88from nebula .addons .functions import print_msg_box
99from nebula .addons .reporter import Reporter
1010from nebula .core .aggregation .aggregator import create_aggregator , create_target_aggregator
11- from nebula .core .eventmanager import EventManager , event_handler
11+ from nebula .core .eventmanager import EventManager
1212from nebula .core .network .communications import CommunicationsManager
13- from nebula .core .pb import nebula_pb2
1413from nebula .core .utils .locker import Locker
1514
1615logging .getLogger ("requests" ).setLevel (logging .WARNING )
@@ -140,22 +139,18 @@ def __init__(
140139 self ._reporter = Reporter (config = self .config , trainer = self .trainer , cm = self .cm )
141140
142141 self ._event_manager = EventManager (
143- default_callbacks = [
144- self ._discovery_discover_callback ,
145- self ._control_alive_callback ,
146- self ._connection_connect_callback ,
147- self ._connection_disconnect_callback ,
148- self ._federation_ready_callback ,
149- self ._start_federation_callback ,
150- self ._federation_models_included_callback ,
151- ]
142+ # default_callbacks=[
143+ # self._discovery_discover_callback,
144+ # self._control_alive_callback,
145+ # self._connection_connect_callback,
146+ # self._connection_disconnect_callback,
147+ # # self._federation_ready_callback,
148+ # # self._start_federation_callback,
149+ # # self._federation_models_included_callback,
150+ # ]
152151 )
153152
154- # Register additional callbacks
155- self ._event_manager .register_callback (
156- self ._reputation_callback ,
157- # ... add more callbacks here
158- )
153+ self .register_message_events_callbacks ()
159154
160155 @property
161156 def cm (self ):
@@ -207,7 +202,26 @@ def get_federation_setup_lock(self):
207202 def get_round_lock (self ):
208203 return self .round_lock
209204
210- @event_handler (nebula_pb2 .DiscoveryMessage , nebula_pb2 .DiscoveryMessage .Action .DISCOVER )
205+ def register_message_events_callbacks (self ):
206+ me_dict = self .cm .get_messages_events ()
207+ message_events = [
208+ (message_name , message_action )
209+ for (message_name , message_actions ) in me_dict .items ()
210+ for message_action in message_actions
211+ ]
212+ logging .info (f"{ message_events } " )
213+ for event_type , action in message_events :
214+ callback_name = f"_{ event_type } _{ action } _callback"
215+ logging .info (f"Searching callback named: { callback_name } " )
216+ method = getattr (self , callback_name , None )
217+
218+ if callable (method ):
219+ self .event_manager .subscribe ((event_type , action ), method )
220+
221+ async def trigger_event (self , message_event ):
222+ logging .info (f"Publishing MessageEvent: { message_event .message_type } " )
223+ await self .event_manager .publish (message_event )
224+
211225 async def _discovery_discover_callback (self , source , message ):
212226 logging .info (
213227 f"🔍 handle_discovery_message | Trigger | Received discovery message from { source } (network propagation)"
@@ -231,7 +245,6 @@ async def _discovery_discover_callback(self, source, message):
231245 f"🔍 Invalid geolocation received from { source } : latitude={ message .latitude } , longitude={ message .longitude } "
232246 )
233247
234- @event_handler (nebula_pb2 .ControlMessage , nebula_pb2 .ControlMessage .Action .ALIVE )
235248 async def _control_alive_callback (self , source , message ):
236249 logging .info (f"🔧 handle_control_message | Trigger | Received alive message from { source } " )
237250 current_connections = await self .cm .get_addrs_current_connections (myself = True )
@@ -243,38 +256,27 @@ async def _control_alive_callback(self, source, message):
243256 else :
244257 logging .error (f"❗️ Connection { source } not found in connections..." )
245258
246- @event_handler (nebula_pb2 .ConnectionMessage , nebula_pb2 .ConnectionMessage .Action .CONNECT )
247259 async def _connection_connect_callback (self , source , message ):
248260 logging .info (f"🔗 handle_connection_message | Trigger | Received connection message from { source } " )
249261 current_connections = await self .cm .get_addrs_current_connections (myself = True )
250262 if source not in current_connections :
251263 logging .info (f"🔗 handle_connection_message | Trigger | Connecting to { source } " )
252264 await self .cm .connect (source , direct = True )
253265
254- @event_handler (nebula_pb2 .ConnectionMessage , nebula_pb2 .ConnectionMessage .Action .DISCONNECT )
255266 async def _connection_disconnect_callback (self , source , message ):
256267 logging .info (f"🔗 handle_connection_message | Trigger | Received disconnection message from { source } " )
257268 await self .cm .disconnect (source , mutual_disconnection = False )
258269
259- @event_handler (
260- nebula_pb2 .FederationMessage ,
261- nebula_pb2 .FederationMessage .Action .FEDERATION_READY ,
262- )
263- async def _federation_ready_callback (self , source , message ):
270+ async def _federation_federation_ready_callback (self , source , message ):
264271 logging .info (f"📝 handle_federation_message | Trigger | Received ready federation message from { source } " )
265272 if self .config .participant ["device_args" ]["start" ]:
266273 logging .info (f"📝 handle_federation_message | Trigger | Adding ready connection { source } " )
267274 await self .cm .add_ready_connection (source )
268275
269- @event_handler (
270- nebula_pb2 .FederationMessage ,
271- nebula_pb2 .FederationMessage .Action .FEDERATION_START ,
272- )
273- async def _start_federation_callback (self , source , message ):
276+ async def _federation_federation_start_callback (self , source , message ):
274277 logging .info (f"📝 handle_federation_message | Trigger | Received start federation message from { source } " )
275278 await self .create_trainer_module ()
276279
277- @event_handler (nebula_pb2 .FederationMessage , nebula_pb2 .FederationMessage .Action .REPUTATION )
278280 async def _reputation_callback (self , source , message ):
279281 malicious_nodes = message .arguments # List of malicious nodes
280282 if self .with_reputation :
@@ -287,11 +289,7 @@ async def _reputation_callback(self, source, message):
287289 malicious_nodes ,
288290 )
289291
290- @event_handler (
291- nebula_pb2 .FederationMessage ,
292- nebula_pb2 .FederationMessage .Action .FEDERATION_MODELS_INCLUDED ,
293- )
294- async def _federation_models_included_callback (self , source , message ):
292+ async def _federation_federation_models_included_callback (self , source , message ):
295293 logging .info (f"📝 handle_federation_message | Trigger | Received aggregation finished message from { source } " )
296294 try :
297295 await self .cm .get_connections_lock ().acquire_async ()
@@ -346,7 +344,8 @@ async def deploy_federation(self):
346344 while not await self .cm .check_federation_ready ():
347345 await asyncio .sleep (1 )
348346 logging .info ("Sending FEDERATION_START to neighbors..." )
349- message = self .cm .mm .generate_federation_message (nebula_pb2 .FederationMessage .Action .FEDERATION_START )
347+ # message = self.cm.mm.generate_federation_message(nebula_pb2.FederationMessage.Action.FEDERATION_START)
348+ message = self .cm .create_message ("federation" , "federation_start" )
350349 await self .cm .send_message_to_neighbors (message )
351350 await self .get_federation_ready_lock ().release_async ()
352351 await self .create_trainer_module ()
@@ -355,7 +354,8 @@ async def deploy_federation(self):
355354
356355 else :
357356 logging .info ("Sending FEDERATION_READY to neighbors..." )
358- message = self .cm .mm .generate_federation_message (nebula_pb2 .FederationMessage .Action .FEDERATION_READY )
357+ # message = self.cm.mm.generate_federation_message(nebula_pb2.FederationMessage.Action.FEDERATION_READY)
358+ message = self .cm .create_message ("federation" , "federation_ready" )
359359 await self .cm .send_message_to_neighbors (message )
360360 logging .info ("💤 Waiting until receiving the start signal from the start node" )
361361
@@ -454,6 +454,9 @@ async def _waiting_model_updates(self):
454454 else :
455455 logging .error ("Aggregation finished with no parameters" )
456456
457+ def learning_cycle_finished (self ):
458+ return not (self .round < self .total_rounds )
459+
457460 async def _learning_cycle (self ):
458461 while self .round is not None and self .round < self .total_rounds :
459462 print_msg_box (
@@ -488,8 +491,6 @@ async def _learning_cycle(self):
488491 # End of the learning cycle
489492 self .trainer .on_learning_cycle_end ()
490493 await self .trainer .test ()
491- self .round = None
492- self .total_rounds = None
493494 print_msg_box (
494495 msg = "Federated Learning process has been completed." ,
495496 indent = 2 ,
@@ -507,8 +508,7 @@ async def _learning_cycle(self):
507508 while not self .cm .check_finished_experiment ():
508509 await asyncio .sleep (1 )
509510
510- # Enable loggin info
511- logging .getLogger ().disabled = True
511+ await asyncio .sleep (5 )
512512
513513 # Kill itself
514514 if self .config .participant ["scenario_args" ]["deployment" ] == "docker" :
@@ -564,9 +564,10 @@ def reputation_calculation(self, aggregated_models_weights):
564564
565565 async def send_reputation (self , malicious_nodes ):
566566 logging .info (f"Sending REPUTATION to the rest of the topology: { malicious_nodes } " )
567- message = self .cm .mm .generate_federation_message (
568- nebula_pb2 .FederationMessage .Action .REPUTATION , malicious_nodes
569- )
567+ # message = self.cm.mm.generate_federation_message(
568+ # nebula_pb2.FederationMessage.Action.REPUTATION, malicious_nodes
569+ # )
570+ message = self .cm .create_message ("federation" , "reputation" , arguments = [str (arg ) for arg in (malicious_nodes )])
570571 await self .cm .send_message_to_neighbors (message )
571572
572573
0 commit comments