Skip to content

Commit 37f1c20

Browse files
committed
Merge branch 'main' of github.com:razvanmatisan/early-stopping-diffusion
2 parents 22f02e0 + 8b560ba commit 37f1c20

File tree

3 files changed

+37
-5
lines changed

3 files changed

+37
-5
lines changed

models/early_exit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def forward(self, x, timesteps, y=None):
280280
)
281281
time_token = time_token.unsqueeze(dim=1)
282282
x = torch.cat((time_token, x), dim=1)
283-
if y is not None:
283+
if y is not None and self.uvit.label_emb is not None:
284284
label_emb = self.uvit.label_emb(y)
285285
label_emb = label_emb.unsqueeze(dim=1)
286286
x = torch.cat((label_emb, x), dim=1)

models/uvit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ def __init__(
275275
self.label_emb = nn.Embedding(self.num_classes, embed_dim)
276276
self.extras = 2
277277
else:
278+
self.label_emb = None
278279
self.extras = 1
279280

280281
self.pos_embed = nn.Parameter(

sampler.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from einops import rearrange
99
from matplotlib import pyplot as plt
1010
from tqdm import tqdm
11+
from typing import List
1112

1213
from models.utils.autoencoder import get_autoencoder
1314
from models.uvit import UViT
@@ -89,13 +90,15 @@ def get_samples(
8990
use_ddim: bool,
9091
ddim_steps: int,
9192
ddim_eta: float,
93+
timesteps_save: List[int],
9294
y: int = None,
9395
autoencoder=None,
9496
late_model=None,
9597
t_switch=np.inf,
9698
):
9799
seed_everything(seed)
98100
x = torch.randn(batch_size, num_channels, sample_height, sample_width).to(device)
101+
intermediate_samples = []
99102

100103
if use_ddim:
101104
timesteps = np.linspace(0, 999, ddim_steps).astype(int)[::-1]
@@ -119,25 +122,41 @@ def get_samples(
119122
if t < 1000 - t_switch:
120123
model = late_model
121124

125+
if 1000 - t in timesteps_save:
126+
intermediate_samples.append(x)
127+
122128
else:
123129
for t in tqdm(range(999, -1, -1)):
124130
time_tensor = t * torch.ones(batch_size, device=device)
125131
with torch.no_grad():
126132
model_output = model(x, time_tensor, y)
127133
x = postprocessing(model_output, x, t)
134+
128135
if t == 1000 - t_switch:
129136
model = late_model
130137

138+
if 1000 - t in timesteps_save:
139+
intermediate_samples.append(x)
140+
141+
131142
if autoencoder:
132143
print("Decode the images...")
133144
x = autoencoder.decode(x)
134145

135146
samples = (x + 1) / 2
136147
samples = rearrange(samples, "b c h w -> b h w c")
137-
return samples.cpu().numpy()
138148

149+
for i, x in enumerate(intermediate_samples):
150+
if autoencoder:
151+
x = autoencoder.decode(x)
152+
x = (x + 1) / 2
153+
x = rearrange(x, "b c h w -> b h w c")
154+
intermediate_samples[i] = x.cpu().numpy()
155+
156+
return samples.cpu().numpy(), intermediate_samples
139157

140-
def dump_samples(samples, output_folder: Path):
158+
159+
def dump_samples(samples, output_folder: Path, timestep=1000):
141160
# plt.hist(samples.flatten())
142161
# plt.savefig(output_folder / "histogram.png")
143162
# plt.clf()
@@ -151,7 +170,8 @@ def dump_samples(samples, output_folder: Path):
151170

152171
for sample_id, sample in enumerate(samples):
153172
sample = np.clip(sample, 0, 1)
154-
plt.imsave(output_folder / f"{sample_id}.png", sample)
173+
filename = f"{sample_id}_{timestep}.png" if timestep !=1000 else f"{sample_id}.png"
174+
plt.imsave(output_folder / filename, sample)
155175

156176
row, col = divmod(sample_id, grid_size)
157177
grid_img[
@@ -226,6 +246,12 @@ def get_args():
226246
type=float,
227247
default=0.0,
228248
)
249+
parser.add_argument(
250+
"--timesteps_save",
251+
type=int,
252+
nargs="+",
253+
default=[]
254+
)
229255

230256
return parser.parse_args()
231257

@@ -303,7 +329,7 @@ def main():
303329
autoencoder = None
304330

305331
tic = time.time()
306-
samples = get_samples(
332+
samples, intermediate_samples = get_samples(
307333
model=model,
308334
batch_size=args.batch_size,
309335
postprocessing=postprocessing,
@@ -318,12 +344,17 @@ def main():
318344
autoencoder=autoencoder,
319345
late_model=model_late,
320346
t_switch=args.t_switch,
347+
timesteps_save=args.timesteps_save
321348
)
322349
tac = time.time()
323350
dump_statistics(tac - tic, output_folder)
324351

325352
dump_samples(samples, output_folder)
326353

354+
if args.timesteps_save:
355+
for timestep, samples in zip(args.timesteps_save, intermediate_samples):
356+
dump_samples(samples, output_folder, timestep)
357+
327358

328359
if __name__ == "__main__":
329360
main()

0 commit comments

Comments
 (0)