From 8e697391e771837f324880b45949e19308ed941c Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Wed, 20 Aug 2025 14:39:47 +0300 Subject: [PATCH 1/4] refactor(python): refactor THP-related exception types [no changelog] --- python/src/trezorlib/cli/__init__.py | 2 +- python/src/trezorlib/debuglink.py | 16 ++------- python/src/trezorlib/exceptions.py | 26 ++++++++++++-- .../trezorlib/transport/thp/protocol_v2.py | 35 ++++++------------- tests/device_tests/thp/test_multiple_hosts.py | 5 ++- tests/device_tests/thp/test_pairing.py | 6 ++-- 6 files changed, 41 insertions(+), 49 deletions(-) diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 2d332e50955..d5c95e0731d 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -293,7 +293,7 @@ def session_context( empty_passphrase=empty_passphrase, must_resume=must_resume, ) - except exceptions.DeviceLockedException: + except exceptions.DeviceLocked: click.echo( "Device is locked, enter a pin on the device.", err=True, diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 1a47f7ddb26..af92bb5c627 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -34,12 +34,7 @@ from . import btc, mapping, messages, models, protobuf from .client import ProtocolVersion, TrezorClient -from .exceptions import ( - Cancelled, - DeviceLockedException, - TrezorFailure, - UnexpectedMessageError, -) +from .exceptions import Cancelled, TrezorFailure, UnexpectedMessageError from .log import DUMP_BYTES from .messages import DebugTouchEventType, DebugWaitType from .tools import parse_path @@ -1317,14 +1312,7 @@ def get_pin(_msg: messages.PinMatrixRequest) -> str: self.pin_callback = get_pin self.button_callback = self.ui.button_request - try: - super().__init__(transport) - except DeviceLockedException: - LOG.debug("Locked device handling") - self.debug.input("") - self.debug.input(self.debug.encode_pin("1234")) - super().__init__(transport) - + super().__init__(transport) self.sync_responses() # So that we can choose right screenshotting logic (T1 vs TT) diff --git a/python/src/trezorlib/exceptions.py b/python/src/trezorlib/exceptions.py index 95d0cd30f2c..0f8c0d1f275 100644 --- a/python/src/trezorlib/exceptions.py +++ b/python/src/trezorlib/exceptions.py @@ -111,13 +111,33 @@ class DerivationOnUninitaizedDeviceError(TrezorException): To communicate with uninitialized device, use seedless session instead.""" -class DeviceLockedException(TrezorException): +class UnexpectedCodeEntryTagException(TrezorException): pass -class UnexpectedCodeEntryTagException(TrezorException): +class ThpError(TrezorException): pass -class ThpError(TrezorException): +class TransportBusy(ThpError): + pass + + +class UnallocatedChannel(ThpError): + pass + + +class DecryptionFailed(ThpError): + pass + + +class InvalidData(ThpError): + pass + + +class DeviceLocked(ThpError): + pass + + +class ThpUnknownError(ThpError): pass diff --git a/python/src/trezorlib/transport/thp/protocol_v2.py b/python/src/trezorlib/transport/thp/protocol_v2.py index c805e249f2a..7a56a57812c 100644 --- a/python/src/trezorlib/transport/thp/protocol_v2.py +++ b/python/src/trezorlib/transport/thp/protocol_v2.py @@ -231,11 +231,6 @@ def _send_handshake_init_request(self) -> None: def _read_handshake_init_response(self) -> bytes: header, payload = self._read_until_valid_crc_check() - if control_byte.is_error(header.ctrl_byte): - if payload == b"\x05": - raise exceptions.DeviceLockedException() - else: - raise exceptions.ThpError(_get_error_from_int(payload[0])) if not header.is_handshake_init_response(): LOG.error("Received message is not a valid handshake init response message") @@ -278,8 +273,6 @@ def _read_handshake_completion_response(self) -> int: header, data = self._read_until_valid_crc_check() if not header.is_handshake_comp_response(): LOG.error("Received message is not a valid handshake completion response") - if control_byte.is_error(header.ctrl_byte): - raise exceptions.ThpError(_get_error_from_int(data[0])) trezor_state = self._noise.decrypt(bytes(data)) assert trezor_state == b"\x00" or trezor_state == b"\x01" self._send_ack_bit(bit=1) @@ -289,8 +282,6 @@ def _read_ack(self): header, payload = self._read_until_valid_crc_check() if not header.is_ack() or len(payload) > 0: LOG.error("Received message is not a valid ACK") - if control_byte.is_error(header.ctrl_byte): - raise exceptions.ThpError(_get_error_from_int(payload[0])) def _send_ack_bit(self, bit: int): if bit not in (0, 1): @@ -336,8 +327,6 @@ def read_and_decrypt( continue if control_byte.is_ack(header.ctrl_byte): continue - if control_byte.is_error(header.ctrl_byte): - raise exceptions.ThpError(_get_error_from_int(raw_payload[0])) if not header.is_encrypted_transport(): LOG.error( "Trying to decrypt not encrypted message! (" @@ -392,6 +381,10 @@ def _read_until_valid_crc_check( self.sync_bit_receive = 1 - self.sync_bit_receive + if control_byte.is_error(header.ctrl_byte): + code = payload[0] + raise _ERRORS_MAP.get(code) or exceptions.ThpUnknownError(code) + return header, payload def _is_valid_channel_allocation_response( @@ -420,16 +413,10 @@ def _is_valid_pong( return True -def _get_error_from_int(error_code: int) -> str: - # TODO FIXME improve this (ThpErrorType) - if error_code == 1: - return "TRANSPORT BUSY" - if error_code == 2: - return "UNALLOCATED CHANNEL" - if error_code == 3: - return "DECRYPTION FAILED" - if error_code == 4: - return "INVALID DATA" - if error_code == 5: - return "DEVICE LOCKED" - raise Exception("Not Implemented error case") +_ERRORS_MAP = { + 1: exceptions.TransportBusy, + 2: exceptions.UnallocatedChannel, + 3: exceptions.DecryptionFailed, + 4: exceptions.InvalidData, + 5: exceptions.DeviceLocked, +} diff --git a/tests/device_tests/thp/test_multiple_hosts.py b/tests/device_tests/thp/test_multiple_hosts.py index 4e145ecc7bf..d424576ece6 100644 --- a/tests/device_tests/thp/test_multiple_hosts.py +++ b/tests/device_tests/thp/test_multiple_hosts.py @@ -31,10 +31,9 @@ def test_concurrent_handshakes(client: Client) -> None: # The second host starts handshake protocol_2._send_handshake_init_request() - # The second host should not be able to interrupt the first host's handshake - with pytest.raises(exceptions.ThpError) as e: + # The second host should not be able to interrupt the first host's handshake immediately + with pytest.raises(exceptions.TransportBusy): protocol_2._read_ack() - assert e.value.args[0] == "TRANSPORT BUSY" # The first host can complete handshake protocol_1._send_handshake_completion_request() diff --git a/tests/device_tests/thp/test_pairing.py b/tests/device_tests/thp/test_pairing.py index 5885762572f..63186abc598 100644 --- a/tests/device_tests/thp/test_pairing.py +++ b/tests/device_tests/thp/test_pairing.py @@ -406,9 +406,8 @@ def test_credential_phase(client: Client) -> None: protocol._noise.noise_protocol.cipher_state_encrypt.n = 250 protocol._send_message(ButtonAck()) - with pytest.raises(exceptions.ThpError) as e: + with pytest.raises(exceptions.DecryptionFailed): protocol.read(1) - assert e.value.args[0] == "DECRYPTION FAILED" # Connect using credential with confirmation and ask for autoconnect credential. protocol = prepare_protocol_for_handshake(client) @@ -457,9 +456,8 @@ def test_credential_phase(client: Client) -> None: protocol._noise.noise_protocol.cipher_state_encrypt.n = 100 protocol._send_message(ButtonAck()) - with pytest.raises(exceptions.ThpError) as e: + with pytest.raises(exceptions.DecryptionFailed): protocol.read(1) - assert e.value.args[0] == "DECRYPTION FAILED" # Connect using autoconnect credential - should work the same as above protocol = prepare_protocol_for_handshake(client) From 52faa5b1e7fef35c9ddea03959c5f8c4e1cb54aa Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Thu, 21 Aug 2025 12:40:46 +0300 Subject: [PATCH 2/4] fix(core): correct `ticks_diff` arguments' order [no changelog] --- core/mocks/utime.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/mocks/utime.pyi b/core/mocks/utime.pyi index c2ff9bb82e2..3bbcae6884a 100644 --- a/core/mocks/utime.pyi +++ b/core/mocks/utime.pyi @@ -5,5 +5,5 @@ def ticks_ms() -> int: ... def ticks_us() -> int: ... def ticks_cpu() -> int: ... def ticks_add(ticks_in: int, delta_in: int) -> int: ... -def ticks_diff(old: int, new: int) -> int: ... +def ticks_diff(new: int, old: int) -> int: ... def gmtime2000(timestamp: int) -> tuple: ... From daff0c89bc0510243baa0d7d4e889772635e5631 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Thu, 21 Aug 2025 12:41:51 +0300 Subject: [PATCH 3/4] chore(core): exclude logging on non-debug builds [no changelog] --- core/src/trezor/wire/thp/session_context.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py index 0c23bb70a7c..54667419942 100644 --- a/core/src/trezor/wire/thp/session_context.py +++ b/core/src/trezor/wire/thp/session_context.py @@ -50,7 +50,8 @@ async def handle(self, message: Message | None = None) -> None: if message is None: message = await self._read_next_message() await handle_single_message(self, message) - self.channel._log("session loop is over") + if __debug__: + self.channel._log("session loop is over") return except protocol_common.WireError as e: if __debug__: @@ -70,9 +71,10 @@ async def _read_next_message(self) -> Message: session_id, message = await self.channel.decrypt_message() if session_id == self.session_id: return message - self.channel._log( - "Ignored message for unexpected session", logger=log.warning - ) + if __debug__: + self.channel._log( + "Ignored message for unexpected session", logger=log.warning + ) async def read( self, From bfcb11cb9ef240826e50057a4a70000ac1ffdc48 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Thu, 21 Aug 2025 11:38:21 +0300 Subject: [PATCH 4/4] feat(core): allow aborting stale THP workflows [no changelog] --- core/src/trezor/wire/__init__.py | 5 ++++- core/src/trezor/wire/context.py | 6 +++++- core/src/trezor/wire/message_handler.py | 7 ++++++- core/src/trezor/wire/thp/channel.py | 20 +++++++++++++++++++ core/src/trezor/wire/thp/interface_context.py | 9 ++++++--- core/src/trezor/wire/thp/pairing_context.py | 4 +++- 6 files changed, 44 insertions(+), 7 deletions(-) diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 99cb8a4d362..5fc4aa107f8 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -104,7 +104,10 @@ async def handle_session(iface: WireInterface) -> None: if __debug__: _THP_CHANNELS.append(ctx._channels) try: - channel = await ctx.get_next_message() + while (channel := await ctx.get_next_message()) is None: + # If there is an active THP channel on another interface, + # wait for event loop restart, while responding to the host. + pass while await received_message_handler.handle_received_message(channel): pass finally: diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index 96598f0c079..0a81c0f3666 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -35,7 +35,11 @@ T = TypeVar("T") -class UnexpectedMessageException(Exception): +class AbortWorkflow(Exception): + pass + + +class UnexpectedMessageException(AbortWorkflow): """A message was received that is not part of the current workflow. Utility exception to inform the session handler that the current workflow diff --git a/core/src/trezor/wire/message_handler.py b/core/src/trezor/wire/message_handler.py index 470eec4f721..6bb32b08637 100644 --- a/core/src/trezor/wire/message_handler.py +++ b/core/src/trezor/wire/message_handler.py @@ -5,7 +5,7 @@ from trezor.enums import FailureType from trezor.messages import Failure -from .context import UnexpectedMessageException, with_context +from .context import AbortWorkflow, UnexpectedMessageException, with_context from .errors import ActionCancelled, DataError, Error, UnexpectedMessage from .protocol_common import Context, Message @@ -153,6 +153,11 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool: # to process. It is not a standard exception that should be logged and a result # sent to the wire. raise + + except AbortWorkflow: + # Aborting current workflow, to restart the event loop. + return False + except BaseException as exc: # Either: # - the message had a type that has a registered handler, but does not have diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index dfe875b6397..f378e6e1e9e 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -1,4 +1,5 @@ import ustruct +import utime from micropython import const from typing import TYPE_CHECKING @@ -23,6 +24,7 @@ from trezor import protobuf, utils, workflow from trezor.loop import Timeout +from ..context import AbortWorkflow from ..protocol_common import Message from . import ( ACK_MESSAGE, @@ -58,6 +60,8 @@ _MAX_RETRANSMISSION_COUNT = const(50) _MIN_RETRANSMISSION_COUNT = const(2) +_STALE_CHANNEL_TIMEOUT_MS = const(1000) + class Reassembler: def __init__(self, cid: int, read_buf: ThpBuffer) -> None: @@ -270,12 +274,28 @@ async def _get_reassembled_message( self, timeout_ms: int | None = None ) -> memoryview: """Doesn't block if a message has been already reassembled.""" + ping_sent_ms: int | None = None while self.reassembler.message is None: # receive and reassemble a new message from this channel channel = await self.ctx.get_next_message(timeout_ms=timeout_ms) if channel is self: break + # interrupted by a different channel + if ping_sent_ms is None: + ping_sent_ms = utime.ticks_ms() + continue + + elapsed = utime.ticks_diff(utime.ticks_ms(), ping_sent_ms) + if elapsed < _STALE_CHANNEL_TIMEOUT_MS: + continue + + if __debug__: + self._log( + f"Interrupting stale channel after {elapsed} ms", logger=log.warning + ) + raise AbortWorkflow + # currently only single-channel sessions are supported during a single event loop run self._log( "Ignoring unexpected channel: ", diff --git a/core/src/trezor/wire/thp/interface_context.py b/core/src/trezor/wire/thp/interface_context.py index 819cd09346a..6eeec382903 100644 --- a/core/src/trezor/wire/thp/interface_context.py +++ b/core/src/trezor/wire/thp/interface_context.py @@ -47,11 +47,12 @@ def __init__(self, iface: WireInterface) -> None: self._write = loop.wait(iface.iface_num() | io.POLL_WRITE) self._channels: dict[int, Channel] = {} - async def get_next_message(self, timeout_ms: int | None = None) -> Channel: + async def get_next_message(self, timeout_ms: int | None = None) -> Channel | None: """ Reassemble a valid THP payload and return its channel. Also handle THP channel allocation. + Return `None` if there is already another active THP channel. """ from .. import THP_BUFFERS_PROVIDER @@ -80,9 +81,11 @@ async def get_next_message(self, timeout_ms: int | None = None) -> Channel: if (channel := self._channels.get(cid)) is None: if (buffers := THP_BUFFERS_PROVIDER.take()) is None: - # concurrent payload reassembly is not supported + # Concurrent payload reassembly is not supported: + # - Notify the new channel, so it will retry later. + # - Interrupt this method, to detect whether the existing channel is still alive. await self.write_error(cid, ThpErrorType.TRANSPORT_BUSY) - continue + return None channel = self._channels[cid] = Channel(cache, self, buffers) if channel.reassemble(packet): diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py index 8520f08c2c5..5ce11d6d102 100644 --- a/core/src/trezor/wire/thp/pairing_context.py +++ b/core/src/trezor/wire/thp/pairing_context.py @@ -3,7 +3,7 @@ from trezor import loop, protobuf, workflow from trezor.wire import context, message_handler, protocol_common -from trezor.wire.context import UnexpectedMessageException +from trezor.wire.context import AbortWorkflow, UnexpectedMessageException from trezor.wire.errors import ActionCancelled, DataError, SilentError from trezor.wire.protocol_common import Context, Message from trezor.wire.thp import ChannelState, get_enabled_pairing_methods, ui @@ -234,6 +234,8 @@ async def handle_message( # We might handle only the few common cases here, like # Initialize and Cancel. return exc.msg + except AbortWorkflow: + return None except SilentError as exc: if __debug__: log.error(__name__, "SilentError: %s", exc.message, iface=pairing_ctx.iface)