Skip to content

Commit d5bd4dd

Browse files
fix trust
1 parent c8cf61c commit d5bd4dd

File tree

3 files changed

+89
-13
lines changed

3 files changed

+89
-13
lines changed

nebula/addons/trustworthiness/factsheet.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from nebula.core.models.mnist.mlp import MNISTModelMLP
1313
from nebula.core.models.mnist.cnn import MNISTModelCNN
1414
from nebula.addons.trustworthiness.calculation import get_elapsed_time, get_bytes_models, get_bytes_sent_recv, get_avg_loss_accuracy, get_cv, get_clever_score, get_feature_importance_cv
15-
from nebula.addons.trustworthiness.utils import count_class_samples, read_csv, check_field_filled, get_entropy
15+
from nebula.addons.trustworthiness.utils import count_all_class_samples, read_csv, check_field_filled, get_entropy, get_all_data_entropy
1616
# from nebula.core.models.syscall.mlp import SyscallModelMLP
1717

1818
dirname = os.path.dirname(__file__)
@@ -160,16 +160,20 @@ def populate_factsheet_post_train(self, scenario_name, start_time, end_time, cla
160160
train_model_file = f"{files_dir}/participant_1_train_model.pk"
161161
emissions_file = os.path.join(files_dir, "emissions.csv")
162162

163-
# Entropy
164-
i = 0
165-
for file in dataloaders_files:
166-
with open(file, "rb") as file:
167-
dataloader = pickle.load(file)
168-
get_entropy(i, scenario_name, dataloader)
169-
i += 1
163+
# # Entropy
164+
# i = 0
165+
# for file in dataloaders_files:
166+
# with open(file, "rb") as file:
167+
# dataloader = pickle.load(file)
168+
# get_entropy(i, scenario_name, dataloader)
169+
# i += 1
170+
171+
get_all_data_entropy(scenario_name)
170172

171173
with open(f"{files_dir}/entropy.json", "r") as file:
172174
entropy_distribution = json.load(file)
175+
176+
logging.info(f"[ALEX] entropy_distribution: {entropy_distribution}")
173177

174178
values = np.array(list(entropy_distribution.values()))
175179

@@ -197,12 +201,14 @@ def populate_factsheet_post_train(self, scenario_name, start_time, end_time, cla
197201

198202
factsheet["fairness"]["selection_cv"] = 1
199203

200-
count_class_samples(scenario_name, dataloaders_files, class_counter)
204+
count_all_class_samples(scenario_name)
201205

202-
# FER
206+
# # FER
203207

204208
with open(f"{files_dir}/count_class.json", "r") as file:
205209
class_distribution = json.load(file)
210+
211+
logging.info(f"[ALEX] class_distribution: {class_distribution}")
206212

207213
class_samples_sizes = [x for x in class_distribution.values()]
208214
class_imbalance = get_cv(list=class_samples_sizes)
@@ -220,6 +226,8 @@ def populate_factsheet_post_train(self, scenario_name, start_time, end_time, cla
220226
# else:
221227
# model = CIFAR10ModelCNN()
222228

229+
logging.info(f"[ALEX] parte de training hecha")
230+
223231
model.load_state_dict(lightning_model.state_dict())
224232

225233
with open(test_dataloader_file, "rb") as file:

nebula/addons/trustworthiness/trustworthiness.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from nebula.addons.trustworthiness.utils import save_results_csv
1212
from codecarbon import EmissionsTracker
1313
import asyncio
14+
from collections import Counter
1415

1516
""" ##############################
1617
# TRUST WORKLOADS #
@@ -115,6 +116,8 @@ async def _process_experiment_finished_event(self, efe:ExperimentFinishEvent):
115116
# Save model in trustworthy dir
116117
with open(model_file, 'wb') as f:
117118
pickle.dump(self._engine.trainer.model, f)
119+
120+
118121

119122
class TrustWorkloadServer(TrustWorkload):
120123

@@ -199,8 +202,9 @@ async def _generate_factsheet(self, trust_config, experiment_name):
199202
}
200203

201204
trust_metric_manager = TrustMetricManager(self._start_time)
202-
# trust_metric_manager.evaluate(experiment_name, weights, use_weights=True)
203-
logging.info("[FER] evaluation done")
205+
trust_metric_manager.evaluate(experiment_name, weights, use_weights=True)
206+
#logging.info("[FER] evaluation done")
207+
logging.info("Trust work DONE")
204208

205209
async def _process_test_metrics_event(self, tme: TestMetricsEvent):
206210
cur_loss, cur_acc = await tme.get_event_data()
@@ -262,6 +266,10 @@ async def _create_trustworthiness_directory(self):
262266
logging.info("log2")
263267

264268
async def _process_experiment_finish_event(self, efe: ExperimentFinishEvent):
269+
from nebula.addons.trustworthiness.utils import save_class_count_per_participant
270+
class_counter = self._engine.trainer.datamodule.get_samples_per_label()
271+
save_class_count_per_participant(self._experiment_name, class_counter, self._idx)
272+
265273
await self.tw.finish_experiment_role_pre_actions()
266274

267275
last_loss, last_accuracy = self.tw.get_metrics()
@@ -283,7 +291,7 @@ async def _process_experiment_finish_event(self, efe: ExperimentFinishEvent):
283291
save_results_csv(self._experiment_name, self._idx, bytes_sent, bytes_recv, last_loss, last_accuracy)
284292
stop_emissions_tracking_and_save(self._tracker, self._trust_dir_files, self._emissions_file, self._role.value, workload, sample_size)
285293
await self.tw.finish_experiment_role_post_actions(self._trust_config, self._experiment_name)
286-
294+
287295
def _factory_trust_workload(self, role: Role, engine: Engine, idx, trust_files_route) -> TrustWorkload:
288296
trust_workloads = {
289297
Role.TRAINER: TrustWorkloadTrainer,

nebula/addons/trustworthiness/utils.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,36 @@
1717
dirname = os.path.dirname(__file__)
1818

1919

20+
def save_class_count_per_participant(experiment_name, class_counter: Counter, idx):
    """Persist one participant's per-class sample counts as a JSON file.

    Writes ``<idx>_class_count.json`` under the experiment's
    ``trustworthiness`` directory (rooted at ``NEBULA_LOGS_DIR``).
    Class ids are obfuscated with the module-level ``hashids`` encoder
    before being used as JSON keys.

    Args:
        experiment_name: Name of the experiment (directory under NEBULA_LOGS_DIR).
        class_counter: Counter mapping class id -> number of samples.
        idx: Participant index used to name the output file.
    """
    out_path = os.path.join(
        os.environ.get('NEBULA_LOGS_DIR'),
        experiment_name,
        "trustworthiness",
        f"{str(idx)}_class_count.json",
    )
    # Encode each class id so raw labels are not stored on disk.
    encoded_counts = {}
    for class_id, count in class_counter.items():
        encoded_counts[hashids.encode(int(class_id))] = count
    with open(out_path, "w") as f:
        json.dump(encoded_counts, f)
25+
26+
def count_all_class_samples(experiment_name):
    """Aggregate every participant's class counts into one global file.

    Scans the experiment's ``trustworthiness`` directory for
    ``<i>_class_count.json`` files with consecutive participant ids
    starting at 0, stopping at the first missing id, and sums the counts
    per class hash. The total is written to ``count_class.json``.

    Args:
        experiment_name: Name of the experiment (directory under NEBULA_LOGS_DIR).
    """
    base_dir = os.path.join(
        os.environ.get('NEBULA_LOGS_DIR'), experiment_name, "trustworthiness"
    )

    totals = {}
    participant_id = 0
    while True:
        per_participant_file = os.path.join(
            base_dir, f"{str(participant_id)}_class_count.json"
        )
        # Participant files are numbered consecutively; the first gap ends the scan.
        if not os.path.exists(per_participant_file):
            break

        with open(per_participant_file, "r") as f:
            counts = json.load(f)

        for class_hash, count in counts.items():
            totals[class_hash] = totals.get(class_hash, 0) + count

        participant_id += 1

    # Save the aggregated total in count_class.json.
    output_file = os.path.join(base_dir, "count_class.json")

    with open(output_file, "w") as f:
        json.dump(totals, f, indent=2)
49+
2050
def count_class_samples(scenario_name, dataloaders_files, class_counter: Counter = None):
2151
"""
2252
Counts the number of samples by class.
@@ -56,6 +86,35 @@ def count_class_samples(scenario_name, dataloaders_files, class_counter: Counter
5686
json.dump(result, f)
5787

5888

89+
def get_all_data_entropy(experiment_name):
    """Compute each participant's class-distribution entropy and save it.

    Reads ``<i>_class_count.json`` files (consecutive participant ids
    starting at 0, stopping at the first missing id) from the experiment's
    ``trustworthiness`` directory, computes the Shannon entropy (base 2)
    of each participant's class-count distribution, and writes the results
    to ``entropy.json`` keyed by participant id (as a string).

    Args:
        experiment_name: Name of the experiment (directory under NEBULA_LOGS_DIR).
    """
    trust_dir = os.path.join(
        os.environ.get('NEBULA_LOGS_DIR'), experiment_name, "trustworthiness"
    )
    entropy_per_participant = {}
    participant_id = 0

    # NOTE: the redundant pre-loop path computation was removed; the path is
    # built fresh for each participant inside the loop.
    while True:
        data_class_count_file = os.path.join(
            trust_dir, f"{str(participant_id)}_class_count.json"
        )

        if not os.path.exists(data_class_count_file):
            break

        with open(data_class_count_file, "r") as f:
            class_count = json.load(f)

        total = sum(class_count.values())
        if total == 0:
            # No samples at all: define the entropy as 0 instead of dividing by zero.
            entropy_value = 0.0
        else:
            probabilities = [count / total for count in class_count.values()]
            entropy_value = entropy(probabilities, base=2)

        entropy_per_participant[str(participant_id)] = round(entropy_value, 6)
        participant_id += 1

    name_file = os.path.join(trust_dir, "entropy.json")

    with open(name_file, "w") as f:
        json.dump(entropy_per_participant, f, indent=2)
117+
59118
def get_entropy(client_id, scenario_name, dataloader):
60119
"""
61120
Get the entropy of each client in the scenario.
@@ -72,6 +131,7 @@ def get_entropy(client_id, scenario_name, dataloader):
72131
name_file = os.path.join(os.environ.get('NEBULA_LOGS_DIR'), scenario_name, "trustworthiness", "entropy.json")
73132

74133
if os.path.exists(name_file):
134+
logging.info(f"entropy fiel already exists.. loading.")
75135
with open(name_file, "r") as f:
76136
client_entropy = json.load(f)
77137

0 commit comments

Comments
 (0)