Skip to content

Commit 7788431

Browse files
committed
fix: fix PyTorch model extension in simplify
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent a33e270 commit 7788431

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

dpgen/simplify/simplify.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,9 @@ def run_model_devi(iter_index, jdata, mdata):
221221
commands = []
222222
run_tasks = ["."]
223223
# get models
224-
models = glob.glob(os.path.join(work_path, "graph*pb"))
224+
suffix = _get_model_suffix(jdata)
225+
models = glob.glob(os.path.join(work_path, f"graph*{suffix}"))
226+
assert len(models) > 0, "No model file found."
225227
model_names = [os.path.basename(ii) for ii in models]
226228
task_model_list = []
227229
for ii in model_names:

0 commit comments

Comments
 (0)