Skip to content

Commit 008cc87

Browse files
Refactor dataset handlers and update Makefile
1 parent bf1a9e0 commit 008cc87

File tree

13 files changed

+182
-313
lines changed

13 files changed

+182
-313
lines changed

Makefile

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
UV := uv
1+
UV := $(HOME)/.local/bin/uv
22
PYTHON_VERSION := 3.11
33
UV_INSTALL_SCRIPT := https://astral.sh/uv/install.sh
44
PATH := $(HOME)/.local/bin:$(PATH)
@@ -20,8 +20,9 @@ check-uv: ## Check and install uv if necessary
2020
install-python: check-uv ## Install Python with uv
2121
@echo "🐍 Installing Python $(PYTHON_VERSION) with uv"
2222
@$(UV) python install $(PYTHON_VERSION)
23-
@echo "🔧 Configuring Python $(PYTHON_VERSION) as the default Python version"
23+
@echo "🐍 Configuring Python $(PYTHON_VERSION) as the default Python version"
2424
@$(UV) python pin $(PYTHON_VERSION)
25+
@echo "🐍 Python installation complete."
2526

2627
.PHONY: install
2728
install: install-python ## Install core dependencies

nebula/core/datasets/cifar10/cifar10.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99

1010
class CIFAR10PartitionHandler(NebulaPartitionHandler):
11-
def __init__(self, file_path, prefix, mode):
12-
super().__init__(file_path, prefix, mode)
11+
def __init__(self, file_path, prefix, config, empty=False):
12+
super().__init__(file_path, prefix, config, empty)
1313

1414
# Custom transform for CIFAR10
1515
mean = (0.4914, 0.4822, 0.4465)
@@ -22,9 +22,17 @@ def __init__(self, file_path, prefix, mode):
2222
])
2323

2424
def __getitem__(self, idx):
25-
img, target = super().__getitem__(idx)
25+
data, target = super().__getitem__(idx)
2626

27-
img = Image.fromarray(img)
27+
# CIFAR10 from torchvision returns a tuple (image, target)
28+
if isinstance(data, tuple):
29+
img, target = data
30+
else:
31+
img = data
32+
33+
# Only convert if not already a PIL image
34+
if not isinstance(img, Image.Image):
35+
img = Image.fromarray(img)
2836

2937
if self.transform is not None:
3038
img = self.transform(img)

nebula/core/datasets/cifar100/cifar100.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99

1010
class CIFAR100PartitionHandler(NebulaPartitionHandler):
11-
def __init__(self, file_path, prefix, mode):
12-
super().__init__(file_path, prefix, mode)
11+
def __init__(self, file_path, prefix, config, empty=False):
12+
super().__init__(file_path, prefix, config, empty)
1313

1414
# Custom transform for CIFAR100
1515
mean = (0.4914, 0.4822, 0.4465)
@@ -22,9 +22,17 @@ def __init__(self, file_path, prefix, mode):
2222
])
2323

2424
def __getitem__(self, idx):
25-
img, target = super().__getitem__(idx)
25+
data, target = super().__getitem__(idx)
2626

27-
img = Image.fromarray(img)
27+
# CIFAR100 from torchvision returns a tuple (image, target)
28+
if isinstance(data, tuple):
29+
img, target = data
30+
else:
31+
img = data
32+
33+
# Only convert if not already a PIL image
34+
if not isinstance(img, Image.Image):
35+
img = Image.fromarray(img)
2836

2937
if self.transform is not None:
3038
img = self.transform(img)

nebula/core/datasets/datamodule.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def __init__(
1717
train_set_indices,
1818
test_set,
1919
test_set_indices,
20+
local_test_set,
2021
local_test_set_indices,
2122
batch_size=32,
2223
num_workers=0,
@@ -28,6 +29,7 @@ def __init__(
2829
self.train_set_indices = train_set_indices
2930
self.test_set = test_set
3031
self.test_set_indices = test_set_indices
32+
self.local_test_set = local_test_set
3133
self.local_test_set_indices = local_test_set_indices
3234
self.batch_size = batch_size
3335
self.num_workers = num_workers
@@ -73,7 +75,7 @@ def setup(self, stage=None):
7375
if stage in (None, "test"):
7476
# Test sets
7577
self.global_te_subset = ChangeableSubset(self.test_set, self.test_set_indices)
76-
self.local_te_subset = ChangeableSubset(self.test_set, self.local_test_set_indices)
78+
self.local_te_subset = ChangeableSubset(self.local_test_set, self.local_test_set_indices)
7779

7880
def teardown(self, stage=None):
7981
# Teardown the datasets

nebula/core/datasets/emnist/emnist.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99

1010
class EMNISTPartitionHandler(NebulaPartitionHandler):
11-
def __init__(self, file_path, prefix, mode):
12-
super().__init__(file_path, prefix, mode)
11+
def __init__(self, file_path, prefix, config, empty=False):
12+
super().__init__(file_path, prefix, config, empty)
1313

1414
# Custom transform for EMNIST
1515
mean = (0.5,)
@@ -22,9 +22,17 @@ def __init__(self, file_path, prefix, mode):
2222
])
2323

2424
def __getitem__(self, idx):
25-
img, target = super().__getitem__(idx)
25+
data, target = super().__getitem__(idx)
2626

27-
img = Image.fromarray(img, mode="L")
27+
# EMNIST from torchvision returns a tuple (image, target)
28+
if isinstance(data, tuple):
29+
img, target = data
30+
else:
31+
img = data
32+
33+
# Only convert if not already a PIL image
34+
if not isinstance(img, Image.Image):
35+
img = Image.fromarray(img, mode="L")
2836

2937
if self.transform is not None:
3038
img = self.transform(img)

nebula/core/datasets/fashionmnist/fashionmnist.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99

1010
class FashionMNISTPartitionHandler(NebulaPartitionHandler):
11-
def __init__(self, file_path, prefix, mode):
12-
super().__init__(file_path, prefix, mode)
11+
def __init__(self, file_path, prefix, config, empty=False):
12+
super().__init__(file_path, prefix, config, empty)
1313

1414
# Custom transform for MNIST
1515
self.transform = transforms.Compose([
@@ -18,9 +18,17 @@ def __init__(self, file_path, prefix, mode):
1818
])
1919

2020
def __getitem__(self, idx):
21-
img, target = super().__getitem__(idx)
21+
data, target = super().__getitem__(idx)
2222

23-
img = Image.fromarray(img, mode="L")
23+
# FashionMNIST from torchvision returns a tuple (image, target)
24+
if isinstance(data, tuple):
25+
img, target = data
26+
else:
27+
img = data
28+
29+
# Only convert if not already a PIL image
30+
if not isinstance(img, Image.Image):
31+
img = Image.fromarray(img, mode="L")
2432

2533
if self.transform is not None:
2634
img = self.transform(img)

nebula/core/datasets/mnist/mnist.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99

1010
class MNISTPartitionHandler(NebulaPartitionHandler):
11-
def __init__(self, file_path, prefix, mode):
12-
super().__init__(file_path, prefix, mode)
11+
def __init__(self, file_path, prefix, config, empty=False):
12+
super().__init__(file_path, prefix, config, empty)
1313

1414
# Custom transform for MNIST
1515
self.transform = transforms.Compose([
@@ -18,9 +18,17 @@ def __init__(self, file_path, prefix, mode):
1818
])
1919

2020
def __getitem__(self, idx):
21-
img, target = super().__getitem__(idx)
21+
data, target = super().__getitem__(idx)
2222

23-
img = Image.fromarray(img, mode="L")
23+
# MNIST from torchvision returns a tuple (image, target)
24+
if isinstance(data, tuple):
25+
img, target = data
26+
else:
27+
img = data
28+
29+
# Only convert if not already a PIL image
30+
if not isinstance(img, Image.Image):
31+
img = Image.fromarray(img, mode="L")
2432

2533
if self.transform is not None:
2634
img = self.transform(img)

nebula/core/datasets/mnistML/__init__.py

Whitespace-only changes.

nebula/core/datasets/mnistML/mnist.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

0 commit comments

Comments
 (0)