@@ -45,12 +45,13 @@ def get_samples(
45
45
sample_height : int ,
46
46
sample_width : int ,
47
47
threshold : float ,
48
+ depth : int ,
48
49
y : int = None ,
49
50
autoencoder = None ,
50
51
):
51
52
seed_everything (seed )
52
53
x = torch .randn (batch_size , num_channels , sample_height , sample_width ).to (device )
53
- error_prediction_by_timestep = torch .zeros (1000 , 13 )
54
+ error_prediction_by_timestep = torch .zeros (1000 , depth )
54
55
indices_by_timestep = torch .zeros (1000 , batch_size )
55
56
56
57
for t in tqdm (range (999 , - 1 , - 1 )):
@@ -67,7 +68,7 @@ def get_samples(
67
68
model_output = outputs [indices , torch .arange (batch_size )]
68
69
69
70
# Log for visualization
70
- error_prediction_by_timestep [t ] = classifier_outputs .mean (axis = 1 )[:13 ]
71
+ error_prediction_by_timestep [t ] = classifier_outputs .mean (axis = 1 )[:depth ]
71
72
indices_by_timestep [t , :] = indices
72
73
73
74
alpha_t = alphas [t ]
@@ -158,18 +159,28 @@ def main():
158
159
num_channels = config ["model_params" ]["in_chans" ]
159
160
sample_height = config ["model_params" ]["img_size" ]
160
161
sample_width = config ["model_params" ]["img_size" ]
162
+ depth = config ["model_params" ]["depth" ]
161
163
162
164
state_dict = torch .load (args .checkpoint_path , map_location = "cpu" )
163
165
if "model_state_dict" in state_dict :
164
166
state_dict = state_dict ["model_state_dict" ]
165
167
model .load_state_dict (state_dict )
166
168
model = model .eval ().to (device )
167
169
170
+ # y = (
171
+ # torch.ones(args.batch_size, dtype=torch.int).to(device) * args.class_id
172
+ # if args.class_id is not None
173
+ # else None
174
+ # )
175
+
176
+ seed_everything (args .seed )
177
+
168
178
y = (
169
- torch .ones ( args .batch_size , dtype = torch . int ) .to (device ) * args . class_id
179
+ torch .randint ( 1 , 1001 , ( args .batch_size ,)) .to (device )
170
180
if args .class_id is not None
171
181
else None
172
182
)
183
+
173
184
if "autoencoder" in config :
174
185
autoencoder = get_autoencoder (
175
186
config ["autoencoder" ]["autoencoder_checkpoint_path" ]
@@ -186,6 +197,7 @@ def main():
186
197
sample_height = sample_height ,
187
198
sample_width = sample_width ,
188
199
threshold = args .threshold ,
200
+ depth = depth ,
189
201
y = y ,
190
202
autoencoder = autoencoder ,
191
203
)
0 commit comments