@@ -86,21 +86,47 @@ def get_samples(
86
86
num_channels : int ,
87
87
sample_height : int ,
88
88
sample_width : int ,
89
+ use_ddim : bool ,
90
+ ddim_steps : int ,
91
+ ddim_eta : float ,
89
92
y : int = None ,
90
93
autoencoder = None ,
91
94
late_model = None ,
92
- t_switch = np .inf
95
+ t_switch = np .inf ,
93
96
):
94
97
seed_everything (seed )
95
98
x = torch .randn (batch_size , num_channels , sample_height , sample_width ).to (device )
96
99
97
- for t in tqdm (range (999 , - 1 , - 1 )):
98
- time_tensor = t * torch .ones (batch_size , device = device )
99
- with torch .no_grad ():
100
- model_output = model (x , time_tensor , y )
101
- x = postprocessing (model_output , x , t )
102
- if t == 1000 - t_switch :
103
- model = late_model
100
+ if use_ddim :
101
+ timesteps = np .linspace (0 , 999 , ddim_steps ).astype (int )[::- 1 ]
102
+ for t , s in zip (tqdm (timesteps [:- 1 ]), timesteps [1 :]):
103
+ assert s < t
104
+
105
+ time_tensor = t * torch .ones (batch_size , device = device )
106
+ with torch .no_grad ():
107
+ model_output = model (x , time_tensor , y )
108
+
109
+ sigma_t_squared = betas_tilde [t ] * ddim_eta
110
+
111
+ mean = torch .sqrt (alphas_bar [s ] / alphas_bar [t ]) * (
112
+ x - torch .sqrt (1 - alphas_bar [t ]) * model_output
113
+ )
114
+ mean += torch .sqrt (1 - alphas_bar [s ] - sigma_t_squared ) * model_output
115
+
116
+ z = torch .randn_like (x ) if s > 0 else 0
117
+ x = mean + sigma_t_squared * z
118
+
119
+ if t < 1000 - t_switch :
120
+ model = late_model
121
+
122
+ else :
123
+ for t in tqdm (range (999 , - 1 , - 1 )):
124
+ time_tensor = t * torch .ones (batch_size , device = device )
125
+ with torch .no_grad ():
126
+ model_output = model (x , time_tensor , y )
127
+ x = postprocessing (model_output , x , t )
128
+ if t == 1000 - t_switch :
129
+ model = late_model
104
130
105
131
if autoencoder :
106
132
print ("Decode the images..." )
@@ -128,7 +154,11 @@ def dump_samples(samples, output_folder: Path):
128
154
plt .imsave (output_folder / f"{ sample_id } .png" , sample )
129
155
130
156
row , col = divmod (sample_id , grid_size )
131
- grid_img [row * sample_height :(row + 1 ) * sample_height , col * sample_width :(col + 1 ) * sample_width , :] = sample
157
+ grid_img [
158
+ row * sample_height : (row + 1 ) * sample_height ,
159
+ col * sample_width : (col + 1 ) * sample_width ,
160
+ :,
161
+ ] = sample
132
162
133
163
plt .imsave (output_folder / "grid_image.png" , grid_img )
134
164
@@ -141,14 +171,18 @@ def dump_statistics(elapsed_time, output_folder: Path):
141
171
def get_args ():
142
172
parser = ArgumentParser ()
143
173
parser .add_argument ("--seed" , type = int , default = 0 )
144
- parser .add_argument ("--checkpoint_path" ,
145
- type = str ,
146
- required = True ,
147
- help = "Path to checkpoint of the model" )
148
- parser .add_argument ("--checkpoint_path_late" ,
149
- type = str ,
150
- default = None ,
151
- help = "Path to checkpoint of the model to be used in the latest steps" )
174
+ parser .add_argument (
175
+ "--checkpoint_path" ,
176
+ type = str ,
177
+ required = True ,
178
+ help = "Path to checkpoint of the model" ,
179
+ )
180
+ parser .add_argument (
181
+ "--checkpoint_path_late" ,
182
+ type = str ,
183
+ default = None ,
184
+ help = "Path to checkpoint of the model to be used in the latest steps" ,
185
+ )
152
186
parser .add_argument ("--batch_size" , type = int , required = True )
153
187
parser .add_argument (
154
188
"--parametrization" ,
@@ -181,6 +215,17 @@ def get_args():
181
215
default = None ,
182
216
help = "Number up to 1000 that corresponds to a class" ,
183
217
)
218
+ parser .add_argument ("--use_ddim" , action = "store_true" )
219
+ parser .add_argument (
220
+ "--ddim_steps" ,
221
+ type = int ,
222
+ default = 50 ,
223
+ )
224
+ parser .add_argument (
225
+ "--ddim_eta" ,
226
+ type = float ,
227
+ default = 0.0 ,
228
+ )
184
229
185
230
return parser .parse_args ()
186
231
@@ -259,19 +304,20 @@ def main():
259
304
260
305
tic = time .time ()
261
306
samples = get_samples (
262
- model ,
263
- args .batch_size ,
264
- postprocessing ,
265
- args .seed ,
266
- num_channels ,
267
- sample_height ,
268
- sample_width ,
269
- y ,
270
- autoencoder ,
271
- model_late ,
272
- args .t_switch
273
-
274
-
307
+ model = model ,
308
+ batch_size = args .batch_size ,
309
+ postprocessing = postprocessing ,
310
+ seed = args .seed ,
311
+ num_channels = num_channels ,
312
+ sample_height = sample_height ,
313
+ sample_width = sample_width ,
314
+ use_ddim = args .use_ddim ,
315
+ ddim_steps = args .ddim_steps ,
316
+ ddim_eta = args .ddim_eta ,
317
+ y = y ,
318
+ autoencoder = autoencoder ,
319
+ late_model = model_late ,
320
+ t_switch = args .t_switch ,
275
321
)
276
322
tac = time .time ()
277
323
dump_statistics (tac - tic , output_folder )
0 commit comments