25
25
from . import ENCRYPTED , ChannelState , PacketHeader , ThpDecryptionError , ThpError
26
26
from . import alternating_bit_protocol as ABP
27
27
from . import control_byte , crypto , memory_manager
28
- from .checksum import CHECKSUM_LENGTH
28
+ from .checksum import CHECKSUM_LENGTH , is_valid
29
29
from .transmission_loop import TransmissionLoop
30
30
from .writer import MESSAGE_TYPE_LENGTH
31
31
@@ -52,20 +52,23 @@ def __init__(self, cid: int) -> None:
52
52
self .reset ()
53
53
54
54
def reset (self ) -> None :
55
- self .bytes_read = 0
56
- self .buffer_len = 0
55
+ self .bytes_read : int = 0
56
+ self .buffer_len : int = 0
57
+ self .message : memoryview | None = None
57
58
58
- def get_next_message (self , packet : memoryview ) -> memoryview | None :
59
+ def handle_packet (self , packet : memoryview ) -> bool :
59
60
"""
60
- Process current packet, returning the payload buffer on success.
61
+ Process current packet, returning `True` when a valid message is reassembled.
62
+ The parsed message can retrieved via the `message` field (if it's not `None`).
63
+ In case of a checksum error or if the reassembly is not over, return `False`.
61
64
62
65
May raise `WireBufferError` if there is a concurrent payload reassembly in progress.
63
66
"""
64
67
ctrl_byte = packet [0 ]
65
68
if control_byte .is_continuation (ctrl_byte ):
66
69
if not self .bytes_read :
67
70
# ignore unexpected continuation packets
68
- return None
71
+ return False
69
72
70
73
# may raise WireBufferError
71
74
buffer = memory_manager .get_existing_read_buffer (self .cid )
@@ -86,19 +89,36 @@ def get_next_message(self, packet: memoryview) -> memoryview | None:
86
89
87
90
assert len (buffer ) == self .buffer_len
88
91
if self .bytes_read < self .buffer_len :
89
- return None
90
- elif self .bytes_read == self .buffer_len :
91
- self .reset ()
92
- return buffer
93
- else :
92
+ return False
93
+
94
+ if self .bytes_read > self .buffer_len :
94
95
raise ThpError ("read more bytes than expected" )
95
96
97
+ if not verify_checksum (buffer ):
98
+ return False
99
+
100
+ assert self .message is None
101
+ self .message = buffer
102
+ return True
103
+
96
104
def _buffer_packet_data (
97
105
self , payload_buffer : memoryview , packet : memoryview , offset : int
98
106
) -> None :
99
107
self .bytes_read += utils .memcpy (payload_buffer , self .bytes_read , packet , offset )
100
108
101
109
110
+ def verify_checksum (buffer : memoryview ) -> memoryview | None :
111
+ """
112
+ Return the buffer if the checksum is valid, otherwise return `None`.
113
+ """
114
+ if is_valid (buffer [- CHECKSUM_LENGTH :], buffer [:- CHECKSUM_LENGTH ]):
115
+ return buffer
116
+ # ignore invalid payloads
117
+ if __debug__ :
118
+ log .warning ("Invalid payload checksum: %s" , utils .hexlify_if_bytes (buffer ))
119
+ return None
120
+
121
+
102
122
class Channel :
103
123
"""
104
124
THP protocol encrypted communication channel.
@@ -184,11 +204,18 @@ def is_channel_to_replace(self) -> bool:
184
204
185
205
# READ and DECRYPT
186
206
187
- def handle_packet (self , packet : utils .BufferType ) -> memoryview | None :
207
+ def reassemble (self , packet : utils .BufferType ) -> bool :
208
+ """
209
+ Process current packet, returning `True` when a valid message is reassembled.
210
+ The parsed message can retrieved via the `message` field (if it's not `None`).
211
+ In case of a checksum error or if the reassembly is not over, return `False`.
212
+
213
+ May raise `WireBufferError` if there is a concurrent payload reassembly in progress.
214
+ """
188
215
if self .get_channel_state () == ChannelState .UNALLOCATED :
189
- return None
216
+ return False
190
217
try :
191
- return self .reassembler .get_next_message (memoryview (packet ))
218
+ return self .reassembler .handle_packet (memoryview (packet ))
192
219
except WireBufferError :
193
220
self .reassembler .reset ()
194
221
raise
0 commit comments