Skip to content

[WIP] feat(core): reimplement stale THP channel interruption #5591

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/mocks/utime.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ def ticks_ms() -> int: ...
def ticks_us() -> int: ...
def ticks_cpu() -> int: ...
def ticks_add(ticks_in: int, delta_in: int) -> int: ...
def ticks_diff(old: int, new: int) -> int: ...
def ticks_diff(new: int, old: int) -> int: ...
def gmtime2000(timestamp: int) -> tuple: ...
5 changes: 4 additions & 1 deletion core/src/trezor/wire/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@ async def handle_session(iface: WireInterface) -> None:
if __debug__:
_THP_CHANNELS.append(ctx._channels)
try:
channel = await ctx.get_next_message()
while (channel := await ctx.get_next_message()) is None:
# If there is an active THP channel on another interface,
# wait for event loop restart, while responding to the host.
pass
while await received_message_handler.handle_received_message(channel):
pass
finally:
Expand Down
6 changes: 5 additions & 1 deletion core/src/trezor/wire/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
T = TypeVar("T")


class UnexpectedMessageException(Exception):
class AbortWorkflow(Exception):
pass


class UnexpectedMessageException(AbortWorkflow):
"""A message was received that is not part of the current workflow.

Utility exception to inform the session handler that the current workflow
Expand Down
7 changes: 6 additions & 1 deletion core/src/trezor/wire/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from trezor.enums import FailureType
from trezor.messages import Failure

from .context import UnexpectedMessageException, with_context
from .context import AbortWorkflow, UnexpectedMessageException, with_context
from .errors import ActionCancelled, DataError, Error, UnexpectedMessage
from .protocol_common import Context, Message

Expand Down Expand Up @@ -153,6 +153,11 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool:
# to process. It is not a standard exception that should be logged and a result
# sent to the wire.
raise

except AbortWorkflow:
# Aborting current workflow, to restart the event loop.
return False

except BaseException as exc:
# Either:
# - the message had a type that has a registered handler, but does not have
Expand Down
20 changes: 20 additions & 0 deletions core/src/trezor/wire/thp/channel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ustruct
import utime
from micropython import const
from typing import TYPE_CHECKING

Expand All @@ -23,6 +24,7 @@
from trezor import protobuf, utils, workflow
from trezor.loop import Timeout

from ..context import AbortWorkflow
from ..protocol_common import Message
from . import (
ACK_MESSAGE,
Expand Down Expand Up @@ -58,6 +60,8 @@
_MAX_RETRANSMISSION_COUNT = const(50)
_MIN_RETRANSMISSION_COUNT = const(2)

_STALE_CHANNEL_TIMEOUT_MS = const(1000)


class Reassembler:
def __init__(self, cid: int, read_buf: ThpBuffer) -> None:
Expand Down Expand Up @@ -270,12 +274,28 @@ async def _get_reassembled_message(
self, timeout_ms: int | None = None
) -> memoryview:
"""Doesn't block if a message has been already reassembled."""
ping_sent_ms: int | None = None
while self.reassembler.message is None:
# receive and reassemble a new message from this channel
channel = await self.ctx.get_next_message(timeout_ms=timeout_ms)
if channel is self:
break

# interrupted by a different channel
if ping_sent_ms is None:
ping_sent_ms = utime.ticks_ms()
continue

elapsed = utime.ticks_diff(utime.ticks_ms(), ping_sent_ms)
if elapsed < _STALE_CHANNEL_TIMEOUT_MS:
continue

if __debug__:
self._log(
f"Interrupting stale channel after {elapsed} ms", logger=log.warning
)
raise AbortWorkflow

# currently only single-channel sessions are supported during a single event loop run
self._log(
"Ignoring unexpected channel: ",
Expand Down
9 changes: 6 additions & 3 deletions core/src/trezor/wire/thp/interface_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@ def __init__(self, iface: WireInterface) -> None:
self._write = loop.wait(iface.iface_num() | io.POLL_WRITE)
self._channels: dict[int, Channel] = {}

async def get_next_message(self, timeout_ms: int | None = None) -> Channel:
async def get_next_message(self, timeout_ms: int | None = None) -> Channel | None:
"""
Reassemble a valid THP payload and return its channel.

Also handle THP channel allocation.
Return `None` if there is already another active THP channel.
"""
from .. import THP_BUFFERS_PROVIDER

Expand Down Expand Up @@ -80,9 +81,11 @@ async def get_next_message(self, timeout_ms: int | None = None) -> Channel:

if (channel := self._channels.get(cid)) is None:
if (buffers := THP_BUFFERS_PROVIDER.take()) is None:
# concurrent payload reassembly is not supported
# Concurrent payload reassembly is not supported:
# - Notify the new channel, so it will retry later.
# - Interrupt this method, to detect whether the existing channel is still alive.
await self.write_error(cid, ThpErrorType.TRANSPORT_BUSY)
continue
return None
channel = self._channels[cid] = Channel(cache, self, buffers)

if channel.reassemble(packet):
Expand Down
4 changes: 3 additions & 1 deletion core/src/trezor/wire/thp/pairing_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from trezor import loop, protobuf, workflow
from trezor.wire import context, message_handler, protocol_common
from trezor.wire.context import UnexpectedMessageException
from trezor.wire.context import AbortWorkflow, UnexpectedMessageException
from trezor.wire.errors import ActionCancelled, DataError, SilentError
from trezor.wire.protocol_common import Context, Message
from trezor.wire.thp import ChannelState, get_enabled_pairing_methods, ui
Expand Down Expand Up @@ -234,6 +234,8 @@ async def handle_message(
# We might handle only the few common cases here, like
# Initialize and Cancel.
return exc.msg
except AbortWorkflow:
return None
except SilentError as exc:
if __debug__:
log.error(__name__, "SilentError: %s", exc.message, iface=pairing_ctx.iface)
Expand Down
10 changes: 6 additions & 4 deletions core/src/trezor/wire/thp/session_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ async def handle(self, message: Message | None = None) -> None:
if message is None:
message = await self._read_next_message()
await handle_single_message(self, message)
self.channel._log("session loop is over")
if __debug__:
self.channel._log("session loop is over")
return
except protocol_common.WireError as e:
if __debug__:
Expand All @@ -70,9 +71,10 @@ async def _read_next_message(self) -> Message:
session_id, message = await self.channel.decrypt_message()
if session_id == self.session_id:
return message
self.channel._log(
"Ignored message for unexpected session", logger=log.warning
)
if __debug__:
self.channel._log(
"Ignored message for unexpected session", logger=log.warning
)

async def read(
self,
Expand Down
2 changes: 1 addition & 1 deletion python/src/trezorlib/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def session_context(
empty_passphrase=empty_passphrase,
must_resume=must_resume,
)
except exceptions.DeviceLockedException:
except exceptions.DeviceLocked:
click.echo(
"Device is locked, enter a pin on the device.",
err=True,
Expand Down
16 changes: 2 additions & 14 deletions python/src/trezorlib/debuglink.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,7 @@

from . import btc, mapping, messages, models, protobuf
from .client import ProtocolVersion, TrezorClient
from .exceptions import (
Cancelled,
DeviceLockedException,
TrezorFailure,
UnexpectedMessageError,
)
from .exceptions import Cancelled, TrezorFailure, UnexpectedMessageError
from .log import DUMP_BYTES
from .messages import DebugTouchEventType, DebugWaitType
from .tools import parse_path
Expand Down Expand Up @@ -1317,14 +1312,7 @@ def get_pin(_msg: messages.PinMatrixRequest) -> str:
self.pin_callback = get_pin
self.button_callback = self.ui.button_request

try:
super().__init__(transport)
except DeviceLockedException:
LOG.debug("Locked device handling")
self.debug.input("")
self.debug.input(self.debug.encode_pin("1234"))
super().__init__(transport)

super().__init__(transport)
self.sync_responses()

# So that we can choose right screenshotting logic (T1 vs TT)
Expand Down
26 changes: 23 additions & 3 deletions python/src/trezorlib/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,33 @@ class DerivationOnUninitaizedDeviceError(TrezorException):
To communicate with uninitialized device, use seedless session instead."""


class DeviceLockedException(TrezorException):
class UnexpectedCodeEntryTagException(TrezorException):
pass


class UnexpectedCodeEntryTagException(TrezorException):
class ThpError(TrezorException):
pass


class ThpError(TrezorException):
class TransportBusy(ThpError):
pass


class UnallocatedChannel(ThpError):
pass


class DecryptionFailed(ThpError):
pass


class InvalidData(ThpError):
pass


class DeviceLocked(ThpError):
pass


class ThpUnknownError(ThpError):
pass
35 changes: 11 additions & 24 deletions python/src/trezorlib/transport/thp/protocol_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,6 @@ def _send_handshake_init_request(self) -> None:

def _read_handshake_init_response(self) -> bytes:
header, payload = self._read_until_valid_crc_check()
if control_byte.is_error(header.ctrl_byte):
if payload == b"\x05":
raise exceptions.DeviceLockedException()
else:
raise exceptions.ThpError(_get_error_from_int(payload[0]))

if not header.is_handshake_init_response():
LOG.error("Received message is not a valid handshake init response message")
Expand Down Expand Up @@ -278,8 +273,6 @@ def _read_handshake_completion_response(self) -> int:
header, data = self._read_until_valid_crc_check()
if not header.is_handshake_comp_response():
LOG.error("Received message is not a valid handshake completion response")
if control_byte.is_error(header.ctrl_byte):
raise exceptions.ThpError(_get_error_from_int(data[0]))
trezor_state = self._noise.decrypt(bytes(data))
assert trezor_state == b"\x00" or trezor_state == b"\x01"
self._send_ack_bit(bit=1)
Expand All @@ -289,8 +282,6 @@ def _read_ack(self):
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
LOG.error("Received message is not a valid ACK")
if control_byte.is_error(header.ctrl_byte):
raise exceptions.ThpError(_get_error_from_int(payload[0]))

def _send_ack_bit(self, bit: int):
if bit not in (0, 1):
Expand Down Expand Up @@ -336,8 +327,6 @@ def read_and_decrypt(
continue
if control_byte.is_ack(header.ctrl_byte):
continue
if control_byte.is_error(header.ctrl_byte):
raise exceptions.ThpError(_get_error_from_int(raw_payload[0]))
if not header.is_encrypted_transport():
LOG.error(
"Trying to decrypt not encrypted message! ("
Expand Down Expand Up @@ -392,6 +381,10 @@ def _read_until_valid_crc_check(

self.sync_bit_receive = 1 - self.sync_bit_receive

if control_byte.is_error(header.ctrl_byte):
code = payload[0]
raise _ERRORS_MAP.get(code) or exceptions.ThpUnknownError(code)

return header, payload

def _is_valid_channel_allocation_response(
Expand Down Expand Up @@ -420,16 +413,10 @@ def _is_valid_pong(
return True


def _get_error_from_int(error_code: int) -> str:
# TODO FIXME improve this (ThpErrorType)
if error_code == 1:
return "TRANSPORT BUSY"
if error_code == 2:
return "UNALLOCATED CHANNEL"
if error_code == 3:
return "DECRYPTION FAILED"
if error_code == 4:
return "INVALID DATA"
if error_code == 5:
return "DEVICE LOCKED"
raise Exception("Not Implemented error case")
_ERRORS_MAP = {
1: exceptions.TransportBusy,
2: exceptions.UnallocatedChannel,
3: exceptions.DecryptionFailed,
4: exceptions.InvalidData,
5: exceptions.DeviceLocked,
}
5 changes: 2 additions & 3 deletions tests/device_tests/thp/test_multiple_hosts.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,9 @@ def test_concurrent_handshakes(client: Client) -> None:
# The second host starts handshake
protocol_2._send_handshake_init_request()

# The second host should not be able to interrupt the first host's handshake
with pytest.raises(exceptions.ThpError) as e:
# The second host should not be able to interrupt the first host's handshake immediately
with pytest.raises(exceptions.TransportBusy):
protocol_2._read_ack()
assert e.value.args[0] == "TRANSPORT BUSY"

# The first host can complete handshake
protocol_1._send_handshake_completion_request()
Expand Down
6 changes: 2 additions & 4 deletions tests/device_tests/thp/test_pairing.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,9 +406,8 @@ def test_credential_phase(client: Client) -> None:
protocol._noise.noise_protocol.cipher_state_encrypt.n = 250

protocol._send_message(ButtonAck())
with pytest.raises(exceptions.ThpError) as e:
with pytest.raises(exceptions.DecryptionFailed):
protocol.read(1)
assert e.value.args[0] == "DECRYPTION FAILED"

# Connect using credential with confirmation and ask for autoconnect credential.
protocol = prepare_protocol_for_handshake(client)
Expand Down Expand Up @@ -457,9 +456,8 @@ def test_credential_phase(client: Client) -> None:
protocol._noise.noise_protocol.cipher_state_encrypt.n = 100

protocol._send_message(ButtonAck())
with pytest.raises(exceptions.ThpError) as e:
with pytest.raises(exceptions.DecryptionFailed):
protocol.read(1)
assert e.value.args[0] == "DECRYPTION FAILED"

# Connect using autoconnect credential - should work the same as above
protocol = prepare_protocol_for_handshake(client)
Expand Down