@@ -129,13 +129,13 @@ def _get_model_suffix(jdata) -> str:
129
129
"""Return the model suffix based on the backend."""
130
130
mlp_engine = jdata .get ("mlp_engine" , "dp" )
131
131
if mlp_engine == "dp" :
132
- suffix_map = {"tensorflow" : ".pb" , "pytorch" : ".pth" }
132
+ suffix_map = {"tensorflow" : ".pb" , "pytorch" : ".pth" , "jax" : ".savedmodel" }
133
133
backend = jdata .get ("train_backend" , "tensorflow" )
134
134
if backend in suffix_map :
135
135
suffix = suffix_map [backend ]
136
136
else :
137
137
raise ValueError (
138
- f"The backend { backend } is not available. Supported backends are: 'tensorflow', 'pytorch'."
138
+ f"The backend { backend } is not available. Supported backends are: 'tensorflow', 'pytorch', 'jax'."
139
139
)
140
140
return suffix
141
141
else :
@@ -766,6 +766,8 @@ def run_train_dp(iter_index, jdata, mdata):
766
766
# assert train_command == "dp", "The 'train_command' should be 'dp'" # the tests should be updated to run this command
767
767
if suffix == ".pth" :
768
768
train_command += " --pt"
769
+ elif suffix == ".savedmodel" :
770
+ train_command += " --jax"
769
771
770
772
# paths
771
773
iter_name = make_iter_name (iter_index )
@@ -803,6 +805,8 @@ def run_train_dp(iter_index, jdata, mdata):
803
805
ckpt_suffix = ".index"
804
806
elif suffix == ".pth" :
805
807
ckpt_suffix = ".pt"
808
+ elif suffix == ".savedmodel" :
809
+ ckpt_suffix = ".jax"
806
810
else :
807
811
raise RuntimeError (f"Unknown suffix { suffix } " )
808
812
command = f"{{ if [ ! -f model.ckpt{ ckpt_suffix } ]; then { command } { init_flag } ; else { command } --restart model.ckpt; fi }}"
@@ -840,6 +844,10 @@ def run_train_dp(iter_index, jdata, mdata):
840
844
]
841
845
elif suffix == ".pth" :
842
846
forward_files += [os .path .join ("old" , "model.ckpt.pt" )]
847
+ elif suffix == ".savedmodel" :
848
+ forward_files += [os .path .join ("old" , "model.ckpt.jax" )]
849
+ else :
850
+ raise RuntimeError (f"Unknown suffix { suffix } " )
843
851
elif training_init_frozen_model is not None or training_finetune_model is not None :
844
852
forward_files .append (os .path .join ("old" , f"init{ suffix } " ))
845
853
@@ -860,6 +868,10 @@ def run_train_dp(iter_index, jdata, mdata):
860
868
]
861
869
elif suffix == ".pth" :
862
870
backward_files += ["model.ckpt.pt" ]
871
+ elif suffix == ".savedmodel" :
872
+ backward_files += ["model.ckpt.jax" ]
873
+ else :
874
+ raise RuntimeError (f"Unknown suffix { suffix } " )
863
875
864
876
if not jdata .get ("one_h5" , False ):
865
877
init_data_sys_ = jdata ["init_data_sys" ]
0 commit comments