Skip to content

Commit 614044d

Browse files
committed
refactor(core): move checksum validation to Channel
Also, move channel usage update to `interface_context`. Keep the last valid reassembled message at `Reassembler.message`. It will be used to implement `ThpChannel` read functionality. [no changelog]
1 parent 45e870b commit 614044d

File tree

4 files changed

+51
-38
lines changed

4 files changed

+51
-38
lines changed

core/src/trezor/wire/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ async def handle_session(iface: WireInterface) -> None:
105105

106106
while True:
107107
try:
108-
(channel, message) = await ctx.get_next_message()
108+
channel = await ctx.get_next_message()
109+
message = channel.reassembler.message
110+
assert message is not None
109111
await received_message_handler.handle_received_message(channel, message)
110112
except Exception:
111113
loop.clear() # restart event loop in case of error

core/src/trezor/wire/thp/channel.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError
2626
from . import alternating_bit_protocol as ABP
2727
from . import control_byte, crypto, memory_manager
28-
from .checksum import CHECKSUM_LENGTH
28+
from .checksum import CHECKSUM_LENGTH, is_valid
2929
from .transmission_loop import TransmissionLoop
3030
from .writer import MESSAGE_TYPE_LENGTH
3131

@@ -52,20 +52,23 @@ def __init__(self, cid: int) -> None:
5252
self.reset()
5353

5454
def reset(self) -> None:
55-
self.bytes_read = 0
56-
self.buffer_len = 0
55+
self.bytes_read: int = 0
56+
self.buffer_len: int = 0
57+
self.message: memoryview | None = None
5758

58-
def get_next_message(self, packet: memoryview) -> memoryview | None:
59+
def handle_packet(self, packet: memoryview) -> bool:
5960
"""
60-
Process current packet, returning the payload buffer on success.
61+
Process current packet, returning `True` when a valid message is reassembled.
62+
The parsed message can retrieved via the `message` field (if it's not `None`).
63+
In case of a checksum error or if the reassembly is not over, return `False`.
6164
6265
May raise `WireBufferError` if there is a concurrent payload reassembly in progress.
6366
"""
6467
ctrl_byte = packet[0]
6568
if control_byte.is_continuation(ctrl_byte):
6669
if not self.bytes_read:
6770
# ignore unexpected continuation packets
68-
return None
71+
return False
6972

7073
# may raise WireBufferError
7174
buffer = memory_manager.get_existing_read_buffer(self.cid)
@@ -86,19 +89,36 @@ def get_next_message(self, packet: memoryview) -> memoryview | None:
8689

8790
assert len(buffer) == self.buffer_len
8891
if self.bytes_read < self.buffer_len:
89-
return None
90-
elif self.bytes_read == self.buffer_len:
91-
self.reset()
92-
return buffer
93-
else:
92+
return False
93+
94+
if self.bytes_read > self.buffer_len:
9495
raise ThpError("read more bytes than expected")
9596

97+
if not verify_checksum(buffer):
98+
return False
99+
100+
assert self.message is None
101+
self.message = buffer
102+
return True
103+
96104
def _buffer_packet_data(
97105
self, payload_buffer: memoryview, packet: memoryview, offset: int
98106
) -> None:
99107
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
100108

101109

110+
def verify_checksum(buffer: memoryview) -> memoryview | None:
111+
"""
112+
Return the buffer if the checksum is valid, otherwise return `None`.
113+
"""
114+
if is_valid(buffer[-CHECKSUM_LENGTH:], buffer[:-CHECKSUM_LENGTH]):
115+
return buffer
116+
# ignore invalid payloads
117+
if __debug__:
118+
log.warning("Invalid payload checksum: %s", utils.hexlify_if_bytes(buffer))
119+
return None
120+
121+
102122
class Channel:
103123
"""
104124
THP protocol encrypted communication channel.
@@ -184,11 +204,18 @@ def is_channel_to_replace(self) -> bool:
184204

185205
# READ and DECRYPT
186206

187-
def handle_packet(self, packet: utils.BufferType) -> memoryview | None:
207+
def reassemble(self, packet: utils.BufferType) -> bool:
208+
"""
209+
Process current packet, returning `True` when a valid message is reassembled.
210+
The parsed message can retrieved via the `message` field (if it's not `None`).
211+
In case of a checksum error or if the reassembly is not over, return `False`.
212+
213+
May raise `WireBufferError` if there is a concurrent payload reassembly in progress.
214+
"""
188215
if self.get_channel_state() == ChannelState.UNALLOCATED:
189-
return None
216+
return False
190217
try:
191-
return self.reassembler.get_next_message(memoryview(packet))
218+
return self.reassembler.handle_packet(memoryview(packet))
192219
except WireBufferError:
193220
self.reassembler.reset()
194221
raise

core/src/trezor/wire/thp/interface_context.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
BROADCAST_CHANNEL_ID,
88
ChannelCache,
99
iter_allocated_channels,
10+
update_channel_last_used,
1011
)
1112
from trezor import io, loop, utils
1213

@@ -60,7 +61,7 @@ def __init__(self, iface: WireInterface) -> None:
6061
self._write = loop.wait(iface.iface_num() | io.POLL_WRITE)
6162
self._channels: dict[int, Channel] = {}
6263

63-
async def get_next_message(self) -> tuple[Channel, memoryview]:
64+
async def get_next_message(self) -> Channel:
6465
packet = bytearray(self._iface.RX_PACKET_LEN)
6566
while True:
6667
packet_len = await self._read
@@ -84,11 +85,11 @@ async def get_next_message(self) -> tuple[Channel, memoryview]:
8485
continue
8586

8687
try:
87-
message = channel.handle_packet(packet)
88-
if message is not None:
89-
# `message` must be handled ASAP without blocking,
88+
if channel.reassemble(packet):
89+
update_channel_last_used(channel.channel_id)
90+
# The reassembled message must be handled ASAP without blocking,
9091
# since it may point to the global read buffer.
91-
return channel, message
92+
return channel
9293
except WireBufferError:
9394
await channel.write_error(ThpErrorType.TRANSPORT_BUSY)
9495
continue

core/src/trezor/wire/thp/received_message_handler.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
KEY_LENGTH,
1313
SESSION_ID_LENGTH,
1414
TAG_LENGTH,
15-
update_channel_last_used,
1615
update_session_last_used,
1716
)
1817
from trezor import config, loop, protobuf, utils
@@ -39,7 +38,7 @@
3938
ThpUnallocatedSessionError,
4039
)
4140
from . import alternating_bit_protocol as ABP
42-
from . import checksum, control_byte, get_encoded_device_properties, session_manager
41+
from . import control_byte, get_encoded_device_properties, session_manager
4342
from .checksum import CHECKSUM_LENGTH
4443
from .crypto import PUBKEY_LENGTH, Handshake
4544
from .session_context import SeedlessSessionContext
@@ -82,8 +81,6 @@ async def handle_received_message(
8281
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer)
8382
message_length = payload_length + PacketHeader.INIT_LENGTH
8483

85-
_check_checksum(message_length, message_buffer)
86-
8784
# Synchronization process
8885
seq_bit = control_byte.get_seq_bit(ctrl_byte)
8986
ack_bit = control_byte.get_ack_bit(ctrl_byte)
@@ -95,8 +92,6 @@ async def handle_received_message(
9592
ack_bit,
9693
iface=ctx.iface,
9794
)
98-
# 0: Update "last-time used"
99-
update_channel_last_used(ctx.channel_id)
10095

10196
# 1: Handle ACKs
10297
if control_byte.is_ack(ctrl_byte):
@@ -161,18 +156,6 @@ def _send_ack(channel: Channel, ack_bit: int) -> Awaitable[None]:
161156
return channel.ctx.write_payload(header, b"")
162157

163158

164-
def _check_checksum(message_length: int, message_buffer: utils.BufferType) -> None:
165-
if __debug__:
166-
log.debug(__name__, "check_checksum")
167-
if not checksum.is_valid(
168-
checksum=message_buffer[message_length - CHECKSUM_LENGTH : message_length],
169-
data=memoryview(message_buffer)[: message_length - CHECKSUM_LENGTH],
170-
):
171-
if __debug__:
172-
log.debug(__name__, "Invalid checksum, ignoring message.")
173-
raise ThpError("Invalid checksum, ignoring message.")
174-
175-
176159
async def handle_ack(ctx: Channel, ack_bit: int) -> None:
177160
if not ABP.is_ack_valid(ctx.channel_cache, ack_bit):
178161
return

0 commit comments

Comments
 (0)