Skip to content

Commit 754ca88

Browse files
update round count and improve code
1 parent 679f90f commit 754ca88

File tree

1 file changed

+17
-21
lines changed

1 file changed

+17
-21
lines changed

nebula/core/engine.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -279,15 +279,14 @@ async def _federation_federation_start_callback(self, source, message):
279279

280280
async def _reputation_callback(self, source, message):
281281
malicious_nodes = message.arguments # List of malicious nodes
282-
if self.with_reputation:
283-
if len(malicious_nodes) > 0 and not self._is_malicious:
284-
if self.is_dynamic_topology:
285-
await self._disrupt_connection_using_reputation(malicious_nodes)
286-
if self.is_dynamic_aggregation and self.aggregator != self.target_aggregation:
287-
await self._dynamic_aggregator(
288-
self.aggregator.get_nodes_pending_models_to_aggregate(),
289-
malicious_nodes,
290-
)
282+
if self.with_reputation and len(malicious_nodes) > 0 and not self._is_malicious:
283+
if self.is_dynamic_topology:
284+
await self._disrupt_connection_using_reputation(malicious_nodes)
285+
if self.is_dynamic_aggregation and self.aggregator != self.target_aggregation:
286+
await self._dynamic_aggregator(
287+
self.aggregator.get_nodes_pending_models_to_aggregate(),
288+
malicious_nodes,
289+
)
291290

292291
async def _federation_federation_models_included_callback(self, source, message):
293292
logging.info(f"📝 handle_federation_message | Trigger | Received aggregation finished message from {source}")
@@ -366,7 +365,7 @@ async def _start_learning(self):
366365
self.total_rounds = self.config.participant["scenario_args"]["rounds"]
367366
epochs = self.config.participant["training_args"]["epochs"]
368367
await self.get_round_lock().acquire_async()
369-
self.round = 1
368+
self.round = 0
370369
await self.get_round_lock().release_async()
371370
await self.learning_cycle_lock.release_async()
372371
print_msg_box(
@@ -433,7 +432,7 @@ async def _dynamic_aggregator(self, aggregated_models_weights, malicious_nodes):
433432
self.aggregator = self.target_aggregation
434433
await self.aggregator.update_federation_nodes(self.federation_nodes)
435434

436-
for subnodes in aggregated_models_weights.keys():
435+
for subnodes in aggregated_models_weights:
437436
sublist = subnodes.split()
438437
(submodel, weights) = aggregated_models_weights[subnodes]
439438
for node in sublist:
@@ -458,9 +457,9 @@ def learning_cycle_finished(self):
458457
return not (self.round < self.total_rounds)
459458

460459
async def _learning_cycle(self):
461-
while self.round is not None and self.round <= self.total_rounds:
460+
while self.round is not None and self.round < self.total_rounds:
462461
print_msg_box(
463-
msg=f"Round {self.round} of {self.total_rounds} started.",
462+
msg=f"Round {self.round} of {self.total_rounds - 1} started (max. {self.total_rounds} rounds)",
464463
indent=2,
465464
title="Round information",
466465
)
@@ -476,13 +475,13 @@ async def _learning_cycle(self):
476475

477476
await self.get_round_lock().acquire_async()
478477
print_msg_box(
479-
msg=f"Round {self.round} of {self.total_rounds} finished.",
478+
msg=f"Round {self.round} of {self.total_rounds - 1} finished (max. {self.total_rounds} rounds)",
480479
indent=2,
481480
title="Round information",
482481
)
483482
await self.aggregator.reset()
484483
self.trainer.on_round_end()
485-
self.round = self.round + 1
484+
self.round += 1
486485
self.config.participant["federation_args"]["round"] = (
487486
self.round
488487
) # Set current round in config (send to the controller)
@@ -492,7 +491,7 @@ async def _learning_cycle(self):
492491
self.trainer.on_learning_cycle_end()
493492
await self.trainer.test()
494493
print_msg_box(
495-
msg="Federated Learning process has been completed.",
494+
msg=f"FL process has been completed successfully (max. {self.total_rounds} rounds reached)",
496495
indent=2,
497496
title="End of the experiment",
498497
)
@@ -529,7 +528,7 @@ def reputation_calculation(self, aggregated_models_weights):
529528
loss_threshold = 0.5
530529

531530
current_models = {}
532-
for subnodes in aggregated_models_weights.keys():
531+
for subnodes in aggregated_models_weights:
533532
sublist = subnodes.split()
534533
submodel = aggregated_models_weights[subnodes][0]
535534
for node in sublist:
@@ -564,9 +563,6 @@ def reputation_calculation(self, aggregated_models_weights):
564563

565564
async def send_reputation(self, malicious_nodes):
566565
logging.info(f"Sending REPUTATION to the rest of the topology: {malicious_nodes}")
567-
# message = self.cm.mm.generate_federation_message(
568-
# nebula_pb2.FederationMessage.Action.REPUTATION, malicious_nodes
569-
# )
570566
message = self.cm.create_message("federation", "reputation", arguments=[str(arg) for arg in (malicious_nodes)])
571567
await self.cm.send_message_to_neighbors(message)
572568

@@ -593,7 +589,7 @@ def __init__(
593589
async def _extended_learning_cycle(self):
594590
try:
595591
await self.attack.attack()
596-
except:
592+
except Exception:
597593
attack_name = self.config.participant["adversarial_args"]["attacks"]
598594
logging.exception(f"Attack {attack_name} failed")
599595

0 commit comments

Comments
 (0)