Skip to content

Commit 484ba95

Browse files
committed
Merge branch 'main' of github.com:razvanmatisan/early-stopping-diffusion
2 parents 2cc41d1 + 7e57fa9 commit 484ba95

File tree

7 files changed

+151
-30
lines changed

7 files changed

+151
-30
lines changed

configs/uvit_celeba_3.yaml

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,12 @@
1+
model_params:
2+
img_size: 64
3+
patch_size: 4
4+
in_chans: 3
5+
embed_dim: 512
6+
depth: 3
7+
num_heads: 8
8+
mlp_ratio: 4
9+
qkv_bias: False
10+
mlp_time_embed: False
11+
num_classes: -1
12+
normalize_timesteps: True

configs/uvit_celeba_5.yaml

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,12 @@
1+
model_params:
2+
img_size: 64
3+
patch_size: 4
4+
in_chans: 3
5+
embed_dim: 512
6+
depth: 5
7+
num_heads: 8
8+
mlp_ratio: 4
9+
qkv_bias: False
10+
mlp_time_embed: False
11+
num_classes: -1
12+
normalize_timesteps: True

configs/uvit_celeba_7.yaml

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,12 @@
1+
model_params:
2+
img_size: 64
3+
patch_size: 4
4+
in_chans: 3
5+
embed_dim: 512
6+
depth: 7
7+
num_heads: 8
8+
mlp_ratio: 4
9+
qkv_bias: False
10+
mlp_time_embed: False
11+
num_classes: -1
12+
normalize_timesteps: True

configs/uvit_cifar10_3.yaml

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,12 @@
1+
model_params:
2+
img_size: 32
3+
patch_size: 2
4+
in_chans: 3
5+
embed_dim: 512
6+
depth: 3
7+
num_heads: 8
8+
mlp_ratio: 4
9+
qkv_bias: False
10+
mlp_time_embed: False
11+
num_classes: -1
12+
normalize_timesteps: True

configs/uvit_imagenet256_3.yaml

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,15 @@
1+
model_params:
2+
img_size: 32
3+
patch_size: 2
4+
in_chans: 4
5+
embed_dim: 1024
6+
depth: 3
7+
num_heads: 16
8+
mlp_ratio: 4
9+
qkv_bias: False
10+
mlp_time_embed: False
11+
num_classes: 1001
12+
normalize_timesteps: False
13+
14+
autoencoder:
15+
autoencoder_checkpoint_path: ./checkpoints/autoencoder/autoencoder_kl.pth

configs/uvit_imagenet64_3.yaml

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,12 @@
1+
model_params:
2+
img_size: 64
3+
patch_size: 4
4+
in_chans: 3
5+
embed_dim: 768
6+
depth: 3
7+
num_heads: 12
8+
mlp_ratio: 4
9+
qkv_bias: False
10+
mlp_time_embed: False
11+
num_classes: 1000
12+
normalize_timesteps: False

sampler.py

Lines changed: 76 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -86,21 +86,47 @@ def get_samples(
8686
num_channels: int,
8787
sample_height: int,
8888
sample_width: int,
89+
use_ddim: bool,
90+
ddim_steps: int,
91+
ddim_eta: float,
8992
y: int = None,
9093
autoencoder=None,
9194
late_model=None,
92-
t_switch=np.inf
95+
t_switch=np.inf,
9396
):
9497
seed_everything(seed)
9598
x = torch.randn(batch_size, num_channels, sample_height, sample_width).to(device)
9699

97-
for t in tqdm(range(999, -1, -1)):
98-
time_tensor = t * torch.ones(batch_size, device=device)
99-
with torch.no_grad():
100-
model_output = model(x, time_tensor, y)
101-
x = postprocessing(model_output, x, t)
102-
if t == 1000 - t_switch:
103-
model = late_model
100+
if use_ddim:
101+
timesteps = np.linspace(0, 999, ddim_steps).astype(int)[::-1]
102+
for t, s in zip(tqdm(timesteps[:-1]), timesteps[1:]):
103+
assert s < t
104+
105+
time_tensor = t * torch.ones(batch_size, device=device)
106+
with torch.no_grad():
107+
model_output = model(x, time_tensor, y)
108+
109+
sigma_t_squared = betas_tilde[t] * ddim_eta
110+
111+
mean = torch.sqrt(alphas_bar[s] / alphas_bar[t]) * (
112+
x - torch.sqrt(1 - alphas_bar[t]) * model_output
113+
)
114+
mean += torch.sqrt(1 - alphas_bar[s] - sigma_t_squared) * model_output
115+
116+
z = torch.randn_like(x) if s > 0 else 0
117+
x = mean + sigma_t_squared * z
118+
119+
if t < 1000 - t_switch:
120+
model = late_model
121+
122+
else:
123+
for t in tqdm(range(999, -1, -1)):
124+
time_tensor = t * torch.ones(batch_size, device=device)
125+
with torch.no_grad():
126+
model_output = model(x, time_tensor, y)
127+
x = postprocessing(model_output, x, t)
128+
if t == 1000 - t_switch:
129+
model = late_model
104130

105131
if autoencoder:
106132
print("Decode the images...")
@@ -128,7 +154,11 @@ def dump_samples(samples, output_folder: Path):
128154
plt.imsave(output_folder / f"{sample_id}.png", sample)
129155

130156
row, col = divmod(sample_id, grid_size)
131-
grid_img[row * sample_height:(row + 1) * sample_height, col * sample_width:(col + 1) * sample_width, :] = sample
157+
grid_img[
158+
row * sample_height : (row + 1) * sample_height,
159+
col * sample_width : (col + 1) * sample_width,
160+
:,
161+
] = sample
132162

133163
plt.imsave(output_folder / "grid_image.png", grid_img)
134164

@@ -141,14 +171,18 @@ def dump_statistics(elapsed_time, output_folder: Path):
141171
def get_args():
142172
parser = ArgumentParser()
143173
parser.add_argument("--seed", type=int, default=0)
144-
parser.add_argument("--checkpoint_path",
145-
type=str,
146-
required=True,
147-
help="Path to checkpoint of the model")
148-
parser.add_argument("--checkpoint_path_late",
149-
type=str,
150-
default=None,
151-
help="Path to checkpoint of the model to be used in the latest steps")
174+
parser.add_argument(
175+
"--checkpoint_path",
176+
type=str,
177+
required=True,
178+
help="Path to checkpoint of the model",
179+
)
180+
parser.add_argument(
181+
"--checkpoint_path_late",
182+
type=str,
183+
default=None,
184+
help="Path to checkpoint of the model to be used in the latest steps",
185+
)
152186
parser.add_argument("--batch_size", type=int, required=True)
153187
parser.add_argument(
154188
"--parametrization",
@@ -181,6 +215,17 @@ def get_args():
181215
default=None,
182216
help="Number up to 1000 that corresponds to a class",
183217
)
218+
parser.add_argument("--use_ddim", action="store_true")
219+
parser.add_argument(
220+
"--ddim_steps",
221+
type=int,
222+
default=50,
223+
)
224+
parser.add_argument(
225+
"--ddim_eta",
226+
type=float,
227+
default=0.0,
228+
)
184229

185230
return parser.parse_args()
186231

@@ -259,19 +304,20 @@ def main():
259304

260305
tic = time.time()
261306
samples = get_samples(
262-
model,
263-
args.batch_size,
264-
postprocessing,
265-
args.seed,
266-
num_channels,
267-
sample_height,
268-
sample_width,
269-
y,
270-
autoencoder,
271-
model_late,
272-
args.t_switch
273-
274-
307+
model=model,
308+
batch_size=args.batch_size,
309+
postprocessing=postprocessing,
310+
seed=args.seed,
311+
num_channels=num_channels,
312+
sample_height=sample_height,
313+
sample_width=sample_width,
314+
use_ddim=args.use_ddim,
315+
ddim_steps=args.ddim_steps,
316+
ddim_eta=args.ddim_eta,
317+
y=y,
318+
autoencoder=autoencoder,
319+
late_model=model_late,
320+
t_switch=args.t_switch,
275321
)
276322
tac = time.time()
277323
dump_statistics(tac - tic, output_folder)

0 commit comments

Comments
 (0)