|
7 | 7 | from matplotlib import pyplot as plt
|
8 | 8 | from tqdm import tqdm
|
9 | 9 |
|
| 10 | +from models.utils.autoencoder import get_autoencoder |
10 | 11 | from models.uvit import UViT
|
11 |
| -from utils.train_utils import seed_everything |
12 | 12 | from utils.config_utils import load_config
|
| 13 | +from utils.train_utils import seed_everything |
13 | 14 |
|
14 | 15 | checkpoint_path_by_parametrization = {
|
15 | 16 | "predict_noise": "../logs/6218182/cifar10_uvit.pth",
|
@@ -75,16 +76,30 @@ def predict_previous_postprocessing(model_output, x, t):
|
75 | 76 | return model_output + sigma_t * z
|
76 | 77 |
|
77 | 78 |
|
78 |
def get_samples(
    model,
    batch_size: int,
    postprocessing: callable,
    seed: int,
    num_channels: int,
    sample_height: int,
    sample_width: int,
    y=None,
    autoencoder=None,
):
    """Draw ``batch_size`` samples with the full 1000-step reverse diffusion loop.

    Optionally class-conditioned via ``y`` and optionally decoded from latent
    space by ``autoencoder`` before post-processing for display.

    Args:
        model: denoising network, called as ``model(x, time_tensor, y)``.
        batch_size: number of samples generated in one batch.
        postprocessing: callable ``(model_output, x, t) -> x`` producing the
            next (less noisy) state from the network output.
        seed: RNG seed, fixed up front so sampling is reproducible.
        num_channels: channel count of the initial noise tensor.
        sample_height: spatial height of the initial noise tensor.
        sample_width: spatial width of the initial noise tensor.
        y: optional per-sample class labels, forwarded untouched to the model.
            NOTE(review): previously annotated ``int`` but a tensor is what
            callers pass — presumably a LongTensor of shape (batch_size,);
            confirm against the UViT label-embedding layer.
        autoencoder: optional latent decoder; when given, the final latents
            are decoded with ``autoencoder.decode`` before rescaling.

    Returns:
        numpy array of shape (batch_size, H, W, C), values mapped from
        [-1, 1] to [0, 1] (assumes the model's output range is [-1, 1]).
    """
    seed_everything(seed)

    # Start from pure Gaussian noise and iteratively denoise t = 999 ... 0.
    x = torch.randn(batch_size, num_channels, sample_height, sample_width).to(device)

    for t in tqdm(range(999, -1, -1)):
        time_tensor = t * torch.ones(batch_size, device=device)
        with torch.no_grad():
            model_output = model(x, time_tensor, y)
        x = postprocessing(model_output, x, t)

    # Explicit `is not None`: presence check should not rely on the
    # truthiness of an arbitrary decoder object.
    if autoencoder is not None:
        print("Decode the images...")
        # no_grad: inference-only decode; avoids building an autograd graph.
        with torch.no_grad():
            x = autoencoder.decode(x)

    # Map [-1, 1] -> [0, 1] and move channels last for plotting/saving.
    samples = (x + 1) / 2
    samples = rearrange(samples, "b c h w -> b h w c")
    return samples.cpu().numpy()
|
@@ -139,11 +154,36 @@ def main():
|
139 | 154 | config = load_config(args.config_path)
|
140 | 155 | model = UViT(**config["model_params"])
|
141 | 156 |
|
142 |
| - model.load_state_dict( |
143 |
| - torch.load(args.checkpoint_path, map_location="cpu")["model_state_dict"] |
144 |
| - ) |
| 157 | + num_channels = config["model_params"]["in_chans"] |
| 158 | + sample_height = config["model_params"]["img_size"] |
| 159 | + sample_width = config["model_params"]["img_size"] |
| 160 | + |
| 161 | + print(num_channels) |
| 162 | + print(sample_height) |
| 163 | + |
| 164 | + print(config) |
| 165 | + |
| 166 | + model.load_state_dict(torch.load(args.checkpoint_path, map_location="cpu")) |
145 | 167 | model = model.eval().to(device)
|
146 |
| - samples = get_samples(model, args.batch_size, postprocessing, args.seed) |
| 168 | + |
| 169 | + y = torch.ones(args.batch_size, dtype=torch.int).to(device) * 3 |
| 170 | +    autoencoder = ( |
| 171 | +        get_autoencoder(config["autoencoder"]["autoencoder_checkpoint_path"]).to(device) |
| 172 | +        if "autoencoder" in config |
| 173 | +        else None |
| 174 | +    ) |
| 175 | + |
| 176 | + samples = get_samples( |
| 177 | + model, |
| 178 | + args.batch_size, |
| 179 | + postprocessing, |
| 180 | + args.seed, |
| 181 | + num_channels, |
| 182 | + sample_height, |
| 183 | + sample_width, |
| 184 | + y, |
| 185 | + autoencoder, |
| 186 | + ) |
147 | 187 | dump_samples(samples, args.output_folder)
|
148 | 188 |
|
149 | 189 |
|
|
0 commit comments