Skip to content

Commit e98f233

Browse files
CoolCat467jakkdl
andcommitted
Suggestions from code review
Co-authored-by: jakkdl <h6+github@pm.me>
1 parent ce73cb1 commit e98f233

File tree

6 files changed

+61
-64
lines changed

6 files changed

+61
-64
lines changed

autobahn/client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ async def get_case_count(url: str) -> int:
2626
return int(case_count)
2727

2828

29-
async def get_case_info(url: str, case: str) -> Any:
29+
async def get_case_info(url: str, case: str) -> object:
3030
url = f'{url}/getCaseInfo?case={case}'
3131
async with open_websocket_url(url) as conn:
3232
return json.loads(await conn.get_message())
@@ -63,7 +63,10 @@ async def run_tests(args: argparse.Namespace) -> None:
6363
test_cases = list(range(1, case_count + 1))
6464
exception_cases = []
6565
for case in test_cases:
66-
case_id = (await get_case_info(args.url, case))['id']
66+
result = await get_case_info(args.url, case)
67+
assert isinstance(result, dict)
68+
case_id = result['id']
69+
assert isinstance(case_id, int)
6770
if case_count:
6871
logger.info("Running test case %s (%d of %d)", case_id, case, case_count)
6972
else:

mypy.ini

Lines changed: 0 additions & 28 deletions
This file was deleted.

pyproject.toml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
[tool.mypy]
2+
explicit_package_bases = true
3+
files = ["trio_websocket", "tests", "autobahn", "examples"]
4+
show_column_numbers = true
5+
show_error_codes = true
6+
show_traceback = true
7+
disallow_any_decorated = true
8+
disallow_any_unimported = true
9+
ignore_missing_imports = true
10+
local_partial_types = true
11+
no_implicit_optional = true
12+
strict = true
13+
warn_unreachable = true

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
'Programming Language :: Python :: 3.12',
3636
'Programming Language :: Python :: Implementation :: CPython',
3737
'Programming Language :: Python :: Implementation :: PyPy',
38+
'Typing :: Typed',
3839
],
3940
python_requires=">=3.8",
4041
keywords='websocket client server trio',

tests/test_connection.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from functools import partial, wraps
4040
from typing import TYPE_CHECKING, TypeVar, cast
4141
from unittest.mock import Mock, patch
42+
from importlib.metadata import version
4243

4344
import attr
4445
import pytest
@@ -89,9 +90,11 @@
8990
from collections.abc import Awaitable, Callable
9091
from wsproto.events import Event
9192

92-
from typing_extensions import ParamSpec
93+
from typing_extensions import ParamSpec, TypeAlias
9394
PS = ParamSpec("PS")
9495

96+
StapledMemoryStream: TypeAlias = trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream]
97+
9598
WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split('.')))
9699

97100
HOST = '127.0.0.1'
@@ -116,6 +119,8 @@ async def echo_server(nursery: trio.Nursery) -> AsyncGenerator[WebSocketServer,
116119
serve_fn = partial(serve_websocket, echo_request_handler, HOST, 0,
117120
ssl_context=None)
118121
server = await nursery.start(serve_fn)
122+
# Cast needed because currently `nursery.start` has typing issues
123+
# blocked by https://github.com/python/mypy/pull/17512
119124
yield cast(WebSocketServer, server)
120125

121126

@@ -147,37 +152,28 @@ def __init__(self, seconds: int) -> None:
147152
self._seconds = seconds
148153

149154
def __call__(self, fn: Callable[PS, Awaitable[T]]) -> Callable[PS, Awaitable[T | None]]:
155+
# Type of decorated function contains type `Any`
150156
@wraps(fn)
151-
async def wrapper(*args: PS.args, **kwargs: PS.kwargs) -> T | None:
152-
result: T | None = None
157+
async def wrapper( # type: ignore[misc]
158+
*args: PS.args,
159+
**kwargs: PS.kwargs,
160+
) -> T:
153161
with trio.move_on_after(self._seconds) as cancel_scope:
154-
result = await fn(*args, **kwargs)
162+
return await fn(*args, **kwargs)
155163
if cancel_scope.cancelled_caught:
156164
pytest.fail(f'Test runtime exceeded the maximum {self._seconds} seconds')
157-
return result
165+
raise AssertionError("Should be unreachable")
158166
return wrapper
159167

160168

161169
@attr.s(hash=False, eq=False)
162-
class MemoryListener(
163-
trio.abc.Listener[
164-
"trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream]"
165-
]
166-
):
170+
class MemoryListener(trio.abc.Listener["StapledMemoryStream"]):
167171
closed: bool = attr.ib(default=False)
168-
accepted_streams: list[
169-
trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream]
170-
] = attr.ib(factory=list)
172+
accepted_streams: list[StapledMemoryStream] = attr.ib(factory=list)
171173
queued_streams: tuple[
172-
trio.MemorySendChannel[
173-
trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream]
174-
],
175-
trio.MemoryReceiveChannel[
176-
trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream]
177-
],
178-
] = attr.ib(factory=lambda: trio.open_memory_channel[
179-
"trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream]"
180-
](1))
174+
trio.MemorySendChannel[StapledMemoryStream],
175+
trio.MemoryReceiveChannel[StapledMemoryStream],
176+
] = attr.ib(factory=lambda: trio.open_memory_channel["StapledMemoryStream"](1))
181177
accept_hook: Callable[[], Awaitable[object]] | None = attr.ib(default=None)
182178

183179
async def connect(self) -> trio.StapledStream[
@@ -385,8 +381,11 @@ async def test_ascii_encoded_path_is_ok(echo_server: WebSocketServer) -> None:
385381
assert conn.path == RESOURCE + '/' + path
386382

387383

384+
# Type ignore because @patch contains `Any`
388385
@patch('trio_websocket._impl.open_websocket')
389-
def test_client_open_url_options(open_websocket_mock: Mock) -> None:
386+
def test_client_open_url_options( # type: ignore[misc]
387+
open_websocket_mock: Mock,
388+
) -> None:
390389
"""open_websocket_url() must pass its options on to open_websocket()"""
391390
port = 1234
392391
url = f'ws://{HOST}:{port}{RESOURCE}'
@@ -618,7 +617,7 @@ async def handler(request: WebSocketRequest) -> None:
618617
assert exc_info.value.__context__ is user_cancelled_context
619618

620619
def _trio_default_non_strict_exception_groups() -> bool:
621-
version = trio.__version__ # type: ignore[attr-defined]
620+
version = version("trio")
622621
assert re.match(r'^0\.\d\d\.', version), "unexpected trio versioning scheme"
623622
return int(version[2:4]) < 25
624623

trio_websocket/_impl.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import struct
1313
import urllib.parse
1414
from typing import Any, List, NoReturn, Optional, Union, TypeVar, TYPE_CHECKING, Generic, cast
15+
from importlib.metadata import version
1516

1617
import outcome
1718
import trio
@@ -38,22 +39,21 @@
3839

3940
if TYPE_CHECKING:
4041
from types import TracebackType
42+
from typing_extensions import Final
4143
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Coroutine, Sequence
4244

43-
_IS_TRIO_MULTI_ERROR = tuple(
44-
map(int, trio.__version__.split(".")[:2]) # type: ignore[attr-defined]
45-
) < (0, 22)
45+
_IS_TRIO_MULTI_ERROR: Final = tuple(map(int, version("trio").split(".")[:2])) < (0, 22)
4646

4747
if _IS_TRIO_MULTI_ERROR:
4848
_TRIO_EXC_GROUP_TYPE = trio.MultiError # type: ignore[attr-defined] # pylint: disable=no-member
4949
else:
5050
_TRIO_EXC_GROUP_TYPE = BaseExceptionGroup # pylint: disable=possibly-used-before-assignment
5151

52-
CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds
53-
MESSAGE_QUEUE_SIZE = 1
54-
MAX_MESSAGE_SIZE = 2 ** 20 # 1 MiB
55-
RECEIVE_BYTES = 4 * 2 ** 10 # 4 KiB
56-
logger = logging.getLogger('trio-websocket')
52+
CONN_TIMEOUT: Final = 60 # default connect & disconnect timeout, in seconds
53+
MESSAGE_QUEUE_SIZE: Final = 1
54+
MAX_MESSAGE_SIZE: Final = 2 ** 20 # 1 MiB
55+
RECEIVE_BYTES: Final = 4 * 2 ** 10 # 4 KiB
56+
logger: Final = logging.getLogger('trio-websocket')
5757

5858
T = TypeVar("T")
5959
E = TypeVar("E", bound=BaseException)
@@ -770,11 +770,16 @@ def __repr__(self) -> str:
770770
f'<code={self.code}, name={self.name}, reason={self.reason}>'
771771

772772

773+
NULL: Final = object()
774+
775+
773776
class Future(Generic[T]):
774777
''' Represents a value that will be available in the future. '''
775778
def __init__(self) -> None:
776779
''' Constructor. '''
777-
self._value: T | None = None
780+
# We do some type shenanigins
781+
# Would do `T | Literal[NULL]` but that's not right apparently.
782+
self._value: T = cast(T, NULL)
778783
self._value_event = trio.Event()
779784

780785
def set_value(self, value: T) -> None:
@@ -793,7 +798,8 @@ async def wait_value(self) -> T:
793798
:returns: The value set by ``set_value()``.
794799
'''
795800
await self._value_event.wait()
796-
return cast(T, self._value)
801+
assert self._value is not NULL
802+
return self._value
797803

798804

799805
class WebSocketRequest:
@@ -1509,6 +1515,9 @@ async def _reader_task(self) -> None:
15091515
handler = handlers[event_type]
15101516
logger.debug('%s received event: %s', self,
15111517
event_type)
1518+
# Type checkers don't understand looking up type in handlers.
1519+
# If we wanted to fix, best I can figure is we'd need a huge
1520+
# if-else or case block for every type individually.
15121521
await handler(event) # type: ignore[operator]
15131522
except KeyError:
15141523
logger.warning('%s received unknown event type: "%s"', self,

0 commit comments

Comments
 (0)