File tree Expand file tree Collapse file tree 1 file changed +13
-1
lines changed Expand file tree Collapse file tree 1 file changed +13
-1
lines changed Original file line number Diff line number Diff line change @@ -165,7 +165,19 @@ def main():
165
165
raise ValueError (f"Invalid parametrization { args .parametrization } " )
166
166
167
167
config = load_config (args .config_path )
168
- model = UViT (** config ["model_params" ])
168
+ model = UViT (
169
+ img_size = config ["model_params" ]["img_size" ],
170
+ patch_size = config ["model_params" ]["patch_size" ],
171
+ in_chans = config ["model_params" ]["in_chans" ],
172
+ embed_dim = config ["model_params" ]["embed_dim" ],
173
+ depth = config ["model_params" ]["depth" ],
174
+ num_heads = config ["model_params" ]["num_heads" ],
175
+ mlp_ratio = config ["model_params" ]["mlp_ratio" ],
176
+ qkv_bias = config ["model_params" ]["qkv_bias" ],
177
+ mlp_time_embed = config ["model_params" ]["mlp_time_embed" ],
178
+ num_classes = config ["model_params" ]["num_classes" ],
179
+ normalize_timesteps = config ["model_params" ]["normalize_timesteps" ],
180
+ )
169
181
170
182
num_channels = config ["model_params" ]["in_chans" ]
171
183
sample_height = config ["model_params" ]["img_size" ]
You can’t perform that action at this time.
0 commit comments