Skip to content

Commit b1ff153

Browse files
committed
upgrade trust
1 parent c8ff363 commit b1ff153

File tree

5 files changed

+76
-53
lines changed

5 files changed

+76
-53
lines changed

nebula/addons/trustworthiness/factsheet.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def populate_factsheet_pre_train(self, data, scenario_name):
6767
poisoned_node_percent = 0
6868
poisoned_sample_percent = 0
6969
poisoned_noise_percent = 0
70-
with_reputation = data["with_reputation"]
70+
with_reputation = data["reputation"]["enabled"]
7171
is_dynamic_topology = False # data["is_dynamic_topology"]
7272
is_dynamic_aggregation = False # data["is_dynamic_aggregation"]
7373
target_aggregation = False # data["target_aggregation"]
@@ -125,15 +125,14 @@ def populate_factsheet_pre_train(self, data, scenario_name):
125125
factsheet["configuration"]["trainable_param_num"] = model.count_parameters()
126126
factsheet["configuration"]["local_update_steps"] = 1
127127

128+
f.seek(0)
129+
f.truncate()
130+
json.dump(factsheet, f, indent=4)
131+
128132
except JSONDecodeError as e:
129133
logging.warning(f"{factsheet_file} is invalid")
130134
logging.error(e)
131135

132-
f.seek(0)
133-
f.truncate()
134-
json.dump(factsheet, f, indent=4)
135-
f.close()
136-
137136
def populate_factsheet_post_train(self, scenario_name, start_time, end_time):
138137
"""
139138
Populates the factsheet with values after the training.
@@ -198,14 +197,16 @@ def populate_factsheet_post_train(self, scenario_name, start_time, end_time):
198197

199198
factsheet["fairness"]["selection_cv"] = 1
200199

201-
count_class_samples(scenario_name, dataloaders_files)
200+
# count_class_samples(scenario_name, dataloaders_files)
201+
202+
## FER
202203

203-
with open(f"{files_dir}/count_class.json", "r") as file:
204-
class_distribution = json.load(file)
204+
# with open(f"{files_dir}/count_class.json", "r") as file:
205+
# class_distribution = json.load(file)
205206

206-
class_samples_sizes = [x for x in class_distribution.values()]
207-
class_imbalance = get_cv(list=class_samples_sizes)
208-
factsheet["fairness"]["class_imbalance"] = 1 if class_imbalance > 1 else class_imbalance
207+
# class_samples_sizes = [x for x in class_distribution.values()]
208+
# class_imbalance = get_cv(list=class_samples_sizes)
209+
# factsheet["fairness"]["class_imbalance"] = 1 if class_imbalance > 1 else class_imbalance
209210

210211
with open(train_model_file, "rb") as file:
211212
lightning_model = pickle.load(file)
@@ -273,11 +274,10 @@ def populate_factsheet_post_train(self, scenario_name, start_time, end_time):
273274
factsheet["sustainability"]["emissions_communication_uplink"] = check_field_filled(factsheet, ["sustainability", "emissions_communication_uplink"], factsheet["system"]["total_upload_bytes"] * 2.24e-10 * factsheet["sustainability"]["avg_carbon_intensity_clients"], "")
274275
factsheet["sustainability"]["emissions_communication_downlink"] = check_field_filled(factsheet, ["sustainability", "emissions_communication_downlink"], factsheet["system"]["total_download_bytes"] * 2.24e-10 * factsheet["sustainability"]["avg_carbon_intensity_server"], "")
275276

276-
except JSONDecodeError as e:
277-
logging.warning(f"{factsheet_file} is invalid")
278-
logging.error(e)
277+
f.seek(0)
278+
f.truncate()
279+
json.dump(factsheet, f, indent=4)
279280

280-
f.seek(0)
281-
f.truncate()
282-
json.dump(factsheet, f, indent=4)
283-
f.close()
281+
except JSONDecodeError as e:
282+
logging.info(f"{factsheet_file} is invalid")
283+
logging.error(e)

nebula/addons/trustworthiness/trustworthiness.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -166,40 +166,40 @@ async def _generate_factsheet(self, trust_config, experiment_name):
166166
factsheet.populate_factsheet_post_train(experiment_name, self._start_time, self._end_time)
167167
logging.info("[FER] factsheet post train done")
168168

169-
data_file_path = os.path.join(os.environ.get('NEBULA_CONFIG_DIR'), experiment_name, "scenario.json")
170-
with open(data_file_path, 'r') as data_file:
171-
data = json.load(data_file)
169+
# data_file_path = os.path.join(os.environ.get('NEBULA_CONFIG_DIR'), experiment_name, "scenario.json")
170+
# with open(data_file_path, 'r') as data_file:
171+
# data = json.load(data_file)
172172

173-
weights = {
174-
"robustness": float(data["robustness_pillar"]),
175-
"resilience_to_attacks": float(data["resilience_to_attacks"]),
176-
"algorithm_robustness": float(data["algorithm_robustness"]),
177-
"client_reliability": float(data["client_reliability"]),
178-
"privacy": float(data["privacy_pillar"]),
179-
"technique": float(data["technique"]),
180-
"uncertainty": float(data["uncertainty"]),
181-
"indistinguishability": float(data["indistinguishability"]),
182-
"fairness": float(data["fairness_pillar"]),
183-
"selection_fairness": float(data["selection_fairness"]),
184-
"performance_fairness": float(data["performance_fairness"]),
185-
"class_distribution": float(data["class_distribution"]),
186-
"explainability": float(data["explainability_pillar"]),
187-
"interpretability": float(data["interpretability"]),
188-
"post_hoc_methods": float(data["post_hoc_methods"]),
189-
"accountability": float(data["accountability_pillar"]),
190-
"factsheet_completeness": float(data["factsheet_completeness"]),
191-
"architectural_soundness": float(data["architectural_soundness_pillar"]),
192-
"client_management": float(data["client_management"]),
193-
"optimization": float(data["optimization"]),
194-
"sustainability": float(data["sustainability_pillar"]),
195-
"energy_source": float(data["energy_source"]),
196-
"hardware_efficiency": float(data["hardware_efficiency"]),
197-
"federation_complexity": float(data["federation_complexity"])
198-
}
173+
# weights = {
174+
# "robustness": float(data["robustness_pillar"]),
175+
# "resilience_to_attacks": float(data["resilience_to_attacks"]),
176+
# "algorithm_robustness": float(data["algorithm_robustness"]),
177+
# "client_reliability": float(data["client_reliability"]),
178+
# "privacy": float(data["privacy_pillar"]),
179+
# "technique": float(data["technique"]),
180+
# "uncertainty": float(data["uncertainty"]),
181+
# "indistinguishability": float(data["indistinguishability"]),
182+
# "fairness": float(data["fairness_pillar"]),
183+
# "selection_fairness": float(data["selection_fairness"]),
184+
# "performance_fairness": float(data["performance_fairness"]),
185+
# "class_distribution": float(data["class_distribution"]),
186+
# "explainability": float(data["explainability_pillar"]),
187+
# "interpretability": float(data["interpretability"]),
188+
# "post_hoc_methods": float(data["post_hoc_methods"]),
189+
# "accountability": float(data["accountability_pillar"]),
190+
# "factsheet_completeness": float(data["factsheet_completeness"]),
191+
# "architectural_soundness": float(data["architectural_soundness_pillar"]),
192+
# "client_management": float(data["client_management"]),
193+
# "optimization": float(data["optimization"]),
194+
# "sustainability": float(data["sustainability_pillar"]),
195+
# "energy_source": float(data["energy_source"]),
196+
# "hardware_efficiency": float(data["hardware_efficiency"]),
197+
# "federation_complexity": float(data["federation_complexity"])
198+
# }
199199

200-
trust_metric_manager = TrustMetricManager(self._start_time)
201-
trust_metric_manager.evaluate(experiment_name, weights, use_weights=True)
202-
logging.info("[FER] evaluation done")
200+
# trust_metric_manager = TrustMetricManager(self._start_time)
201+
# trust_metric_manager.evaluate(experiment_name, weights, use_weights=True)
202+
# logging.info("[FER] evaluation done")
203203

204204
async def _process_test_metrics_event(self, tme: TestMetricsEvent):
205205
cur_loss, cur_acc = await tme.get_event_data()

nebula/core/datasets/datamodule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(
1919
test_set_indices,
2020
local_test_set,
2121
local_test_set_indices,
22+
samples_per_label,
2223
batch_size=32,
2324
num_workers=0,
2425
val_percent=0.1,
@@ -35,6 +36,7 @@ def __init__(
3536
self.num_workers = num_workers
3637
self.val_percent = val_percent
3738
self.seed = seed
39+
self.samples_per_label = samples_per_label
3840

3941
self.model_weight = None
4042

@@ -93,7 +95,6 @@ def train_dataloader(self):
9395
"Train dataset not initialized. Please call setup('fit') before requesting train_dataloader."
9496
)
9597
logging_training.info(f"Train set size: {len(self.data_train)}")
96-
logging.info("[FER] train_dataloader")
9798
return DataLoader(
9899
self.data_train,
99100
batch_size=self.batch_size,
@@ -125,7 +126,6 @@ def test_dataloader(self):
125126
)
126127
logging_training.info(f"Local test set size: {len(self.local_te_subset)}")
127128
logging_training.info(f"Global test set size: {len(self.global_te_subset)}")
128-
logging.info("[FER] test_dataloader")
129129
return [
130130
DataLoader(
131131
self.local_te_subset,

nebula/core/node.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
1616
import logging
17+
from collections import Counter
1718

1819
from nebula.config.config import Config
1920
from nebula.core.datasets.cifar10.cifar10 import CIFAR10PartitionHandler
@@ -149,6 +150,7 @@ async def main(config):
149150
dataset = NebulaPartition(handler=handler, config=config)
150151
dataset.load_partition()
151152
dataset.log_partition()
153+
samples_per_label = Counter(dataset.get_train_labels())
152154

153155
datamodule = DataModule(
154156
train_set=dataset.train_set,
@@ -159,6 +161,7 @@ async def main(config):
159161
local_test_set_indices=dataset.local_test_indices,
160162
num_workers=num_workers,
161163
batch_size=batch_size,
164+
samples_per_label = samples_per_label
162165
)
163166

164167
trainer = None

nebula/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,26 @@ def create_docker_network(cls, network_name, subnet=None, prefix=24):
105105
finally:
106106
client.close() # Ensure the Docker client is closed
107107

108+
@classmethod
109+
def check_docker_by_prefix(cls, prefix):
110+
try:
111+
# Connect to Docker client
112+
client = docker.from_env()
113+
114+
containers = client.containers.list(all=True) # `all=True` to include stopped containers
115+
116+
# Iterate through containers and remove those with the matching prefix
117+
for container in containers:
118+
if container.name.startswith(prefix):
119+
return True
120+
121+
return False
122+
123+
except docker.errors.APIError:
124+
logging.exception("Error interacting with Docker")
125+
except Exception:
126+
logging.exception("Unexpected error")
127+
108128
@classmethod
109129
def remove_docker_network(cls, network_name):
110130
try:

0 commit comments

Comments
 (0)