Skip to content

Commit 2cc41d1

Browse files
committed
Add depth parameter and sample with a random ImageNet class
1 parent 74001a2 commit 2cc41d1

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

eesampler.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,13 @@ def get_samples(
4545
sample_height: int,
4646
sample_width: int,
4747
threshold: float,
48+
depth: int,
4849
y: int = None,
4950
autoencoder=None,
5051
):
5152
seed_everything(seed)
5253
x = torch.randn(batch_size, num_channels, sample_height, sample_width).to(device)
53-
error_prediction_by_timestep = torch.zeros(1000, 13)
54+
error_prediction_by_timestep = torch.zeros(1000, depth)
5455
indices_by_timestep = torch.zeros(1000, batch_size)
5556

5657
for t in tqdm(range(999, -1, -1)):
@@ -67,7 +68,7 @@ def get_samples(
6768
model_output = outputs[indices, torch.arange(batch_size)]
6869

6970
# Log for visualization
70-
error_prediction_by_timestep[t] = classifier_outputs.mean(axis=1)[:13]
71+
error_prediction_by_timestep[t] = classifier_outputs.mean(axis=1)[:depth]
7172
indices_by_timestep[t, :] = indices
7273

7374
alpha_t = alphas[t]
@@ -158,18 +159,28 @@ def main():
158159
num_channels = config["model_params"]["in_chans"]
159160
sample_height = config["model_params"]["img_size"]
160161
sample_width = config["model_params"]["img_size"]
162+
depth = config["model_params"]["depth"]
161163

162164
state_dict = torch.load(args.checkpoint_path, map_location="cpu")
163165
if "model_state_dict" in state_dict:
164166
state_dict = state_dict["model_state_dict"]
165167
model.load_state_dict(state_dict)
166168
model = model.eval().to(device)
167169

170+
# y = (
171+
# torch.ones(args.batch_size, dtype=torch.int).to(device) * args.class_id
172+
# if args.class_id is not None
173+
# else None
174+
# )
175+
176+
seed_everything(args.seed)
177+
168178
y = (
169-
torch.ones(args.batch_size, dtype=torch.int).to(device) * args.class_id
179+
torch.randint(1, 1001, (args.batch_size,)).to(device)
170180
if args.class_id is not None
171181
else None
172182
)
183+
173184
if "autoencoder" in config:
174185
autoencoder = get_autoencoder(
175186
config["autoencoder"]["autoencoder_checkpoint_path"]
@@ -186,6 +197,7 @@ def main():
186197
sample_height=sample_height,
187198
sample_width=sample_width,
188199
threshold=args.threshold,
200+
depth=depth,
189201
y=y,
190202
autoencoder=autoencoder,
191203
)

sampler.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,11 +236,20 @@ def main():
236236
else:
237237
model_late = None
238238

239+
# y = (
240+
# torch.ones(args.batch_size, dtype=torch.int).to(device) * args.class_id
241+
# if args.class_id is not None
242+
# else None
243+
# )
244+
245+
seed_everything(args.seed)
246+
239247
y = (
240-
torch.ones(args.batch_size, dtype=torch.int).to(device) * args.class_id
248+
torch.randint(1, 1001, (args.batch_size,)).to(device)
241249
if args.class_id is not None
242250
else None
243251
)
252+
244253
if "autoencoder" in config:
245254
autoencoder = get_autoencoder(
246255
config["autoencoder"]["autoencoder_checkpoint_path"]

0 commit comments

Comments
 (0)