Skip to content

Commit cd002d7

Browse files
authored
Optimization/messages (#32)
* messages optimized * fix: solved logging error when the experiment is finished * fix: websockets on monitor and layout * make checl
1 parent ac6adae commit cd002d7

File tree

10 files changed

+374
-223
lines changed

10 files changed

+374
-223
lines changed

docs/_prebuilt/developerguide.md

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -707,10 +707,7 @@ The new aggregator must inherit from the **Aggregator** class. You can use **Fed
707707
logging.info(
708708
f"🔄 include_model_in_buffer | Broadcasting MODELS_INCLUDED for round {self.engine.get_round()}"
709709
)
710-
message = self.cm.mm.generate_federation_message(
711-
nebula_pb2.FederationMessage.Action.FEDERATION_MODELS_INCLUDED,
712-
[self.engine.get_round()],
713-
)
710+
message = self.cm.create_message("federation", "federation_models_included", [str(arg) for arg in [self.engine.get_round()]])
714711
await self.cm.send_message_to_neighbors(message)
715712

716713
return

nebula/core/aggregation/aggregator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from abc import ABC, abstractmethod
44
from functools import partial
55

6-
from nebula.core.pb import nebula_pb2
76
from nebula.core.utils.locker import Locker
87

98

@@ -195,9 +194,12 @@ async def include_model_in_buffer(self, model, weight, source=None, round=None,
195194
logging.info(
196195
f"🔄 include_model_in_buffer | Broadcasting MODELS_INCLUDED for round {self.engine.get_round()}"
197196
)
198-
message = self.cm.mm.generate_federation_message(
199-
nebula_pb2.FederationMessage.Action.FEDERATION_MODELS_INCLUDED,
200-
[self.engine.get_round()],
197+
# message = self.cm.mm.generate_federation_message(
198+
# nebula_pb2.FederationMessage.Action.FEDERATION_MODELS_INCLUDED,
199+
# [self.engine.get_round()],
200+
# )
201+
message = self.cm.create_message(
202+
"federation", "federation_models_included", [str(arg) for arg in [self.engine.get_round()]]
201203
)
202204
await self.cm.send_message_to_neighbors(message)
203205

nebula/core/engine.py

Lines changed: 46 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@
88
from nebula.addons.functions import print_msg_box
99
from nebula.addons.reporter import Reporter
1010
from nebula.core.aggregation.aggregator import create_aggregator, create_target_aggregator
11-
from nebula.core.eventmanager import EventManager, event_handler
11+
from nebula.core.eventmanager import EventManager
1212
from nebula.core.network.communications import CommunicationsManager
13-
from nebula.core.pb import nebula_pb2
1413
from nebula.core.utils.locker import Locker
1514

1615
logging.getLogger("requests").setLevel(logging.WARNING)
@@ -140,22 +139,18 @@ def __init__(
140139
self._reporter = Reporter(config=self.config, trainer=self.trainer, cm=self.cm)
141140

142141
self._event_manager = EventManager(
143-
default_callbacks=[
144-
self._discovery_discover_callback,
145-
self._control_alive_callback,
146-
self._connection_connect_callback,
147-
self._connection_disconnect_callback,
148-
self._federation_ready_callback,
149-
self._start_federation_callback,
150-
self._federation_models_included_callback,
151-
]
142+
# default_callbacks=[
143+
# self._discovery_discover_callback,
144+
# self._control_alive_callback,
145+
# self._connection_connect_callback,
146+
# self._connection_disconnect_callback,
147+
# # self._federation_ready_callback,
148+
# # self._start_federation_callback,
149+
# # self._federation_models_included_callback,
150+
# ]
152151
)
153152

154-
# Register additional callbacks
155-
self._event_manager.register_callback(
156-
self._reputation_callback,
157-
# ... add more callbacks here
158-
)
153+
self.register_message_events_callbacks()
159154

160155
@property
161156
def cm(self):
@@ -207,7 +202,26 @@ def get_federation_setup_lock(self):
207202
def get_round_lock(self):
208203
return self.round_lock
209204

210-
@event_handler(nebula_pb2.DiscoveryMessage, nebula_pb2.DiscoveryMessage.Action.DISCOVER)
205+
def register_message_events_callbacks(self):
206+
me_dict = self.cm.get_messages_events()
207+
message_events = [
208+
(message_name, message_action)
209+
for (message_name, message_actions) in me_dict.items()
210+
for message_action in message_actions
211+
]
212+
logging.info(f"{message_events}")
213+
for event_type, action in message_events:
214+
callback_name = f"_{event_type}_{action}_callback"
215+
logging.info(f"Searching callback named: {callback_name}")
216+
method = getattr(self, callback_name, None)
217+
218+
if callable(method):
219+
self.event_manager.subscribe((event_type, action), method)
220+
221+
async def trigger_event(self, message_event):
222+
logging.info(f"Publishing MessageEvent: {message_event.message_type}")
223+
await self.event_manager.publish(message_event)
224+
211225
async def _discovery_discover_callback(self, source, message):
212226
logging.info(
213227
f"🔍 handle_discovery_message | Trigger | Received discovery message from {source} (network propagation)"
@@ -231,7 +245,6 @@ async def _discovery_discover_callback(self, source, message):
231245
f"🔍 Invalid geolocation received from {source}: latitude={message.latitude}, longitude={message.longitude}"
232246
)
233247

234-
@event_handler(nebula_pb2.ControlMessage, nebula_pb2.ControlMessage.Action.ALIVE)
235248
async def _control_alive_callback(self, source, message):
236249
logging.info(f"🔧 handle_control_message | Trigger | Received alive message from {source}")
237250
current_connections = await self.cm.get_addrs_current_connections(myself=True)
@@ -243,38 +256,27 @@ async def _control_alive_callback(self, source, message):
243256
else:
244257
logging.error(f"❗️ Connection {source} not found in connections...")
245258

246-
@event_handler(nebula_pb2.ConnectionMessage, nebula_pb2.ConnectionMessage.Action.CONNECT)
247259
async def _connection_connect_callback(self, source, message):
248260
logging.info(f"🔗 handle_connection_message | Trigger | Received connection message from {source}")
249261
current_connections = await self.cm.get_addrs_current_connections(myself=True)
250262
if source not in current_connections:
251263
logging.info(f"🔗 handle_connection_message | Trigger | Connecting to {source}")
252264
await self.cm.connect(source, direct=True)
253265

254-
@event_handler(nebula_pb2.ConnectionMessage, nebula_pb2.ConnectionMessage.Action.DISCONNECT)
255266
async def _connection_disconnect_callback(self, source, message):
256267
logging.info(f"🔗 handle_connection_message | Trigger | Received disconnection message from {source}")
257268
await self.cm.disconnect(source, mutual_disconnection=False)
258269

259-
@event_handler(
260-
nebula_pb2.FederationMessage,
261-
nebula_pb2.FederationMessage.Action.FEDERATION_READY,
262-
)
263-
async def _federation_ready_callback(self, source, message):
270+
async def _federation_federation_ready_callback(self, source, message):
264271
logging.info(f"📝 handle_federation_message | Trigger | Received ready federation message from {source}")
265272
if self.config.participant["device_args"]["start"]:
266273
logging.info(f"📝 handle_federation_message | Trigger | Adding ready connection {source}")
267274
await self.cm.add_ready_connection(source)
268275

269-
@event_handler(
270-
nebula_pb2.FederationMessage,
271-
nebula_pb2.FederationMessage.Action.FEDERATION_START,
272-
)
273-
async def _start_federation_callback(self, source, message):
276+
async def _federation_federation_start_callback(self, source, message):
274277
logging.info(f"📝 handle_federation_message | Trigger | Received start federation message from {source}")
275278
await self.create_trainer_module()
276279

277-
@event_handler(nebula_pb2.FederationMessage, nebula_pb2.FederationMessage.Action.REPUTATION)
278280
async def _reputation_callback(self, source, message):
279281
malicious_nodes = message.arguments # List of malicious nodes
280282
if self.with_reputation:
@@ -287,11 +289,7 @@ async def _reputation_callback(self, source, message):
287289
malicious_nodes,
288290
)
289291

290-
@event_handler(
291-
nebula_pb2.FederationMessage,
292-
nebula_pb2.FederationMessage.Action.FEDERATION_MODELS_INCLUDED,
293-
)
294-
async def _federation_models_included_callback(self, source, message):
292+
async def _federation_federation_models_included_callback(self, source, message):
295293
logging.info(f"📝 handle_federation_message | Trigger | Received aggregation finished message from {source}")
296294
try:
297295
await self.cm.get_connections_lock().acquire_async()
@@ -346,7 +344,8 @@ async def deploy_federation(self):
346344
while not await self.cm.check_federation_ready():
347345
await asyncio.sleep(1)
348346
logging.info("Sending FEDERATION_START to neighbors...")
349-
message = self.cm.mm.generate_federation_message(nebula_pb2.FederationMessage.Action.FEDERATION_START)
347+
# message = self.cm.mm.generate_federation_message(nebula_pb2.FederationMessage.Action.FEDERATION_START)
348+
message = self.cm.create_message("federation", "federation_start")
350349
await self.cm.send_message_to_neighbors(message)
351350
await self.get_federation_ready_lock().release_async()
352351
await self.create_trainer_module()
@@ -355,7 +354,8 @@ async def deploy_federation(self):
355354

356355
else:
357356
logging.info("Sending FEDERATION_READY to neighbors...")
358-
message = self.cm.mm.generate_federation_message(nebula_pb2.FederationMessage.Action.FEDERATION_READY)
357+
# message = self.cm.mm.generate_federation_message(nebula_pb2.FederationMessage.Action.FEDERATION_READY)
358+
message = self.cm.create_message("federation", "federation_ready")
359359
await self.cm.send_message_to_neighbors(message)
360360
logging.info("💤 Waiting until receiving the start signal from the start node")
361361

@@ -454,6 +454,9 @@ async def _waiting_model_updates(self):
454454
else:
455455
logging.error("Aggregation finished with no parameters")
456456

457+
def learning_cycle_finished(self):
458+
return not (self.round < self.total_rounds)
459+
457460
async def _learning_cycle(self):
458461
while self.round is not None and self.round < self.total_rounds:
459462
print_msg_box(
@@ -488,8 +491,6 @@ async def _learning_cycle(self):
488491
# End of the learning cycle
489492
self.trainer.on_learning_cycle_end()
490493
await self.trainer.test()
491-
self.round = None
492-
self.total_rounds = None
493494
print_msg_box(
494495
msg="Federated Learning process has been completed.",
495496
indent=2,
@@ -507,8 +508,7 @@ async def _learning_cycle(self):
507508
while not self.cm.check_finished_experiment():
508509
await asyncio.sleep(1)
509510

510-
# Enable loggin info
511-
logging.getLogger().disabled = True
511+
await asyncio.sleep(5)
512512

513513
# Kill itself
514514
if self.config.participant["scenario_args"]["deployment"] == "docker":
@@ -564,9 +564,10 @@ def reputation_calculation(self, aggregated_models_weights):
564564

565565
async def send_reputation(self, malicious_nodes):
566566
logging.info(f"Sending REPUTATION to the rest of the topology: {malicious_nodes}")
567-
message = self.cm.mm.generate_federation_message(
568-
nebula_pb2.FederationMessage.Action.REPUTATION, malicious_nodes
569-
)
567+
# message = self.cm.mm.generate_federation_message(
568+
# nebula_pb2.FederationMessage.Action.REPUTATION, malicious_nodes
569+
# )
570+
message = self.cm.create_message("federation", "reputation", arguments=[str(arg) for arg in (malicious_nodes)])
570571
await self.cm.send_message_to_neighbors(message)
571572

572573

nebula/core/eventmanager.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from collections import defaultdict
55
from functools import wraps
66

7+
from nebula.core.network.messages import MessageEvent
8+
79

810
def event_handler(message_type, action):
911
"""Decorator for registering an event handler."""
@@ -33,6 +35,30 @@ class EventManager:
3335
def __init__(self, default_callbacks=None):
3436
self._event_callbacks = defaultdict(list)
3537
self._register_default_callbacks(default_callbacks or [])
38+
self._subscribers: dict[tuple[str, str], list] = {}
39+
40+
def subscribe(self, event_type: tuple[str, str], callback: callable):
41+
"""Register a callback for a specific event type."""
42+
if event_type not in self._subscribers:
43+
self._subscribers[event_type] = []
44+
self._subscribers[event_type].append(callback)
45+
logging.info(f"EventManager | Subscribed callback for event: {event_type}")
46+
47+
async def publish(self, message_event: MessageEvent):
48+
"""Trigger all callbacks registered for a specific event type."""
49+
event_type = message_event.message_type
50+
if event_type not in self._subscribers:
51+
logging.error(f"EventManager | No subscribers for event: {event_type}")
52+
return
53+
54+
for callback in self._subscribers[event_type]:
55+
try:
56+
logging.info(
57+
f"EventManager | Triggering callback for event: {event_type}, from source: {message_event.source}"
58+
)
59+
await callback(message_event.source, message_event.message)
60+
except Exception as e:
61+
logging.exception(f"EventManager | Error in callback for event {event_type}: {e}")
3662

3763
def _register_default_callbacks(self, default_callbacks):
3864
"""Registers default callbacks for events."""

nebula/core/network/actions.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from enum import Enum
2+
3+
from nebula.core.pb import nebula_pb2
4+
5+
6+
class ConnectionAction(Enum):
7+
CONNECT = nebula_pb2.ConnectionMessage.Action.CONNECT
8+
DISCONNECT = nebula_pb2.ConnectionMessage.Action.DISCONNECT
9+
10+
11+
class FederationAction(Enum):
12+
FEDERATION_START = nebula_pb2.FederationMessage.Action.FEDERATION_START
13+
REPUTATION = nebula_pb2.FederationMessage.Action.REPUTATION
14+
FEDERATION_MODELS_INCLUDED = nebula_pb2.FederationMessage.Action.FEDERATION_MODELS_INCLUDED
15+
FEDERATION_READY = nebula_pb2.FederationMessage.Action.FEDERATION_READY
16+
17+
18+
class DiscoveryAction(Enum):
19+
DISCOVER = nebula_pb2.DiscoveryMessage.Action.DISCOVER
20+
REGISTER = nebula_pb2.DiscoveryMessage.Action.REGISTER
21+
DEREGISTER = nebula_pb2.DiscoveryMessage.Action.DEREGISTER
22+
23+
24+
class ControlAction(Enum):
25+
ALIVE = nebula_pb2.ControlMessage.Action.ALIVE
26+
OVERHEAD = nebula_pb2.ControlMessage.Action.OVERHEAD
27+
MOBILITY = nebula_pb2.ControlMessage.Action.MOBILITY
28+
RECOVERY = nebula_pb2.ControlMessage.Action.RECOVERY
29+
WEAK_LINK = nebula_pb2.ControlMessage.Action.WEAK_LINK
30+
31+
32+
ACTION_CLASSES = {
33+
"connection": ConnectionAction,
34+
"federation": FederationAction,
35+
"discovery": DiscoveryAction,
36+
"control": ControlAction,
37+
}
38+
39+
40+
def get_action_name_from_value(message_type: str, action_value: int) -> str:
41+
# Obtener el Enum correspondiente al tipo de mensaje
42+
enum_class = ACTION_CLASSES.get(message_type)
43+
if not enum_class:
44+
raise ValueError(f"Unknown message type: {message_type}")
45+
46+
# Buscar el nombre de la acción a partir del valor
47+
for action in enum_class:
48+
if action.value == action_value:
49+
return action.name.lower() # Convertimos a lowercase para mantener el formato "late_connect"
50+
51+
raise ValueError(f"Unknown action value {action_value} for message type {message_type}")
52+
53+
54+
def get_actions_names(message_type: str):
55+
message_actions = ACTION_CLASSES.get(message_type)
56+
if not message_actions:
57+
raise ValueError(f"Invalid message type: {message_type}")
58+
59+
return [action.name.lower() for action in message_actions]
60+
61+
62+
def factory_message_action(message_type: str, action: str):
63+
message_actions = ACTION_CLASSES.get(message_type)
64+
65+
if message_actions:
66+
normalized_action = action.upper()
67+
enum_action = message_actions[normalized_action]
68+
# logging.info(f"Message action: {enum_action}, value: {enum_action.value}")
69+
return enum_action.value
70+
else:
71+
return None

0 commit comments

Comments
 (0)