@@ -18,6 +18,7 @@

 @dataclass
 class MessageChunk:
+    __slots__ = ['index', 'data', 'is_last']
     index: int
     data: bytes
     is_last: bool
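
Note: a hand-written __slots__ inside a @dataclass works here only because none
of the fields has a default value (a field default becomes a class attribute,
which would collide with the slot of the same name). On Python 3.10+ the
decorator can generate the slots itself; a minimal equivalent sketch:

    from dataclasses import dataclass

    @dataclass(slots=True)  # generates __slots__ = ('index', 'data', 'is_last')
    class MessageChunk:
        index: int
        data: bytes
        is_last: bool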
@@ -58,7 +59,7 @@ def __init__(
         self.loop = asyncio.get_event_loop()
         self.read_task = None
         self.process_task = None
-        self.pending_messages_queue = asyncio.Queue()
+        self.pending_messages_queue = asyncio.Queue(maxsize=100)
         self.message_buffers: Dict[bytes, Dict[int, MessageChunk]] = {}

         self.EOT_CHAR = b"\x00\x00\x00\x04"
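
Note: bounding the queue makes put() the natural backpressure point: once 100
messages are pending, await pending_messages_queue.put(...) suspends the
producer until a consumer calls get(). A self-contained sketch of that
behaviour:

    import asyncio

    async def demo() -> None:
        q: asyncio.Queue = asyncio.Queue(maxsize=1)
        await q.put("a")                           # returns immediately
        pending = asyncio.create_task(q.put("b"))  # queue full: this suspends
        await asyncio.sleep(0)
        assert not pending.done()                  # producer is parked
        q.get_nowait()                             # consumer frees a slot...
        await pending                              # ...and the put completes

    asyncio.run(demo())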
@@ -239,7 +240,7 @@ async def _send_chunks(self, message_id: bytes, data: bytes) -> None:
             self.writer.write(chunk_with_header)
             await self.writer.drain()

-            logging.debug(f"Sent message {message_id.hex()} | chunk {chunk_index+1}/{num_chunks} | size: {len(chunk)} bytes")
+            # logging.debug(f"Sent message {message_id.hex()} | chunk {chunk_index+1}/{num_chunks} | size: {len(chunk)} bytes")

     def _calculate_chunk_size(self, data_size: int) -> int:
         if data_size <= 1024:  # 1 KB
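
Note: commenting out the per-chunk debug calls removes the f-string formatting
cost from the hot send path, but it also loses the log entirely. Lazy %-style
arguments keep the message available while deferring the formatting until DEBUG
is actually enabled; a sketch (message_id.hex() is still evaluated eagerly, so
the isEnabledFor guard skips even that when logging is off):

    if logging.getLogger().isEnabledFor(logging.DEBUG):
        logging.debug("Sent message %s | chunk %d/%d | size: %d bytes",
                      message_id.hex(), chunk_index + 1, num_chunks, len(chunk))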
@@ -250,14 +251,18 @@ def _calculate_chunk_size(self, data_size: int) -> int:
         return 1024 * 1024  # 1 MB

     async def handle_incoming_message(self) -> None:
+        reusable_buffer = bytearray(self.MAX_CHUNK_SIZE)
         try:
             while True:
+                if self.pending_messages_queue.full():
+                    await asyncio.sleep(0.1)  # Wait a bit if the queue is full to create backpressure
+                    continue
                 header = await self._read_exactly(self.HEADER_SIZE)
                 message_id, chunk_index, is_last_chunk = self._parse_header(header)

-                chunk_data = await self._read_chunk()
+                chunk_data = await self._read_chunk(reusable_buffer)
                 self._store_chunk(message_id, chunk_index, chunk_data, is_last_chunk)
-                logging.debug(f"Received chunk {chunk_index} of message {message_id.hex()} | size: {len(chunk_data)} bytes")
+                # logging.debug(f"Received chunk {chunk_index} of message {message_id.hex()} | size: {len(chunk_data)} bytes")

                 if is_last_chunk:
                     await self._process_complete_message(message_id)
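
Note: the full()/sleep(0.1) poll stops the reader before it pulls more bytes
off the socket, at the cost of up to 100 ms of added latency each time the
queue drains. Because the queue is now bounded, the await
pending_messages_queue.put(...) in _process_complete_message would already
suspend this same coroutine when the queue fills; the explicit check mainly
prevents partially received chunks from piling up in message_buffers while the
consumer is behind.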
@@ -293,32 +298,40 @@ def _parse_header(self, header: bytes) -> tuple[bytes, int, bool]:
         is_last_chunk = header[20] == 1
         return message_id, chunk_index, is_last_chunk

-    async def _read_chunk(self) -> bytes:
+    async def _read_chunk(self, buffer: Optional[bytearray] = None) -> memoryview:
+        if buffer is None:
+            buffer = bytearray(self.MAX_CHUNK_SIZE)
+
         chunk_size_bytes = await self._read_exactly(4)
         chunk_size = int.from_bytes(chunk_size_bytes, "big")

         if chunk_size > self.MAX_CHUNK_SIZE:
             raise ValueError(f"Chunk size {chunk_size} exceeds MAX_CHUNK_SIZE {self.MAX_CHUNK_SIZE}")

         chunk = await self._read_exactly(chunk_size)
+        buffer[:chunk_size] = chunk
         eot = await self._read_exactly(len(self.EOT_CHAR))

         if eot != self.EOT_CHAR:
             raise ValueError("Invalid EOT character")

-        return chunk
+        return memoryview(buffer)[:chunk_size]

-    def _store_chunk(self, message_id: bytes, chunk_index: int, data: bytes, is_last: bool) -> None:
+    def _store_chunk(self, message_id: bytes, chunk_index: int, buffer: memoryview, is_last: bool) -> None:
         if message_id not in self.message_buffers:
             self.message_buffers[message_id] = {}
-        self.message_buffers[message_id][chunk_index] = MessageChunk(chunk_index, data, is_last)
-        logging.debug(f"Stored chunk {chunk_index} of message {message_id.hex()} | size: {len(data)} bytes")
+        try:
+            self.message_buffers[message_id][chunk_index] = MessageChunk(chunk_index, buffer.tobytes(), is_last)
+            # logging.debug(f"Stored chunk {chunk_index} of message {message_id.hex()} | size: {len(buffer)} bytes")
+        except Exception as e:
+            if message_id in self.message_buffers:
+                del self.message_buffers[message_id]
+            logging.error(f"Error storing chunk {chunk_index} for message {message_id.hex()}: {e}")

     async def _process_complete_message(self, message_id: bytes) -> None:
         chunks = sorted(self.message_buffers[message_id].values(), key=lambda x: x.index)
         complete_message = b"".join(chunk.data for chunk in chunks)
         del self.message_buffers[message_id]
-        gc.collect()

         data_type_prefix = complete_message[:4]
         message_content = complete_message[4:]
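
Note: two things are worth double-checking in this hunk. Dropping gc.collect()
from the per-message path is a clear win: del already releases the buffered
chunks, and forced full collections are expensive. The reusable buffer, though,
adds copies rather than removing them: _read_exactly(chunk_size) already
allocates a fresh bytes object, which is then copied into the bytearray, and
_store_chunk() copies it again with tobytes() (a necessary copy, since the
returned memoryview aliases a buffer that the next _read_chunk() call
overwrites). asyncio streams offer no readinto(), so a simpler shape with one
allocation per chunk is to return the bytes from _read_exactly() directly, as
the code did before; a sketch:

    async def _read_chunk(self) -> bytes:
        chunk_size = int.from_bytes(await self._read_exactly(4), "big")
        if chunk_size > self.MAX_CHUNK_SIZE:
            raise ValueError(f"Chunk size {chunk_size} exceeds MAX_CHUNK_SIZE {self.MAX_CHUNK_SIZE}")
        chunk = await self._read_exactly(chunk_size)  # one allocation, owned by the caller
        if await self._read_exactly(len(self.EOT_CHAR)) != self.EOT_CHAR:
            raise ValueError("Invalid EOT character")
        return chunk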
@@ -328,8 +341,8 @@ async def _process_complete_message(self, message_id: bytes) -> None:
         if message_content is None:
             return

-        await self.pending_messages_queue.put((data_type_prefix, message_content))
-        logging.debug(f"Processed complete message {message_id.hex()} | total size: {len(complete_message)} bytes")
+        await self.pending_messages_queue.put((data_type_prefix, memoryview(message_content)))
+        # logging.debug(f"Processed complete message {message_id.hex()} | total size: {len(complete_message)} bytes")

     def _decompress(self, data: bytes, compression: str) -> Optional[bytes]:
         if compression == "zlib":
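
Note: wrapping the payload in memoryview(message_content) avoids copies when a
consumer slices it, but memoryview is not a drop-in replacement for bytes:
consumers that call bytes-only methods need an explicit conversion. The string
branch of _handle_message below calls message.decode('utf-8'), which memoryview
does not provide; a quick illustration:

    mv = memoryview(b"hello")
    # mv.decode("utf-8")       -> AttributeError: no 'decode' on memoryview
    bytes(mv).decode("utf-8")  # 'hello' (copies back into bytes first)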
@@ -360,7 +373,7 @@ async def process_message_queue(self) -> None:

     async def _handle_message(self, data_type_prefix: bytes, message: bytes) -> None:
         if data_type_prefix == self.DATA_TYPE_PREFIXES["pb"]:
-            logging.debug("Received a protobuf message")
+            # logging.debug("Received a protobuf message")
             asyncio.create_task(self.cm.handle_incoming_message(message, self.addr), name=f"Connection {self.addr} message handler")
         elif data_type_prefix == self.DATA_TYPE_PREFIXES["string"]:
             logging.debug(f"Received string message: {message.decode('utf-8')}")
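
Note: the event loop holds only a weak reference to tasks, so a handler task
created with asyncio.create_task() and never awaited can be garbage-collected
while still running. The usual remedy is to keep a strong reference until the
task finishes; a sketch, assuming a (hypothetical) self.background_tasks set on
the connection:

    task = asyncio.create_task(
        self.cm.handle_incoming_message(message, self.addr),
        name=f"Connection {self.addr} message handler",
    )
    self.background_tasks.add(task)                        # hypothetical holder
    task.add_done_callback(self.background_tasks.discard)  # drop when finished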