1
+ import time
1
2
from argparse import ArgumentParser
2
3
from pathlib import Path
3
4
@@ -105,19 +106,21 @@ def get_samples(
105
106
return samples .cpu ().numpy ()
106
107
107
108
108
- def dump_samples (samples , output_folder ):
109
- output_folder = Path (output_folder )
110
- output_folder .mkdir (parents = True , exist_ok = True )
111
-
112
- plt .hist (samples .flatten ())
113
- plt .savefig (output_folder / "histogram.png" )
114
- plt .clf ()
109
def dump_samples(samples, output_folder: Path):
    """Save each sample as a PNG image in ``output_folder``.

    Args:
        samples: iterable of image arrays — assumed (H, W) or (H, W, C)
            with values roughly in [0, 1]; TODO confirm against get_samples.
        output_folder: existing directory; one ``<index>.png`` file is
            written per sample.
    """
    for sample_id, sample in enumerate(samples):
        # imsave requires values in [0, 1]; clip to guard against
        # slight over/undershoot from the sampler.
        sample = np.clip(sample, 0, 1)
        plt.imsave(output_folder / f"{sample_id}.png", sample)
119
117
120
118
119
def dump_statistics(elapsed_time, output_folder: Path):
    """Record the elapsed sampling time in ``statistics.txt``.

    Args:
        elapsed_time: wall-clock duration in seconds (any value that
            formats sensibly in an f-string).
        output_folder: existing directory to write the file into.
    """
    stats_path = output_folder / "statistics.txt"
    stats_path.write_text(f"Elapsed time: {elapsed_time} s\n")
121
124
def get_args ():
122
125
parser = ArgumentParser ()
123
126
parser .add_argument ("--seed" , type = int , default = 0 )
@@ -148,6 +151,10 @@ def get_args():
148
151
149
152
def main ():
150
153
args = get_args ()
154
+
155
+ output_folder = Path (args .output_folder )
156
+ output_folder .mkdir (parents = True , exist_ok = True )
157
+
151
158
if args .parametrization == "predict_noise" :
152
159
postprocessing = predict_noise_postprocessing
153
160
elif args .parametrization == "predict_original" :
@@ -164,25 +171,25 @@ def main():
164
171
sample_height = config ["model_params" ]["img_size" ]
165
172
sample_width = config ["model_params" ]["img_size" ]
166
173
167
- print (num_channels )
168
- print (sample_height )
169
-
170
- print (config )
171
-
172
- model .load_state_dict (torch .load (args .checkpoint_path , map_location = "cpu" ))
174
+ state_dict = torch .load (args .checkpoint_path , map_location = "cpu" )
175
+ if "model_state_dict" in state_dict :
176
+ state_dict = state_dict ["model_state_dict" ]
177
+ model .load_state_dict (state_dict )
173
178
model = model .eval ().to (device )
174
179
175
180
y = (
176
181
torch .ones (args .batch_size , dtype = torch .int ).to (device ) * args .class_id
177
182
if args .class_id is not None
178
183
else None
179
184
)
180
- autoencoder = (
181
- get_autoencoder (config ["autoencoder" ]["autoencoder_checkpoint_path" ])
182
- if "autoencoder" in config
183
- else None
184
- ).to (device )
185
+ if "autoencoder" in config :
186
+ autoencoder = get_autoencoder (
187
+ config ["autoencoder" ]["autoencoder_checkpoint_path" ]
188
+ ).to (device )
189
+ else :
190
+ autoencoder = None
185
191
192
+ tic = time .time ()
186
193
samples = get_samples (
187
194
model ,
188
195
args .batch_size ,
@@ -194,7 +201,10 @@ def main():
194
201
y ,
195
202
autoencoder ,
196
203
)
197
- dump_samples (samples , args .output_folder )
204
+ tac = time .time ()
205
+ dump_statistics (tac - tic , output_folder )
206
+
207
+ dump_samples (samples , output_folder )
198
208
199
209
200
210
if __name__ == "__main__" :
0 commit comments