Skip to content

Commit 7a4e440

Browse files
committed
Fix bug in sampler
1 parent bb65e69 commit 7a4e440

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

sampler.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,19 @@ def main():
165165
raise ValueError(f"Invalid parametrization {args.parametrization}")
166166

167167
config = load_config(args.config_path)
168-
model = UViT(**config["model_params"])
168+
model = UViT(
169+
img_size=config["model_params"]["img_size"],
170+
patch_size=config["model_params"]["patch_size"],
171+
in_chans=config["model_params"]["in_chans"],
172+
embed_dim=config["model_params"]["embed_dim"],
173+
depth=config["model_params"]["depth"],
174+
num_heads=config["model_params"]["num_heads"],
175+
mlp_ratio=config["model_params"]["mlp_ratio"],
176+
qkv_bias=config["model_params"]["qkv_bias"],
177+
mlp_time_embed=config["model_params"]["mlp_time_embed"],
178+
num_classes=config["model_params"]["num_classes"],
179+
normalize_timesteps=config["model_params"]["normalize_timesteps"],
180+
)
169181

170182
num_channels = config["model_params"]["in_chans"]
171183
sample_height = config["model_params"]["img_size"]

0 commit comments

Comments
 (0)