Commit 56977db
[Feature/datasets] Improve dataset management (#33)
* new implementation of datasets
* update dataset implementation
* update dataset implementation
* update dataset implementation and remove reload
* update dataset implementation (all datasets)
* include dataset mode
* include dataset mode
* fix dataset parameters
* optimize dataset management
* clean dataset resources
* improve nebuladataset

1 parent d4283bd · commit 56977db

File tree: 38 files changed, +872 −1685 lines changed

docs/_prebuilt/developerguide.md

Lines changed: 0 additions & 1 deletion

```diff
@@ -259,7 +259,6 @@ First, you must add the Dataset option in the frontend. Adding the Dataset optio
     "EMNIST": ["MLP", "CNN"],
     "CIFAR10": ["CNN", "CNNv2", "CNNv3", "ResNet9", "fastermobilenet", "simplemobilenet"],
     "CIFAR100": ["CNN"],
-    "KITSUN": ["MLP"],
 }
 var datasetSelect = document.getElementById("datasetSelect");
 var modelSelect = document.getElementById("modelSelect");
```

nebula/addons/trustworthiness/factsheet.py

Lines changed: 3 additions & 7 deletions

```diff
@@ -19,8 +19,8 @@
     get_feature_importance_cv,
 )
 from nebula.addons.trustworthiness.utils import check_field_filled, count_class_samples, get_entropy, read_csv
-from nebula.core.models.mnist.cnn import CIFAR10ModelCNN, CIFAR10TorchModelCNN, MNISTModelCNN, MNISTTorchModelCNN
-from nebula.core.models.mnist.mlp import MNISTModelMLP, MNISTTorchModelMLP, SyscallModelMLP, SyscallTorchModelMLP
+from nebula.core.models.mnist.cnn import MNISTModelCNN
+from nebula.core.models.mnist.mlp import MNISTModelMLP
 
 dirname = os.path.dirname(__file__)
 
@@ -119,10 +119,8 @@ def populate_factsheet_pre_train(self, data, scenario_name):
             model = MNISTModelMLP()
         elif dataset == "MNIST" and algorithm == "CNN":
             model = MNISTModelCNN()
-        elif dataset == "Syscall" and algorithm == "MLP":
-            model = SyscallModelMLP()
         else:
-            model = CIFAR10ModelCNN()
+            model = MNISTModelCNN()
 
         factsheet["configuration"]["learning_rate"] = model.get_learning_rate()
         factsheet["configuration"]["trainable_param_num"] = model.count_parameters()
@@ -225,8 +223,6 @@ def populate_factsheet_post_train(self, scenario):
             pytorch_model = MNISTTorchModelMLP()
         elif dataset == "MNIST" and model == "CNN":
            pytorch_model = MNISTTorchModelCNN()
-        elif dataset == "Syscall" and model == "MLP":
-            pytorch_model = SyscallTorchModelMLP()
         else:
             pytorch_model = CIFAR10TorchModelCNN()
```
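The remaining model selection is a pair of if/elif chains keyed on (dataset, algorithm). As a readability sketch, the pre-train branch could be expressed as a lookup table; the helper below is hypothetical, not part of the commit, and assumes only the two model classes kept by the new imports. (Note that populate_factsheet_post_train still instantiates CIFAR10TorchModelCNN in its else branch even though this hunk drops that import, which would raise a NameError unless the class is imported elsewhere in the file.)

```python
# Hypothetical sketch, not in the commit: table-driven version of the
# pre-train model dispatch. Only MNISTModelMLP and MNISTModelCNN survive
# the new imports; everything else falls back to MNISTModelCNN, mirroring
# the else branch in populate_factsheet_pre_train.
from nebula.core.models.mnist.cnn import MNISTModelCNN
from nebula.core.models.mnist.mlp import MNISTModelMLP

_PRE_TRAIN_MODELS = {
    ("MNIST", "MLP"): MNISTModelMLP,
    ("MNIST", "CNN"): MNISTModelCNN,
}

def build_pre_train_model(dataset: str, algorithm: str):
    # Unknown (dataset, algorithm) pairs default to MNISTModelCNN.
    model_cls = _PRE_TRAIN_MODELS.get((dataset, algorithm), MNISTModelCNN)
    return model_cls()
```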

nebula/controller.py

Lines changed: 5 additions & 3 deletions

```diff
@@ -193,18 +193,20 @@ def verify_nodes_ports(self, src_path):
         try:
             port_mapping = {}
             new_port_start = 50000
-
+
             participant_files = sorted(
                 f for f in os.listdir(scenario_path) if f.endswith(".json") and f.startswith("participant")
             )
-
+
             for filename in participant_files:
                 file_path = os.path.join(scenario_path, filename)
                 with open(file_path) as json_file:
                     node = json.load(json_file)
                 current_port = node["network_args"]["port"]
                 port_mapping[current_port] = SocketUtils.find_free_port(start_port=new_port_start)
-                logging.info(f"Participant file: {filename} | Current port: {current_port} | New port: {port_mapping[current_port]}")
+                logging.info(
+                    f"Participant file: {filename} | Current port: {current_port} | New port: {port_mapping[current_port]}"
+                )
                 new_port_start = port_mapping[current_port] + 1
 
             for filename in participant_files:
```
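Advancing new_port_start past each assigned port keeps the remapped ports unique across participants. For context, find_free_port presumably probes ports upward from start_port; a minimal sketch of that behavior (an assumption, the real SocketUtils implementation may differ):

```python
import socket

def find_free_port(start_port: int = 50000, max_port: int = 65535) -> int:
    """Sketch of a free-port probe: bind to consecutive ports until one succeeds.

    Assumed behavior only; SocketUtils.find_free_port may be implemented
    differently in the repository.
    """
    for port in range(start_port, max_port + 1):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("", port))
                return port  # Bind succeeded, so the port is currently free
            except OSError:
                continue  # Port in use; try the next one
    raise RuntimeError("No free port available in the requested range")
```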
nebula/core/datasets/cifar10/cifar10.py

Lines changed: 35 additions & 40 deletions

```diff
@@ -1,36 +1,63 @@
 import os
 
+from PIL import Image
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
 
-from nebula.core.datasets.nebuladataset import NebulaDataset
+from nebula.core.datasets.nebuladataset import NebulaDataset, NebulaPartitionHandler
+
+
+class CIFAR10PartitionHandler(NebulaPartitionHandler):
+    def __init__(self, file_path, prefix, mode):
+        super().__init__(file_path, prefix, mode)
+
+        # Custom transform for CIFAR10
+        mean = (0.4914, 0.4822, 0.4465)
+        std = (0.2471, 0.2435, 0.2616)
+        self.transform = transforms.Compose([
+            transforms.RandomCrop(32, padding=4),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize(mean, std, inplace=True),
+        ])
+
+    def __getitem__(self, idx):
+        img, target = super().__getitem__(idx)
+
+        img = Image.fromarray(img)
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
 
 
 class CIFAR10Dataset(NebulaDataset):
     def __init__(
         self,
         num_classes=10,
-        partition_id=0,
         partitions_number=1,
         batch_size=32,
         num_workers=4,
         iid=True,
         partition="dirichlet",
         partition_parameter=0.5,
         seed=42,
-        config=None,
+        config_dir=None,
     ):
         super().__init__(
             num_classes=num_classes,
-            partition_id=partition_id,
             partitions_number=partitions_number,
             batch_size=batch_size,
             num_workers=num_workers,
             iid=iid,
             partition=partition,
             partition_parameter=partition_parameter,
             seed=seed,
-            config=config,
+            config_dir=config_dir,
         )
 
     def initialize_dataset(self):
@@ -40,39 +67,15 @@ def initialize_dataset(self):
         if self.test_set is None:
             self.test_set = self.load_cifar10_dataset(train=False)
 
-        # All nodes have the same test set (indices are the same for all nodes)
-        self.test_indices_map = list(range(len(self.test_set)))
-
-        # Depending on the iid flag, generate a non-iid or iid map of the train set
-        if self.iid:
-            self.train_indices_map = self.generate_iid_map(self.train_set, self.partition, self.partition_parameter)
-            self.local_test_indices_map = self.generate_iid_map(self.test_set, self.partition, self.partition_parameter)
-        else:
-            self.train_indices_map = self.generate_non_iid_map(self.train_set, self.partition, self.partition_parameter)
-            self.local_test_indices_map = self.generate_non_iid_map(
-                self.test_set, self.partition, self.partition_parameter
-            )
-
-        print(f"Length of train indices map: {len(self.train_indices_map)}")
-        print(f"Lenght of test indices map (global): {len(self.test_indices_map)}")
-        print(f"Length of test indices map (local): {len(self.local_test_indices_map)}")
+        self.data_partitioning(plot=True)
 
     def load_cifar10_dataset(self, train=True):
-        mean = (0.4914, 0.4822, 0.4465)
-        std = (0.2471, 0.2435, 0.2616)
-        apply_transforms = transforms.Compose([
-            transforms.RandomCrop(32, padding=4),
-            transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            transforms.Normalize(mean, std, inplace=True),
-        ])
         data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
         os.makedirs(data_dir, exist_ok=True)
         return CIFAR10(
             data_dir,
             train=train,
             download=True,
-            transform=apply_transforms,
         )
 
     def generate_non_iid_map(self, dataset, partition="dirichlet", partition_parameter=0.5):
@@ -83,11 +86,7 @@ def generate_non_iid_map(self, dataset, partition="dirichlet", partition_paramet
         else:
             raise ValueError(f"Partition {partition} is not supported for Non-IID map")
 
-        if self.partition_id == 0:
-            self.plot_data_distribution(dataset, partitions_map)
-            self.plot_all_data_distribution(dataset, partitions_map)
-
-        return partitions_map[self.partition_id]
+        return partitions_map
 
     def generate_iid_map(self, dataset, partition="balancediid", partition_parameter=2):
         if partition == "balancediid":
@@ -97,8 +96,4 @@ def generate_iid_map(self, dataset, partition="balancediid", partition_parameter
         else:
             raise ValueError(f"Partition {partition} is not supported for IID map")
 
-        if self.partition_id == 0:
-            self.plot_data_distribution(dataset, partitions_map)
-            self.plot_all_data_distribution(dataset, partitions_map)
-
-        return partitions_map[self.partition_id]
+        return partitions_map
```
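A usage sketch of the reworked API (argument values are illustrative, and the semantics of data_partitioning and the partition handler are inferred from this diff rather than verified against the full sources). The constructor no longer takes partition_id, and both generate_*_map variants now return the full node-to-indices mapping instead of one node's slice:

```python
# Illustrative only. initialize_dataset() now delegates per-node splitting
# to data_partitioning(plot=True) instead of computing one node's indices.
dataset = CIFAR10Dataset(
    num_classes=10,
    partitions_number=5,          # one partition per federated node
    batch_size=32,
    iid=False,                    # use the Dirichlet non-IID partitioner
    partition="dirichlet",
    partition_parameter=0.5,
    seed=42,
    config_dir="/tmp/nebula-scenario",  # hypothetical path
)
dataset.initialize_dataset()

# The map now covers every node, not just a single partition_id:
partitions_map = dataset.generate_non_iid_map(dataset.train_set)
node_indices = partitions_map[0]  # indices assigned to node 0
```

Moving augmentation out of load_cifar10_dataset and into CIFAR10PartitionHandler.__getitem__ suggests the raw arrays can be partitioned and stored once, with RandomCrop/flip transforms applied lazily when each node loads its split.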
nebula/core/datasets/cifar100/cifar100.py

Lines changed: 35 additions & 40 deletions

```diff
@@ -1,36 +1,63 @@
 import os
 
+from PIL import Image
 from torchvision import transforms
 from torchvision.datasets import CIFAR100
 
-from nebula.core.datasets.nebuladataset import NebulaDataset
+from nebula.core.datasets.nebuladataset import NebulaDataset, NebulaPartitionHandler
+
+
+class CIFAR100PartitionHandler(NebulaPartitionHandler):
+    def __init__(self, file_path, prefix, mode):
+        super().__init__(file_path, prefix, mode)
+
+        # Custom transform for CIFAR100
+        mean = (0.4914, 0.4822, 0.4465)
+        std = (0.2471, 0.2435, 0.2616)
+        self.transform = transforms.Compose([
+            transforms.RandomCrop(32, padding=4),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize(mean, std, inplace=True),
+        ])
+
+    def __getitem__(self, idx):
+        img, target = super().__getitem__(idx)
+
+        img = Image.fromarray(img)
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
 
 
 class CIFAR100Dataset(NebulaDataset):
     def __init__(
         self,
         num_classes=100,
-        partition_id=0,
         partitions_number=1,
         batch_size=32,
         num_workers=4,
         iid=True,
         partition="dirichlet",
         partition_parameter=0.5,
         seed=42,
-        config=None,
+        config_dir=None,
     ):
         super().__init__(
             num_classes=num_classes,
-            partition_id=partition_id,
             partitions_number=partitions_number,
             batch_size=batch_size,
             num_workers=num_workers,
             iid=iid,
             partition=partition,
             partition_parameter=partition_parameter,
             seed=seed,
-            config=config,
+            config_dir=config_dir,
         )
 
     def initialize_dataset(self):
@@ -40,37 +67,13 @@ def initialize_dataset(self):
         if self.test_set is None:
             self.test_set = self.load_cifar100_dataset(train=False)
 
-        # All nodes have the same test set (indices are the same for all nodes)
-        self.test_indices_map = list(range(len(self.test_set)))
-
-        # Depending on the iid flag, generate a non-iid or iid map of the train set
-        if self.iid:
-            self.train_indices_map = self.generate_iid_map(self.train_set, self.partition, self.partition_parameter)
-            self.local_test_indices_map = self.generate_iid_map(self.test_set, self.partition, self.partition_parameter)
-        else:
-            self.train_indices_map = self.generate_non_iid_map(self.train_set, self.partition, self.partition_parameter)
-            self.local_test_indices_map = self.generate_non_iid_map(
-                self.test_set, self.partition, self.partition_parameter
-            )
-
-        print(f"Length of train indices map: {len(self.train_indices_map)}")
-        print(f"Lenght of test indices map (global): {len(self.test_indices_map)}")
-        print(f"Length of test indices map (local): {len(self.local_test_indices_map)}")
+        self.data_partitioning(plot=True)
 
     def load_cifar100_dataset(self, train=True):
-        mean = (0.4914, 0.4822, 0.4465)
-        std = (0.2471, 0.2435, 0.2616)
-        apply_transforms = transforms.Compose([
-            transforms.RandomCrop(32, padding=4),
-            transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            transforms.Normalize(mean, std, inplace=True),
-        ])
         return CIFAR100(
             os.path.join(os.path.dirname(os.path.abspath(__file__)), "data"),
             train=train,
             download=True,
-            transform=apply_transforms,
         )
 
     def generate_non_iid_map(self, dataset, partition="dirichlet", partition_parameter=0.5):
@@ -81,11 +84,7 @@ def generate_non_iid_map(self, dataset, partition="dirichlet", partition_paramet
         else:
             raise ValueError(f"Partition {partition} is not supported for Non-IID map")
 
-        if self.partition_id == 0:
-            self.plot_data_distribution(dataset, partitions_map)
-            self.plot_all_data_distribution(dataset, partitions_map)
-
-        return partitions_map[self.partition_id]
+        return partitions_map
 
     def generate_iid_map(self, dataset, partition="balancediid", partition_parameter=2):
         if partition == "balancediid":
@@ -95,8 +94,4 @@ def generate_iid_map(self, dataset, partition="balancediid", partition_parameter
         else:
             raise ValueError(f"Partition {partition} is not supported for IID map")
 
-        if self.partition_id == 0:
-            self.plot_data_distribution(dataset, partitions_map)
-            self.plot_all_data_distribution(dataset, partitions_map)
-
-        return partitions_map[self.partition_id]
+        return partitions_map
```
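One detail worth flagging: CIFAR100PartitionHandler normalizes with the same mean/std as the CIFAR-10 handler, and (0.4914, 0.4822, 0.4465) are CIFAR-10's channel means; CIFAR-100's own statistics differ slightly (means roughly (0.507, 0.487, 0.441)). If dataset-specific values were wanted, they could be computed with a sketch like this (illustrative, not part of the commit):

```python
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR100

def channel_stats(dataset, batch_size=1024):
    """Per-channel mean/std over a dataset of equal-sized CHW float tensors."""
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=2)
    n, mean, sq_mean = 0, torch.zeros(3), torch.zeros(3)
    for imgs, _ in loader:
        # Per-image channel means; averaging them equals the global mean
        # because every CIFAR image has the same pixel count.
        mean += imgs.mean(dim=(2, 3)).sum(dim=0)
        sq_mean += (imgs ** 2).mean(dim=(2, 3)).sum(dim=0)
        n += imgs.size(0)
    mean /= n
    std = (sq_mean / n - mean ** 2).sqrt()  # Var = E[X^2] - E[X]^2
    return mean, std

train = CIFAR100("./data", train=True, download=True, transform=transforms.ToTensor())
print(channel_stats(train))  # roughly (0.507, 0.487, 0.441) for the means
```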

nebula/core/datasets/datamodule.py

Lines changed: 0 additions & 7 deletions

```diff
@@ -18,8 +18,6 @@ def __init__(
         test_set,
         test_set_indices,
         local_test_set_indices,
-        partition_id=0,
-        partitions_number=1,
         batch_size=32,
         num_workers=0,
         val_percent=0.1,
@@ -31,8 +29,6 @@
         self.test_set = test_set
         self.test_set_indices = test_set_indices
         self.local_test_set_indices = local_test_set_indices
-        self.partition_id = partition_id
-        self.partitions_number = partitions_number
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.val_percent = val_percent
@@ -79,9 +75,6 @@ def setup(self, stage=None):
         self.global_te_subset = ChangeableSubset(self.test_set, self.test_set_indices)
         self.local_te_subset = ChangeableSubset(self.test_set, self.local_test_set_indices)
 
-        if len(self.test_set) < self.partitions_number:
-            raise ValueError("Too many partitions for the size of the test set.")
-
     def teardown(self, stage=None):
         # Teardown the datasets
         if stage in (None, "fit"):
```
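With partition_id and partitions_number removed, the datamodule is driven entirely by the index lists it receives, and the partitions-versus-test-set-size check drops out of setup(). A minimal usage sketch (the class name and the leading train_set arguments are assumptions; only the parameters visible in the hunks above are confirmed by the diff):

```python
# Illustrative: the datamodule now only needs the data plus index lists.
# train_set/test_set and the maps would come from a NebulaDataset after
# data_partitioning(); the names below are placeholders.
node_id = 0
datamodule = DataModule(
    train_set=train_set,
    train_set_indices=partitions_map[node_id],      # this node's share
    test_set=test_set,
    test_set_indices=list(range(len(test_set))),    # global test set
    local_test_set_indices=local_partitions_map[node_id],
    batch_size=32,
    num_workers=0,
    val_percent=0.1,
)
datamodule.setup("fit")
```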
