Skip to content

Commit 5107295

Browse files
committed
Fid for iamgenet
1 parent 484ba95 commit 5107295

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

fid.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def get_args():
1212
"--dataset",
1313
type=str,
1414
required=True,
15-
choices=["cifar10", "celeba"],
15+
choices=["cifar10", "celeba", "imagenet64", "imagenet256"],
1616
help="Dataset name.",
1717
)
1818
parser.add_argument(
@@ -43,6 +43,7 @@ def main():
4343
args = get_args()
4444
generated_images = read_samples(args.samples_path)
4545
n_samples = len(generated_images)
46+
print(f"Using {n_samples}")
4647
real_images = get_dataset_samples(
4748
args.dataset, args.data_path, args.seed, n_samples
4849
)

utils/evaluation_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from datasets.celeba import get_celeba_dataloader
99
from datasets.cifar10 import get_cifar10_dataloader
10+
from datasets.imagenet import get_imagenet_dataloader
1011

1112

1213
def read_samples(path):
@@ -28,6 +29,10 @@ def get_dataset_samples(dataset_name, data_path, seed, n_samples):
2829
dataset = get_cifar10_dataloader(n_samples, seed, data_path, normalize=False)
2930
elif dataset_name == "celeba":
3031
dataset = get_celeba_dataloader(n_samples, seed, data_path, normalize=False)
32+
elif dataset_name == "imagenet64":
33+
dataset = get_imagenet_dataloader(n_samples, seed, data_path, normalize=False, resize=True)
34+
elif dataset_name == "imagenet256":
35+
dataset = get_imagenet_dataloader(n_samples, seed, data_path, normalize=False, resize=False)
3136
else:
3237
raise ValueError("Incorrect dataset name")
3338

0 commit comments

Comments
 (0)