Skip to content

Commit 7e57fa9

Browse files
committed
Implement DDIM sampler
1 parent ba72a2e commit 7e57fa9

File tree

1 file changed

+76
-30
lines changed

1 file changed

+76
-30
lines changed

sampler.py

Lines changed: 76 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -86,21 +86,47 @@ def get_samples(
8686
num_channels: int,
8787
sample_height: int,
8888
sample_width: int,
89+
use_ddim: bool,
90+
ddim_steps: int,
91+
ddim_eta: float,
8992
y: int = None,
9093
autoencoder=None,
9194
late_model=None,
92-
t_switch=np.inf
95+
t_switch=np.inf,
9396
):
9497
seed_everything(seed)
9598
x = torch.randn(batch_size, num_channels, sample_height, sample_width).to(device)
9699

97-
for t in tqdm(range(999, -1, -1)):
98-
time_tensor = t * torch.ones(batch_size, device=device)
99-
with torch.no_grad():
100-
model_output = model(x, time_tensor, y)
101-
x = postprocessing(model_output, x, t)
102-
if t == 1000 - t_switch:
103-
model = late_model
100+
if use_ddim:
101+
timesteps = np.linspace(0, 999, ddim_steps).astype(int)[::-1]
102+
for t, s in zip(tqdm(timesteps[:-1]), timesteps[1:]):
103+
assert s < t
104+
105+
time_tensor = t * torch.ones(batch_size, device=device)
106+
with torch.no_grad():
107+
model_output = model(x, time_tensor, y)
108+
109+
sigma_t_squared = betas_tilde[t] * ddim_eta
110+
111+
mean = torch.sqrt(alphas_bar[s] / alphas_bar[t]) * (
112+
x - torch.sqrt(1 - alphas_bar[t]) * model_output
113+
)
114+
mean += torch.sqrt(1 - alphas_bar[s] - sigma_t_squared) * model_output
115+
116+
z = torch.randn_like(x) if s > 0 else 0
117+
x = mean + sigma_t_squared * z
118+
119+
if t < 1000 - t_switch:
120+
model = late_model
121+
122+
else:
123+
for t in tqdm(range(999, -1, -1)):
124+
time_tensor = t * torch.ones(batch_size, device=device)
125+
with torch.no_grad():
126+
model_output = model(x, time_tensor, y)
127+
x = postprocessing(model_output, x, t)
128+
if t == 1000 - t_switch:
129+
model = late_model
104130

105131
if autoencoder:
106132
print("Decode the images...")
@@ -128,7 +154,11 @@ def dump_samples(samples, output_folder: Path):
128154
plt.imsave(output_folder / f"{sample_id}.png", sample)
129155

130156
row, col = divmod(sample_id, grid_size)
131-
grid_img[row * sample_height:(row + 1) * sample_height, col * sample_width:(col + 1) * sample_width, :] = sample
157+
grid_img[
158+
row * sample_height : (row + 1) * sample_height,
159+
col * sample_width : (col + 1) * sample_width,
160+
:,
161+
] = sample
132162

133163
plt.imsave(output_folder / "grid_image.png", grid_img)
134164

@@ -141,14 +171,18 @@ def dump_statistics(elapsed_time, output_folder: Path):
141171
def get_args():
142172
parser = ArgumentParser()
143173
parser.add_argument("--seed", type=int, default=0)
144-
parser.add_argument("--checkpoint_path",
145-
type=str,
146-
required=True,
147-
help="Path to checkpoint of the model")
148-
parser.add_argument("--checkpoint_path_late",
149-
type=str,
150-
default=None,
151-
help="Path to checkpoint of the model to be used in the latest steps")
174+
parser.add_argument(
175+
"--checkpoint_path",
176+
type=str,
177+
required=True,
178+
help="Path to checkpoint of the model",
179+
)
180+
parser.add_argument(
181+
"--checkpoint_path_late",
182+
type=str,
183+
default=None,
184+
help="Path to checkpoint of the model to be used in the latest steps",
185+
)
152186
parser.add_argument("--batch_size", type=int, required=True)
153187
parser.add_argument(
154188
"--parametrization",
@@ -181,6 +215,17 @@ def get_args():
181215
default=None,
182216
help="Number up to 1000 that corresponds to a class",
183217
)
218+
parser.add_argument("--use_ddim", action="store_true")
219+
parser.add_argument(
220+
"--ddim_steps",
221+
type=int,
222+
default=50,
223+
)
224+
parser.add_argument(
225+
"--ddim_eta",
226+
type=float,
227+
default=0.0,
228+
)
184229

185230
return parser.parse_args()
186231

@@ -250,19 +295,20 @@ def main():
250295

251296
tic = time.time()
252297
samples = get_samples(
253-
model,
254-
args.batch_size,
255-
postprocessing,
256-
args.seed,
257-
num_channels,
258-
sample_height,
259-
sample_width,
260-
y,
261-
autoencoder,
262-
model_late,
263-
args.t_switch
264-
265-
298+
model=model,
299+
batch_size=args.batch_size,
300+
postprocessing=postprocessing,
301+
seed=args.seed,
302+
num_channels=num_channels,
303+
sample_height=sample_height,
304+
sample_width=sample_width,
305+
use_ddim=args.use_ddim,
306+
ddim_steps=args.ddim_steps,
307+
ddim_eta=args.ddim_eta,
308+
y=y,
309+
autoencoder=autoencoder,
310+
late_model=model_late,
311+
t_switch=args.t_switch,
266312
)
267313
tac = time.time()
268314
dump_statistics(tac - tic, output_folder)

0 commit comments

Comments (0)