39
39
from functools import partial , wraps
40
40
from typing import TYPE_CHECKING , TypeVar , cast
41
41
from unittest .mock import Mock , patch
42
+ from importlib .metadata import version
42
43
43
44
import attr
44
45
import pytest
89
90
from collections .abc import Awaitable , Callable
90
91
from wsproto .events import Event
91
92
92
- from typing_extensions import ParamSpec
93
+ from typing_extensions import ParamSpec , TypeAlias
93
94
PS = ParamSpec ("PS" )
94
95
96
+ StapledMemoryStream : TypeAlias = trio .StapledStream [trio .testing .MemorySendStream , trio .testing .MemoryReceiveStream ]
97
+
95
98
WS_PROTO_VERSION = tuple (map (int , wsproto .__version__ .split ('.' )))
96
99
97
100
HOST = '127.0.0.1'
@@ -116,6 +119,8 @@ async def echo_server(nursery: trio.Nursery) -> AsyncGenerator[WebSocketServer,
116
119
serve_fn = partial (serve_websocket , echo_request_handler , HOST , 0 ,
117
120
ssl_context = None )
118
121
server = await nursery .start (serve_fn )
122
+ # Cast needed because currently `nursery.start` has typing issues
123
+ # blocked by https://github.com/python/mypy/pull/17512
119
124
yield cast (WebSocketServer , server )
120
125
121
126
@@ -147,37 +152,28 @@ def __init__(self, seconds: int) -> None:
147
152
self ._seconds = seconds
148
153
149
154
def __call__ (self , fn : Callable [PS , Awaitable [T ]]) -> Callable [PS , Awaitable [T | None ]]:
155
+ # Type of decorated function contains type `Any`
150
156
@wraps (fn )
151
- async def wrapper (* args : PS .args , ** kwargs : PS .kwargs ) -> T | None :
152
- result : T | None = None
157
+ async def wrapper ( # type: ignore[misc]
158
+ * args : PS .args ,
159
+ ** kwargs : PS .kwargs ,
160
+ ) -> T :
153
161
with trio .move_on_after (self ._seconds ) as cancel_scope :
154
- result = await fn (* args , ** kwargs )
162
+ return await fn (* args , ** kwargs )
155
163
if cancel_scope .cancelled_caught :
156
164
pytest .fail (f'Test runtime exceeded the maximum { self ._seconds } seconds' )
157
- return result
165
+ raise AssertionError ( "Should be unreachable" )
158
166
return wrapper
159
167
160
168
161
169
@attr .s (hash = False , eq = False )
162
- class MemoryListener (
163
- trio .abc .Listener [
164
- "trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream]"
165
- ]
166
- ):
170
+ class MemoryListener (trio .abc .Listener ["StapledMemoryStream" ]):
167
171
closed : bool = attr .ib (default = False )
168
- accepted_streams : list [
169
- trio .StapledStream [trio .testing .MemorySendStream , trio .testing .MemoryReceiveStream ]
170
- ] = attr .ib (factory = list )
172
+ accepted_streams : list [StapledMemoryStream ] = attr .ib (factory = list )
171
173
queued_streams : tuple [
172
- trio .MemorySendChannel [
173
- trio .StapledStream [trio .testing .MemorySendStream , trio .testing .MemoryReceiveStream ]
174
- ],
175
- trio .MemoryReceiveChannel [
176
- trio .StapledStream [trio .testing .MemorySendStream , trio .testing .MemoryReceiveStream ]
177
- ],
178
- ] = attr .ib (factory = lambda : trio .open_memory_channel [
179
- "trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream]"
180
- ](1 ))
174
+ trio .MemorySendChannel [StapledMemoryStream ],
175
+ trio .MemoryReceiveChannel [StapledMemoryStream ],
176
+ ] = attr .ib (factory = lambda : trio .open_memory_channel ["StapledMemoryStream" ](1 ))
181
177
accept_hook : Callable [[], Awaitable [object ]] | None = attr .ib (default = None )
182
178
183
179
async def connect (self ) -> trio .StapledStream [
@@ -385,8 +381,11 @@ async def test_ascii_encoded_path_is_ok(echo_server: WebSocketServer) -> None:
385
381
assert conn .path == RESOURCE + '/' + path
386
382
387
383
384
+ # Type ignore because @patch contains `Any`
388
385
@patch ('trio_websocket._impl.open_websocket' )
389
- def test_client_open_url_options (open_websocket_mock : Mock ) -> None :
386
+ def test_client_open_url_options ( # type: ignore[misc]
387
+ open_websocket_mock : Mock ,
388
+ ) -> None :
390
389
"""open_websocket_url() must pass its options on to open_websocket()"""
391
390
port = 1234
392
391
url = f'ws://{ HOST } :{ port } { RESOURCE } '
@@ -618,7 +617,7 @@ async def handler(request: WebSocketRequest) -> None:
618
617
assert exc_info .value .__context__ is user_cancelled_context
619
618
620
619
def _trio_default_non_strict_exception_groups () -> bool :
621
- version = trio . __version__ # type: ignore[attr-defined]
620
+ version = version ( " trio" )
622
621
assert re .match (r'^0\.\d\d\.' , version ), "unexpected trio versioning scheme"
623
622
return int (version [2 :4 ]) < 25
624
623
0 commit comments