From 348417967166e6893e7560d332caa1c9c03f9541 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Fri, 15 Aug 2025 09:00:35 +0300 Subject: [PATCH 1/3] test(core): wait for device availability to avoid UI tests flakiness [no changelog] --- tests/ui_tests/__init__.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/tests/ui_tests/__init__.py b/tests/ui_tests/__init__.py index 48350646786..f38acacf6a8 100644 --- a/tests/ui_tests/__init__.py +++ b/tests/ui_tests/__init__.py @@ -2,7 +2,6 @@ import logging import shutil -import time import typing as t from contextlib import contextmanager @@ -10,7 +9,6 @@ from _pytest.nodes import Node from _pytest.outcomes import Failed -from trezorlib.client import ProtocolVersion from trezorlib.debuglink import TrezorClientDebugLink as Client LOG = logging.getLogger(__name__) @@ -58,15 +56,9 @@ def screen_recording( shutil.rmtree(testcase.actual_dir, ignore_errors=True) testcase.actual_dir.mkdir() - if client.protocol_version is ProtocolVersion.V2: - # In case of an event loop restart, it's possible that the first - # packet(s) of `DebugLinkRecordScreen` will be lost, resulting in - # `TrezorFailure: FirmwareError: Invalid magic` error responses - # during test setup. - # This issue will be resolved as part of THP event loop restart refactoring, - # but till then let's wait a bit here, to reduce the packet loss probability. - # TODO: remove after THP event loop restart refactoring - time.sleep(0.1) + # Make sure the device is ready - otherwise, the next `DebugLinkRecordScreen` request + # may be lost due to an event loop restart. + client.sync_responses() try: client.debug.start_recording(str(testcase.actual_dir)) yield From 5f7d5b4f3afb3a549abf61ac9d71b92b8552c0f7 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Wed, 13 Aug 2025 16:50:12 +0300 Subject: [PATCH 2/3] feat(core): reimplement THP event loop restarts [no changelog] --- core/embed/upymod/qstrdefsport.h | 2 - core/src/apps/base.py | 5 +- .../apps/management/reboot_to_bootloader.py | 2 +- core/src/apps/management/wipe_device.py | 5 +- core/src/apps/thp/pairing.py | 8 +- core/src/storage/cache_thp.py | 17 +- core/src/trezor/wire/__init__.py | 27 +- core/src/trezor/wire/errors.py | 5 - core/src/trezor/wire/protocol_common.py | 5 +- core/src/trezor/wire/thp/__init__.py | 8 - core/src/trezor/wire/thp/channel.py | 280 ++++++++------ core/src/trezor/wire/thp/interface_context.py | 69 ++-- core/src/trezor/wire/thp/memory_manager.py | 150 +------- core/src/trezor/wire/thp/pairing_context.py | 15 +- .../wire/thp/received_message_handler.py | 349 ++++-------------- core/src/trezor/wire/thp/session_context.py | 51 +-- core/src/trezor/wire/thp/transmission_loop.py | 62 ---- core/tests/test_trezor.wire.thp.py | 3 - core/tests/thp_common.py | 7 +- tests/device_tests/thp/test_multiple_hosts.py | 140 +------ tests/ui_tests/fixtures.json | 18 +- 21 files changed, 347 insertions(+), 881 deletions(-) delete mode 100644 core/src/trezor/wire/thp/transmission_loop.py diff --git a/core/embed/upymod/qstrdefsport.h b/core/embed/upymod/qstrdefsport.h index bbea7d5e407..c80eb2f80d9 100644 --- a/core/embed/upymod/qstrdefsport.h +++ b/core/embed/upymod/qstrdefsport.h @@ -421,7 +421,6 @@ Q(session_manager) Q(storage.cache_thp) Q(storage.cache_thp_keys) Q(thp) -Q(transmission_loop) Q(trezor.enums.ThpMessageType) Q(trezor.enums.ThpPairingMethod) Q(trezor.wire.thp) @@ -438,7 +437,6 @@ Q(trezor.wire.thp.pairing_context) Q(trezor.wire.thp.received_message_handler) Q(trezor.wire.thp.session_context) Q(trezor.wire.thp.session_manager) -Q(trezor.wire.thp.transmission_loop) Q(trezor.wire.thp.ui) Q(trezor.wire.thp.writer) Q(ui) diff --git a/core/src/apps/base.py b/core/src/apps/base.py index bbe145cc5a3..7d3d680a18a 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -231,7 +231,7 @@ async def handle_ThpCreateNewSession( Returns an appropriate `Failure` message if session creation fails. """ - from trezor import log, loop + from trezor import log from trezor.enums import FailureType from trezor.messages import Failure from trezor.wire import NotInitialized @@ -281,9 +281,6 @@ async def handle_ThpCreateNewSession( message.passphrase if message.passphrase is not None else "", ) - channel.sessions[new_session.session_id] = new_session - loop.schedule(new_session.handle()) - return Success(message="New session created.") async def handle_ThpCredentialRequest( diff --git a/core/src/apps/management/reboot_to_bootloader.py b/core/src/apps/management/reboot_to_bootloader.py index fadc65a9912..1b0523deeae 100644 --- a/core/src/apps/management/reboot_to_bootloader.py +++ b/core/src/apps/management/reboot_to_bootloader.py @@ -94,7 +94,7 @@ async def reboot_to_bootloader(msg: RebootToBootloader) -> NoReturn: boot_args = None ctx = get_context() - await ctx.write_force(Success(message="Rebooting")) + await ctx.write(Success(message="Rebooting")) # make sure the outgoing USB buffer is flushed await loop.wait(ctx.iface.iface_num() | io.POLL_WRITE) diff --git a/core/src/apps/management/wipe_device.py b/core/src/apps/management/wipe_device.py index a4a8ad0eafd..b4f2b628f5b 100644 --- a/core/src/apps/management/wipe_device.py +++ b/core/src/apps/management/wipe_device.py @@ -13,7 +13,7 @@ async def wipe_device(msg: WipeDevice) -> NoReturn: import storage - from trezor import TR, config, loop, translations + from trezor import TR, config, translations from trezor.enums import ButtonRequestType from trezor.messages import Success from trezor.pin import render_empty_loader @@ -49,7 +49,7 @@ async def wipe_device(msg: WipeDevice) -> NoReturn: translations.deinit() translations.erase() try: - await get_context().write_force(Success(message="Device wiped")) + await get_context().write(Success(message="Device wiped")) except Exception: if __debug__: log.debug(__name__, "Failed to send Success message after wipe.") @@ -58,6 +58,5 @@ async def wipe_device(msg: WipeDevice) -> NoReturn: # reload settings reload_settings_from_storage() - loop.clear() if __debug__: log.debug(__name__, "Device wipe - finished") diff --git a/core/src/apps/thp/pairing.py b/core/src/apps/thp/pairing.py index a9c945592eb..61a37bf27f1 100644 --- a/core/src/apps/thp/pairing.py +++ b/core/src/apps/thp/pairing.py @@ -221,7 +221,7 @@ async def _handle_code_entry_is_selected(ctx: PairingContext) -> None: if ctx.code_entry_secret is None: await _handle_code_entry_is_selected_first_time(ctx) else: - await ctx.write_force(ThpPairingPreparationsFinished()) + await ctx.write(ThpPairingPreparationsFinished()) async def _handle_code_entry_is_selected_first_time(ctx: PairingContext) -> None: @@ -251,7 +251,7 @@ async def _handle_code_entry_is_selected_first_time(ctx: PairingContext) -> None ) assert ctx.code_code_entry is not None ctx.cpace.generate_keys(f"{ctx.code_code_entry:06}".encode("ascii")) - await ctx.write_force( + await ctx.write( ThpCodeEntryCpaceTrezor(cpace_trezor_public_key=ctx.cpace.trezor_public_key) ) @@ -259,7 +259,7 @@ async def _handle_code_entry_is_selected_first_time(ctx: PairingContext) -> None @check_state_and_log(ChannelState.TP1) async def _handle_nfc_is_selected(ctx: PairingContext) -> None: ctx.nfc_secret = random.bytes(16) - await ctx.write_force(ThpPairingPreparationsFinished()) + await ctx.write(ThpPairingPreparationsFinished()) @check_state_and_log(ChannelState.TP1) @@ -271,7 +271,7 @@ async def _handle_qr_code_is_selected(ctx: PairingContext) -> None: sha_ctx.update(ctx.qr_code_secret) ctx.code_qr_code = sha_ctx.digest()[:16] - await ctx.write_force(ThpPairingPreparationsFinished()) + await ctx.write(ThpPairingPreparationsFinished()) @check_state_and_log(ChannelState.TP3) diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py index 5a9c9a00757..7247616ceab 100644 --- a/core/src/storage/cache_thp.py +++ b/core/src/storage/cache_thp.py @@ -1,11 +1,9 @@ import builtins from micropython import const -from typing import TYPE_CHECKING from storage.cache_common import ( CHANNEL_HOST_STATIC_PUBKEY, CHANNEL_ID, - CHANNEL_IFACE, CHANNEL_STATE, CHANNEL_SYNC, SESSION_ID, @@ -13,12 +11,6 @@ DataCache, ) -if TYPE_CHECKING: - from typing import Iterable, Tuple - - pass - - # THP specific constants _MAX_CHANNELS_COUNT = const(10) _MAX_SESSIONS_COUNT = const(20) @@ -183,13 +175,14 @@ def update_session_last_used(channel_id: bytes, session_id: bytes) -> None: return -def iter_allocated_channels(iface_num: int) -> Iterable[ChannelCache]: +def find_allocated_channel(cid: int) -> ChannelCache | None: for channel in _CHANNELS: state = channel.get_int(CHANNEL_STATE, _UNALLOCATED_STATE) if state == _UNALLOCATED_STATE: continue - if channel.get_int(CHANNEL_IFACE) == iface_num: - yield channel + if channel.get_int(CHANNEL_ID) == cid: + return channel + return None def get_allocated_session( @@ -393,7 +386,7 @@ def clear_all() -> None: channel.clear() -def clear_all_except_one_session_keys(excluded: Tuple[bytes, bytes]) -> None: +def clear_all_except_one_session_keys(excluded: tuple[bytes, bytes]) -> None: cid, sid = excluded for channel in _CHANNELS: diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 46b4607df72..ce54741fd36 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -82,9 +82,10 @@ def setup(iface: WireInterface) -> None: if utils.USE_THP: - # memory_manager is imported to create READ/WRITE buffers - # in more stable area of memory - from .thp import memory_manager # noqa: F401 + from .thp.memory_manager import ThpBuffer + + # Allocate THP read/write buffers in more stable area of memory + THP_BUFFERS_PROVIDER = Provider((ThpBuffer(), ThpBuffer())) if __debug__: _THP_CHANNELS = [] @@ -99,19 +100,17 @@ def find_thp_channel(channel_id: bytes) -> Channel | None: return None async def handle_session(iface: WireInterface) -> None: - ctx = ThpContext.load_from_cache(iface) + ctx = ThpContext(iface) if __debug__: _THP_CHANNELS.append(ctx._channels) - - while True: - try: - channel = await ctx.get_next_message() - message = channel.reassembler.message - assert message is not None - await received_message_handler.handle_received_message(channel, message) - except Exception: - loop.clear() # restart event loop in case of error - raise # the traceback will be printed by `loop._step()` + try: + channel = await ctx.get_next_message() + while await received_message_handler.handle_received_message(channel): + pass + finally: + # Wait for all active workflows to finish. + await workflow.join_all() + loop.clear() else: diff --git a/core/src/trezor/wire/errors.py b/core/src/trezor/wire/errors.py index 8f572fcf0c2..e8b2d3feb45 100644 --- a/core/src/trezor/wire/errors.py +++ b/core/src/trezor/wire/errors.py @@ -14,11 +14,6 @@ def __init__(self, message: str) -> None: self.message = message -class WireBufferError(Error): - def __init__(self, message: str = "Buffer error") -> None: - super().__init__(FailureType.BufferError, message) - - class UnexpectedMessage(Error): def __init__(self, message: str) -> None: super().__init__(FailureType.UnexpectedMessage, message) diff --git a/core/src/trezor/wire/protocol_common.py b/core/src/trezor/wire/protocol_common.py index da35c0d291d..d7fa3532194 100644 --- a/core/src/trezor/wire/protocol_common.py +++ b/core/src/trezor/wire/protocol_common.py @@ -74,13 +74,10 @@ async def read( """ ... - async def write(self, msg: protobuf.MessageType) -> None: + def write(self, msg: protobuf.MessageType) -> Awaitable[None]: """Write a message to the wire.""" ... - def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]: - return self.write(msg) - async def call( self, msg: protobuf.MessageType, diff --git a/core/src/trezor/wire/thp/__init__.py b/core/src/trezor/wire/thp/__init__.py index 38c56329adb..56c32e20ce7 100644 --- a/core/src/trezor/wire/thp/__init__.py +++ b/core/src/trezor/wire/thp/__init__.py @@ -44,18 +44,10 @@ class ThpDecryptionError(ThpError): pass -class ThpInvalidDataError(ThpError): - pass - - class ThpDeviceLockedError(ThpError): pass -class ThpUnallocatedChannelError(ThpError): - pass - - class ThpUnallocatedSessionError(ThpError): def __init__(self, session_id: int) -> None: diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 8661d0c4501..4a772f16507 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -19,14 +19,20 @@ conditionally_replace_channel, is_there_a_channel_to_replace, ) -from trezor import loop, protobuf, utils, workflow -from trezor.wire.errors import WireBufferError - -from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError +from trezor import protobuf, utils, workflow + +from ..protocol_common import Message +from . import ( + ACK_MESSAGE, + ENCRYPTED, + ChannelState, + PacketHeader, + ThpDecryptionError, + ThpError, +) from . import alternating_bit_protocol as ABP from . import control_byte, crypto, memory_manager from .checksum import CHECKSUM_LENGTH, is_valid -from .transmission_loop import TransmissionLoop from .writer import MESSAGE_TYPE_LENGTH if __debug__: @@ -36,19 +42,21 @@ from . import state_to_str if TYPE_CHECKING: - from typing import Any, Awaitable + from typing import Any, Awaitable, Callable from trezor.messages import ThpPairingCredential from trezor.wire import WireInterface from .interface_context import ThpContext + from .memory_manager import ThpBuffer from .pairing_context import PairingContext from .session_context import GenericSessionContext class Reassembler: - def __init__(self, cid: int) -> None: + def __init__(self, cid: int, read_buf: ThpBuffer) -> None: self.cid = cid + self.thp_read_buf = read_buf self.reset() def reset(self) -> None: @@ -61,8 +69,6 @@ def handle_packet(self, packet: memoryview) -> bool: Process current packet, returning `True` when a valid message is reassembled. The parsed message can retrieved via the `message` field (if it's not `None`). In case of a checksum error or if the reassembly is not over, return `False`. - - May raise `WireBufferError` if there is a concurrent payload reassembly in progress. """ ctrl_byte = packet[0] if control_byte.is_continuation(ctrl_byte): @@ -70,8 +76,7 @@ def handle_packet(self, packet: memoryview) -> bool: # ignore unexpected continuation packets return False - # may raise WireBufferError - buffer = memory_manager.get_existing_read_buffer(self.cid) + buffer = self.thp_read_buf.get(self.buffer_len) self._buffer_packet_data(buffer, packet, PacketHeader.CONT_LENGTH) else: self.reset() @@ -83,8 +88,7 @@ def handle_packet(self, packet: memoryview) -> bool: buffer = packet[: self.buffer_len] self.bytes_read = len(buffer) else: - # may raise WireBufferError - buffer = memory_manager.get_new_read_buffer(self.cid, self.buffer_len) + buffer = self.thp_read_buf.get(self.buffer_len) self._buffer_packet_data(buffer, packet, 0) assert len(buffer) == self.buffer_len @@ -124,26 +128,27 @@ class Channel: THP protocol encrypted communication channel. """ - def __init__(self, channel_cache: ChannelCache, ctx: ThpContext) -> None: + def __init__( + self, + channel_cache: ChannelCache, + ctx: ThpContext, + buffers: tuple[ThpBuffer, ThpBuffer], + ) -> None: assert ctx._iface.iface_num() == channel_cache.get_int(CHANNEL_IFACE) # Channel properties self.channel_id: bytes = channel_cache.channel_id self.ctx: ThpContext = ctx + self.read_buf, self.write_buf = buffers if __debug__: self._log("channel initialization") self.channel_cache: ChannelCache = channel_cache # Shared variables self.sessions: dict[int, GenericSessionContext] = {} - self.reassembler = Reassembler(self.get_channel_id_int()) - - # Objects for writing a message to a wire - self.transmission_loop: TransmissionLoop | None = None - self.write_task_spawn: loop.spawn | None = None + self.reassembler = Reassembler(self.get_channel_id_int(), self.read_buf) # Temporary objects - self.handshake: crypto.Handshake | None = None self.credential: ThpPairingCredential | None = None self.connection_context: PairingContext | None = None @@ -156,7 +161,6 @@ def iface(self) -> WireInterface: def clear(self) -> None: clear_sessions_with_channel_id(self.channel_id) - memory_manager.release_lock_if_owner(self.get_channel_id_int()) self.channel_cache.clear() # ACCESS TO CHANNEL_DATA @@ -204,36 +208,100 @@ def is_channel_to_replace(self) -> bool: # READ and DECRYPT + async def recv_payload( + self, expected_ctrl_byte: Callable[[int], bool] | None + ) -> memoryview: + """ + Receive and return a valid THP payload from this channel & its control byte. + Also handle ACKs while waiting for the payload. + + Raise if the received control byte is an unexpected one. + + If `expected_ctrl_byte` is `None`, returns after the first received ACK. + """ + while True: + # Handle an existing message (if already reassembled). + # Otherwise, receive and reassemble a new one. + msg = await self._get_reassembled_message() + + # Synchronization process + ctrl_byte = msg[0] + payload = msg[PacketHeader.INIT_LENGTH : -CHECKSUM_LENGTH] + seq_bit = control_byte.get_seq_bit(ctrl_byte) + + # 1: Handle ACKs + if control_byte.is_ack(ctrl_byte): + handle_ack(self, control_byte.get_ack_bit(ctrl_byte)) + if expected_ctrl_byte is None: + return payload + continue + + if expected_ctrl_byte is None or not expected_ctrl_byte(ctrl_byte): + raise ThpError("Unexpected control byte") + + # 2: Handle message with unexpected sequential bit + if seq_bit != ABP.get_expected_receive_seq_bit(self.channel_cache): + if __debug__: + self._log( + "Received message with an unexpected sequential bit", + ) + await send_ack(self, ack_bit=seq_bit) + raise ThpError("Received message with an unexpected sequential bit") + + # 3: Send ACK in response + await send_ack(self, ack_bit=seq_bit) + + ABP.set_expected_receive_seq_bit(self.channel_cache, 1 - seq_bit) + + return payload + + async def _get_reassembled_message(self) -> memoryview: + """Doesn't block if a message has been already reassembled.""" + while self.reassembler.message is None: + # receive and reassemble a new message from this channel + channel = await self.ctx.get_next_message() + if channel is self: + break + + # currently only single-channel sessions are supported during a single event loop run + self._log( + "Ignoring unexpected channel: ", + utils.hexlify_if_bytes(channel.channel_id), + logger=log.warning, + ) + + msg = self.reassembler.message + self.reassembler.reset() # next call will reassemble a new message + assert msg is not None + return msg + def reassemble(self, packet: utils.BufferType) -> bool: """ Process current packet, returning `True` when a valid message is reassembled. The parsed message can retrieved via the `message` field (if it's not `None`). In case of a checksum error or if the reassembly is not over, return `False`. - - May raise `WireBufferError` if there is a concurrent payload reassembly in progress. """ if self.get_channel_state() == ChannelState.UNALLOCATED: return False - try: - return self.reassembler.handle_packet(memoryview(packet)) - except WireBufferError: - self.reassembler.reset() - raise - - def decrypt_buffer( - self, message_length: int, offset: int = PacketHeader.INIT_LENGTH - ) -> None: - buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int()) - - noise_buffer = memoryview(buffer)[ - offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH - ] - tag = buffer[ - message_length - - CHECKSUM_LENGTH - - TAG_LENGTH : message_length - - CHECKSUM_LENGTH - ] + return self.reassembler.handle_packet(memoryview(packet)) + + async def decrypt_message(self) -> tuple[int, Message]: + """ + Receive, decrypt and return a `(session_id, message)` tuple. + Also handle ACKs while waiting for the message. + """ + payload = await self.recv_payload(control_byte.is_encrypted_transport) + self._decrypt_buffer(payload) + session_id, message_type = ustruct.unpack(">BH", payload) + message = Message( + message_type, + payload[SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH : -TAG_LENGTH], + ) + return (session_id, message) + + def _decrypt_buffer(self, payload: memoryview) -> None: + noise_buffer = payload[:-TAG_LENGTH] + tag = payload[-TAG_LENGTH:] key_receive = self.channel_cache.get(CHANNEL_KEY_RECEIVE) nonce_receive = self.channel_cache.get_int(CHANNEL_NONCE_RECEIVE) @@ -264,7 +332,6 @@ async def write( self, msg: protobuf.MessageType, session_id: int = 0, - force: bool = False, ) -> None: if __debug__: self._log( @@ -279,93 +346,54 @@ async def write( iface=self.iface, ) - cid = self.get_channel_id_int() msg_size = protobuf.encoded_length(msg) payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size length = payload_size + CHECKSUM_LENGTH + TAG_LENGTH + PacketHeader.INIT_LENGTH - buffer = memory_manager.get_new_write_buffer(cid, length) + buffer = self.write_buf.get(length) noise_payload_len = memory_manager.encode_into_buffer(buffer, msg, session_id) - task = self._write_and_encrypt(noise_payload_len=noise_payload_len, force=force) - if task is not None: - await task - - def write_error(self, err_type: int) -> Awaitable[None]: - msg_data = err_type.to_bytes(1, "big") - length = len(msg_data) + CHECKSUM_LENGTH - header = PacketHeader.get_error_header(self.get_channel_id_int(), length) - return self.ctx.write_payload(header, msg_data) - - def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None: - self._prepare_write() - self.write_task_spawn = loop.spawn( - self._write_encrypted_payload_loop(ctrl_byte, payload) - ) - - def _write_and_encrypt( - self, - noise_payload_len: int, - force: bool = False, - ) -> Awaitable[None] | None: - buffer = memory_manager.get_existing_write_buffer(self.get_channel_id_int()) - self._encrypt(buffer, noise_payload_len) payload_length = noise_payload_len + TAG_LENGTH - if self.write_task_spawn is not None: - self.write_task_spawn.close() # TODO might break something - if __debug__: - self._log("Closed write task", logger=log.warning) - self._prepare_write() - if force: - if __debug__: - self._log("Writing FORCE message (without async or retransmission).") - - return self._write_encrypted_payload_loop( - ENCRYPTED, memoryview(buffer[:payload_length]) - ) - self.write_task_spawn = loop.spawn( - self._write_encrypted_payload_loop( - ENCRYPTED, memoryview(buffer[:payload_length]) - ) - ) - return None + return await self.write_encrypted_payload(ENCRYPTED, buffer[:payload_length]) - def _prepare_write(self) -> None: - # TODO add condition that disallows to write when can_send_message is false - ABP.set_sending_allowed(self.channel_cache, False) + def write_handshake_message( + self, ctrl_byte: int, payload: bytes + ) -> Awaitable[None]: + return self.write_encrypted_payload(ctrl_byte, payload) - async def _write_encrypted_payload_loop( - self, ctrl_byte: int, payload: bytes, only_once: bool = False - ) -> None: + async def write_encrypted_payload(self, ctrl_byte: int, payload: bytes) -> None: if __debug__: self._log("write_encrypted_payload_loop") + assert ABP.is_sending_allowed(self.channel_cache) + payload_len = len(payload) + CHECKSUM_LENGTH sync_bit = ABP.get_send_seq_bit(self.channel_cache) ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(ctrl_byte, sync_bit) header = PacketHeader(ctrl_byte, self.get_channel_id_int(), payload_len) - self.transmission_loop = TransmissionLoop(self, header, payload) - if only_once: - if __debug__: - self._log('Starting transmission loop "only once"') - await self.transmission_loop.start(max_retransmission_count=1) - else: - if __debug__: - self._log("Starting transmission loop") - await self.transmission_loop.start() + + # ACK is needed before sending more data + ABP.set_sending_allowed(self.channel_cache, False) + + # TODO implement retransmissions: + # sender = loop.spawn(self._retransmit(header, payload)) # will raise on timeout + # receiver = loop.spawn(self._wait_for_ack()) # will return on success + # await loop.race(sender, receiver) + await self.ctx.write_payload(header, payload) + await self._wait_for_ack() + + # `ABP.set_sending_allowed()` will be called after a valid ACK + assert ABP.is_sending_allowed(self.channel_cache) ABP.set_send_seq_bit_to_opposite(self.channel_cache) - # Let the main loop be restarted and clear loop, if there is no other - # workflow and the state is ENCRYPTED_TRANSPORT - # TODO only once is there to not clear when FALLBACK - # TODO missing transmission loop is active -> do not clear - if not only_once and self._can_clear_loop(): - if __debug__: - self._log("clearing loop from channel") - loop.clear() + async def _wait_for_ack(self) -> None: + # `ABP.set_sending_allowed()` will be called after a valid ACK + while not ABP.is_sending_allowed(self.channel_cache): + # wait and return after receiving an ACK, or raise in case of an unexpected message. + await self.recv_payload(expected_ctrl_byte=None) def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None: if __debug__: @@ -389,11 +417,6 @@ def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None: buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag - def _can_clear_loop(self) -> bool: - return ( - not workflow.tasks - ) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT - if __debug__: def _log(self, text_1: str, text_2: str = "", logger: Any = log.debug) -> None: @@ -405,3 +428,30 @@ def _log(self, text_1: str, text_2: str = "", logger: Any = log.debug) -> None: text_2, iface=self.iface, ) + + +def send_ack(channel: Channel, ack_bit: int) -> Awaitable[None]: + ctrl_byte = control_byte.add_ack_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit) + header = PacketHeader(ctrl_byte, channel.get_channel_id_int(), CHECKSUM_LENGTH) + if __debug__: + log.debug( + __name__, + "Writing ACK message to a channel with cid: %s, ack_bit: %d", + hexlify_if_bytes(channel.channel_id), + ack_bit, + iface=channel.iface, + ) + return channel.ctx.write_payload(header, b"") + + +def handle_ack(ctx: Channel, ack_bit: int) -> None: + if not ABP.is_ack_valid(ctx.channel_cache, ack_bit): + return + # ACK is expected and it has correct sync bit + if __debug__: + log.debug( + __name__, + "Received ACK message with correct ack bit", + iface=ctx.iface, + ) + ABP.set_sending_allowed(ctx.channel_cache, True) diff --git a/core/src/trezor/wire/thp/interface_context.py b/core/src/trezor/wire/thp/interface_context.py index f52a95f5681..b81942917a5 100644 --- a/core/src/trezor/wire/thp/interface_context.py +++ b/core/src/trezor/wire/thp/interface_context.py @@ -5,13 +5,11 @@ from storage.cache_thp import ( BROADCAST_CHANNEL_ID, - ChannelCache, - iter_allocated_channels, + find_allocated_channel, update_channel_last_used, ) from trezor import io, loop, utils -from ..errors import WireBufferError from . import ( CHANNEL_ALLOCATION_REQ, CODEC_V1, @@ -43,19 +41,6 @@ class ThpContext: It also handles and responds to low-level single packet THP messages, creating new channels if needed. """ - @classmethod - def load_from_cache(cls, iface: WireInterface) -> "ThpContext": - ctx = cls(iface) - for channel_cache in iter_allocated_channels(iface.iface_num()): - ctx._load_channel(channel_cache) - return ctx - - def _load_channel(self, cache: ChannelCache) -> Channel: - channel_id = int.from_bytes(cache.channel_id, "big") - assert channel_id not in self._channels - self._channels[channel_id] = channel = Channel(cache, self) - return channel - def __init__(self, iface: WireInterface) -> None: self._iface = iface self._read = loop.wait(iface.iface_num() | io.POLL_READ) @@ -63,6 +48,13 @@ def __init__(self, iface: WireInterface) -> None: self._channels: dict[int, Channel] = {} async def get_next_message(self) -> Channel: + """ + Reassemble a valid THP payload and return its channel. + + Also handle THP channel allocation. + """ + from .. import THP_BUFFERS_PROVIDER + packet = bytearray(self._iface.RX_PACKET_LEN) while True: packet_len = await self._read @@ -76,25 +68,25 @@ async def get_next_message(self) -> Channel: continue cid = ustruct.unpack(">BH", packet)[1] - if cid == BROADCAST_CHANNEL_ID: await self._handle_broadcast(packet) continue - channel = self._channels.get(cid) - if channel is None: - await self._handle_unallocated(cid, packet) + if (cache := find_allocated_channel(cid)) is None: + if not control_byte.is_continuation(_get_ctrl_byte(packet)): + await self.write_error(cid, ThpErrorType.UNALLOCATED_CHANNEL) continue - try: - if channel.reassemble(packet): - update_channel_last_used(channel.channel_id) - # The reassembled message must be handled ASAP without blocking, - # since it may point to the global read buffer. - return channel - except WireBufferError: - await channel.write_error(ThpErrorType.TRANSPORT_BUSY) - continue + if (channel := self._channels.get(cid)) is None: + if (buffers := THP_BUFFERS_PROVIDER.take()) is None: + # concurrent payload reassembly is not supported + await self.write_error(cid, ThpErrorType.TRANSPORT_BUSY) + continue + channel = self._channels[cid] = Channel(cache, self, buffers) + + if channel.reassemble(packet): + update_channel_last_used(channel.channel_id) + return channel def write_payload(self, header: PacketHeader, payload: bytes) -> Awaitable[None]: checksum = crc.crc32(payload, crc.crc32(header.to_bytes())) @@ -148,10 +140,8 @@ async def _handle_broadcast(self, packet: bytes) -> None: raise ThpError("Unexpected ctrl_byte in a broadcast channel packet") channel_cache = channel_manager.create_new_channel(self._iface) - channel = self._load_channel(channel_cache) - response_data = get_channel_allocation_response( - nonce, channel.channel_id, self._iface + nonce, channel_cache.channel_id, self._iface ) response_header = PacketHeader.get_channel_allocation_response_header( len(response_data) + CHECKSUM_LENGTH, @@ -159,18 +149,17 @@ async def _handle_broadcast(self, packet: bytes) -> None: if __debug__: log.debug( __name__, - "New channel allocated with id %d", - channel.get_channel_id_int(), + "New channel allocated with id: %s", + utils.hexlify_if_bytes(channel_cache.channel_id), iface=self._iface, ) await self.write_payload(response_header, response_data) - async def _handle_unallocated(self, cid: int, packet: bytes) -> None: - if control_byte.is_continuation(_get_ctrl_byte(packet)): - return - data = (ThpErrorType.UNALLOCATED_CHANNEL).to_bytes(1, "big") - header = PacketHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH) - await self.write_payload(header, data) + def write_error(self, cid: int, err_type: ThpErrorType) -> Awaitable[None]: + msg_data = err_type.to_bytes(1, "big") + length = len(msg_data) + CHECKSUM_LENGTH + header = PacketHeader.get_error_header(cid, length) + return self.write_payload(header, msg_data) def _get_ctrl_byte(packet: bytes) -> int: diff --git a/core/src/trezor/wire/thp/memory_manager.py b/core/src/trezor/wire/thp/memory_manager.py index 2c05026a406..ad730e80910 100644 --- a/core/src/trezor/wire/thp/memory_manager.py +++ b/core/src/trezor/wire/thp/memory_manager.py @@ -1,154 +1,20 @@ -import utime from micropython import const from storage.cache_thp import SESSION_ID_LENGTH from trezor import protobuf, utils -from trezor.wire.errors import WireBufferError -from . import ThpError -from .writer import MAX_PAYLOAD_LEN, MESSAGE_TYPE_LENGTH +from .writer import MESSAGE_TYPE_LENGTH -if __debug__: - from trezor import log - from trezor.utils import hexlify_if_bytes +_PROTOBUF_BUFFER_SIZE = const(8192) -_PROTOBUF_BUFFER_SIZE = 8192 -READ_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) -WRITE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) -LOCK_TIMEOUT = 200 # miliseconds +class ThpBuffer: + def __init__(self) -> None: + self.buf = memoryview(bytearray(_PROTOBUF_BUFFER_SIZE)) -lock_owner_cid: int | None = None -lock_time: int = 0 - -READ_BUFFER_SLICE: memoryview | None = None -WRITE_BUFFER_SLICE: memoryview | None = None - -# Buffer types -_READ: int = const(0) -_WRITE: int = const(1) - - -# -# Access to buffer slices - - -def release_lock_if_owner(channel_id: int) -> None: - global lock_owner_cid - if lock_owner_cid == channel_id: - lock_owner_cid = None - - -def get_new_read_buffer(channel_id: int, length: int) -> memoryview: - return _get_new_buffer(_READ, channel_id, length) - - -def get_new_write_buffer(channel_id: int, length: int) -> memoryview: - return _get_new_buffer(_WRITE, channel_id, length) - - -def get_existing_read_buffer(channel_id: int) -> memoryview: - return _get_existing_buffer(_READ, channel_id) - - -def get_existing_write_buffer(channel_id: int) -> memoryview: - return _get_existing_buffer(_WRITE, channel_id) - - -def _get_new_buffer(buffer_type: int, channel_id: int, length: int) -> memoryview: - if is_locked(): - if not is_owner(channel_id): - if __debug__: - log.debug( - __name__, - "Failed to get new buffer to channel %s. Owner is %s.", - hexlify_if_bytes((channel_id or 0).to_bytes(2, "big")), - hexlify_if_bytes((lock_owner_cid or 0).to_bytes(2, "big")), - ) - raise WireBufferError - update_lock_time() - else: - update_lock(channel_id) - - if buffer_type == _READ: - global READ_BUFFER - buffer = READ_BUFFER - elif buffer_type == _WRITE: - global WRITE_BUFFER - buffer = WRITE_BUFFER - else: - raise ValueError("Unknown buffer_type") - - if length > MAX_PAYLOAD_LEN or length > len(buffer): - raise ThpError("Message is too large") # TODO reword - - if buffer_type == _READ: - global READ_BUFFER_SLICE - READ_BUFFER_SLICE = memoryview(READ_BUFFER)[:length] - return READ_BUFFER_SLICE - - if buffer_type == _WRITE: - global WRITE_BUFFER_SLICE - WRITE_BUFFER_SLICE = memoryview(WRITE_BUFFER)[:length] - return WRITE_BUFFER_SLICE - - raise ValueError("Unknown buffer_type") - - -def _get_existing_buffer(buffer_type: int, channel_id: int) -> memoryview: - if not is_owner(channel_id): - raise WireBufferError - update_lock_time() - - if buffer_type == _READ: - global READ_BUFFER_SLICE - if READ_BUFFER_SLICE is None: - raise WireBufferError - return READ_BUFFER_SLICE - - if buffer_type == _WRITE: - global WRITE_BUFFER_SLICE - if WRITE_BUFFER_SLICE is None: - raise WireBufferError - return WRITE_BUFFER_SLICE - - raise ValueError("Unknown buffer_type") - - -# -# Buffer locking - - -def is_locked() -> bool: - global lock_owner_cid - global lock_time - - time_diff = utime.ticks_diff(utime.ticks_ms(), lock_time) - return lock_owner_cid is not None and time_diff < LOCK_TIMEOUT - - -def is_owner(channel_id: int) -> bool: - global lock_owner_cid - return lock_owner_cid is not None and lock_owner_cid == channel_id - - -def update_lock(channel_id: int) -> None: - set_owner(channel_id) - update_lock_time() - - -def set_owner(channel_id: int) -> None: - global lock_owner_cid - lock_owner_cid = channel_id - - -def update_lock_time() -> None: - global lock_time - lock_time = utime.ticks_ms() - - -# -# Helper for encoding messages into buffer + def get(self, length: int) -> memoryview: + assert length <= len(self.buf) + return self.buf[:length] def encode_into_buffer( diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py index 1f0c4a93e71..8520f08c2c5 100644 --- a/core/src/trezor/wire/thp/pairing_context.py +++ b/core/src/trezor/wire/thp/pairing_context.py @@ -28,7 +28,6 @@ class PairingContext(Context): def __init__(self, channel_ctx: Channel) -> None: super().__init__(channel_ctx.iface, channel_ctx.channel_id, "ThpMessageType") self.channel_ctx: Channel = channel_ctx - self.incoming_message = loop.mailbox() self.nfc_secret: bytes | None = None self.qr_code_secret: bytes | None = None self.code_entry_secret: bytes | None = None @@ -46,8 +45,7 @@ def __init__(self, channel_ctx: Channel) -> None: self.host_name: str | None self.app_name: str | None - async def handle(self) -> None: - next_message: Message | None = None + async def handle(self, next_message: Message | None = None) -> None: while True: try: @@ -55,7 +53,7 @@ async def handle(self) -> None: # If the previous run did not keep an unprocessed message for us, # wait for a new one. try: - message: Message = await self.incoming_message + _, message = await self.channel_ctx.decrypt_message() except protocol_common.WireError as e: if __debug__: log.exception(__name__, e, iface=self.iface) @@ -102,7 +100,7 @@ async def read( iface=self.iface, ) - message: Message = await self.incoming_message + _, message = await self.channel_ctx.decrypt_message() if message.type not in expected_types: from trezor.messages import Cancel @@ -119,11 +117,8 @@ async def read( return message_handler.wrap_protobuf_load(message.data, expected_type) - async def write(self, msg: protobuf.MessageType) -> None: - return await self.channel_ctx.write(msg) - - def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]: - return self.channel_ctx.write(msg, force=True) + def write(self, msg: protobuf.MessageType) -> Awaitable[None]: + return self.channel_ctx.write(msg) async def call_any( self, msg: protobuf.MessageType, *expected_types: int diff --git a/core/src/trezor/wire/thp/received_message_handler.py b/core/src/trezor/wire/thp/received_message_handler.py index 0e4828cc624..bef6b6598c4 100644 --- a/core/src/trezor/wire/thp/received_message_handler.py +++ b/core/src/trezor/wire/thp/received_message_handler.py @@ -1,4 +1,3 @@ -import ustruct from typing import TYPE_CHECKING from storage.cache_common import ( @@ -8,45 +7,31 @@ CHANNEL_NONCE_RECEIVE, CHANNEL_NONCE_SEND, ) -from storage.cache_thp import ( - KEY_LENGTH, - SESSION_ID_LENGTH, - TAG_LENGTH, - update_session_last_used, -) -from trezor import config, loop, protobuf, utils +from storage.cache_thp import KEY_LENGTH, TAG_LENGTH, update_session_last_used +from trezor import config, protobuf, utils from trezor.enums import FailureType from trezor.messages import Failure -from trezor.wire.thp import memory_manager from .. import message_handler from ..errors import DataError -from ..protocol_common import Message from . import ( - ACK_MESSAGE, HANDSHAKE_COMP_RES, HANDSHAKE_INIT_RES, ChannelState, - PacketHeader, SessionState, ThpDecryptionError, ThpDeviceLockedError, ThpError, ThpErrorType, - ThpInvalidDataError, - ThpUnallocatedChannelError, ThpUnallocatedSessionError, + control_byte, + get_encoded_device_properties, + session_manager, ) -from . import alternating_bit_protocol as ABP -from . import control_byte, get_encoded_device_properties, session_manager -from .checksum import CHECKSUM_LENGTH from .crypto import PUBKEY_LENGTH, Handshake from .session_context import SeedlessSessionContext -from .writer import MESSAGE_TYPE_LENGTH if TYPE_CHECKING: - from typing import Awaitable - from trezor.messages import ThpHandshakeCompletionReqNoisePayload from .channel import Channel @@ -60,185 +45,61 @@ _TREZOR_STATE_PAIRED_AUTOCONNECT = b"\x02" -async def handle_received_message( - ctx: Channel, message_buffer: utils.BufferType -) -> None: - """Handle a message received from the channel.""" - - if __debug__: - log.debug(__name__, "handle_received_message", iface=ctx.iface) - # TODO remove after performance tests are done - # try: - # import micropython - - # print("micropython.mem_info() from received_message_handler.py") - # micropython.mem_info() - # print("Allocation count:", micropython.alloc_count()) - # except AttributeError: - # print( - # "To show allocation count, create the build with TREZOR_MEMPERF=1" - # ) - ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer) - message_length = payload_length + PacketHeader.INIT_LENGTH - - # Synchronization process - seq_bit = control_byte.get_seq_bit(ctrl_byte) - ack_bit = control_byte.get_ack_bit(ctrl_byte) - if __debug__: - log.debug( - __name__, - "handle_completed_message - seq bit of message: %d, ack bit of message: %d", - seq_bit, - ack_bit, - iface=ctx.iface, - ) - - # 1: Handle ACKs - if control_byte.is_ack(ctrl_byte): - await handle_ack(ctx, ack_bit) - return - - if _should_have_ctrl_byte_encrypted_transport( - ctx - ) and not control_byte.is_encrypted_transport(ctrl_byte): - raise ThpError("Message is not encrypted. Ignoring") - - # 2: Handle message with unexpected sequential bit - if seq_bit != ABP.get_expected_receive_seq_bit(ctx.channel_cache): - if __debug__: - log.debug( - __name__, - "Received message with an unexpected sequential bit", - iface=ctx.iface, - ) - await _send_ack(ctx, ack_bit=seq_bit) - raise ThpError("Received message with an unexpected sequential bit") - - # 3: Send ACK in response - await _send_ack(ctx, ack_bit=seq_bit) - - ABP.set_expected_receive_seq_bit(ctx.channel_cache, 1 - seq_bit) +async def handle_received_message(channel: Channel) -> bool: + """ + Handle a message received from the channel. + Returns False if we can restart the event loop. + """ try: - _handle_message_to_app_or_channel( - ctx, payload_length, message_length, ctrl_byte - ) + state = channel.get_channel_state() + if state is ChannelState.ENCRYPTED_TRANSPORT: + await _handle_state_ENCRYPTED_TRANSPORT(channel) + return False + elif _is_channel_state_pairing(state): + await _handle_pairing(channel) + return False + elif state is ChannelState.TH1: + await _handle_state_handshake(channel) + return channel.get_channel_state() == ChannelState.TC1 + else: + raise ThpError("Unimplemented channel state") + except ThpUnallocatedSessionError as e: error_message = Failure(code=FailureType.ThpUnallocatedSession) - await ctx.write(error_message, e.session_id) - except ThpUnallocatedChannelError: - await ctx.write_error(ThpErrorType.UNALLOCATED_CHANNEL) - ctx.clear() + await channel.write(error_message, e.session_id) except ThpDecryptionError: - await ctx.write_error(ThpErrorType.DECRYPTION_FAILED) - ctx.clear() - except ThpInvalidDataError: - await ctx.write_error(ThpErrorType.INVALID_DATA) - ctx.clear() - except ThpDeviceLockedError: - await ctx.write_error(ThpErrorType.DEVICE_LOCKED) - - if __debug__: - log.debug(__name__, "handle_received_message - end", iface=ctx.iface) - - -def _send_ack(channel: Channel, ack_bit: int) -> Awaitable[None]: - ctrl_byte = control_byte.add_ack_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit) - header = PacketHeader(ctrl_byte, channel.get_channel_id_int(), CHECKSUM_LENGTH) - if __debug__: - log.debug( - __name__, - "Writing ACK message to a channel with cid: %s, ack_bit: %d", - hexlify_if_bytes(channel.channel_id), - ack_bit, - iface=channel.iface, + await channel.ctx.write_error( + channel.get_channel_id_int(), ThpErrorType.DECRYPTION_FAILED ) - return channel.ctx.write_payload(header, b"") - - -async def handle_ack(ctx: Channel, ack_bit: int) -> None: - if not ABP.is_ack_valid(ctx.channel_cache, ack_bit): - return - # ACK is expected and it has correct sync bit - if __debug__: - log.debug( - __name__, - "Received ACK message with correct ack bit", - iface=ctx.iface, + channel.clear() + except ThpDeviceLockedError: + await channel.ctx.write_error( + channel.get_channel_id_int(), ThpErrorType.DEVICE_LOCKED ) - if ctx.transmission_loop is not None: - ctx.transmission_loop.stop_immediately() - if __debug__: - log.debug(__name__, "Stopped transmission loop", iface=ctx.iface) - elif __debug__: - log.debug(__name__, "Transmission loop was not stopped!", iface=ctx.iface) - - ABP.set_sending_allowed(ctx.channel_cache, True) - - if ctx.write_task_spawn is not None: - if __debug__: - log.debug( - __name__, - 'Control to "write_encrypted_payload_loop" task', - iface=ctx.iface, - ) - await ctx.write_task_spawn - # Note that no the write_task_spawn could result in loop.clear(), - # which will result in termination of this function - any code after - # this await might not be executed + return False -def _handle_message_to_app_or_channel( +async def _handle_state_handshake( ctx: Channel, - payload_length: int, - message_length: int, - ctrl_byte: int, ) -> None: - state = ctx.get_channel_state() - - if state == ChannelState.ENCRYPTED_TRANSPORT: - return _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length) - - if state == ChannelState.TH1: - return _handle_state_TH1(ctx, payload_length, message_length, ctrl_byte) - - if state == ChannelState.TH2: - return _handle_state_TH2(ctx, message_length, ctrl_byte) - - if _is_channel_state_pairing(state): - return _handle_pairing(ctx, message_length) - - raise ThpError("Unimplemented channel state") + if __debug__: + log.debug(__name__, "handle_state_handshake", iface=ctx.iface) + payload = await ctx.recv_payload(control_byte.is_handshake_init_req) -def _handle_state_TH1( - ctx: Channel, - payload_length: int, - message_length: int, - ctrl_byte: int, -) -> None: - if __debug__: - log.debug(__name__, "handle_state_TH1", iface=ctx.iface) - if not control_byte.is_handshake_init_req(ctrl_byte): - raise ThpError("Message received is not a handshake init request!") - if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH: + if len(payload) != PUBKEY_LENGTH: raise ThpError("Message received is not a valid handshake init request!") if not config.is_unlocked(): raise ThpDeviceLockedError - ctx.handshake = Handshake() + handshake = Handshake() - buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int()) - # if buffer is BufferError: - # pass # TODO buffer is gone :/ - - host_ephemeral_public_key = bytearray( - buffer[PacketHeader.INIT_LENGTH : message_length - CHECKSUM_LENGTH] - ) trezor_ephemeral_public_key, encrypted_trezor_static_public_key, tag = ( - ctx.handshake.handle_th1_crypto( - get_encoded_device_properties(ctx.iface), host_ephemeral_public_key + handshake.handle_th1_crypto( + get_encoded_device_properties(ctx.iface), + host_ephemeral_public_key=payload, ) ) @@ -260,61 +121,31 @@ def _handle_state_TH1( payload = trezor_ephemeral_public_key + encrypted_trezor_static_public_key + tag # send handshake init response message - ctx.write_handshake_message(HANDSHAKE_INIT_RES, payload) - ctx.set_channel_state(ChannelState.TH2) - return + await ctx.write_encrypted_payload(HANDSHAKE_INIT_RES, payload) - -def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -> None: - from apps.thp.credential_manager import decode_credential, validate_credential - - if __debug__: - log.debug(__name__, "handle_state_TH2", iface=ctx.iface) - if not control_byte.is_handshake_comp_req(ctrl_byte): - raise ThpError("Message received is not a handshake completion request!") - - if ctx.handshake is None: - raise ThpUnallocatedChannelError( - "Handshake object is not prepared. Create new channel." - ) + payload = await ctx.recv_payload(control_byte.is_handshake_comp_req) if not config.is_unlocked(): raise ThpDeviceLockedError - buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int()) - # if buffer is BufferError: - # pass # TODO handle - host_encrypted_static_public_key = buffer[ - PacketHeader.INIT_LENGTH : PacketHeader.INIT_LENGTH + KEY_LENGTH + TAG_LENGTH - ] - handshake_completion_request_noise_payload = buffer[ - PacketHeader.INIT_LENGTH - + KEY_LENGTH - + TAG_LENGTH : message_length - - CHECKSUM_LENGTH - ] - - ctx.handshake.handle_th2_crypto( + host_encrypted_static_public_key = payload[: KEY_LENGTH + TAG_LENGTH] + handshake_completion_request_noise_payload = payload[KEY_LENGTH + TAG_LENGTH :] + + handshake.handle_th2_crypto( host_encrypted_static_public_key, handshake_completion_request_noise_payload ) - ctx.channel_cache.set(CHANNEL_KEY_RECEIVE, ctx.handshake.key_receive) - ctx.channel_cache.set(CHANNEL_KEY_SEND, ctx.handshake.key_send) - ctx.channel_cache.set(CHANNEL_HANDSHAKE_HASH, ctx.handshake.h) + ctx.channel_cache.set(CHANNEL_KEY_RECEIVE, handshake.key_receive) + ctx.channel_cache.set(CHANNEL_KEY_SEND, handshake.key_send) + ctx.channel_cache.set(CHANNEL_HANDSHAKE_HASH, handshake.h) ctx.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0) ctx.channel_cache.set_int(CHANNEL_NONCE_SEND, 1) - noise_payload = _decode_message( - buffer[ - PacketHeader.INIT_LENGTH - + KEY_LENGTH - + TAG_LENGTH : message_length - - CHECKSUM_LENGTH - - TAG_LENGTH - ], - 0, - "ThpHandshakeCompletionReqNoisePayload", - ) + buffer = payload[KEY_LENGTH + TAG_LENGTH : -TAG_LENGTH] + + payload_type = protobuf.type_for_name("ThpHandshakeCompletionReqNoisePayload") + noise_payload = message_handler.wrap_protobuf_load(buffer, payload_type) + if TYPE_CHECKING: assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload) @@ -335,6 +166,8 @@ def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -> None trezor_state = _TREZOR_STATE_UNPAIRED if noise_payload.host_pairing_credential is not None: + from apps.thp.credential_manager import decode_credential, validate_credential + try: # TODO change try-except for something better credential = decode_credential(noise_payload.host_pairing_credential) paired = validate_credential( @@ -352,12 +185,8 @@ def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -> None pass # send hanshake completion response - ctx.write_handshake_message( - HANDSHAKE_COMP_RES, - ctx.handshake.get_handshake_completion_response(trezor_state), - ) - - ctx.handshake = None + response = handshake.get_handshake_completion_response(trezor_state) + await ctx.write_encrypted_payload(HANDSHAKE_COMP_RES, response) if paired: ctx.set_channel_state(ChannelState.TC1) @@ -365,18 +194,11 @@ def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -> None ctx.set_channel_state(ChannelState.TP0) -def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -> None: +async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel) -> None: if __debug__: log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT", iface=ctx.iface) - ctx.decrypt_buffer(message_length) - - buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int()) - # if buffer is BufferError: - # pass # TODO handle - session_id, message_type = ustruct.unpack( - ">BH", memoryview(buffer)[PacketHeader.INIT_LENGTH :] - ) + session_id, message = await ctx.decrypt_message() if session_id not in ctx.sessions: s = session_manager.get_session_from_cache(ctx, session_id) @@ -385,61 +207,22 @@ def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -> None s = SeedlessSessionContext(ctx, session_id) ctx.sessions[session_id] = s - loop.schedule(s.handle()) elif ctx.sessions[session_id].get_session_state() is SessionState.UNALLOCATED: raise ThpUnallocatedSessionError(session_id) s = ctx.sessions[session_id] update_session_last_used(s.channel_id, (s.session_id).to_bytes(1, "big")) + await s.handle(message) - s.incoming_message.put( - Message( - message_type, - buffer[ - PacketHeader.INIT_LENGTH - + MESSAGE_TYPE_LENGTH - + SESSION_ID_LENGTH : message_length - - CHECKSUM_LENGTH - - TAG_LENGTH - ], - ) - ) - if __debug__: - log.debug( - __name__, - f"Scheduled message to be handled by a session (session_id: {session_id}, msg_type (int): {message_type})", - iface=ctx.iface, - ) - -def _handle_pairing(ctx: Channel, message_length: int) -> None: +async def _handle_pairing(ctx: Channel) -> None: from .pairing_context import PairingContext - if ctx.connection_context is None: - ctx.connection_context = PairingContext(ctx) - loop.schedule(ctx.connection_context.handle()) - - ctx.decrypt_buffer(message_length) - buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int()) - # if buffer is BufferError: - # pass # TODO handle - message_type = ustruct.unpack( - ">H", buffer[PacketHeader.INIT_LENGTH + SESSION_ID_LENGTH :] - )[0] - - ctx.connection_context.incoming_message.put( - Message( - message_type, - buffer[ - PacketHeader.INIT_LENGTH - + MESSAGE_TYPE_LENGTH - + SESSION_ID_LENGTH : message_length - - CHECKSUM_LENGTH - - TAG_LENGTH - ], - ) - ) + ctx.connection_context = PairingContext(ctx) + + _session_id, message = await ctx.decrypt_message() + await ctx.connection_context.handle(message) def _should_have_ctrl_byte_encrypted_transport(ctx: Channel) -> bool: diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py index ad84c7a93d5..0c23bb70a7c 100644 --- a/core/src/trezor/wire/thp/session_context.py +++ b/core/src/trezor/wire/thp/session_context.py @@ -3,10 +3,10 @@ from storage import cache_thp from storage.cache_common import InvalidSessionError from storage.cache_thp import SessionThpCache -from trezor import loop, protobuf +from trezor import protobuf from trezor.wire import message_handler, protocol_common from trezor.wire.context import UnexpectedMessageException -from trezor.wire.message_handler import failure +from trezor.wire.message_handler import failure, handle_single_message from ..protocol_common import Context, Message from . import SessionState @@ -34,9 +34,8 @@ def __init__(self, channel: Channel, session_id: int) -> None: super().__init__(channel.iface, channel.channel_id) self.channel: Channel = channel self.session_id: int = session_id - self.incoming_message = loop.mailbox() - async def handle(self) -> None: + async def handle(self, message: Message | None = None) -> None: if __debug__: log.debug( __name__, @@ -46,14 +45,12 @@ async def handle(self) -> None: iface=self.iface, ) - next_message: Message | None = None - while True: - message = next_message - next_message = None try: - await self._handle_message(message) - loop.schedule(self.handle()) + if message is None: + message = await self._read_next_message() + await handle_single_message(self, message) + self.channel._log("session loop is over") return except protocol_common.WireError as e: if __debug__: @@ -62,26 +59,20 @@ async def handle(self) -> None: except UnexpectedMessageException as unexpected: # The workflow was interrupted by an unexpected message. We need to # process it as if it was a new message... - next_message = unexpected.msg + message = unexpected.msg except Exception as exc: # Log and try again. if __debug__: log.exception(__name__, exc, iface=self.iface) - async def _handle_message( - self, - next_message: Message | None, - ) -> None: - - if next_message is not None: - # Process the message from previous run. - message = next_message - next_message = None - else: - # Wait for a new message from wire - message = await self.incoming_message - - await message_handler.handle_single_message(self, message) + async def _read_next_message(self) -> Message: + while True: + session_id, message = await self.channel.decrypt_message() + if session_id == self.session_id: + return message + self.channel._log( + "Ignored message for unexpected session", logger=log.warning + ) async def read( self, @@ -99,7 +90,8 @@ async def read( exp_type, iface=self.iface, ) - message: Message = await self.incoming_message + + message = await self._read_next_message() if message.type not in expected_types: if __debug__: log.debug( @@ -118,11 +110,8 @@ async def read( return message_handler.wrap_protobuf_load(message.data, expected_type) - async def write(self, msg: protobuf.MessageType) -> None: - return await self.channel.write(msg, self.session_id) - - def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]: - return self.channel.write(msg, self.session_id, force=True) + def write(self, msg: protobuf.MessageType) -> Awaitable[None]: + return self.channel.write(msg, self.session_id) def get_session_state(self) -> SessionState: ... diff --git a/core/src/trezor/wire/thp/transmission_loop.py b/core/src/trezor/wire/thp/transmission_loop.py deleted file mode 100644 index 015b9ba9945..00000000000 --- a/core/src/trezor/wire/thp/transmission_loop.py +++ /dev/null @@ -1,62 +0,0 @@ -from micropython import const -from typing import TYPE_CHECKING - -from trezor import loop - -if TYPE_CHECKING: - from . import PacketHeader - from .channel import Channel - -MAX_RETRANSMISSION_COUNT = const(50) -MIN_RETRANSMISSION_COUNT = const(2) - - -class TransmissionLoop: - - def __init__( - self, channel: Channel, header: PacketHeader, transport_payload: bytes - ) -> None: - self.channel: Channel = channel - self.header: PacketHeader = header - self.transport_payload: bytes = transport_payload - self.wait_task: loop.spawn | None = None - self.min_retransmisson_count_achieved: bool = False - self.stop: bool = False - - async def start( - self, max_retransmission_count: int = MAX_RETRANSMISSION_COUNT - ) -> None: - self.min_retransmisson_count_achieved = False - for i in range(max_retransmission_count): - if i >= MIN_RETRANSMISSION_COUNT: - self.min_retransmisson_count_achieved = True - await self.channel.ctx.write_payload(self.header, self.transport_payload) - - # Do not create wait task for last iteration - if i == max_retransmission_count - 1: - break - - # Got ack while we were retransmitting? - if self.stop: - break - - self.wait_task = loop.spawn(self._wait(i)) - try: - await self.wait_task - except loop.TaskClosed: - break - finally: - self.wait_task = None - - def stop_immediately(self) -> None: - if self.wait_task is not None: - self.wait_task.close() - self.wait_task = None - self.stop = True - - async def _wait(self, counter: int = 0) -> None: - timeout_ms = round(10200 - 1010000 / (counter + 100)) - await loop.sleep(timeout_ms) - - def __del__(self) -> None: - self.stop_immediately() diff --git a/core/tests/test_trezor.wire.thp.py b/core/tests/test_trezor.wire.thp.py index ea165a9c2b9..de6bff16624 100644 --- a/core/tests/test_trezor.wire.thp.py +++ b/core/tests/test_trezor.wire.thp.py @@ -6,7 +6,6 @@ if utils.USE_THP: import thp_common from trezor.wire import handle_session as thp_main_loop - from trezor.wire.thp import memory_manager @unittest.skipUnless(utils.USE_THP, "only needed for THP") @@ -19,8 +18,6 @@ def __init__(self): def setUp(self): self.interface = MockHID() - memory_manager.READ_BUFFER = bytearray(64) - memory_manager.WRITE_BUFFER = bytearray(256) def test_codec_message(self): self.assertEqual(len(self.interface.data), 0) diff --git a/core/tests/thp_common.py b/core/tests/thp_common.py index 22427180f45..85b12da35f3 100644 --- a/core/tests/thp_common.py +++ b/core/tests/thp_common.py @@ -10,6 +10,7 @@ from trezor.wire.thp.channel import Channel from trezor.wire.thp.channel_manager import create_new_channel from trezor.wire.thp.interface_context import ThpContext + from trezor.wire.thp.memory_manager import ThpBuffer from trezor.wire.thp.session_context import SessionContext if TYPE_CHECKING: @@ -21,12 +22,14 @@ def prepare_context() -> None: session_cache = cache_thp.create_or_replace_session( channel_cache, session_id=b"\x01" ) - channel = Channel(channel_cache, ThpContext.load_from_cache(mock_iface)) + channel = Channel( + channel_cache, ThpContext(mock_iface), (ThpBuffer(), ThpBuffer()) + ) context.CURRENT_CONTEXT = SessionContext(channel, session_cache) def get_new_channel(iface: WireInterface) -> Channel: channel_cache = create_new_channel(iface) - return Channel(channel_cache, ThpContext(iface)) + return Channel(channel_cache, ThpContext(iface), (ThpBuffer(), ThpBuffer())) if __debug__: diff --git a/tests/device_tests/thp/test_multiple_hosts.py b/tests/device_tests/thp/test_multiple_hosts.py index 4dd7aa16393..4e145ecc7bf 100644 --- a/tests/device_tests/thp/test_multiple_hosts.py +++ b/tests/device_tests/thp/test_multiple_hosts.py @@ -1,70 +1,27 @@ -import os -from time import sleep - import pytest from trezorlib import exceptions from trezorlib.client import ProtocolV2Channel from trezorlib.debuglink import TrezorClientDebugLink as Client -from ...conftest import LOCK_TIME - pytestmark = [pytest.mark.protocol("protocol_v2"), pytest.mark.invalidate_client] -# LOCK_TIME = 0.2 - - -def _prepare_two_hosts_for_handshake( - client: Client, init_noise: bool = True -) -> tuple[ProtocolV2Channel, ProtocolV2Channel]: - # Sleep for LOCK_TIME - sleep(LOCK_TIME) - - protocol_1 = client.protocol - protocol_1._reset_sync_bits() - protocol_2 = ProtocolV2Channel( - protocol_1.transport, protocol_1.mapping, prepare_channel_without_pairing=False - ) - protocol_2._reset_sync_bits() - - nonce_1 = os.urandom(8) - nonce_2 = os.urandom(8) - if nonce_1 == nonce_2: - nonce_2 = (int.from_bytes(nonce_1) + 1).to_bytes(8, "big") - protocol_1._send_channel_allocation_request(nonce_1) - protocol_1.channel_id, protocol_1.device_properties = ( - protocol_1._read_channel_allocation_response(nonce_1) - ) - protocol_2._send_channel_allocation_request(nonce_2) - protocol_2.channel_id, protocol_2.device_properties = ( - protocol_2._read_channel_allocation_response(nonce_2) - ) - if init_noise: - protocol_1._init_noise() - protocol_2._init_noise() - - return protocol_1, protocol_2 - -def _prepare_two_hosts(client: Client) -> tuple[ProtocolV2Channel, ProtocolV2Channel]: - protocol_1, protocol_2 = _prepare_two_hosts_for_handshake( - client=client, init_noise=False +def _new_channel(client) -> ProtocolV2Channel: + channel = ProtocolV2Channel( + transport=client.transport, + mapping=client.mapping, + credential=None, + prepare_channel_without_pairing=False, ) - protocol_1._do_handshake() - - client.protocol = protocol_1 - client.do_pairing() - sleep(LOCK_TIME) - protocol_2._do_handshake() - client.protocol = protocol_2 - client.do_pairing() - - return protocol_1, protocol_2 + channel._do_channel_allocation() + channel._init_noise() + return channel -def test_concurrent_handshakes_1(client: Client) -> None: - client = client.get_new_client() - protocol_1, protocol_2 = _prepare_two_hosts_for_handshake(client) +def test_concurrent_handshakes(client: Client) -> None: + protocol_1 = _new_channel(client) + protocol_2 = _new_channel(client) # The first host starts handshake protocol_1._send_handshake_init_request() @@ -75,79 +32,14 @@ def test_concurrent_handshakes_1(client: Client) -> None: protocol_2._send_handshake_init_request() # The second host should not be able to interrupt the first host's handshake - # until timeout (LOCK_TIME) has expired with pytest.raises(exceptions.ThpError) as e: protocol_2._read_ack() assert e.value.args[0] == "TRANSPORT BUSY" - # Wait for LOCK_TIME to expire - sleep(LOCK_TIME) - - # The second host retries and finishes handhake successfully - protocol_2._init_noise() - protocol_2._send_handshake_init_request() - protocol_2._read_ack() - protocol_2._read_handshake_init_response() - - protocol_2._send_handshake_completion_request() - protocol_2._read_ack() - protocol_2._read_handshake_completion_response() - - # The second host performs action that results - # in the invalidation of the first host's handshake state - client.protocol = protocol_2 - client.do_pairing() - - # Even after LOCK_TIME passes, the first host's channel cannot - # be resumed - sleep(LOCK_TIME) + # The first host can complete handshake protocol_1._send_handshake_completion_request() protocol_1._read_ack() + protocol_1._read_handshake_completion_response() - with pytest.raises(exceptions.ThpError) as e: - protocol_1._read_handshake_completion_response() - assert e.value.args[0] == "UNALLOCATED CHANNEL" - - -def test_concurrent_handshakes_2(client: Client) -> None: - protocol_1, protocol_2 = _prepare_two_hosts_for_handshake(client) - - # The first host starts handshake - protocol_1._send_handshake_init_request() - protocol_1._read_ack() - protocol_1._read_handshake_init_response() - - # The second host starts handshake - protocol_2._send_handshake_init_request() - - # The second host should not be able to interrupt the first host's handshake - # until timeout (LOCK_TIME) has expired - with pytest.raises(exceptions.ThpError) as e: - protocol_2._read_ack() - assert e.value.args[0] == "TRANSPORT BUSY" - - # Wait for LOCK_TIME to expire - sleep(LOCK_TIME) - - # The second host retries and finishes handhake successfully - protocol_2._init_noise() - protocol_2._send_handshake_init_request() - protocol_2._read_ack() - protocol_2._read_handshake_init_response() - - protocol_2._send_handshake_completion_request() - protocol_2._read_ack() - protocol_2._read_handshake_completion_response() - - # The first host tries to continue handshake immediately after - # the second host finishes it - - protocol_1._send_handshake_completion_request() - - with pytest.raises(exceptions.ThpError) as e: - protocol_1._read_ack() - - # protocol_1._read_handshake_completion_response() - assert e.value.args[0] == "TRANSPORT BUSY" - - # TODO - test ACK fallback, test standard encrypted message fallback + # Now the second handshake can be done + protocol_2._do_handshake() diff --git a/tests/ui_tests/fixtures.json b/tests/ui_tests/fixtures.json index cf54b74cea0..b2bc348e369 100644 --- a/tests/ui_tests/fixtures.json +++ b/tests/ui_tests/fixtures.json @@ -30190,8 +30190,7 @@ "T3W1_cs_thp-test_abp.py::test_abp": "599b7c45309b3e09f382610f3e2198c264fc35aefb4effe7816040e6ec29f91d", "T3W1_cs_thp-test_handshake.py::test_allocate_channel": "535037bfe5f1459cfdf305915835d8bf2a9a427c3f60264a8cc3ca6f306a61b1", "T3W1_cs_thp-test_handshake.py::test_handshake": "599b7c45309b3e09f382610f3e2198c264fc35aefb4effe7816040e6ec29f91d", -"T3W1_cs_thp-test_multiple_hosts.py::test_concurrent_handshakes_1": "a8055674ce15f6440992bcd745f6da0f0a0367b08bf0b90ba8fdf74988a17dea", -"T3W1_cs_thp-test_multiple_hosts.py::test_concurrent_handshakes_2": "535037bfe5f1459cfdf305915835d8bf2a9a427c3f60264a8cc3ca6f306a61b1", +"T3W1_cs_thp-test_multiple_hosts.py::test_concurrent_handshakes": "535037bfe5f1459cfdf305915835d8bf2a9a427c3f60264a8cc3ca6f306a61b1", "T3W1_cs_thp-test_pairing.py::test_autoconnect_credential_request_cancel": "457ab9e14416f52c1bb1ea257c7dd4bd43c781a71ba4400fb700cf376200cd27", "T3W1_cs_thp-test_pairing.py::test_channel_replacement": "2ec597b20ac7bbd981d5b446e41c8ee99530b887232b0e6e050eb543776eef78", "T3W1_cs_thp-test_pairing.py::test_connection_confirmation_cancel": "c22a79b68dd1f57a5db58388cd0344cb5111dbc68afef20e3362853d8ea54f8a", @@ -31650,8 +31649,7 @@ "T3W1_de_thp-test_abp.py::test_abp": "4089127fa2d033a6d9c4a09781f54ce1f92fc1684ce2ffdacc153586685e8392", "T3W1_de_thp-test_handshake.py::test_allocate_channel": "2bf2200d3f158d1cffae639cd632a3b651fe53af5942b65eb881cfce12817592", "T3W1_de_thp-test_handshake.py::test_handshake": "4089127fa2d033a6d9c4a09781f54ce1f92fc1684ce2ffdacc153586685e8392", -"T3W1_de_thp-test_multiple_hosts.py::test_concurrent_handshakes_1": "e8a51f88f32744b43c7b996a50ec865033caa965b23dc70263fd442fe79c7729", -"T3W1_de_thp-test_multiple_hosts.py::test_concurrent_handshakes_2": "2bf2200d3f158d1cffae639cd632a3b651fe53af5942b65eb881cfce12817592", +"T3W1_de_thp-test_multiple_hosts.py::test_concurrent_handshakes": "2bf2200d3f158d1cffae639cd632a3b651fe53af5942b65eb881cfce12817592", "T3W1_de_thp-test_pairing.py::test_autoconnect_credential_request_cancel": "356d3b7c5aed10f8b91d9e8e35d06edf483a25afa2ce9444b442f9810fa6fe8f", "T3W1_de_thp-test_pairing.py::test_channel_replacement": "f6d77878c5385ae01b727db3d1253041ec8b8d5594084a68bb64c4ef9065d5bf", "T3W1_de_thp-test_pairing.py::test_connection_confirmation_cancel": "34b20c5e6fdd46511dce62e10f92ec5b08cd55ad06ba0acef6704ddd931000e4", @@ -33110,8 +33108,7 @@ "T3W1_en_thp-test_abp.py::test_abp": "e8338fb6f6d8dabdb3ba79c521852b3db26cf2d65de359a133074e1858173aba", "T3W1_en_thp-test_handshake.py::test_allocate_channel": "2b19d878184abddf53159d4acb504a1e86a7c2d5fd15de433495742ba7df9cc8", "T3W1_en_thp-test_handshake.py::test_handshake": "e8338fb6f6d8dabdb3ba79c521852b3db26cf2d65de359a133074e1858173aba", -"T3W1_en_thp-test_multiple_hosts.py::test_concurrent_handshakes_1": "89b192fe654bedc8f9d3f11be08e8544abb55ee66ca6d4036242d67039c47401", -"T3W1_en_thp-test_multiple_hosts.py::test_concurrent_handshakes_2": "2b19d878184abddf53159d4acb504a1e86a7c2d5fd15de433495742ba7df9cc8", +"T3W1_en_thp-test_multiple_hosts.py::test_concurrent_handshakes": "2b19d878184abddf53159d4acb504a1e86a7c2d5fd15de433495742ba7df9cc8", "T3W1_en_thp-test_pairing.py::test_autoconnect_credential_request_cancel": "4d40957d3dca48a1e792042d1e5b7fe90200336ddd7df008e66027844eba1fab", "T3W1_en_thp-test_pairing.py::test_channel_replacement": "40cfae080c587feac48b4d4c781c0e4c15edf37fbf4468e53b6ac43b4fd53000", "T3W1_en_thp-test_pairing.py::test_connection_confirmation_cancel": "a920958bb70e5be597f04a644eb529c77a7c0a6ac87fe77ce037d5299526e6da", @@ -34570,8 +34567,7 @@ "T3W1_es_thp-test_abp.py::test_abp": "446c4b55e2b04cf4f12f20dd2a544a139cd1f04db2de732de69f0ab1595d34a9", "T3W1_es_thp-test_handshake.py::test_allocate_channel": "991c2fedae415c4284948d276edc08daeea466ca5849934ed2784cc8884ee589", "T3W1_es_thp-test_handshake.py::test_handshake": "446c4b55e2b04cf4f12f20dd2a544a139cd1f04db2de732de69f0ab1595d34a9", -"T3W1_es_thp-test_multiple_hosts.py::test_concurrent_handshakes_1": "d4fa84772b45b0d5be2443fac20dca5f46108b1955317b23306b4974f571fef1", -"T3W1_es_thp-test_multiple_hosts.py::test_concurrent_handshakes_2": "991c2fedae415c4284948d276edc08daeea466ca5849934ed2784cc8884ee589", +"T3W1_es_thp-test_multiple_hosts.py::test_concurrent_handshakes": "991c2fedae415c4284948d276edc08daeea466ca5849934ed2784cc8884ee589", "T3W1_es_thp-test_pairing.py::test_autoconnect_credential_request_cancel": "cb482a59c593c3c8803c35dc8021e1ae9518a680caeaea4af39813a56eb435fd", "T3W1_es_thp-test_pairing.py::test_channel_replacement": "55b7542cab8a983244e79de523287963e077e5237a4db04f4d41d418e03e1b68", "T3W1_es_thp-test_pairing.py::test_connection_confirmation_cancel": "6f433c176315ce1654a4c675ccc4f20b83facaf78124d5004e91d3cd9b6cdd2b", @@ -36030,8 +36026,7 @@ "T3W1_fr_thp-test_abp.py::test_abp": "4a1b056b6117bfb79336e6e174efff3006017d838bdcf4b16f46fd3ed4628d2e", "T3W1_fr_thp-test_handshake.py::test_allocate_channel": "250f727140e1737e1e25e4620b4557d886c40e53f8205987dacb8904050dc474", "T3W1_fr_thp-test_handshake.py::test_handshake": "4a1b056b6117bfb79336e6e174efff3006017d838bdcf4b16f46fd3ed4628d2e", -"T3W1_fr_thp-test_multiple_hosts.py::test_concurrent_handshakes_1": "ade0f0dfc01ce5f2bf20a7a7c063e11425a7b83ecb7173411aa766f169f461a1", -"T3W1_fr_thp-test_multiple_hosts.py::test_concurrent_handshakes_2": "250f727140e1737e1e25e4620b4557d886c40e53f8205987dacb8904050dc474", +"T3W1_fr_thp-test_multiple_hosts.py::test_concurrent_handshakes": "250f727140e1737e1e25e4620b4557d886c40e53f8205987dacb8904050dc474", "T3W1_fr_thp-test_pairing.py::test_autoconnect_credential_request_cancel": "cd3c958a5b358779765d7d2485ddff48d22fd9a5340d6aca2e48f0e0e939b013", "T3W1_fr_thp-test_pairing.py::test_channel_replacement": "ed61fb08c9858fbb1148660d1228f7d7c497f7a89f9066486d5c7bb37f8ef873", "T3W1_fr_thp-test_pairing.py::test_connection_confirmation_cancel": "7b1db2a82d9528fb3dda6ecfa6de511b2066df3eb157f898e33b5c49c9181767", @@ -37490,8 +37485,7 @@ "T3W1_pt_thp-test_abp.py::test_abp": "c3514f5e9f61efb7bb753a10d9af0f9b932f1b741948bfb1ff576ee557c1ceb1", "T3W1_pt_thp-test_handshake.py::test_allocate_channel": "637658c0c6bbf3267f67895c12aeea91280f0471133d7865e4b0d19bdd1da874", "T3W1_pt_thp-test_handshake.py::test_handshake": "c3514f5e9f61efb7bb753a10d9af0f9b932f1b741948bfb1ff576ee557c1ceb1", -"T3W1_pt_thp-test_multiple_hosts.py::test_concurrent_handshakes_1": "6900b07edb410c88a421f4d79b8e15cb8e1399fcff0bc06826e1f58c8c715de8", -"T3W1_pt_thp-test_multiple_hosts.py::test_concurrent_handshakes_2": "637658c0c6bbf3267f67895c12aeea91280f0471133d7865e4b0d19bdd1da874", +"T3W1_pt_thp-test_multiple_hosts.py::test_concurrent_handshakes": "637658c0c6bbf3267f67895c12aeea91280f0471133d7865e4b0d19bdd1da874", "T3W1_pt_thp-test_pairing.py::test_autoconnect_credential_request_cancel": "475b9204de9b71fa3637f14787bea5b1612ad64789c3663936dc15498b362a56", "T3W1_pt_thp-test_pairing.py::test_channel_replacement": "52dce5f754788e21d59789e6a0c5a3a72b4b1810e44d59d1f368896fe93229a1", "T3W1_pt_thp-test_pairing.py::test_connection_confirmation_cancel": "f5ba7f22b3b5dd1038588d7433d5597318724811a2d208fa206434cbfed45c22", From 5d55e483ee45ac2f20a80db30a644a747b289626 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Sun, 17 Aug 2025 10:35:44 +0300 Subject: [PATCH 3/3] test(core): wait for debuglink task before restarting THP event loop [no changelog] --- core/src/trezor/wire/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index ce54741fd36..99cb8a4d362 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -110,6 +110,10 @@ async def handle_session(iface: WireInterface) -> None: finally: # Wait for all active workflows to finish. await workflow.join_all() + if __debug__: + import apps.debug + + await apps.debug.close_session() loop.clear() else: