Skip to content

A few small THP-related fixups #5612

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/mocks/utime.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ def ticks_ms() -> int: ...
def ticks_us() -> int: ...
def ticks_cpu() -> int: ...
def ticks_add(ticks_in: int, delta_in: int) -> int: ...
def ticks_diff(old: int, new: int) -> int: ...
def ticks_diff(new: int, old: int) -> int: ...
def gmtime2000(timestamp: int) -> tuple: ...
38 changes: 13 additions & 25 deletions core/src/apps/thp/pairing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@
SilentError,
UnexpectedMessage,
)
from trezor.wire.thp import ChannelState, ThpError, crypto, get_enabled_pairing_methods
from trezor.wire.thp import (
ChannelState,
ThpError,
crypto,
get_enabled_pairing_methods,
ui,
)
from trezor.wire.thp.pairing_context import PairingContext

from .credential_manager import is_credential_autoconnect, issue_credential
Expand Down Expand Up @@ -112,11 +118,8 @@ async def handle_pairing_request(

ctx.host_name = message.host_name
ctx.app_name = message.app_name
if __debug__ and not ctx.channel_ctx.should_show_pairing_dialog:
await _skip_pairing_dialog(ctx)
else:
await ctx.show_pairing_dialog()
await ctx.write(ThpPairingRequestApproved())
await ctx.show_pairing_dialog()
await ctx.write(ThpPairingRequestApproved())
assert ThpSelectMethod.MESSAGE_WIRE_TYPE is not None
select_method_msg = await ctx.read(
[
Expand Down Expand Up @@ -195,7 +198,7 @@ async def handle_credential_phase(
raise DataError("Missing host/app name in credential")

if show_connection_dialog and not autoconnect:
await ctx.show_connection_dialog()
await ui.show_connection_dialog(ctx.host_name, ctx.app_name)

while ThpCredentialRequest.is_type_of(message):
message = await _handle_credential_request(ctx, message)
Expand Down Expand Up @@ -425,7 +428,9 @@ async def _handle_credential_request(
"Cannot ask for autoconnect credential without a valid credential!"
)

await ctx.show_autoconnect_credential_confirmation_screen() # TODO add device name
await ui.show_autoconnect_credential_confirmation_screen(
host_name=ctx.host_name, app_name=ctx.app_name
)

trezor_static_public_key = crypto.get_trezor_static_public_key()
credential_metadata = ThpCredentialMetadata(
Expand Down Expand Up @@ -476,20 +481,3 @@ def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> N
def _check_method_is_selected(ctx: PairingContext, method: ThpPairingMethod) -> None:
if method is not ctx.selected_method:
raise ThpError("Not selected pairing method")


if __debug__:

async def _skip_pairing_dialog(ctx: PairingContext) -> None:
from trezor.enums import ButtonRequestType
from trezor.messages import ButtonAck, ButtonRequest, ThpPairingRequestApproved
from trezor.wire.errors import ActionCancelled

resp = await ctx.call(
ButtonRequest(code=ButtonRequestType.Other, name="thp_pairing_request"),
expected_type=ButtonAck,
)
if isinstance(resp, ButtonAck):
await ctx.write(ThpPairingRequestApproved())
else:
raise ActionCancelled
3 changes: 0 additions & 3 deletions core/src/trezor/wire/thp/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,6 @@ def __init__(
self.credential: ThpPairingCredential | None = None
self.connection_context: PairingContext | None = None

if __debug__:
self.should_show_pairing_dialog: bool = True

@property
def iface(self) -> WireInterface:
return self.ctx._iface
Expand Down
8 changes: 0 additions & 8 deletions core/src/trezor/wire/thp/pairing_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,6 @@ async def show_pairing_dialog(self) -> None:
action=action_string,
)

async def show_connection_dialog(self) -> None:
await ui.show_connection_dialog(self.host_name, self.app_name)

async def show_autoconnect_credential_confirmation_screen(self) -> None:
await ui.show_autoconnect_credential_confirmation_screen(
self.host_name, self.app_name
)

async def show_pairing_method_screen(
self, selected_method: ThpPairingMethod | None = None
) -> UiResult:
Expand Down
10 changes: 6 additions & 4 deletions core/src/trezor/wire/thp/session_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ async def handle(self, message: Message | None = None) -> None:
if message is None:
message = await self._read_next_message()
await handle_single_message(self, message)
self.channel._log("session loop is over")
if __debug__:
self.channel._log("session loop is over")
return
except protocol_common.WireError as e:
if __debug__:
Expand All @@ -71,9 +72,10 @@ async def _read_next_message(self) -> Message:
session_id, message = await self.channel.decrypt_message()
if session_id == self.session_id:
return message
self.channel._log(
"Ignored message for unexpected session", logger=log.warning
)
if __debug__:
self.channel._log(
"Ignored message for unexpected session", logger=log.warning
)

async def read(
self,
Expand Down
2 changes: 1 addition & 1 deletion python/src/trezorlib/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def session_context(
empty_passphrase=empty_passphrase,
must_resume=must_resume,
)
except exceptions.DeviceLockedException:
except exceptions.DeviceLocked:
click.echo(
"Device is locked, enter a pin on the device.",
err=True,
Expand Down
16 changes: 2 additions & 14 deletions python/src/trezorlib/debuglink.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,7 @@

from . import btc, mapping, messages, models, protobuf
from .client import ProtocolVersion, TrezorClient
from .exceptions import (
Cancelled,
DeviceLockedException,
TrezorFailure,
UnexpectedMessageError,
)
from .exceptions import Cancelled, TrezorFailure, UnexpectedMessageError
from .log import DUMP_BYTES
from .messages import DebugTouchEventType, DebugWaitType
from .tools import parse_path
Expand Down Expand Up @@ -1317,14 +1312,7 @@ def get_pin(_msg: messages.PinMatrixRequest) -> str:
self.pin_callback = get_pin
self.button_callback = self.ui.button_request

try:
super().__init__(transport)
except DeviceLockedException:
LOG.debug("Locked device handling")
self.debug.input("")
self.debug.input(self.debug.encode_pin("1234"))
super().__init__(transport)

super().__init__(transport)
self.sync_responses()

# So that we can choose right screenshotting logic (T1 vs TT)
Expand Down
26 changes: 23 additions & 3 deletions python/src/trezorlib/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,33 @@ class DerivationOnUninitaizedDeviceError(TrezorException):
To communicate with uninitialized device, use seedless session instead."""


class DeviceLockedException(TrezorException):
class UnexpectedCodeEntryTagException(TrezorException):
pass


class UnexpectedCodeEntryTagException(TrezorException):
class ThpError(TrezorException):
pass


class ThpError(TrezorException):
class TransportBusy(ThpError):
pass


class UnallocatedChannel(ThpError):
pass


class DecryptionFailed(ThpError):
pass


class InvalidData(ThpError):
pass


class DeviceLocked(ThpError):
pass


class ThpUnknownError(ThpError):
pass
35 changes: 11 additions & 24 deletions python/src/trezorlib/transport/thp/protocol_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,6 @@ def _send_handshake_init_request(self) -> None:

def _read_handshake_init_response(self) -> bytes:
header, payload = self._read_until_valid_crc_check()
if control_byte.is_error(header.ctrl_byte):
if payload == b"\x05":
raise exceptions.DeviceLockedException()
else:
raise exceptions.ThpError(_get_error_from_int(payload[0]))

if not header.is_handshake_init_response():
LOG.error("Received message is not a valid handshake init response message")
Expand Down Expand Up @@ -278,8 +273,6 @@ def _read_handshake_completion_response(self) -> int:
header, data = self._read_until_valid_crc_check()
if not header.is_handshake_comp_response():
LOG.error("Received message is not a valid handshake completion response")
if control_byte.is_error(header.ctrl_byte):
raise exceptions.ThpError(_get_error_from_int(data[0]))
trezor_state = self._noise.decrypt(bytes(data))
assert trezor_state == b"\x00" or trezor_state == b"\x01"
self._send_ack_bit(bit=1)
Expand All @@ -289,8 +282,6 @@ def _read_ack(self):
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
LOG.error("Received message is not a valid ACK")
if control_byte.is_error(header.ctrl_byte):
raise exceptions.ThpError(_get_error_from_int(payload[0]))

def _send_ack_bit(self, bit: int):
if bit not in (0, 1):
Expand Down Expand Up @@ -336,8 +327,6 @@ def read_and_decrypt(
continue
if control_byte.is_ack(header.ctrl_byte):
continue
if control_byte.is_error(header.ctrl_byte):
raise exceptions.ThpError(_get_error_from_int(raw_payload[0]))
if not header.is_encrypted_transport():
LOG.error(
"Trying to decrypt not encrypted message! ("
Expand Down Expand Up @@ -392,6 +381,10 @@ def _read_until_valid_crc_check(

self.sync_bit_receive = 1 - self.sync_bit_receive

if control_byte.is_error(header.ctrl_byte):
code = payload[0]
raise _ERRORS_MAP.get(code) or exceptions.ThpUnknownError(code)

return header, payload

def _is_valid_channel_allocation_response(
Expand Down Expand Up @@ -420,16 +413,10 @@ def _is_valid_pong(
return True


def _get_error_from_int(error_code: int) -> str:
# TODO FIXME improve this (ThpErrorType)
if error_code == 1:
return "TRANSPORT BUSY"
if error_code == 2:
return "UNALLOCATED CHANNEL"
if error_code == 3:
return "DECRYPTION FAILED"
if error_code == 4:
return "INVALID DATA"
if error_code == 5:
return "DEVICE LOCKED"
raise Exception("Not Implemented error case")
_ERRORS_MAP = {
1: exceptions.TransportBusy,
2: exceptions.UnallocatedChannel,
3: exceptions.DecryptionFailed,
4: exceptions.InvalidData,
5: exceptions.DeviceLocked,
}
5 changes: 2 additions & 3 deletions tests/device_tests/thp/test_multiple_hosts.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,9 @@ def test_concurrent_handshakes(client: Client) -> None:
# The second host starts handshake
protocol_2._send_handshake_init_request()

# The second host should not be able to interrupt the first host's handshake
with pytest.raises(exceptions.ThpError) as e:
# The second host should not be able to interrupt the first host's handshake immediately
with pytest.raises(exceptions.TransportBusy):
protocol_2._read_ack()
assert e.value.args[0] == "TRANSPORT BUSY"

# The first host can complete handshake
protocol_1._send_handshake_completion_request()
Expand Down
6 changes: 2 additions & 4 deletions tests/device_tests/thp/test_pairing.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,9 +406,8 @@ def test_credential_phase(client: Client) -> None:
protocol._noise.noise_protocol.cipher_state_encrypt.n = 250

protocol._send_message(ButtonAck())
with pytest.raises(exceptions.ThpError) as e:
with pytest.raises(exceptions.DecryptionFailed):
protocol.read(1)
assert e.value.args[0] == "DECRYPTION FAILED"

# Connect using credential with confirmation and ask for autoconnect credential.
protocol = prepare_protocol_for_handshake(client)
Expand Down Expand Up @@ -457,9 +456,8 @@ def test_credential_phase(client: Client) -> None:
protocol._noise.noise_protocol.cipher_state_encrypt.n = 100

protocol._send_message(ButtonAck())
with pytest.raises(exceptions.ThpError) as e:
with pytest.raises(exceptions.DecryptionFailed):
protocol.read(1)
assert e.value.args[0] == "DECRYPTION FAILED"

# Connect using autoconnect credential - should work the same as above
protocol = prepare_protocol_for_handshake(client)
Expand Down