Skip to content

Commit 22fd9fb

Browse files
committed
Make sampler.py more generic
1 parent 879ef2f commit 22fd9fb

File tree

1 file changed

+48
-8
lines changed

1 file changed

+48
-8
lines changed

sampler.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
from matplotlib import pyplot as plt
88
from tqdm import tqdm
99

10+
from models.utils.autoencoder import get_autoencoder
1011
from models.uvit import UViT
11-
from utils.train_utils import seed_everything
1212
from utils.config_utils import load_config
13+
from utils.train_utils import seed_everything
1314

1415
checkpoint_path_by_parametrization = {
1516
"predict_noise": "../logs/6218182/cifar10_uvit.pth",
@@ -75,16 +76,30 @@ def predict_previous_postprocessing(model_output, x, t):
7576
return model_output + sigma_t * z
7677

7778

78-
def get_samples(model, batch_size: int, postprocessing: callable, seed: int):
79+
def get_samples(
80+
model,
81+
batch_size: int,
82+
postprocessing: callable,
83+
seed: int,
84+
num_channels: int,
85+
sample_height: int,
86+
sample_width: int,
87+
y: int = None,
88+
autoencoder=None,
89+
):
7990
seed_everything(seed)
80-
x = torch.randn(batch_size, 3, 32, 32).to(device)
91+
x = torch.randn(batch_size, num_channels, sample_height, sample_width).to(device)
8192

8293
for t in tqdm(range(999, -1, -1)):
8394
time_tensor = t * torch.ones(batch_size, device=device)
8495
with torch.no_grad():
85-
model_output = model(x, time_tensor)
96+
model_output = model(x, time_tensor, y)
8697
x = postprocessing(model_output, x, t)
8798

99+
if autoencoder:
100+
print("Decode the images...")
101+
x = autoencoder.decode(x)
102+
88103
samples = (x + 1) / 2
89104
samples = rearrange(samples, "b c h w -> b h w c")
90105
return samples.cpu().numpy()
@@ -139,11 +154,36 @@ def main():
139154
config = load_config(args.config_path)
140155
model = UViT(**config["model_params"])
141156

142-
model.load_state_dict(
143-
torch.load(args.checkpoint_path, map_location="cpu")["model_state_dict"]
144-
)
157+
num_channels = config["model_params"]["in_chans"]
158+
sample_height = config["model_params"]["img_size"]
159+
sample_width = config["model_params"]["img_size"]
160+
161+
print(num_channels)
162+
print(sample_height)
163+
164+
print(config)
165+
166+
model.load_state_dict(torch.load(args.checkpoint_path, map_location="cpu"))
145167
model = model.eval().to(device)
146-
samples = get_samples(model, args.batch_size, postprocessing, args.seed)
168+
169+
y = torch.ones(args.batch_size, dtype=torch.int).to(device) * 3
170+
autoencoder = (
171+
get_autoencoder(config["autoencoder"]["autoencoder_checkpoint_path"])
172+
if "autoencoder" in config
173+
else None
174+
).to(device)
175+
176+
samples = get_samples(
177+
model,
178+
args.batch_size,
179+
postprocessing,
180+
args.seed,
181+
num_channels,
182+
sample_height,
183+
sample_width,
184+
y,
185+
autoencoder,
186+
)
147187
dump_samples(samples, args.output_folder)
148188

149189

0 commit comments

Comments
 (0)