Skip to content

Commit 10b9399

Browse files
authored
feat: support jax backend in advance
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent c2db0fa commit 10b9399

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

dpgen/generator/run.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,13 @@ def _get_model_suffix(jdata) -> str:
129129
"""Return the model suffix based on the backend."""
130130
mlp_engine = jdata.get("mlp_engine", "dp")
131131
if mlp_engine == "dp":
132-
suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"}
132+
suffix_map = {"tensorflow": ".pb", "pytorch": ".pth", "jax": ".savedmodel"}
133133
backend = jdata.get("train_backend", "tensorflow")
134134
if backend in suffix_map:
135135
suffix = suffix_map[backend]
136136
else:
137137
raise ValueError(
138-
f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'."
138+
f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch', 'jax'."
139139
)
140140
return suffix
141141
else:
@@ -766,6 +766,8 @@ def run_train_dp(iter_index, jdata, mdata):
766766
# assert train_command == "dp", "The 'train_command' should be 'dp'" # the tests should be updated to run this command
767767
if suffix == ".pth":
768768
train_command += " --pt"
769+
elif suffix == ".savedmodel":
770+
train_command += "--jax"
769771

770772
# paths
771773
iter_name = make_iter_name(iter_index)
@@ -803,6 +805,8 @@ def run_train_dp(iter_index, jdata, mdata):
803805
ckpt_suffix = ".index"
804806
elif suffix == ".pth":
805807
ckpt_suffix = ".pt"
808+
elif suffix == ".savedmodel":
809+
ckpt_suffix = ".jax"
806810
else:
807811
raise RuntimeError(f"Unknown suffix {suffix}")
808812
command = f"{{ if [ ! -f model.ckpt{ckpt_suffix} ]; then {command}{init_flag}; else {command} --restart model.ckpt; fi }}"
@@ -840,6 +844,10 @@ def run_train_dp(iter_index, jdata, mdata):
840844
]
841845
elif suffix == ".pth":
842846
forward_files += [os.path.join("old", "model.ckpt.pt")]
847+
elif suffix == ".savedmodel":
848+
forward_files += [os.path.join("old", "model.ckpt.jax")]
849+
else:
850+
raise RuntimeError(f"Unknown suffix {suffix}")
843851
elif training_init_frozen_model is not None or training_finetune_model is not None:
844852
forward_files.append(os.path.join("old", f"init{suffix}"))
845853

@@ -860,6 +868,10 @@ def run_train_dp(iter_index, jdata, mdata):
860868
]
861869
elif suffix == ".pth":
862870
backward_files += ["model.ckpt.pt"]
871+
elif suffix == ".savedmodel":
872+
backward_files += ["model.ckpt.jax"]
873+
else:
874+
raise RuntimeError(f"Unknown suffix {suffix}")
863875

864876
if not jdata.get("one_h5", False):
865877
init_data_sys_ = jdata["init_data_sys"]

0 commit comments

Comments
 (0)