Skip to content

Commit e387a06

Browse files
committed
Compute time in sampler.py
1 parent fae196e commit e387a06

File tree

1 file changed

+29
-19
lines changed

1 file changed

+29
-19
lines changed

sampler.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import time
12
from argparse import ArgumentParser
23
from pathlib import Path
34

@@ -105,19 +106,21 @@ def get_samples(
105106
return samples.cpu().numpy()
106107

107108

108-
def dump_samples(samples, output_folder):
109-
output_folder = Path(output_folder)
110-
output_folder.mkdir(parents=True, exist_ok=True)
111-
112-
plt.hist(samples.flatten())
113-
plt.savefig(output_folder / "histogram.png")
114-
plt.clf()
109+
def dump_samples(samples, output_folder: Path):
    """Write each sample to *output_folder* as an indexed PNG.

    Assumes the caller has already created ``output_folder`` (the
    ``mkdir`` now happens in ``main``).

    Args:
        samples: iterable of image arrays (presumably HxWxC floats —
            TODO confirm against ``get_samples``); values are clipped
            to [0, 1] before saving.
        output_folder: existing directory that receives ``{i}.png`` files.
    """
    for sample_id, sample in enumerate(samples):
        # imsave expects float pixel values in [0, 1]; clip to guard
        # against sampler outputs slightly outside that range.
        sample = np.clip(sample, 0, 1)
        plt.imsave(output_folder / f"{sample_id}.png", sample)
119117

120118

119+
def dump_statistics(elapsed_time, output_folder: Path):
    """Record the sampling wall-clock time in ``statistics.txt``.

    Args:
        elapsed_time: elapsed time in seconds.
        output_folder: existing directory the statistics file is written to.
    """
    stats_path = output_folder / "statistics.txt"
    stats_path.write_text(f"Elapsed time: {elapsed_time} s\n")
122+
123+
121124
def get_args():
122125
parser = ArgumentParser()
123126
parser.add_argument("--seed", type=int, default=0)
@@ -148,6 +151,10 @@ def get_args():
148151

149152
def main():
150153
args = get_args()
154+
155+
output_folder = Path(args.output_folder)
156+
output_folder.mkdir(parents=True, exist_ok=True)
157+
151158
if args.parametrization == "predict_noise":
152159
postprocessing = predict_noise_postprocessing
153160
elif args.parametrization == "predict_original":
@@ -164,25 +171,25 @@ def main():
164171
sample_height = config["model_params"]["img_size"]
165172
sample_width = config["model_params"]["img_size"]
166173

167-
print(num_channels)
168-
print(sample_height)
169-
170-
print(config)
171-
172-
model.load_state_dict(torch.load(args.checkpoint_path, map_location="cpu"))
174+
state_dict = torch.load(args.checkpoint_path, map_location="cpu")
175+
if "model_state_dict" in state_dict:
176+
state_dict = state_dict["model_state_dict"]
177+
model.load_state_dict(state_dict)
173178
model = model.eval().to(device)
174179

175180
y = (
176181
torch.ones(args.batch_size, dtype=torch.int).to(device) * args.class_id
177182
if args.class_id is not None
178183
else None
179184
)
180-
autoencoder = (
181-
get_autoencoder(config["autoencoder"]["autoencoder_checkpoint_path"])
182-
if "autoencoder" in config
183-
else None
184-
).to(device)
185+
if "autoencoder" in config:
186+
autoencoder = get_autoencoder(
187+
config["autoencoder"]["autoencoder_checkpoint_path"]
188+
).to(device)
189+
else:
190+
autoencoder = None
185191

192+
tic = time.time()
186193
samples = get_samples(
187194
model,
188195
args.batch_size,
@@ -194,7 +201,10 @@ def main():
194201
y,
195202
autoencoder,
196203
)
197-
dump_samples(samples, args.output_folder)
204+
tac = time.time()
205+
dump_statistics(tac - tic, output_folder)
206+
207+
dump_samples(samples, output_folder)
198208

199209

200210
if __name__ == "__main__":

0 commit comments

Comments (0)