@@ -8,6 +8,7 @@
 from einops import rearrange
 from matplotlib import pyplot as plt
 from tqdm import tqdm
+from typing import List

 from models.utils.autoencoder import get_autoencoder
 from models.uvit import UViT
@@ -89,13 +90,15 @@ def get_samples(
     use_ddim: bool,
     ddim_steps: int,
     ddim_eta: float,
+    timesteps_save: List[int],
     y: int = None,
     autoencoder=None,
     late_model=None,
     t_switch=np.inf,
 ):
     seed_everything(seed)
     x = torch.randn(batch_size, num_channels, sample_height, sample_width).to(device)
+    intermediate_samples = []

     if use_ddim:
         timesteps = np.linspace(0, 999, ddim_steps).astype(int)[::-1]
@@ -119,25 +122,41 @@ def get_samples(
             if t < 1000 - t_switch:
                 model = late_model

+            if 1000 - t in timesteps_save:
+                intermediate_samples.append(x)
+
     else:
         for t in tqdm(range(999, -1, -1)):
             time_tensor = t * torch.ones(batch_size, device=device)
             with torch.no_grad():
                 model_output = model(x, time_tensor, y)
             x = postprocessing(model_output, x, t)
+
             if t == 1000 - t_switch:
                 model = late_model

+            if 1000 - t in timesteps_save:
+                intermediate_samples.append(x)
+
+
     if autoencoder:
         print("Decode the images...")
         x = autoencoder.decode(x)

     samples = (x + 1) / 2
     samples = rearrange(samples, "b c h w -> b h w c")
-    return samples.cpu().numpy()

+    for i, x in enumerate(intermediate_samples):
+        if autoencoder:
+            x = autoencoder.decode(x)
+        x = (x + 1) / 2
+        x = rearrange(x, "b c h w -> b h w c")
+        intermediate_samples[i] = x.cpu().numpy()
+
+    return samples.cpu().numpy(), intermediate_samples

-def dump_samples(samples, output_folder: Path):
+
+def dump_samples(samples, output_folder: Path, timestep=1000):
     # plt.hist(samples.flatten())
     # plt.savefig(output_folder / "histogram.png")
     # plt.clf()
@@ -151,7 +170,8 @@ def dump_samples(samples, output_folder: Path):

     for sample_id, sample in enumerate(samples):
         sample = np.clip(sample, 0, 1)
-        plt.imsave(output_folder / f"{sample_id}.png", sample)
+        filename = f"{sample_id}_{timestep}.png" if timestep != 1000 else f"{sample_id}.png"
+        plt.imsave(output_folder / filename, sample)

         row, col = divmod(sample_id, grid_size)
         grid_img[
@@ -226,6 +246,12 @@ def get_args():
         type=float,
         default=0.0,
     )
+    parser.add_argument(
+        "--timesteps_save",
+        type=int,
+        nargs="+",
+        default=[]
+    )

     return parser.parse_args()

@@ -303,7 +329,7 @@ def main():
     autoencoder = None

     tic = time.time()
-    samples = get_samples(
+    samples, intermediate_samples = get_samples(
         model=model,
         batch_size=args.batch_size,
         postprocessing=postprocessing,
@@ -318,12 +344,17 @@ def main():
         autoencoder=autoencoder,
         late_model=model_late,
         t_switch=args.t_switch,
+        timesteps_save=args.timesteps_save
     )
     tac = time.time()
     dump_statistics(tac - tic, output_folder)

     dump_samples(samples, output_folder)

+    if args.timesteps_save:
+        for timestep, samples in zip(args.timesteps_save, intermediate_samples):
+            dump_samples(samples, output_folder, timestep)
+

 if __name__ == "__main__":
     main()
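
Note (not part of the diff above): the values passed via the new --timesteps_save flag count completed denoising steps, since the reverse loop runs t from 999 down to 0 and a snapshot is kept when 1000 - t is in the list. A minimal sketch of that bookkeeping, with hypothetical flag values:

    # Hypothetical example, not taken from the patch: --timesteps_save 250 500 750
    timesteps_save = [250, 500, 750]

    saved_at = []
    for t in range(999, -1, -1):          # same reverse order as the sampling loop
        # ... one denoising update of x would happen here ...
        if 1000 - t in timesteps_save:    # 1000 - t = number of steps completed so far
            saved_at.append(t)

    print(saved_at)  # [750, 500, 250] -> snapshots after 250, 500 and 750 steps

With the updated dump_samples, the final images keep the {sample_id}.png name, while each intermediate batch is written as {sample_id}_{timestep}.png.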