
Commit 4d968ff

Improve comms efficiency
1 parent: 8aa0c21

2 files changed: +33 −16 lines

nebula/core/aggregation/aggregator.py

Lines changed: 7 additions & 3 deletions
@@ -197,7 +197,9 @@ async def get_aggregation(self):
 
         if self._waiting_global_update and len(self._pending_models_to_aggregate) == 1:
             logging.info(f"🔄 get_aggregation | Received a global model. Overwriting my model with the aggregated model.")
-            return next(iter(self._pending_models_to_aggregate.values()))[0]
+            aggregated_model = next(iter(self._pending_models_to_aggregate.values()))[0]
+            self._pending_models_to_aggregate.clear()
+            return aggregated_model
 
         unique_nodes_involved = set(node for key in self._pending_models_to_aggregate for node in key.split())
 
@@ -206,8 +208,10 @@ async def get_aggregation(self):
             logging.info(f"🔄 get_aggregation | Aggregation incomplete, missing models from: {missing_nodes}")
         else:
             logging.info(f"🔄 get_aggregation | All models accounted for, proceeding with aggregation.")
-
-        return self.run_aggregation(self._pending_models_to_aggregate)
+
+        aggregated_result = self.run_aggregation(self._pending_models_to_aggregate)
+        self._pending_models_to_aggregate.clear()
+        return aggregated_result
 
     async def include_next_model_in_buffer(self, model, weight, source=None, round=None):
         logging.info(f"🔄 include_next_model_in_buffer | source={source} | round={round} | weight={weight}")

nebula/core/network/connection.py

Lines changed: 26 additions & 13 deletions
@@ -18,6 +18,7 @@
 
 @dataclass
 class MessageChunk:
+    __slots__ = ['index', 'data', 'is_last']
     index: int
     data: bytes
     is_last: bool
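Declaring __slots__ removes the per-instance __dict__, which saves memory when many small MessageChunk objects sit in the reassembly buffers. A standalone sketch of the effect; note that a manual __slots__ declaration only combines safely with @dataclass when the fields have no default values (Python 3.10+ also offers @dataclass(slots=True)):

from dataclasses import dataclass

@dataclass
class ChunkWithDict:
    index: int
    data: bytes
    is_last: bool

@dataclass
class ChunkWithSlots:
    __slots__ = ['index', 'data', 'is_last']
    index: int
    data: bytes
    is_last: bool

a = ChunkWithDict(0, b"x", False)
b = ChunkWithSlots(0, b"x", False)
print(hasattr(a, "__dict__"))  # True  — each instance carries a dict
print(hasattr(b, "__dict__"))  # False — attributes live in fixed slots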
@@ -58,7 +59,7 @@ def __init__(
         self.loop = asyncio.get_event_loop()
         self.read_task = None
         self.process_task = None
-        self.pending_messages_queue = asyncio.Queue()
+        self.pending_messages_queue = asyncio.Queue(maxsize=100)
         self.message_buffers: Dict[bytes, Dict[int, MessageChunk]] = {}
 
         self.EOT_CHAR = b"\x00\x00\x00\x04"
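Bounding the queue with maxsize=100 caps how many complete messages can pile up awaiting processing; on a bounded asyncio.Queue, await put(...) also suspends the producer until a consumer frees a slot. A self-contained sketch of that behaviour:

import asyncio

async def producer(q: asyncio.Queue) -> None:
    for i in range(6):
        await q.put(i)  # suspends once the queue already holds maxsize items
        print(f"produced {i}")

async def consumer(q: asyncio.Queue) -> None:
    while True:
        item = await q.get()
        await asyncio.sleep(0.05)  # simulate slow message handling
        print(f"consumed {item}")
        q.task_done()

async def main() -> None:
    q = asyncio.Queue(maxsize=2)  # bounded, like pending_messages_queue
    worker = asyncio.create_task(consumer(q))
    await producer(q)
    await q.join()  # wait until every item has been processed
    worker.cancel()

asyncio.run(main())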
@@ -239,7 +240,7 @@ async def _send_chunks(self, message_id: bytes, data: bytes) -> None:
             self.writer.write(chunk_with_header)
             await self.writer.drain()
 
-            logging.debug(f"Sent message {message_id.hex()} | chunk {chunk_index+1}/{num_chunks} | size: {len(chunk)} bytes")
+            logging.debug(f"Sent message {message_id.hex()} | chunk {chunk_index+1}/{num_chunks} | size: {len(chunk)} bytes")
 
     def _calculate_chunk_size(self, data_size: int) -> int:
         if data_size <= 1024:  # 1 KB
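For context, _calculate_chunk_size scales the chunk size with the payload; only the 1 KB floor and the 1 MB cap are visible in this diff, so the middle tier below is a hypothetical placeholder, not the real values:

# Hypothetical sketch of tiered chunk sizing. Only the 1 KB floor and the
# 1 MB cap appear in the diff; the 64 KB middle tier is invented here.
def calculate_chunk_size(data_size: int) -> int:
    if data_size <= 1024:            # 1 KB payloads fit in one small chunk
        return 1024
    if data_size <= 1024 * 1024:     # placeholder threshold (assumption)
        return 64 * 1024             # placeholder tier (assumption)
    return 1024 * 1024               # 1 MB cap, as in the original code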
@@ -250,14 +251,18 @@ def _calculate_chunk_size(self, data_size: int) -> int:
         return 1024 * 1024  # 1 MB
 
     async def handle_incoming_message(self) -> None:
+        reusable_buffer = bytearray(self.MAX_CHUNK_SIZE)
         try:
             while True:
+                if self.pending_messages_queue.full():
+                    await asyncio.sleep(0.1)  # Wait a bit if the queue is full to create backpressure
+                    continue
                 header = await self._read_exactly(self.HEADER_SIZE)
                 message_id, chunk_index, is_last_chunk = self._parse_header(header)
 
-                chunk_data = await self._read_chunk()
+                chunk_data = await self._read_chunk(reusable_buffer)
                 self._store_chunk(message_id, chunk_index, chunk_data, is_last_chunk)
-                logging.debug(f"Received chunk {chunk_index} of message {message_id.hex()} | size: {len(chunk_data)} bytes")
+                logging.debug(f"Received chunk {chunk_index} of message {message_id.hex()} | size: {len(chunk_data)} bytes")
 
                 if is_last_chunk:
                     await self._process_complete_message(message_id)
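A design note on the backpressure check: polling pending_messages_queue.full() with a 100 ms sleep stops the reader before it even pulls the next header off the socket, which pushes backpressure down to the TCP layer; the alternative of simply letting the bounded queue's await put(...) suspend would only block once a complete message had already been read and assembled.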
@@ -293,32 +298,40 @@ def _parse_header(self, header: bytes) -> tuple[bytes, int, bool]:
         is_last_chunk = header[20] == 1
         return message_id, chunk_index, is_last_chunk
 
-    async def _read_chunk(self) -> bytes:
+    async def _read_chunk(self, buffer: bytearray = None) -> bytes:
+        if buffer is None:
+            buffer = bytearray(self.MAX_CHUNK_SIZE)
+
         chunk_size_bytes = await self._read_exactly(4)
         chunk_size = int.from_bytes(chunk_size_bytes, "big")
 
         if chunk_size > self.MAX_CHUNK_SIZE:
             raise ValueError(f"Chunk size {chunk_size} exceeds MAX_CHUNK_SIZE {self.MAX_CHUNK_SIZE}")
 
         chunk = await self._read_exactly(chunk_size)
+        buffer[:chunk_size] = chunk
         eot = await self._read_exactly(len(self.EOT_CHAR))
 
         if eot != self.EOT_CHAR:
             raise ValueError("Invalid EOT character")
 
-        return chunk
+        return memoryview(buffer)[:chunk_size]
 
-    def _store_chunk(self, message_id: bytes, chunk_index: int, data: bytes, is_last: bool) -> None:
+    def _store_chunk(self, message_id: bytes, chunk_index: int, buffer: memoryview, is_last: bool) -> None:
         if message_id not in self.message_buffers:
             self.message_buffers[message_id] = {}
-        self.message_buffers[message_id][chunk_index] = MessageChunk(chunk_index, data, is_last)
-        logging.debug(f"Stored chunk {chunk_index} of message {message_id.hex()} | size: {len(data)} bytes")
+        try:
+            self.message_buffers[message_id][chunk_index] = MessageChunk(chunk_index, buffer.tobytes(), is_last)
+            # logging.debug(f"Stored chunk {chunk_index} of message {message_id.hex()} | size: {len(data)} bytes")
+        except Exception as e:
+            if message_id in self.message_buffers:
+                del self.message_buffers[message_id]
+            logging.error(f"Error storing chunk {chunk_index} for message {message_id.hex()}: {e}")
 
     async def _process_complete_message(self, message_id: bytes) -> None:
         chunks = sorted(self.message_buffers[message_id].values(), key=lambda x: x.index)
         complete_message = b"".join(chunk.data for chunk in chunks)
         del self.message_buffers[message_id]
-        gc.collect()
 
         data_type_prefix = complete_message[:4]
         message_content = complete_message[4:]
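_read_chunk now fills a caller-supplied reusable buffer and returns a memoryview into it; that view stays valid only until the next read overwrites the buffer, which is why _store_chunk takes an owned copy with buffer.tobytes() before storing. A small sketch of the aliasing this implies:

buf = bytearray(b"abcdefgh")   # stands in for the reusable receive buffer
view = memoryview(buf)[:4]     # zero-copy view, shares memory with buf
owned = view.tobytes()         # independent copy, safe to keep around

buf[0] = ord("z")              # simulate the next chunk overwriting the buffer
print(view.tobytes())          # b'zbcd' — the view sees the mutation
print(owned)                   # b'abcd' — the owned copy is unaffected

Note that _read_exactly still allocates the chunk before it is copied into the buffer, so the reuse mainly simplifies lifetime management rather than eliminating the per-chunk allocation.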
@@ -328,8 +341,8 @@ async def _process_complete_message(self, message_id: bytes) -> None:
         if message_content is None:
             return
 
-        await self.pending_messages_queue.put((data_type_prefix, message_content))
-        logging.debug(f"Processed complete message {message_id.hex()} | total size: {len(complete_message)} bytes")
+        await self.pending_messages_queue.put((data_type_prefix, memoryview(message_content)))
+        logging.debug(f"Processed complete message {message_id.hex()} | total size: {len(complete_message)} bytes")
 
     def _decompress(self, data: bytes, compression: str) -> Optional[bytes]:
         if compression == "zlib":
@@ -360,7 +373,7 @@ async def process_message_queue(self) -> None:
 
     async def _handle_message(self, data_type_prefix: bytes, message: bytes) -> None:
         if data_type_prefix == self.DATA_TYPE_PREFIXES["pb"]:
-            logging.debug("Received a protobuf message")
+            logging.debug("Received a protobuf message")
             asyncio.create_task(self.cm.handle_incoming_message(message, self.addr), name=f"Connection {self.addr} message handler")
         elif data_type_prefix == self.DATA_TYPE_PREFIXES["string"]:
             logging.debug(f"Received string message: {message.decode('utf-8')}")
