Skip to content

Commit 44debab

Browse files
authored
fix: fix checkpoint filename for the PyTorch backend
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 30bc1e5 commit 44debab

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

dpgen/generator/run.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,13 @@ def run_train_dp(iter_index, jdata, mdata):
808808
elif training_finetune_model is not None:
809809
init_flag = f" --finetune old/init{suffix}"
810810
command = f"{train_command} train {train_input_file}{extra_flags}"
811-
command = f"{{ if [ ! -f model.ckpt.index ]; then {command}{init_flag}; else {command} --restart model.ckpt; fi }}"
811+
if suffix == ".pb":
812+
ckpt_suffix = ".index"
813+
elif suffix == ".pth":
814+
ckpt_suffix = ".pt"
815+
else:
816+
raise RuntimeError(f"Unknown suffix {suffix}")
817+
command = f"{{ if [ ! -f model.ckpt{ckpt_suffix} ]; then {command}{init_flag}; else {command} --restart model.ckpt; fi }}"
812818
command = f"/bin/sh -c {shlex.quote(command)}"
813819
commands.append(command)
814820
command = f"{train_command} freeze"

0 commit comments

Comments
 (0)