Skip to content

Commit d4cf479

Browse files
committed
Fix FID metric bug
Both the real and the fake images should be normalized in the [0, 1] range. Before this, our real images came normalized (with mean = std = 0.5).
1 parent e387a06 commit d4cf479

File tree

3 files changed

+37
-29
lines changed

3 files changed

+37
-29
lines changed

datasets/celeba.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from datasets.sampler import ResumableSeedableSampler
88

99

10-
def get_celeba_dataloader(batch_size, seed, data_dir="data/"):
10+
def get_celeba_dataloader(batch_size, seed, data_dir="data/", normalize: bool = True):
1111
"""
1212
Builds a dataloader with all images from the CelebA dataset.
1313
Args:
@@ -18,17 +18,27 @@ def get_celeba_dataloader(batch_size, seed, data_dir="data/"):
1818
DataLoader: DataLoader object containing the dataset.
1919
2020
"""
21-
mean = (0.5, 0.5, 0.5)
22-
std = (0.5, 0.5, 0.5)
23-
24-
data_transforms = transforms.Compose(
25-
[
26-
transforms.ToTensor(),
27-
transforms.Normalize(mean, std),
28-
transforms.CenterCrop((178, 178)),
29-
transforms.Resize((64, 64)),
30-
]
31-
)
21+
22+
if normalize:
23+
mean = (0.5, 0.5, 0.5)
24+
std = (0.5, 0.5, 0.5)
25+
26+
data_transforms = transforms.Compose(
27+
[
28+
transforms.ToTensor(),
29+
transforms.Normalize(mean, std),
30+
transforms.CenterCrop((178, 178)),
31+
transforms.Resize((64, 64)),
32+
]
33+
)
34+
else:
35+
data_transforms = transforms.Compose(
36+
[
37+
transforms.ToTensor(),
38+
transforms.CenterCrop((178, 178)),
39+
transforms.Resize((64, 64)),
40+
]
41+
)
3242

3343
path = Path(data_dir)
3444

datasets/cifar10.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,7 @@
77
from datasets.sampler import ResumableSeedableSampler
88

99

10-
def get_cifar10_dataloader(
11-
batch_size,
12-
seed,
13-
data_dir,
14-
):
10+
def get_cifar10_dataloader(batch_size, seed, data_dir, normalize: bool = True):
1511
"""
1612
Builds a dataloader with all training images from the CIFAR-10 dataset.
1713
Args:
@@ -22,13 +18,15 @@ def get_cifar10_dataloader(
2218
DataLoader: DataLoader object containing the dataset.
2319
2420
"""
25-
26-
mean = (0.5, 0.5, 0.5)
27-
std = (0.5, 0.5, 0.5)
28-
29-
transform = transforms.Compose(
30-
[transforms.ToTensor(), transforms.Normalize(mean, std)]
31-
)
21+
if normalize:
22+
mean = (0.5, 0.5, 0.5)
23+
std = (0.5, 0.5, 0.5)
24+
25+
transform = transforms.Compose(
26+
[transforms.ToTensor(), transforms.Normalize(mean, std)]
27+
)
28+
else:
29+
transform = transforms.ToTensor()
3230

3331
path = Path(data_dir) / "cifar10"
3432

utils/evaluation_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
from pathlib import Path
22

3-
from PIL import Image
43
import torch
4+
from PIL import Image
55
from torchvision import transforms
66
from torchvision.utils import save_image
77

8-
from datasets.cifar10 import get_cifar10_dataloader
98
from datasets.celeba import get_celeba_dataloader
9+
from datasets.cifar10 import get_cifar10_dataloader
1010

1111

1212
def read_samples(path):
13-
transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])
13+
transform = transforms.Compose([transforms.ToTensor()])
1414

1515
tensor_list = []
1616
for p in Path(path).rglob("*.png"):
@@ -24,9 +24,9 @@ def read_samples(path):
2424

2525
def get_dataset_samples(dataset_name, data_path, seed, n_samples):
2626
if dataset_name == "cifar10":
27-
dataset = get_cifar10_dataloader(n_samples, seed, data_path)
27+
dataset = get_cifar10_dataloader(n_samples, seed, data_path, normalize=False)
2828
elif dataset_name == "celeba":
29-
dataset = get_celeba_dataloader(n_samples, seed, data_path)
29+
dataset = get_celeba_dataloader(n_samples, seed, data_path, normalize=False)
3030
else:
3131
raise ValueError("Incorrect dataset name")
3232

0 commit comments

Comments
 (0)