Commit 81dbeaf

Integrate training of UViT and DeeDiff with imagenet64 and imagenet256
1 parent e5f81ae commit 81dbeaf

4 files changed: +52 additions, -32 deletions

datasets/imagenet.py

Lines changed: 17 additions & 8 deletions
@@ -1,15 +1,18 @@
+from pathlib import Path
+
 from torch.utils.data import DataLoader
 from torchvision import datasets, transforms
 
-from pathlib import Path
 from datasets.sampler import ResumableSeedableSampler
 
 
 # https://www.kaggle.com/datasets/dimensi0n/imagenet-256
 def get_imagenet_dataloader(
     batch_size,
     seed,
-    data_dir="./archive",
+    data_dir,
+    resize: bool,  # resizing to 64x64
+    normalize: bool = True,
 ):
     """
     Builds a dataloader with all images from a 540k subset of ImageNet (with 256x256 resolution).
@@ -22,13 +25,19 @@ def get_imagenet_dataloader(
         DataLoader: DataLoader object containing the dataset.
     """
 
-    mean = (0.5, 0.5, 0.5)
-    std = (0.5, 0.5, 0.5)
-
     # All images from the dataset are 256x256 resolution
-    transform = transforms.Compose(
-        [transforms.ToTensor(), transforms.Normalize(mean, std)]
-    )
+    transformations = [transforms.ToTensor()]
+
+    if normalize:
+        mean = (0.5, 0.5, 0.5)
+        std = (0.5, 0.5, 0.5)
+
+        transformations.append(transforms.Normalize(mean, std))
+
+    if resize:
+        transformations.append(transforms.Resize((64, 64)))
+
+    transform = transforms.Compose(transformations)
 
     path = Path(data_dir) / "imagenet"

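For reference, a minimal usage sketch of the updated loader under its two settings. The call pattern, the ./data directory, and the expected shapes mirror the updated tests in this commit; note that the appended Resize((64, 64)) runs after ToTensor/Normalize, which works on tensor inputs in recent torchvision versions.

from datasets.imagenet import get_imagenet_dataloader

# 64x64 variant: images are downscaled by the appended Resize transform.
loader_64 = get_imagenet_dataloader(
    batch_size=4, seed=0, data_dir="./data", resize=True
)

# 256x256 variant: native resolution, normalization to [-1, 1] by default.
loader_256 = get_imagenet_dataloader(
    batch_size=4, seed=0, data_dir="./data", resize=False
)

x, y = next(iter(loader_64))
assert x.shape == (4, 3, 64, 64)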
main.py

Lines changed: 2 additions & 2 deletions
@@ -184,7 +184,7 @@ def get_args():
         "--dataset",
         type=str,
         default="cifar10",
-        choices=["cifar10", "celeba", "imagenet"],
+        choices=["cifar10", "celeba", "imagenet64", "imagenet256"],
         help="Dataset name",
     )
     parser.add_argument(
@@ -204,7 +204,7 @@ def main():
     config = load_config(args.config_path)
     args.__dict__.update(config["model_params"])
 
-    if args.dataset == "imagenet":
+    if args.dataset == "imagenet256":
        args.__dict__.update(config["autoencoder"])
 
     torch.use_deterministic_algorithms(True)

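A toy sketch of how the config gating now behaves. The section names "model_params" and "autoencoder" come from the diff context; the keys and values inside them are placeholders, not the project's actual config.

from argparse import Namespace

args = Namespace(dataset="imagenet256")
config = {
    "model_params": {"embed_dim": 512},      # placeholder values
    "autoencoder": {"latent_channels": 4},   # placeholder values
}

# Every run gets the model parameters; only imagenet256 additionally
# pulls in the autoencoder settings, while imagenet64 runs without them.
args.__dict__.update(config["model_params"])
if args.dataset == "imagenet256":
    args.__dict__.update(config["autoencoder"])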
tests/test_datasets.py

Lines changed: 20 additions & 17 deletions
@@ -22,22 +22,12 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
-# Might delete later
-def ignore_if_imagenet_data_not_downloaded(f):
-    @wraps(f)
-    def wrapper(*args, **kwargs):
-        if not Path("archive/").exists():
-            return
-
-        return f(*args, **kwargs)
-
-    return wrapper
-
-
 @ignore_if_data_not_downloaded
 @pytest.mark.parametrize("batch_size", [16])
 def test_cifar10(batch_size):
-    dataloader = get_cifar10_dataloader(batch_size=batch_size, seed=0)
+    dataloader = get_cifar10_dataloader(
+        batch_size=batch_size, seed=0, data_dir="./data"
+    )
 
     x, _ = next(iter(dataloader))
     assert x.shape == torch.Size([batch_size, 3, 32, 32])
@@ -46,16 +36,29 @@ def test_cifar10(batch_size):
 @ignore_if_data_not_downloaded
 @pytest.mark.parametrize("batch_size", [4])
 def test_celeba(batch_size):
-    dataloader = get_celeba_dataloader(batch_size=batch_size, seed=0)
+    dataloader = get_celeba_dataloader(batch_size=batch_size, seed=0, data_dir="./data")
 
     x, _ = next(iter(dataloader))
     assert x.shape == torch.Size([batch_size, 3, 64, 64])
 
 
-@ignore_if_imagenet_data_not_downloaded
+@ignore_if_data_not_downloaded
+@pytest.mark.parametrize("batch_size", [4])
+def test_imagenet64(batch_size):
+    dataloader = get_imagenet_dataloader(
+        batch_size=batch_size, seed=0, data_dir="./data", resize=True
+    )
+
+    x, _ = next(iter(dataloader))
+    assert x.shape == torch.Size([batch_size, 3, 64, 64])
+
+
+@ignore_if_data_not_downloaded
 @pytest.mark.parametrize("batch_size", [4])
-def test_imagenet(batch_size):
-    dataloader = get_imagenet_dataloader(batch_size=batch_size, seed=0)
+def test_imagenet256(batch_size):
+    dataloader = get_imagenet_dataloader(
+        batch_size=batch_size, seed=0, data_dir="./data", resize=False
+    )
 
     x, _ = next(iter(dataloader))
     assert x.shape == torch.Size([batch_size, 3, 256, 256])

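The deleted helper duplicated the existing @ignore_if_data_not_downloaded decorator, so both ImageNet tests now share it. For context, a sketch of that shared pattern; the exact path checked by the surviving decorator is an assumption here (./data, matching the data_dir the tests now pass).

from functools import wraps
from pathlib import Path

def ignore_if_data_not_downloaded(f):
    # Turn the test into a no-op when the local dataset folder is missing,
    # so the suite still passes on machines without the data.
    @wraps(f)
    def wrapper(*args, **kwargs):
        if not Path("./data").exists():
            return
        return f(*args, **kwargs)
    return wrapper

With all four tests behind the same decorator, the new cases can be run selectively with pytest tests/test_datasets.py -k imagenet.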
trainer.py

Lines changed: 13 additions & 5 deletions
@@ -140,11 +140,19 @@ def _init_dataloader(self):
                 seed=self.args.seed,
                 data_dir=self.args.data_path,
             )
-        elif self.args.dataset == "imagenet":
+        elif self.args.dataset == "imagenet64":
             self.dataloader = get_imagenet_dataloader(
                 batch_size=self.args.batch_size,
                 seed=self.args.seed,
                 data_dir=self.args.data_path,
+                resize=True,
+            )
+        elif self.args.dataset == "imagenet256":
+            self.dataloader = get_imagenet_dataloader(
+                batch_size=self.args.batch_size,
+                seed=self.args.seed,
+                data_dir=self.args.data_path,
+                resize=False,
             )
         else:
             raise ValueError(f"Dataset {self.args.dataset} not implemented.")
@@ -298,7 +306,7 @@ def _loss_fn(self, batch):
         data = batch[0].to(self.device)
         batch_size = data.size(0)
         clean_images = data
-        labels = batch[1].to(self.device) if self.args.dataset == "imagenet" else None
+        labels = batch[1].to(self.device) if "imagenet" in self.args.dataset else None
 
         timesteps = torch.randint(
             0, self.args.num_timesteps, (batch_size,), device=self.device
@@ -308,13 +316,13 @@ def _loss_fn(self, batch):
 
         if self.args.model == "uvit":
             if self.args.parametrization == "predict_noise":
-                predicted_noise = self.model(noisy_images, timesteps)
+                predicted_noise = self.model(noisy_images, timesteps, labels)
                 loss = F.mse_loss(predicted_noise, noise)
             elif self.args.parametrization == "predict_original":
-                predicted_original = self.model(noisy_images, timesteps)
+                predicted_original = self.model(noisy_images, timesteps, labels)
                 loss = F.mse_loss(predicted_original, clean_images)
             elif self.args.parametrization == "predict_previous":
-                predicted_previous = self.model(noisy_images, timesteps)
+                predicted_previous = self.model(noisy_images, timesteps, labels)
 
                 betas = torch.linspace(1e-4, 0.02, 1000).to(self.device)
                 alphas = 1 - betas

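A condensed sketch of the conditioning path the trainer now follows in _loss_fn. The optional third labels argument to the model's forward call is implied by the diff; the function boundary and the noising step here are simplified stand-ins, not the trainer's actual forward-diffusion code.

import torch
import torch.nn.functional as F

def loss_sketch(model, batch, dataset: str, num_timesteps: int, device):
    clean_images = batch[0].to(device)
    # Class labels are only used for the ImageNet variants ("imagenet64",
    # "imagenet256"); other datasets train unconditionally with labels=None.
    labels = batch[1].to(device) if "imagenet" in dataset else None

    batch_size = clean_images.size(0)
    timesteps = torch.randint(0, num_timesteps, (batch_size,), device=device)

    noise = torch.randn_like(clean_images)
    # Placeholder for the trainer's real forward-diffusion step.
    noisy_images = clean_images + noise

    predicted_noise = model(noisy_images, timesteps, labels)  # predict_noise case
    return F.mse_loss(predicted_noise, noise)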