diff --git a/Makefile b/Makefile index f97c321..346a4c9 100644 --- a/Makefile +++ b/Makefile @@ -14,7 +14,7 @@ lint: $(PYTHON) -m pylint trio_websocket/ tests/ autobahn/ examples/ typecheck: - $(PYTHON) -m mypy --explicit-package-bases trio_websocket tests autobahn examples + $(PYTHON) -m mypy publish: rm -fr build dist .egg trio_websocket.egg-info diff --git a/autobahn/client.py b/autobahn/client.py index d93be1c..5c94b56 100644 --- a/autobahn/client.py +++ b/autobahn/client.py @@ -17,7 +17,7 @@ logger = logging.getLogger('client') -async def get_case_count(url): +async def get_case_count(url: str) -> int: url = url + '/getCaseCount' async with open_websocket_url(url) as conn: case_count = await conn.get_message() @@ -25,13 +25,13 @@ async def get_case_count(url): return int(case_count) -async def get_case_info(url, case): +async def get_case_info(url: str, case: str) -> object: url = f'{url}/getCaseInfo?case={case}' async with open_websocket_url(url) as conn: return json.loads(await conn.get_message()) -async def run_case(url, case): +async def run_case(url: str, case: str) -> None: url = f'{url}/runCase?case={case}&agent={AGENT}' try: async with open_websocket_url(url, max_message_size=MAX_MESSAGE_SIZE) as conn: @@ -42,7 +42,7 @@ async def run_case(url, case): pass -async def update_reports(url): +async def update_reports(url: str) -> None: url = f'{url}/updateReports?agent={AGENT}' async with open_websocket_url(url) as conn: # This command runs as soon as we connect to it, so we don't need to @@ -50,7 +50,7 @@ async def update_reports(url): pass -async def run_tests(args): +async def run_tests(args: argparse.Namespace) -> None: logger = logging.getLogger('trio-websocket') if args.debug_cases: # Don't fetch case count when debugging a subset of test cases. It adds @@ -62,7 +62,10 @@ async def run_tests(args): test_cases = list(range(1, case_count + 1)) exception_cases = [] for case in test_cases: - case_id = (await get_case_info(args.url, case))['id'] + result = await get_case_info(args.url, case) + assert isinstance(result, dict) + case_id = result['id'] + assert isinstance(case_id, int) if case_count: logger.info("Running test case %s (%d of %d)", case_id, case, case_count) else: @@ -82,7 +85,7 @@ async def run_tests(args): sys.exit(1) -def parse_args(): +def parse_args() -> argparse.Namespace: ''' Parse command line arguments. ''' parser = argparse.ArgumentParser(description='Autobahn client for' ' trio-websocket') diff --git a/autobahn/server.py b/autobahn/server.py index ff23846..6a84de4 100644 --- a/autobahn/server.py +++ b/autobahn/server.py @@ -23,14 +23,14 @@ connection_count = 0 -async def main(): +async def main() -> None: ''' Main entry point. ''' logger.info('Starting websocket server on ws://%s:%d', BIND_IP, BIND_PORT) await serve_websocket(handler, BIND_IP, BIND_PORT, ssl_context=None, max_message_size=MAX_MESSAGE_SIZE) -async def handler(request: WebSocketRequest): +async def handler(request: WebSocketRequest) -> None: ''' Reverse incoming websocket messages and send them back. ''' global connection_count # pylint: disable=global-statement connection_count += 1 @@ -46,7 +46,7 @@ async def handler(request: WebSocketRequest): logger.exception(' runtime exception handling connection #%d', connection_count) -def parse_args(): +def parse_args() -> argparse.Namespace: ''' Parse command line arguments. ''' parser = argparse.ArgumentParser(description='Autobahn server for' ' trio-websocket') diff --git a/examples/client.py b/examples/client.py index 030c12b..ba5311c 100644 --- a/examples/client.py +++ b/examples/client.py @@ -11,16 +11,23 @@ import ssl import sys import urllib.parse +from typing import NoReturn import trio -from trio_websocket import open_websocket_url, ConnectionClosed, HandshakeError +from trio_websocket import ( + open_websocket_url, + ConnectionClosed, + HandshakeError, + WebSocketConnection, + CloseReason, +) logging.basicConfig(level=logging.DEBUG) here = pathlib.Path(__file__).parent -def commands(): +def commands() -> None: ''' Print the supported commands. ''' print('Commands: ') print('send -> send message') @@ -29,7 +36,7 @@ def commands(): print() -def parse_args(): +def parse_args() -> argparse.Namespace: ''' Parse command line arguments. ''' parser = argparse.ArgumentParser(description='Example trio-websocket client') parser.add_argument('--heartbeat', action='store_true', @@ -38,7 +45,7 @@ def parse_args(): return parser.parse_args() -async def main(args): +async def main(args: argparse.Namespace) -> bool: ''' Main entry point, returning False in the case of logged error. ''' if urllib.parse.urlsplit(args.url).scheme == 'wss': # Configure SSL context to handle our self-signed certificate. Most @@ -59,9 +66,10 @@ async def main(args): except HandshakeError as e: logging.error('Connection attempt failed: %s', e) return False + return True -async def handle_connection(ws, use_heartbeat): +async def handle_connection(ws: WebSocketConnection, use_heartbeat: bool) -> None: ''' Handle the connection. ''' logging.debug('Connected!') try: @@ -71,11 +79,12 @@ async def handle_connection(ws, use_heartbeat): nursery.start_soon(get_commands, ws) nursery.start_soon(get_messages, ws) except ConnectionClosed as cc: + assert isinstance(cc.reason, CloseReason) reason = '' if cc.reason.reason is None else f'"{cc.reason.reason}"' print(f'Closed: {cc.reason.code}/{cc.reason.name} {reason}') -async def heartbeat(ws, timeout, interval): +async def heartbeat(ws: WebSocketConnection, timeout: float, interval: float) -> NoReturn: ''' Send periodic pings on WebSocket ``ws``. @@ -99,11 +108,10 @@ async def heartbeat(ws, timeout, interval): await trio.sleep(interval) -async def get_commands(ws): +async def get_commands(ws: WebSocketConnection) -> None: ''' In a loop: get a command from the user and execute it. ''' while True: - cmd = await trio.to_thread.run_sync(input, 'cmd> ', - cancellable=True) + cmd = await trio.to_thread.run_sync(input, 'cmd> ') if cmd.startswith('ping'): payload = cmd[5:].encode('utf8') or None await ws.ping(payload) @@ -123,11 +131,11 @@ async def get_commands(ws): await trio.sleep(0.25) -async def get_messages(ws): +async def get_messages(ws: WebSocketConnection) -> None: ''' In a loop: get a WebSocket message and print it out. ''' while True: message = await ws.get_message() - print(f'message: {message}') + print(f'message: {message!r}') if __name__ == '__main__': diff --git a/examples/generate-cert.py b/examples/generate-cert.py index cc21698..fcb36bd 100644 --- a/examples/generate-cert.py +++ b/examples/generate-cert.py @@ -3,7 +3,7 @@ import trustme -def main(): +def main() -> None: here = pathlib.Path(__file__).parent ca_path = here / 'fake.ca.pem' server_path = here / 'fake.server.pem' diff --git a/examples/server.py b/examples/server.py index 611d89b..5274013 100644 --- a/examples/server.py +++ b/examples/server.py @@ -14,7 +14,7 @@ import ssl import trio -from trio_websocket import serve_websocket, ConnectionClosed +from trio_websocket import serve_websocket, ConnectionClosed, WebSocketRequest logging.basicConfig(level=logging.DEBUG) @@ -22,7 +22,7 @@ here = pathlib.Path(__file__).parent -def parse_args(): +def parse_args() -> argparse.Namespace: ''' Parse command line arguments. ''' parser = argparse.ArgumentParser(description='Example trio-websocket client') parser.add_argument('--ssl', action='store_true', help='Use SSL') @@ -32,7 +32,7 @@ def parse_args(): return parser.parse_args() -async def main(args): +async def main(args: argparse.Namespace) -> None: ''' Main entry point. ''' logging.info('Starting websocket server…') if args.ssl: @@ -48,7 +48,7 @@ async def main(args): await serve_websocket(handler, host, args.port, ssl_context) -async def handler(request): +async def handler(request: WebSocketRequest) -> None: ''' Reverse incoming websocket messages and send them back. ''' logging.info('Handler starting on path "%s"', request.path) ws = await request.accept() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..95d5ff9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,13 @@ +[tool.mypy] +explicit_package_bases = true +files = ["trio_websocket", "tests", "autobahn", "examples"] +show_column_numbers = true +show_error_codes = true +show_traceback = true +disallow_any_decorated = true +disallow_any_unimported = true +ignore_missing_imports = true +local_partial_types = true +no_implicit_optional = true +strict = true +warn_unreachable = true diff --git a/setup.py b/setup.py index 17a21f9..46c6506 100644 --- a/setup.py +++ b/setup.py @@ -35,10 +35,12 @@ 'Programming Language :: Python :: 3.12', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy', + 'Typing :: Typed', ], python_requires=">=3.8", keywords='websocket client server trio', packages=find_packages(exclude=['docs', 'examples', 'tests']), + package_data={"trio-websocket": ["py.typed"]}, install_requires=[ 'exceptiongroup; python_version<"3.11"', 'trio>=0.11', diff --git a/tests/test_connection.py b/tests/test_connection.py index 0837aa5..45c1268 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -32,11 +32,14 @@ from __future__ import annotations import copy -from functools import partial, wraps import re import ssl import sys -from unittest.mock import patch +from collections.abc import AsyncGenerator +from functools import partial, wraps +from typing import TYPE_CHECKING, TypeVar, cast +from unittest.mock import Mock, patch +from importlib.metadata import version import attr import pytest @@ -58,30 +61,43 @@ except ImportError: pass + from trio_websocket import ( - connect_websocket, - connect_websocket_url, + CloseReason, ConnectionClosed, ConnectionRejected, ConnectionTimeout, DisconnectionTimeout, Endpoint, HandshakeError, + WebSocketConnection, + WebSocketRequest, + WebSocketServer, + connect_websocket, + connect_websocket_url, open_websocket, open_websocket_url, serve_websocket, - WebSocketConnection, - WebSocketServer, - WebSocketRequest, wrap_client_stream, - wrap_server_stream + wrap_server_stream, ) - from trio_websocket._impl import _TRIO_EXC_GROUP_TYPE if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + from wsproto.events import Event + + from typing_extensions import ParamSpec, TypeAlias + PS = ParamSpec("PS") + + StapledMemoryStream: TypeAlias = trio.StapledStream[ + trio.testing.MemorySendStream, + trio.testing.MemoryReceiveStream, + ] + WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split('.'))) HOST = '127.0.0.1' @@ -96,19 +112,23 @@ FORCE_TIMEOUT = 2 TIMEOUT_TEST_MAX_DURATION = 3 +T = TypeVar("T") + @pytest.fixture -async def echo_server(nursery): +async def echo_server(nursery: trio.Nursery) -> AsyncGenerator[WebSocketServer, None]: ''' A server that reads one message, sends back the same message, then closes the connection. ''' serve_fn = partial(serve_websocket, echo_request_handler, HOST, 0, ssl_context=None) server = await nursery.start(serve_fn) - yield server + # Cast needed because currently `nursery.start` has typing issues + # blocked by https://github.com/python/mypy/pull/17512 + yield cast(WebSocketServer, server) @pytest.fixture -async def echo_conn(echo_server): +async def echo_conn(echo_server: WebSocketServer) -> AsyncGenerator[WebSocketConnection, None]: ''' Return a client connection instance that is connected to an echo server. ''' async with open_websocket(HOST, echo_server.port, RESOURCE, @@ -116,7 +136,7 @@ async def echo_conn(echo_server): yield conn -async def echo_request_handler(request): +async def echo_request_handler(request: WebSocketRequest) -> None: ''' Accept incoming request and then pass off to echo connection handler. ''' @@ -131,35 +151,47 @@ async def echo_request_handler(request): class fail_after: ''' This decorator fails if the runtime of the decorated function (as measured by the Trio clock) exceeds the specified value. ''' - def __init__(self, seconds): + def __init__(self, seconds: int) -> None: self._seconds = seconds - def __call__(self, fn): + def __call__(self, fn: Callable[PS, Awaitable[T]]) -> Callable[PS, Awaitable[T | None]]: + # Type of decorated function contains type `Any` @wraps(fn) - async def wrapper(*args, **kwargs): + async def wrapper( # type: ignore[misc] + *args: PS.args, + **kwargs: PS.kwargs, + ) -> T: with trio.move_on_after(self._seconds) as cancel_scope: - await fn(*args, **kwargs) + return await fn(*args, **kwargs) if cancel_scope.cancelled_caught: pytest.fail(f'Test runtime exceeded the maximum {self._seconds} seconds') + raise AssertionError("Should be unreachable") return wrapper @attr.s(hash=False, eq=False) -class MemoryListener(trio.abc.Listener): - closed = attr.ib(default=False) - accepted_streams: list[ - tuple[trio.abc.SendChannel[str], trio.abc.ReceiveChannel[str]] - ] = attr.ib(factory=list) - queued_streams = attr.ib(factory=lambda: trio.open_memory_channel[str](1)) - accept_hook = attr.ib(default=None) - - async def connect(self): +class MemoryListener(trio.abc.Listener["StapledMemoryStream"]): + closed: bool = attr.ib(default=False) + accepted_streams: list[StapledMemoryStream] = attr.ib(factory=list) + queued_streams: tuple[ + trio.MemorySendChannel[StapledMemoryStream], + trio.MemoryReceiveChannel[StapledMemoryStream], + ] = attr.ib(factory=lambda: trio.open_memory_channel["StapledMemoryStream"](1)) + accept_hook: Callable[[], Awaitable[object]] | None = attr.ib(default=None) + + async def connect(self) -> trio.StapledStream[ + trio.testing.MemorySendStream, + trio.testing.MemoryReceiveStream, + ]: assert not self.closed client, server = memory_stream_pair() await self.queued_streams[0].send(server) return client - async def accept(self): + async def accept(self) -> trio.StapledStream[ + trio.testing.MemorySendStream, + trio.testing.MemoryReceiveStream, + ]: await trio.sleep(0) assert not self.closed if self.accept_hook is not None: @@ -168,12 +200,12 @@ async def accept(self): self.accepted_streams.append(stream) return stream - async def aclose(self): + async def aclose(self) -> None: self.closed = True await trio.sleep(0) -async def test_endpoint_ipv4(): +async def test_endpoint_ipv4() -> None: e1 = Endpoint('10.105.0.2', 80, False) assert e1.url == 'ws://10.105.0.2' assert str(e1) == 'Endpoint(address="10.105.0.2", port=80, is_ssl=False)' @@ -185,7 +217,7 @@ async def test_endpoint_ipv4(): assert str(e3) == 'Endpoint(address="0.0.0.0", port=443, is_ssl=True)' -async def test_listen_port_ipv6(): +async def test_listen_port_ipv6() -> None: e1 = Endpoint('2599:8807:6201:b7:16cf:bb9c:a6d3:51ab', 80, False) assert e1.url == 'ws://[2599:8807:6201:b7:16cf:bb9c:a6d3:51ab]' assert str(e1) == 'Endpoint(address="2599:8807:6201:b7:16cf:bb9c:a6d3' \ @@ -198,17 +230,19 @@ async def test_listen_port_ipv6(): assert str(e3) == 'Endpoint(address="::", port=443, is_ssl=True)' -async def test_server_has_listeners(nursery): +async def test_server_has_listeners(nursery: trio.Nursery) -> None: server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) assert len(server.listeners) > 0 assert isinstance(server.listeners[0], Endpoint) -async def test_serve(nursery): +async def test_serve(nursery: trio.Nursery) -> None: task = current_task() server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) port = server.port assert server.port != 0 # The server nursery begins with one task (server.listen). @@ -221,7 +255,7 @@ async def test_serve(nursery): assert len(task.child_nurseries) == no_clients_nursery_count + 1 -async def test_serve_ssl(nursery): +async def test_serve_ssl(nursery: trio.Nursery) -> None: server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) client_context = ssl.create_default_context() ca = trustme.CA() @@ -231,19 +265,23 @@ async def test_serve_ssl(nursery): server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, server_context) + assert isinstance(server, WebSocketServer) port = server.port async with open_websocket(HOST, port, RESOURCE, use_ssl=client_context ) as conn: assert not conn.closed + assert isinstance(conn.local, Endpoint) assert conn.local.is_ssl + assert isinstance(conn.remote, Endpoint) assert conn.remote.is_ssl -async def test_serve_handler_nursery(nursery): +async def test_serve_handler_nursery(nursery: trio.Nursery) -> None: async with trio.open_nursery() as handler_nursery: serve_with_nursery = partial(serve_websocket, echo_request_handler, HOST, 0, None, handler_nursery=handler_nursery) server = await nursery.start(serve_with_nursery) + assert isinstance(server, WebSocketServer) port = server.port # The server nursery begins with one task (server.listen). assert len(nursery.child_tasks) == 1 @@ -253,25 +291,39 @@ async def test_serve_handler_nursery(nursery): assert len(handler_nursery.child_tasks) == 1 -async def test_serve_with_zero_listeners(): +async def test_serve_with_zero_listeners() -> None: with pytest.raises(ValueError): WebSocketServer(echo_request_handler, []) -async def test_serve_non_tcp_listener(nursery): - listeners = [MemoryListener()] - server = WebSocketServer(echo_request_handler, listeners) +def memory_listener() -> trio.SocketListener: + return MemoryListener() # type: ignore[return-value] + + +async def test_serve_non_tcp_listener(nursery: trio.Nursery) -> None: + listeners = [memory_listener()] + server = WebSocketServer( + echo_request_handler, + listeners, + ) await nursery.start(server.run) assert len(server.listeners) == 1 with pytest.raises(RuntimeError): server.port # pylint: disable=pointless-statement - assert server.listeners[0].startswith('MemoryListener(') + listener = server.listeners[0] + assert isinstance(listener, str) + assert listener.startswith('MemoryListener(') -async def test_serve_multiple_listeners(nursery): +async def test_serve_multiple_listeners(nursery: trio.Nursery) -> None: listener1 = (await trio.open_tcp_listeners(0, host=HOST))[0] - listener2 = MemoryListener() - server = WebSocketServer(echo_request_handler, [listener1, listener2]) + listener2 = memory_listener() + server = WebSocketServer( + echo_request_handler, [ + listener1, + listener2, + ] + ) await nursery.start(server.run) assert len(server.listeners) == 2 with pytest.raises(RuntimeError): @@ -279,13 +331,17 @@ async def test_serve_multiple_listeners(nursery): # usable if you have exactly one listener. server.port # pylint: disable=pointless-statement # The first listener metadata is a ListenPort instance. - assert server.listeners[0].port != 0 + listener_zero = server.listeners[0] + assert isinstance(listener_zero, Endpoint) + assert listener_zero.port != 0 # The second listener metadata is a string containing the repr() of a # MemoryListener object. - assert server.listeners[1].startswith('MemoryListener(') + listener_one = server.listeners[1] + assert isinstance(listener_one, str) + assert listener_one.startswith('MemoryListener(') -async def test_client_open(echo_server): +async def test_client_open(echo_server: WebSocketServer) -> None: async with open_websocket(HOST, echo_server.port, RESOURCE, use_ssl=False) \ as conn: assert not conn.closed @@ -299,35 +355,44 @@ async def test_client_open(echo_server): (RESOURCE + '/path', RESOURCE + '/path'), (RESOURCE + '?foo=bar', RESOURCE + '?foo=bar') ]) -async def test_client_open_url(path, expected_path, echo_server): +async def test_client_open_url(path: str, expected_path: str, echo_server: WebSocketServer) -> None: url = f'ws://{HOST}:{echo_server.port}{path}' async with open_websocket_url(url) as conn: assert conn.path == expected_path -async def test_client_open_invalid_url(echo_server): +async def test_client_open_invalid_url(echo_server: WebSocketServer) -> None: with pytest.raises(ValueError): async with open_websocket_url('http://foo.com/bar'): pass -async def test_client_open_invalid_ssl(echo_server, nursery): +async def test_client_open_invalid_ssl( + echo_server: WebSocketServer, + nursery: trio.Nursery, +) -> None: with pytest.raises(TypeError, match='`use_ssl` argument must be bool or ssl.SSLContext'): - await connect_websocket(nursery, HOST, echo_server.port, RESOURCE, use_ssl=1) + await connect_websocket( + nursery, HOST, echo_server.port, RESOURCE, + use_ssl=1, # type: ignore[arg-type] + ) url = f'ws://{HOST}:{echo_server.port}{RESOURCE}' with pytest.raises(ValueError, match='^SSL context must be None for ws: URL scheme$' ): await connect_websocket_url(nursery, url, ssl_context=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) -async def test_ascii_encoded_path_is_ok(echo_server): +async def test_ascii_encoded_path_is_ok(echo_server: WebSocketServer) -> None: path = '%D7%90%D7%91%D7%90?%D7%90%D7%9E%D7%90' url = f'ws://{HOST}:{echo_server.port}{RESOURCE}/{path}' async with open_websocket_url(url) as conn: assert conn.path == RESOURCE + '/' + path +# Type ignore because @patch contains `Any` @patch('trio_websocket._impl.open_websocket') -def test_client_open_url_options(open_websocket_mock): +def test_client_open_url_options( # type: ignore[misc] + open_websocket_mock: Mock, +) -> None: """open_websocket_url() must pass its options on to open_websocket()""" port = 1234 url = f'ws://{HOST}:{port}{RESOURCE}' @@ -339,7 +404,7 @@ def test_client_open_url_options(open_websocket_mock): 'connect_timeout': 36, 'disconnect_timeout': 37, } - open_websocket_url(url, **options) + open_websocket_url(url, **options) # type: ignore[arg-type] _, call_args, call_kwargs = open_websocket_mock.mock_calls[0] assert call_args == (HOST, port, RESOURCE) assert not call_kwargs.pop('use_ssl') @@ -350,19 +415,19 @@ def test_client_open_url_options(open_websocket_mock): assert call_kwargs['use_ssl'] -async def test_client_connect(echo_server, nursery): +async def test_client_connect(echo_server: WebSocketServer, nursery: trio.Nursery) -> None: conn = await connect_websocket(nursery, HOST, echo_server.port, RESOURCE, use_ssl=False) assert not conn.closed -async def test_client_connect_url(echo_server, nursery): +async def test_client_connect_url(echo_server: WebSocketServer, nursery: trio.Nursery) -> None: url = f'ws://{HOST}:{echo_server.port}{RESOURCE}' conn = await connect_websocket_url(nursery, url) assert not conn.closed -async def test_connection_has_endpoints(echo_conn): +async def test_connection_has_endpoints(echo_conn: WebSocketConnection) -> None: async with echo_conn: assert isinstance(echo_conn.local, Endpoint) assert str(echo_conn.local.address) == HOST @@ -376,47 +441,53 @@ async def test_connection_has_endpoints(echo_conn): @fail_after(1) -async def test_handshake_has_endpoints(nursery): - async def handler(request): +async def test_handshake_has_endpoints(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: + assert isinstance(server, WebSocketServer) + assert isinstance(request.local, Endpoint) assert str(request.local.address) == HOST assert request.local.port == server.port assert not request.local.is_ssl + assert isinstance(request.remote, Endpoint) assert str(request.remote.address) == HOST assert not request.remote.is_ssl await request.accept() server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): pass -async def test_handshake_subprotocol(nursery): - async def handler(request): +async def test_handshake_subprotocol(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: assert request.proposed_subprotocols == ('chat', 'file') server_ws = await request.accept(subprotocol='chat') assert server_ws.subprotocol == 'chat' server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, subprotocols=('chat', 'file')) as client_ws: assert client_ws.subprotocol == 'chat' -async def test_handshake_path(nursery): - async def handler(request): +async def test_handshake_path(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: assert request.path == RESOURCE server_ws = await request.accept() assert server_ws.path == RESOURCE server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, ) as client_ws: assert client_ws.path == RESOURCE @fail_after(1) -async def test_handshake_client_headers(nursery): - async def handler(request): +async def test_handshake_client_headers(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: headers = dict(request.headers) assert b'x-test-header' in headers assert headers[b'x-test-header'] == b'My test header' @@ -424,6 +495,7 @@ async def handler(request): await server_ws.send_message('test') server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) headers = [(b'X-Test-Header', b'My test header')] async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, extra_headers=headers) as client_ws: @@ -431,12 +503,13 @@ async def handler(request): @fail_after(1) -async def test_handshake_server_headers(nursery): - async def handler(request): - headers = [('X-Test-Header', 'My test header')] +async def test_handshake_server_headers(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: + headers = [(b'X-Test-Header', b'My test header')] await request.accept(extra_headers=headers) server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False ) as client_ws: header_key, header_value = client_ws.handshake_headers[0] @@ -444,22 +517,25 @@ async def handler(request): assert header_value == b'My test header' - - @fail_after(5) -async def test_open_websocket_internal_ki(nursery, monkeypatch, autojump_clock): +async def test_open_websocket_internal_ki( + nursery: trio.Nursery, + monkeypatch: pytest.MonkeyPatch, + autojump_clock: trio.testing.MockClock, +) -> None: """_reader_task._handle_ping_event triggers KeyboardInterrupt. user code also raises exception. Make sure that KI is delivered, and the user exception is in the __cause__ exceptiongroup """ - async def ki_raising_ping_handler(*args, **kwargs) -> None: + async def ki_raising_ping_handler(*args: object, **kwargs: object) -> None: raise KeyboardInterrupt monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", ki_raising_ping_handler) - async def handler(request): + async def handler(request: WebSocketRequest) -> None: server_ws = await request.accept() await server_ws.ping(b"a") server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) with pytest.raises(KeyboardInterrupt) as exc_info: async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): with trio.fail_after(1) as cs: @@ -471,7 +547,11 @@ async def handler(request): assert any(isinstance(e, trio.TooSlowError) for e in e_cause.exceptions) @fail_after(5) -async def test_open_websocket_internal_exc(nursery, monkeypatch, autojump_clock): +async def test_open_websocket_internal_exc( + nursery: trio.Nursery, + monkeypatch: pytest.MonkeyPatch, + autojump_clock: trio.testing.MockClock, +) -> None: """_reader_task._handle_ping_event triggers ValueError. user code also raises exception. internal exception is in __context__ exceptiongroup and user exc is delivered @@ -480,15 +560,16 @@ async def test_open_websocket_internal_exc(nursery, monkeypatch, autojump_clock) internal_error.__context__ = TypeError() user_error = NameError() user_error_context = KeyError() - async def raising_ping_event(*args, **kwargs) -> None: + async def raising_ping_event(*args: object, **kwargs: object) -> None: raise internal_error monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", raising_ping_event) - async def handler(request): + async def handler(request: WebSocketRequest) -> None: server_ws = await request.accept() await server_ws.ping(b"a") server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) with pytest.raises(type(user_error)) as exc_info: async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): await trio.lowlevel.checkpoint() @@ -502,19 +583,23 @@ async def handler(request): assert user_error_context in e_context.exceptions @fail_after(5) -async def test_open_websocket_cancellations(nursery, monkeypatch, autojump_clock): +async def test_open_websocket_cancellations( + nursery: trio.Nursery, + monkeypatch: pytest.MonkeyPatch, + autojump_clock: trio.testing.MockClock, +) -> None: """Both user code and _reader_task raise Cancellation. Check that open_websocket reraises the one from user code for traceback reasons. """ - async def sleeping_ping_event(*args, **kwargs) -> None: + async def sleeping_ping_event(*args: object, **kwargs: object) -> None: await trio.sleep_forever() # We monkeypatch WebSocketConnection._handle_ping_event to ensure it will actually # raise Cancelled upon being cancelled. For some reason it doesn't otherwise. monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", sleeping_ping_event) - async def handler(request): + async def handler(request: WebSocketRequest) -> None: server_ws = await request.accept() await server_ws.ping(b"a") user_cancelled = None @@ -522,6 +607,7 @@ async def handler(request): user_cancelled_context = None server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) with trio.move_on_after(2): with pytest.raises(trio.Cancelled) as exc_info: async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): @@ -538,15 +624,16 @@ async def handler(request): assert exc_info.value.__context__ is user_cancelled_context def _trio_default_non_strict_exception_groups() -> bool: - assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme" - return int(trio.__version__[2:4]) < 25 + trio_version = version("trio") + assert re.match(r'^0\.\d\d\.', trio_version), "unexpected trio versioning scheme" + return int(trio_version[2:4]) < 25 @fail_after(1) async def test_handshake_exception_before_accept() -> None: ''' In #107, a request handler that throws an exception before finishing the handshake causes the task to hang. The proper behavior is to raise an exception to the nursery as soon as possible. ''' - async def handler(request): + async def handler(request: WebSocketRequest) -> None: raise ValueError() # pylint fails to resolve that BaseExceptionGroup will always be available @@ -554,6 +641,7 @@ async def handler(request): async with trio.open_nursery() as nursery: server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): pass @@ -572,10 +660,11 @@ async def handler(request): RaisesGroup(ValueError)))).matches(exc.value) -async def test_user_exception_cause(nursery) -> None: - async def handler(request): +async def test_user_exception_cause(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: await request.accept() server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) e_context = TypeError("foo") e_primary = ValueError("bar") e_cause = RuntimeError("zee") @@ -591,12 +680,13 @@ async def handler(request): assert e.__context__ is e_context @fail_after(1) -async def test_reject_handshake(nursery): - async def handler(request): +async def test_reject_handshake(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: body = b'My body' await request.reject(400, body=body) server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) with pytest.raises(ConnectionRejected) as exc_info: async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): pass @@ -605,17 +695,18 @@ async def handler(request): @fail_after(1) -async def test_reject_handshake_invalid_info_status(nursery): +async def test_reject_handshake_invalid_info_status(nursery: trio.Nursery) -> None: ''' An informational status code that is not 101 should cause the client to reject the handshake. Since it is an informational response, there will not be a response body, so this test exercises a different code path. ''' - async def handler(stream): + async def handler(stream: trio.SocketStream) -> None: await stream.send_all(b'HTTP/1.1 100 CONTINUE\r\n\r\n') await stream.receive_some(max_bytes=1024) serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) - listeners = await nursery.start(serve_fn) + raw_listeners = await nursery.start(serve_fn) + listeners = cast("list[trio.SocketListener]", raw_listeners) port = listeners[0].socket.getsockname()[1] with pytest.raises(ConnectionRejected) as exc_info: @@ -627,7 +718,7 @@ async def handler(stream): assert exc.body is None -async def test_handshake_protocol_error(echo_server): +async def test_handshake_protocol_error(echo_server: WebSocketServer) -> None: ''' If a client connects to a trio-websocket server and tries to speak HTTP instead of WebSocket, the server should reject the connection. (If the @@ -641,29 +732,29 @@ async def test_handshake_protocol_error(echo_server): assert response.startswith(b'HTTP/1.1 400') -async def test_client_send_and_receive(echo_conn): +async def test_client_send_and_receive(echo_conn: WebSocketConnection) -> None: async with echo_conn: await echo_conn.send_message('This is a test message.') received_msg = await echo_conn.get_message() assert received_msg == 'This is a test message.' -async def test_client_send_invalid_type(echo_conn): +async def test_client_send_invalid_type(echo_conn: WebSocketConnection) -> None: async with echo_conn: with pytest.raises(ValueError): - await echo_conn.send_message(object()) + await echo_conn.send_message(object()) # type: ignore[arg-type] -async def test_client_ping(echo_conn): +async def test_client_ping(echo_conn: WebSocketConnection) -> None: async with echo_conn: await echo_conn.ping(b'A') with pytest.raises(ConnectionClosed): await echo_conn.ping(b'B') -async def test_client_ping_two_payloads(echo_conn): +async def test_client_ping_two_payloads(echo_conn: WebSocketConnection) -> None: pong_count = 0 - async def ping_and_count(): + async def ping_and_count() -> None: nonlocal pong_count await echo_conn.ping() pong_count += 1 @@ -674,12 +765,12 @@ async def ping_and_count(): assert pong_count == 2 -async def test_client_ping_same_payload(echo_conn): +async def test_client_ping_same_payload(echo_conn: WebSocketConnection) -> None: # This test verifies that two tasks can't ping with the same payload at the # same time. One of them should succeed and the other should get an # exception. exc_count = 0 - async def ping_and_catch(): + async def ping_and_catch() -> None: nonlocal exc_count try: await echo_conn.ping(b'A') @@ -692,47 +783,53 @@ async def ping_and_catch(): assert exc_count == 1 -async def test_client_pong(echo_conn): +async def test_client_pong(echo_conn: WebSocketConnection) -> None: async with echo_conn: await echo_conn.pong(b'A') with pytest.raises(ConnectionClosed): await echo_conn.pong(b'B') -async def test_client_default_close(echo_conn): +async def test_client_default_close(echo_conn: WebSocketConnection) -> None: async with echo_conn: assert not echo_conn.closed + assert isinstance(echo_conn.closed, CloseReason) assert echo_conn.closed.code == 1000 assert echo_conn.closed.reason is None assert repr(echo_conn.closed) == 'CloseReason' -async def test_client_nondefault_close(echo_conn): +async def test_client_nondefault_close(echo_conn: WebSocketConnection) -> None: async with echo_conn: assert not echo_conn.closed await echo_conn.aclose(code=1001, reason='test reason') + assert isinstance(echo_conn.closed, CloseReason) assert echo_conn.closed.code == 1001 assert echo_conn.closed.reason == 'test reason' -async def test_wrap_client_stream(nursery): +async def test_wrap_client_stream(nursery: trio.Nursery) -> None: listener = MemoryListener() - server = WebSocketServer(echo_request_handler, [listener]) + server = WebSocketServer(echo_request_handler, [listener]) # type: ignore[list-item] await nursery.start(server.run) stream = await listener.connect() - conn = await wrap_client_stream(nursery, stream, HOST, RESOURCE) + conn = await wrap_client_stream( + nursery, + stream, # type: ignore[arg-type] + HOST, RESOURCE) async with conn: assert not conn.closed await conn.send_message('Hello from client!') msg = await conn.get_message() assert msg == 'Hello from client!' + assert isinstance(conn.local, str) assert conn.local.startswith('StapledStream(') assert conn.closed -async def test_wrap_server_stream(nursery): - async def handler(stream): +async def test_wrap_server_stream(nursery: trio.Nursery) -> None: + async def handler(stream: trio.SocketStream) -> None: request = await wrap_server_stream(nursery, stream) server_ws = await request.accept() async with server_ws: @@ -741,25 +838,30 @@ async def handler(stream): assert msg == 'Hello from client!' assert server_ws.closed serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) - listeners = await nursery.start(serve_fn) + raw_listeners = await nursery.start(serve_fn) + listeners = cast("list[trio.SocketListener]", raw_listeners) port = listeners[0].socket.getsockname()[1] async with open_websocket(HOST, port, RESOURCE, use_ssl=False) as client: await client.send_message('Hello from client!') @fail_after(TIMEOUT_TEST_MAX_DURATION) -async def test_client_open_timeout(nursery, autojump_clock): +async def test_client_open_timeout( + nursery: trio.Nursery, + autojump_clock: trio.testing.MockClock, +) -> None: ''' The client times out waiting for the server to complete the opening handshake. ''' - async def handler(request): + async def handler(request: WebSocketRequest) -> None: await trio.sleep(FORCE_TIMEOUT) await request.accept() pytest.fail('Should not reach this line.') server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + assert isinstance(server, WebSocketServer) with pytest.raises(ConnectionTimeout): async with open_websocket(HOST, server.port, '/', use_ssl=False, @@ -768,7 +870,10 @@ async def handler(request): @fail_after(TIMEOUT_TEST_MAX_DURATION) -async def test_client_close_timeout(nursery, autojump_clock): +async def test_client_close_timeout( + nursery: trio.Nursery, + autojump_clock: trio.testing.MockClock, +) -> None: ''' This client times out waiting for the server to complete the closing handshake. @@ -778,7 +883,7 @@ async def test_client_close_timeout(nursery, autojump_clock): server's reader so it won't do the closing handshake for at least ``FORCE_TIMEOUT`` seconds. ''' - async def handler(request): + async def handler(request: WebSocketRequest) -> None: server_ws = await request.accept() await trio.sleep(FORCE_TIMEOUT) # The next line should raise ConnectionClosed. @@ -788,6 +893,7 @@ async def handler(request): server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None, message_queue_size=0)) + assert isinstance(server, WebSocketServer) with pytest.raises(DisconnectionTimeout): async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, @@ -795,7 +901,7 @@ async def handler(request): await client_ws.send_message('test') -async def test_client_connect_networking_error(): +async def test_client_connect_networking_error() -> None: with patch('trio_websocket._impl.connect_websocket') as \ connect_websocket_mock: connect_websocket_mock.side_effect = OSError() @@ -805,7 +911,7 @@ async def test_client_connect_networking_error(): @fail_after(TIMEOUT_TEST_MAX_DURATION) -async def test_server_open_timeout(autojump_clock): +async def test_server_open_timeout(autojump_clock: trio.testing.MockClock) -> None: ''' The server times out waiting for the client to complete the opening handshake. @@ -814,12 +920,13 @@ async def test_server_open_timeout(autojump_clock): in an internal nursery and sending exceptions wouldn't be helpful. Instead, timed out tasks silently end. ''' - async def handler(request): + async def handler(request: WebSocketRequest) -> None: pytest.fail('This handler should not be called.') async with trio.open_nursery() as nursery: server = await nursery.start(partial(serve_websocket, handler, HOST, 0, ssl_context=None, handler_nursery=nursery, connect_timeout=TIMEOUT)) + assert isinstance(server, WebSocketServer) old_task_count = len(nursery.child_tasks) # This stream is not a WebSocket, so it won't send a handshake: @@ -837,7 +944,7 @@ async def handler(request): @fail_after(TIMEOUT_TEST_MAX_DURATION) -async def test_server_close_timeout(autojump_clock): +async def test_server_close_timeout(autojump_clock: trio.testing.MockClock) -> None: ''' The server times out waiting for the client to complete the closing handshake. @@ -850,7 +957,7 @@ async def test_server_close_timeout(autojump_clock): its message queue size is 0 and the server sends it exactly 1 message. This blocks the client's reader and prevents it from doing the client handshake. ''' - async def handler(request): + async def handler(request: WebSocketRequest) -> None: ws = await request.accept() # Send one message to block the client's reader task: await ws.send_message('test') @@ -859,6 +966,7 @@ async def handler(request): server = await outer.start(partial(serve_websocket, handler, HOST, 0, ssl_context=None, handler_nursery=outer, disconnect_timeout=TIMEOUT)) + assert isinstance(server, WebSocketServer) old_task_count = len(outer.child_tasks) # Spawn client inside an inner nursery so that we can cancel it's reader @@ -883,12 +991,13 @@ async def handler(request): outer.cancel_scope.cancel() -async def test_client_does_not_close_handshake(nursery): - async def handler(request): +async def test_client_does_not_close_handshake(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: server_ws = await request.accept() with pytest.raises(ConnectionClosed): await server_ws.get_message() server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) stream = await trio.open_tcp_stream(HOST, server.port) client_ws = await wrap_client_stream(nursery, stream, HOST, RESOURCE) async with client_ws: @@ -897,10 +1006,10 @@ async def handler(request): await client_ws.send_message('Hello from client!') -async def test_server_sends_after_close(nursery): +async def test_server_sends_after_close(nursery: trio.Nursery) -> None: done = trio.Event() - async def handler(request): + async def handler(request: WebSocketRequest) -> None: server_ws = await request.accept() with pytest.raises(ConnectionClosed): while True: @@ -908,6 +1017,7 @@ async def handler(request): done.set() server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) stream = await trio.open_tcp_stream(HOST, server.port) client_ws = await wrap_client_stream(nursery, stream, HOST, RESOURCE) async with client_ws: @@ -918,8 +1028,8 @@ async def handler(request): await done.wait() -async def test_server_does_not_close_handshake(nursery): - async def handler(stream): +async def test_server_does_not_close_handshake(nursery: trio.Nursery) -> None: + async def handler(stream: trio.SocketStream) -> None: request = await wrap_server_stream(nursery, stream) server_ws = await request.accept() async with server_ws: @@ -927,20 +1037,25 @@ async def handler(stream): with pytest.raises(ConnectionClosed): await server_ws.send_message('Hello from client!') serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) - listeners = await nursery.start(serve_fn) + raw_listeners = await nursery.start(serve_fn) + listeners = cast("list[trio.SocketListener]", raw_listeners) port = listeners[0].socket.getsockname()[1] async with open_websocket(HOST, port, RESOURCE, use_ssl=False) as client: with pytest.raises(ConnectionClosed): await client.get_message() -async def test_server_handler_exit(nursery, autojump_clock): - async def handler(request): +async def test_server_handler_exit( + nursery: trio.Nursery, + autojump_clock: trio.testing.MockClock, +) -> None: + async def handler(request: WebSocketRequest) -> None: await request.accept() await trio.sleep(1) server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + assert isinstance(server, WebSocketServer) # connection should close when server handler exits with trio.fail_after(2): @@ -949,11 +1064,12 @@ async def handler(request): with pytest.raises(ConnectionClosed) as exc_info: await connection.get_message() exc = exc_info.value + assert isinstance(exc.reason, CloseReason) assert exc.reason.name == 'NORMAL_CLOSURE' @fail_after(DEFAULT_TEST_MAX_DURATION) -async def test_read_messages_after_remote_close(nursery): +async def test_read_messages_after_remote_close(nursery: trio.Nursery) -> None: ''' When the remote endpoint closes, the local endpoint can still read all of the messages sent prior to closing. Any attempt to read beyond that will @@ -963,7 +1079,7 @@ async def test_read_messages_after_remote_close(nursery): ''' server_closed = trio.Event() - async def handler(request): + async def handler(request: WebSocketRequest) -> None: server = await request.accept() async with server: await server.send_message('1') @@ -972,6 +1088,7 @@ async def handler(request): server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + assert isinstance(server, WebSocketServer) # The client needs a message queue of size 2 so that it can buffer both # incoming messages without blocking the reader task. @@ -984,14 +1101,14 @@ async def handler(request): await client.get_message() -async def test_no_messages_after_local_close(nursery): +async def test_no_messages_after_local_close(nursery: trio.Nursery) -> None: ''' If the local endpoint initiates closing, then pending messages are discarded and any attempt to read a message will raise ConnectionClosed. ''' client_closed = trio.Event() - async def handler(request): + async def handler(request: WebSocketRequest) -> None: # The server sends some messages and then closes. server = await request.accept() async with server: @@ -1001,6 +1118,7 @@ async def handler(request): server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + assert isinstance(server, WebSocketServer) async with open_websocket(HOST, server.port, '/', use_ssl=False) as client: pass @@ -1009,7 +1127,10 @@ async def handler(request): client_closed.set() -async def test_cm_exit_with_pending_messages(echo_server, autojump_clock): +async def test_cm_exit_with_pending_messages( + echo_server: WebSocketServer, + autojump_clock: trio.testing.MockClock, +) -> None: ''' Regression test for #74, where a context manager was not able to exit when there were pending messages in the receive queue. @@ -1023,13 +1144,13 @@ async def test_cm_exit_with_pending_messages(echo_server, autojump_clock): @fail_after(DEFAULT_TEST_MAX_DURATION) -async def test_max_message_size(nursery): +async def test_max_message_size(nursery: trio.Nursery) -> None: ''' Set the client's max message size to 100 bytes. The client can send a message larger than 100 bytes, but when it receives a message larger than 100 bytes, it closes the connection with code 1009. ''' - async def handler(request): + async def handler(request: WebSocketRequest) -> None: ''' Similar to the echo_request_handler fixture except it runs in a loop. ''' conn = await request.accept() @@ -1042,6 +1163,7 @@ async def handler(request): server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + assert isinstance(server, WebSocketServer) async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, max_message_size=100) as client: @@ -1057,10 +1179,13 @@ async def handler(request): assert client.closed.code == 1009 -async def test_server_close_client_disconnect_race(nursery, autojump_clock): +async def test_server_close_client_disconnect_race( + nursery: trio.Nursery, + autojump_clock: trio.testing.MockClock, +) -> None: """server attempts close just as client disconnects (issue #96)""" - async def handler(request: WebSocketRequest): + async def handler(request: WebSocketRequest) -> None: ws = await request.accept() ws._for_testing_peer_closed_connection = trio.Event() await ws.send_message('foo') @@ -1070,6 +1195,7 @@ async def handler(request: WebSocketRequest): server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + assert isinstance(server, WebSocketServer) connection = await connect_websocket(nursery, HOST, server.port, RESOURCE, use_ssl=False) @@ -1078,7 +1204,10 @@ async def handler(request: WebSocketRequest): await trio.sleep(.1) -async def test_remote_close_local_message_race(nursery, autojump_clock): +async def test_remote_close_local_message_race( + nursery: trio.Nursery, + autojump_clock: trio.testing.MockClock, +) -> None: """as remote initiates close, local attempts message (issue #175) This exposed multiple problems in the trio-websocket API and implementation: @@ -1089,13 +1218,14 @@ async def test_remote_close_local_message_race(nursery, autojump_clock): * with wsproto >= 1.2.0, LocalProtocolError will be leaked """ - async def handler(request: WebSocketRequest): + async def handler(request: WebSocketRequest) -> None: ws = await request.accept() await ws.get_message() await ws.aclose() server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + assert isinstance(server, WebSocketServer) client = await connect_websocket(nursery, HOST, server.port, RESOURCE, use_ssl=False) @@ -1106,27 +1236,28 @@ async def handler(request: WebSocketRequest): await client.send_message('bar') -async def test_message_after_local_close_race(nursery): +async def test_message_after_local_close_race(nursery: trio.Nursery) -> None: """test message send during local-initiated close handshake (issue #158)""" - async def handler(request: WebSocketRequest): + async def handler(request: WebSocketRequest) -> None: await request.accept() await trio.sleep_forever() server = await nursery.start( partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + assert isinstance(server, WebSocketServer) client = await connect_websocket(nursery, HOST, server.port, RESOURCE, use_ssl=False) orig_send = client._send close_sent = trio.Event() - async def _send_wrapper(event): + async def _send_wrapper(event: Event) -> None: if isinstance(event, CloseConnection): close_sent.set() return await orig_send(event) - client._send = _send_wrapper + client._send = _send_wrapper # type: ignore[method-assign] assert not client.closed nursery.start_soon(client.aclose) await close_sent.wait() @@ -1136,21 +1267,22 @@ async def _send_wrapper(event): @fail_after(DEFAULT_TEST_MAX_DURATION) -async def test_server_tcp_closed_on_close_connection_event(nursery): +async def test_server_tcp_closed_on_close_connection_event(nursery: trio.Nursery) -> None: """ensure server closes TCP immediately after receiving CloseConnection""" server_stream_closed = trio.Event() - async def _close_stream_stub(): + async def _close_stream_stub() -> None: assert not server_stream_closed.is_set() server_stream_closed.set() - async def handle_connection(request): + async def handle_connection(request: WebSocketRequest) -> None: ws = await request.accept() - ws._close_stream = _close_stream_stub + ws._close_stream = _close_stream_stub # type: ignore[method-assign] await trio.sleep_forever() server = await nursery.start( partial(serve_websocket, handle_connection, HOST, 0, ssl_context=None)) + assert isinstance(server, WebSocketServer) client = await connect_websocket(nursery, HOST, server.port, RESOURCE, use_ssl=False) # send a CloseConnection event to server but leave client connected @@ -1158,7 +1290,10 @@ async def handle_connection(request): await server_stream_closed.wait() -async def test_finalization_dropped_exception(echo_server, autojump_clock): +async def test_finalization_dropped_exception( + echo_server: WebSocketServer, + autojump_clock: trio.testing.MockClock, +) -> None: # Confirm that open_websocket finalization does not contribute to dropped # exceptions as described in https://github.com/python-trio/trio/issues/1559. with pytest.raises(ValueError): @@ -1170,7 +1305,7 @@ async def test_finalization_dropped_exception(echo_server, autojump_clock): raise ValueError -async def test_remote_close_rude(): +async def test_remote_close_rude() -> None: """ Bad ordering: 1. Remote close @@ -1180,14 +1315,17 @@ async def test_remote_close_rude(): """ client_stream, server_stream = memory_stream_pair() - async def client(): - client_conn = await wrap_client_stream(nursery, client_stream, HOST, RESOURCE) + async def client() -> None: + client_conn = await wrap_client_stream( + nursery, + client_stream, # type: ignore[arg-type] + HOST, RESOURCE) assert not client_conn.closed await client_conn.send_message('Hello from client!') with pytest.raises(ConnectionClosed): await client_conn.get_message() - async def server(): + async def server() -> None: server_request = await wrap_server_stream(nursery, server_stream) server_ws = await server_request.accept() assert not server_ws.closed @@ -1208,14 +1346,16 @@ async def server(): nursery.start_soon(client) -def test_copy_exceptions(): +def test_copy_exceptions() -> None: # test that exceptions are copy- and pickleable copy.copy(HandshakeError()) copy.copy(ConnectionTimeout()) copy.copy(DisconnectionTimeout()) - assert copy.copy(ConnectionClosed("foo")).reason == "foo" + assert copy.copy( + ConnectionClosed("foo") # type: ignore[arg-type] + ).reason == "foo" # type: ignore[comparison-overlap] - rej_copy = copy.copy(ConnectionRejected(404, (("a", "b"),), b"c")) + rej_copy = copy.copy(ConnectionRejected(404, ((b"a", b"b"),), b"c")) assert rej_copy.status_code == 404 - assert rej_copy.headers == (("a", "b"),) + assert rej_copy.headers == ((b"a", b"b"),) assert rej_copy.body == b"c" diff --git a/trio_websocket/__init__.py b/trio_websocket/__init__.py index 82ca0ae..afa944c 100644 --- a/trio_websocket/__init__.py +++ b/trio_websocket/__init__.py @@ -1,20 +1,21 @@ +# pylint: disable=useless-import-alias from ._impl import ( - CloseReason, - ConnectionClosed, - ConnectionRejected, - ConnectionTimeout, - connect_websocket, - connect_websocket_url, - DisconnectionTimeout, - Endpoint, - HandshakeError, - open_websocket, - open_websocket_url, - WebSocketConnection, - WebSocketRequest, - WebSocketServer, - wrap_client_stream, - wrap_server_stream, - serve_websocket, + CloseReason as CloseReason, + ConnectionClosed as ConnectionClosed, + ConnectionRejected as ConnectionRejected, + ConnectionTimeout as ConnectionTimeout, + connect_websocket as connect_websocket, + connect_websocket_url as connect_websocket_url, + DisconnectionTimeout as DisconnectionTimeout, + Endpoint as Endpoint, + HandshakeError as HandshakeError, + open_websocket as open_websocket, + open_websocket_url as open_websocket_url, + WebSocketConnection as WebSocketConnection, + WebSocketRequest as WebSocketRequest, + WebSocketServer as WebSocketServer, + wrap_client_stream as wrap_client_stream, + wrap_server_stream as wrap_server_stream, + serve_websocket as serve_websocket, ) -from ._version import __version__ +from ._version import __version__ as __version__ diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 5f3a9d4..b7c1ea4 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -2,7 +2,7 @@ import sys from collections import OrderedDict -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, AbstractAsyncContextManager from functools import partial from ipaddress import ip_address import itertools @@ -11,7 +11,8 @@ import ssl import struct import urllib.parse -from typing import Iterable, List, NoReturn, Optional, Union +from typing import Any, List, NoReturn, Optional, Union, TypeVar, TYPE_CHECKING, Generic, cast +from importlib.metadata import version import outcome import trio @@ -36,18 +37,26 @@ # pylint doesn't care about the version_info check, so need to ignore the warning from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin -_IS_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split('.')[:2])) < (0, 22) +if TYPE_CHECKING: + from types import TracebackType + from typing_extensions import Final + from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Coroutine, Sequence + +_IS_TRIO_MULTI_ERROR: Final = tuple(map(int, version("trio").split(".")[:2])) < (0, 22) if _IS_TRIO_MULTI_ERROR: _TRIO_EXC_GROUP_TYPE = trio.MultiError # type: ignore[attr-defined] # pylint: disable=no-member else: _TRIO_EXC_GROUP_TYPE = BaseExceptionGroup # pylint: disable=possibly-used-before-assignment -CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds -MESSAGE_QUEUE_SIZE = 1 -MAX_MESSAGE_SIZE = 2 ** 20 # 1 MiB -RECEIVE_BYTES = 4 * 2 ** 10 # 4 KiB -logger = logging.getLogger('trio-websocket') +CONN_TIMEOUT: Final = 60 # default connect & disconnect timeout, in seconds +MESSAGE_QUEUE_SIZE: Final = 1 +MAX_MESSAGE_SIZE: Final = 2 ** 20 # 1 MiB +RECEIVE_BYTES: Final = 4 * 2 ** 10 # 4 KiB +logger: Final = logging.getLogger('trio-websocket') + +T = TypeVar("T") +E = TypeVar("E", bound=BaseException) class TrioWebsocketInternalError(Exception): @@ -57,7 +66,7 @@ class TrioWebsocketInternalError(Exception): """ -def _ignore_cancel(exc): +def _ignore_cancel(exc: E) -> E | None: return None if isinstance(exc, trio.Cancelled) else exc @@ -73,18 +82,23 @@ class _preserve_current_exception: """ __slots__ = ("_armed",) - def __init__(self): + def __init__(self) -> None: self._armed = False - def __enter__(self): + def __enter__(self) -> None: self._armed = sys.exc_info()[1] is not None - def __exit__(self, ty, value, tb): + def __exit__( + self, + ty: type[BaseException] | None, + value: BaseException | None, + tb: TracebackType | None, + ) -> bool: if value is None or not self._armed: return False if _IS_TRIO_MULTI_ERROR: # pragma: no cover - filtered_exception = trio.MultiError.filter(_ignore_cancel, value) # pylint: disable=no-member + filtered_exception = trio.MultiError.filter(_ignore_cancel, value) # type: ignore[attr-defined] # pylint: disable=no-member elif isinstance(value, BaseExceptionGroup): # pylint: disable=possibly-used-before-assignment filtered_exception = value.subgroup(lambda exc: not isinstance(exc, trio.Cancelled)) else: @@ -94,18 +108,18 @@ def __exit__(self, ty, value, tb): @asynccontextmanager async def open_websocket( - host: str, - port: int, - resource: str, - *, - use_ssl: Union[bool, ssl.SSLContext], - subprotocols: Optional[Iterable[str]] = None, - extra_headers: Optional[list[tuple[bytes,bytes]]] = None, - message_queue_size: int = MESSAGE_QUEUE_SIZE, - max_message_size: int = MAX_MESSAGE_SIZE, - connect_timeout: float = CONN_TIMEOUT, - disconnect_timeout: float = CONN_TIMEOUT - ): + host: str, + port: int, + resource: str, + *, + use_ssl: Union[bool, ssl.SSLContext], + subprotocols: Optional[Iterable[str]] = None, + extra_headers: Optional[list[tuple[bytes,bytes]]] = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, + connect_timeout: float = CONN_TIMEOUT, + disconnect_timeout: float = CONN_TIMEOUT +) -> AsyncGenerator[WebSocketConnection, None]: ''' Open a WebSocket client connection to a host. @@ -286,10 +300,18 @@ def _raise(exc: BaseException) -> NoReturn: result.unwrap() -async def connect_websocket(nursery, host, port, resource, *, use_ssl, - subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE - ) -> WebSocketConnection: +async def connect_websocket( + nursery: trio.Nursery, + host: str, + port: int, + resource: str, + *, + use_ssl: bool | ssl.SSLContext, + subprotocols: Iterable[str] | None = None, + extra_headers: list[tuple[bytes, bytes]] | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, +) -> WebSocketConnection: ''' Return an open WebSocket client connection to a host. @@ -352,10 +374,17 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl, return connection -def open_websocket_url(url, ssl_context=None, *, subprotocols=None, - extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE, - connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): +def open_websocket_url( + url: str, + ssl_context: ssl.SSLContext | None = None, + *, + subprotocols: Iterable[str] | None = None, + extra_headers: list[tuple[bytes, bytes]] | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, + connect_timeout: float = CONN_TIMEOUT, + disconnect_timeout: float = CONN_TIMEOUT, +) -> AbstractAsyncContextManager[WebSocketConnection]: ''' Open a WebSocket client connection to a URL. @@ -386,17 +415,24 @@ def open_websocket_url(url, ssl_context=None, *, subprotocols=None, client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`), or server rejection (:exc:`ConnectionRejected`) during handshakes. ''' - host, port, resource, ssl_context = _url_to_host(url, ssl_context) - return open_websocket(host, port, resource, use_ssl=ssl_context, + host, port, resource, return_ssl_context = _url_to_host(url, ssl_context) + return open_websocket(host, port, resource, use_ssl=return_ssl_context, subprotocols=subprotocols, extra_headers=extra_headers, message_queue_size=message_queue_size, max_message_size=max_message_size, connect_timeout=connect_timeout, disconnect_timeout=disconnect_timeout) -async def connect_websocket_url(nursery, url, ssl_context=None, *, - subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): +async def connect_websocket_url( + nursery: trio.Nursery, + url: str, + ssl_context: ssl.SSLContext | None = None, + *, + subprotocols: Iterable[str] | None = None, + extra_headers: list[tuple[bytes, bytes]] | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, +) -> WebSocketConnection: ''' Return an open WebSocket client connection to a URL. @@ -423,14 +459,17 @@ async def connect_websocket_url(nursery, url, ssl_context=None, *, then the connection is closed with code 1009 (Message Too Big). :rtype: WebSocketConnection ''' - host, port, resource, ssl_context = _url_to_host(url, ssl_context) + host, port, resource, return_ssl_context = _url_to_host(url, ssl_context) return await connect_websocket(nursery, host, port, resource, - use_ssl=ssl_context, subprotocols=subprotocols, + use_ssl=return_ssl_context, subprotocols=subprotocols, extra_headers=extra_headers, message_queue_size=message_queue_size, max_message_size=max_message_size) -def _url_to_host(url, ssl_context): +def _url_to_host( + url: str, + ssl_context: ssl.SSLContext | None, +) -> tuple[str, int, str, ssl.SSLContext | bool]: ''' Convert a WebSocket URL to a (host,port,resource) tuple. @@ -446,11 +485,16 @@ def _url_to_host(url, ssl_context): parts = urllib.parse.urlsplit(url) if parts.scheme not in ('ws', 'wss'): raise ValueError('WebSocket URL scheme must be "ws:" or "wss:"') + return_ssl_context: ssl.SSLContext | bool if ssl_context is None: - ssl_context = parts.scheme == 'wss' + return_ssl_context = parts.scheme == 'wss' elif parts.scheme == 'ws': raise ValueError('SSL context must be None for ws: URL scheme') + else: + return_ssl_context = ssl_context host = parts.hostname + if host is None: + raise ValueError('URL host must not be None') if parts.port is not None: port = parts.port else: @@ -463,12 +507,20 @@ def _url_to_host(url, ssl_context): path_qs = '/' if '?' in url: path_qs += '?' + parts.query - return host, port, path_qs, ssl_context - - -async def wrap_client_stream(nursery, stream, host, resource, *, - subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): + return host, port, path_qs, return_ssl_context + + +async def wrap_client_stream( + nursery: trio.Nursery, + stream: trio.SocketStream | trio.SSLStream[trio.SocketStream], + host: str, + resource: str, + *, + subprotocols: Iterable[str] | None = None, + extra_headers: list[tuple[bytes, bytes]] | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, +) -> WebSocketConnection: ''' Wrap an arbitrary stream in a WebSocket connection. @@ -505,8 +557,12 @@ async def wrap_client_stream(nursery, stream, host, resource, *, return connection -async def wrap_server_stream(nursery, stream, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): +async def wrap_server_stream( + nursery: trio.Nursery, + stream: trio.abc.Stream, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, +) -> WebSocketRequest: ''' Wrap an arbitrary stream in a server-side WebSocket. @@ -523,7 +579,8 @@ async def wrap_server_stream(nursery, stream, :type stream: trio.abc.Stream :rtype: WebSocketRequest ''' - connection = WebSocketConnection(stream, + connection = WebSocketConnection( + stream, WSConnection(ConnectionType.SERVER), message_queue_size=message_queue_size, max_message_size=max_message_size) @@ -532,11 +589,20 @@ async def wrap_server_stream(nursery, stream, return request -async def serve_websocket(handler, host, port, ssl_context, *, - handler_nursery=None, message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT, - disconnect_timeout=CONN_TIMEOUT, task_status=trio.TASK_STATUS_IGNORED): - ''' +async def serve_websocket( + handler: Callable[[WebSocketRequest], Awaitable[None]], + host: str | bytes | None, + port: int, + ssl_context: ssl.SSLContext | None, + *, + handler_nursery: trio.Nursery | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, + connect_timeout: float = CONN_TIMEOUT, + disconnect_timeout: float = CONN_TIMEOUT, + task_status: trio.TaskStatus[WebSocketServer] = trio.TASK_STATUS_IGNORED, +) -> NoReturn: + """ Serve a WebSocket over TCP. This function supports the Trio nursery start protocol: ``server = await @@ -570,17 +636,31 @@ async def serve_websocket(handler, host, port, ssl_context, *, to finish the closing handshake before timing out. :param task_status: Part of Trio nursery start protocol. :returns: This function runs until cancelled. - ''' + """ + open_tcp_listeners: ( + partial[Coroutine[Any, Any, list[trio.SocketListener]]] + | partial[Coroutine[Any, Any, list[trio.SSLListener[trio.SocketStream]]]] + ) if ssl_context is None: open_tcp_listeners = partial(trio.open_tcp_listeners, port, host=host) else: - open_tcp_listeners = partial(trio.open_ssl_over_tcp_listeners, port, - ssl_context, host=host, https_compatible=True) + open_tcp_listeners = partial( + trio.open_ssl_over_tcp_listeners, + port, + ssl_context, + host=host, + https_compatible=True, + ) listeners = await open_tcp_listeners() - server = WebSocketServer(handler, listeners, - handler_nursery=handler_nursery, message_queue_size=message_queue_size, - max_message_size=max_message_size, connect_timeout=connect_timeout, - disconnect_timeout=disconnect_timeout) + server = WebSocketServer( + handler, + listeners, + handler_nursery=handler_nursery, + message_queue_size=message_queue_size, + max_message_size=max_message_size, + connect_timeout=connect_timeout, + disconnect_timeout=disconnect_timeout, + ) await server.run(task_status=task_status) @@ -601,7 +681,7 @@ class ConnectionClosed(Exception): A WebSocket operation cannot be completed because the connection is closed or in the process of closing. ''' - def __init__(self, reason): + def __init__(self, reason: CloseReason | None) -> None: ''' Constructor. @@ -611,7 +691,7 @@ def __init__(self, reason): super().__init__(reason) self.reason = reason - def __repr__(self): + def __repr__(self) -> str: ''' Return representation. ''' return f'{self.__class__.__name__}<{self.reason}>' @@ -621,7 +701,12 @@ class ConnectionRejected(HandshakeError): A WebSocket connection could not be established because the server rejected the connection attempt. ''' - def __init__(self, status_code, headers, body): + def __init__( + self, + status_code: int, + headers: tuple[tuple[bytes, bytes], ...], + body: bytes | None, + ) -> None: ''' Constructor. @@ -636,14 +721,14 @@ def __init__(self, status_code, headers, body): #: an optional ``bytes`` response body self.body = body - def __repr__(self): + def __repr__(self) -> str: ''' Return representation. ''' return f'{self.__class__.__name__}' class CloseReason: ''' Contains information about why a WebSocket was closed. ''' - def __init__(self, code, reason): + def __init__(self, code: int, reason: str | None) -> None: ''' Constructor. @@ -665,34 +750,39 @@ def __init__(self, code, reason): self._reason = reason @property - def code(self): + def code(self) -> int: ''' (Read-only) The numeric close code. ''' return self._code @property - def name(self): + def name(self) -> str: ''' (Read-only) The human-readable close code. ''' return self._name @property - def reason(self): + def reason(self) -> str | None: ''' (Read-only) An arbitrary reason string. ''' return self._reason - def __repr__(self): + def __repr__(self) -> str: ''' Show close code, name, and reason. ''' return f'{self.__class__.__name__}' \ f'' -class Future: +NULL: Final = object() + + +class Future(Generic[T]): ''' Represents a value that will be available in the future. ''' - def __init__(self): + def __init__(self) -> None: ''' Constructor. ''' - self._value = None + # We do some type shenanigins + # Would do `T | Literal[NULL]` but that's not right apparently. + self._value: T = cast(T, NULL) self._value_event = trio.Event() - def set_value(self, value): + def set_value(self, value: T) -> None: ''' Set a value, which will notify any waiters. @@ -701,13 +791,14 @@ def set_value(self, value): self._value = value self._value_event.set() - async def wait_value(self): + async def wait_value(self) -> T: ''' Wait for this future to have a value, then return it. :returns: The value set by ``set_value()``. ''' await self._value_event.wait() + assert self._value is not NULL return self._value @@ -718,7 +809,11 @@ class WebSocketRequest: The server may modify the handshake or leave it as is. The server should call ``accept()`` to finish the handshake and obtain a connection object. ''' - def __init__(self, connection, event): + def __init__( + self, + connection: WebSocketConnection, + event: wsproto.events.Request, + ) -> None: ''' Constructor. @@ -729,7 +824,7 @@ def __init__(self, connection, event): self._event = event @property - def headers(self): + def headers(self) -> list[tuple[bytes, bytes]]: ''' HTTP headers represented as a list of (name, value) pairs. @@ -738,7 +833,7 @@ def headers(self): return self._event.extra_headers @property - def path(self): + def path(self) -> str: ''' The requested URL path. @@ -747,7 +842,7 @@ def path(self): return self._event.target @property - def proposed_subprotocols(self): + def proposed_subprotocols(self) -> tuple[str, ...]: ''' A tuple of protocols proposed by the client. @@ -756,7 +851,7 @@ def proposed_subprotocols(self): return tuple(self._event.subprotocols) @property - def local(self): + def local(self) -> Endpoint | str: ''' The connection's local endpoint. @@ -765,7 +860,7 @@ def local(self): return self._connection.local @property - def remote(self): + def remote(self) -> Endpoint | str: ''' The connection's remote endpoint. @@ -773,7 +868,12 @@ def remote(self): ''' return self._connection.remote - async def accept(self, *, subprotocol=None, extra_headers=None): + async def accept( + self, + *, + subprotocol: str | None = None, + extra_headers: list[tuple[bytes, bytes]] | None = None, + ) -> WebSocketConnection: ''' Accept the request and return a connection object. @@ -789,7 +889,13 @@ async def accept(self, *, subprotocol=None, extra_headers=None): await self._connection._accept(self._event, subprotocol, extra_headers) return self._connection - async def reject(self, status_code, *, extra_headers=None, body=None): + async def reject( + self, + status_code: int, + *, + extra_headers: list[tuple[bytes, bytes]] | None = None, + body: bytes | None = None, + ) -> None: ''' Reject the handshake. @@ -807,7 +913,11 @@ async def reject(self, status_code, *, extra_headers=None, body=None): await self._connection._reject(status_code, extra_headers, body) -def _get_stream_endpoint(stream, *, local): +def _get_stream_endpoint( + stream: trio.abc.Stream, + *, + local: bool, +) -> Endpoint | str: ''' Construct an endpoint from a stream. @@ -823,6 +933,7 @@ def _get_stream_endpoint(stream, *, local): elif isinstance(stream, trio.SSLStream): socket = stream.transport_stream.socket is_ssl = True + endpoint: Endpoint | str if socket: addr, port, *_ = socket.getsockname() if local else socket.getpeername() endpoint = Endpoint(addr, port, is_ssl) @@ -837,16 +948,17 @@ class WebSocketConnection(trio.abc.AsyncResource): CONNECTION_ID = itertools.count() def __init__( - self, - stream: trio.SocketStream | trio.SSLStream[trio.SocketStream], - ws_connection: wsproto.WSConnection, - *, - host=None, - path=None, - client_subprotocols=None, client_extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE - ): + self, + stream: trio.abc.Stream, + ws_connection: wsproto.WSConnection, + *, + host: str | None = None, + path: str | None = None, + client_subprotocols: Iterable[str] | None = None, + client_extra_headers: list[tuple[bytes, bytes]] | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE + ) -> None: ''' Constructor. @@ -886,16 +998,18 @@ def __init__( self._max_message_size = max_message_size self._reader_running = True if ws_connection.client: + assert host is not None + assert path is not None self._initial_request: Optional[Request] = Request(host=host, target=path, - subprotocols=client_subprotocols, + subprotocols=list(client_subprotocols or ()), extra_headers=client_extra_headers or []) else: self._initial_request = None self._path = path self._subprotocol: Optional[str] = None - self._handshake_headers: tuple[tuple[str,str], ...] = tuple() + self._handshake_headers: tuple[tuple[bytes, bytes], ...] = () self._reject_status = 0 - self._reject_headers: tuple[tuple[str,str], ...] = tuple() + self._reject_headers: tuple[tuple[bytes, bytes], ...] = () self._reject_body = b'' self._send_channel, self._recv_channel = trio.open_memory_channel[ Union[bytes, str] @@ -903,7 +1017,7 @@ def __init__( self._pings: OrderedDict[bytes, trio.Event] = OrderedDict() # Set when the server has received a connection request event. This # future is never set on client connections. - self._connection_proposal = Future() + self._connection_proposal: Future[WebSocketRequest] | None = Future[WebSocketRequest]() # Set once the WebSocket open handshake takes place, i.e. # ConnectionRequested for server or ConnectedEstablished for client. self._open_handshake = trio.Event() @@ -915,7 +1029,7 @@ def __init__( self._for_testing_peer_closed_connection = trio.Event() @property - def closed(self): + def closed(self) -> CloseReason | None: ''' (Read-only) The reason why the connection was or is being closed, else ``None``. @@ -925,17 +1039,17 @@ def closed(self): return self._close_reason @property - def is_client(self): + def is_client(self) -> bool: ''' (Read-only) Is this a client instance? ''' return self._wsproto.client @property - def is_server(self): + def is_server(self) -> bool: ''' (Read-only) Is this a server instance? ''' return not self._wsproto.client @property - def local(self): + def local(self) -> Endpoint | str: ''' The local endpoint of the connection. @@ -944,7 +1058,7 @@ def local(self): return _get_stream_endpoint(self._stream, local=True) @property - def remote(self): + def remote(self) -> Endpoint | str: ''' The remote endpoint of the connection. @@ -953,17 +1067,17 @@ def remote(self): return _get_stream_endpoint(self._stream, local=False) @property - def path(self): + def path(self) -> str | None: ''' The requested URL path. For clients, this is set when the connection is instantiated. For servers, it is set after the handshake completes. - :rtype: str + :rtype: str or None ''' return self._path @property - def subprotocol(self): + def subprotocol(self) -> str | None: ''' (Read-only) The negotiated subprotocol, or ``None`` if there is no subprotocol. @@ -975,7 +1089,7 @@ def subprotocol(self): return self._subprotocol @property - def handshake_headers(self): + def handshake_headers(self) -> tuple[tuple[bytes, bytes], ...]: ''' The HTTP headers that were sent by the remote during the handshake, stored as 2-tuples containing key/value pairs. Header keys are always @@ -985,7 +1099,7 @@ def handshake_headers(self): ''' return self._handshake_headers - async def aclose(self, code=1000, reason=None): # pylint: disable=arguments-differ + async def aclose(self, code: int = 1000, reason: str | None = None) -> None: # pylint: disable=arguments-differ ''' Close the WebSocket connection. @@ -1003,7 +1117,7 @@ async def aclose(self, code=1000, reason=None): # pylint: disable=arguments-dif with _preserve_current_exception(): await self._aclose(code, reason) - async def _aclose(self, code, reason): + async def _aclose(self, code: int, reason: str | None) -> None: if self._close_reason: # Per AsyncResource interface, calling aclose() on a closed resource # should succeed. @@ -1029,7 +1143,7 @@ async def _aclose(self, code, reason): # stream is closed. await self._close_stream() - async def get_message(self): + async def get_message(self) -> str | bytes: ''' Receive the next WebSocket message. @@ -1052,7 +1166,7 @@ async def get_message(self): raise ConnectionClosed(self._close_reason) from None return message - async def ping(self, payload: bytes|None=None): + async def ping(self, payload: bytes | None = None) -> None: ''' Send WebSocket ping to remote endpoint and wait for a correspoding pong. @@ -1083,7 +1197,7 @@ async def ping(self, payload: bytes|None=None): await self._send(Ping(payload=payload)) await event.wait() - async def pong(self, payload=None): + async def pong(self, payload: bytes | None = None) -> None: ''' Send an unsolicted pong. @@ -1094,9 +1208,9 @@ async def pong(self, payload=None): ''' if self._close_reason: raise ConnectionClosed(self._close_reason) - await self._send(Pong(payload=payload)) + await self._send(Pong(payload=payload or b'')) - async def send_message(self, message): + async def send_message(self, message: str | bytes) -> None: ''' Send a WebSocket message. @@ -1106,6 +1220,7 @@ async def send_message(self, message): ''' if self._close_reason: raise ConnectionClosed(self._close_reason) + event: TextMessage | BytesMessage if isinstance(message, str): event = TextMessage(data=message) elif isinstance(message, bytes): @@ -1114,12 +1229,17 @@ async def send_message(self, message): raise ValueError('message must be str or bytes') await self._send(event) - def __str__(self): + def __str__(self) -> str: ''' Connection ID and type. ''' type_ = 'client' if self.is_client else 'server' return f'{type_}-{self._id}' - async def _accept(self, request, subprotocol, extra_headers): + async def _accept( + self, + request: Request, + subprotocol: str | None, + extra_headers: list[tuple[bytes, bytes]], + ) -> None: ''' Accept the handshake. @@ -1137,7 +1257,12 @@ async def _accept(self, request, subprotocol, extra_headers): extra_headers=extra_headers)) self._open_handshake.set() - async def _reject(self, status_code, headers, body): + async def _reject( + self, + status_code: int, + headers: list[tuple[bytes, bytes]], + body: bytes, + ) -> None: ''' Reject the handshake. @@ -1149,7 +1274,7 @@ async def _reject(self, status_code, headers, body): :param bytes body: An optional response body. ''' if body: - headers.append(('Content-length', str(len(body)).encode('ascii'))) + headers.append((b'Content-length', str(len(body)).encode('ascii'))) reject_conn = RejectConnection(status_code=status_code, headers=headers, has_body=bool(body)) await self._send(reject_conn) @@ -1159,7 +1284,7 @@ async def _reject(self, status_code, headers, body): self._close_reason = CloseReason(1006, 'Rejected WebSocket handshake') self._close_handshake.set() - async def _abort_web_socket(self): + async def _abort_web_socket(self) -> None: ''' If a stream is closed outside of this class, e.g. due to network conditions or because some other code closed our stream object, then we @@ -1176,7 +1301,7 @@ async def _abort_web_socket(self): # (e.g. self.aclose()) to resume. self._close_handshake.set() - async def _close_stream(self): + async def _close_stream(self) -> None: ''' Close the TCP connection. ''' self._reader_running = False try: @@ -1186,7 +1311,7 @@ async def _close_stream(self): # This means the TCP connection is already dead. pass - async def _close_web_socket(self, code, reason=None): + async def _close_web_socket(self, code: int, reason: str | None = None) -> None: ''' Mark the WebSocket as closed. Close the message channel so that if any tasks are suspended in get_message(), they will wake up with a @@ -1197,7 +1322,7 @@ async def _close_web_socket(self, code, reason=None): logger.debug('%s websocket closed %r', self, exc) await self._send_channel.aclose() - async def _get_request(self): + async def _get_request(self) -> WebSocketRequest: ''' Return a proposal for a WebSocket handshake. @@ -1215,7 +1340,7 @@ async def _get_request(self): self._connection_proposal = None return proposal - async def _handle_request_event(self, event): + async def _handle_request_event(self, event: wsproto.events.Request) -> None: ''' Handle a connection request. @@ -1225,9 +1350,10 @@ async def _handle_request_event(self, event): :param event: ''' proposal = WebSocketRequest(self, event) + assert self._connection_proposal is not None self._connection_proposal.set_value(proposal) - async def _handle_accept_connection_event(self, event): + async def _handle_accept_connection_event(self, event: wsproto.events.AcceptConnection) -> None: ''' Handle an AcceptConnection event. @@ -1237,7 +1363,7 @@ async def _handle_accept_connection_event(self, event): self._handshake_headers = tuple(event.extra_headers) self._open_handshake.set() - async def _handle_reject_connection_event(self, event): + async def _handle_reject_connection_event(self, event: wsproto.events.RejectConnection) -> None: ''' Handle a RejectConnection event. @@ -1249,7 +1375,7 @@ async def _handle_reject_connection_event(self, event): raise ConnectionRejected(self._reject_status, self._reject_headers, body=None) - async def _handle_reject_data_event(self, event): + async def _handle_reject_data_event(self, event: wsproto.events.RejectData) -> None: ''' Handle a RejectData event. @@ -1260,7 +1386,7 @@ async def _handle_reject_data_event(self, event): raise ConnectionRejected(self._reject_status, self._reject_headers, body=self._reject_body) - async def _handle_close_connection_event(self, event): + async def _handle_close_connection_event(self, event: wsproto.events.CloseConnection) -> None: ''' Handle a close event. @@ -1281,7 +1407,10 @@ async def _handle_close_connection_event(self, event): if self.is_server: await self._close_stream() - async def _handle_message_event(self, event): + async def _handle_message_event( + self, + event: wsproto.events.BytesMessage | wsproto.events.TextMessage, + ) -> None: ''' Handle a message event. @@ -1299,8 +1428,12 @@ async def _handle_message_event(self, event): await self._recv_channel.aclose() self._reader_running = False elif event.message_finished: - msg = (b'' if isinstance(event, BytesMessage) else '') \ - .join(self._message_parts) + msg: str | bytes + # Type checker does not understand `_message_parts` + if isinstance(event, BytesMessage): + msg = b''.join(cast("list[bytes]", self._message_parts)) + else: + msg = ''.join(cast("list[str]", self._message_parts)) self._message_size = 0 self._message_parts = [] try: @@ -1311,7 +1444,7 @@ async def _handle_message_event(self, event): # and there's no useful cleanup that we can do here. pass - async def _handle_ping_event(self, event): + async def _handle_ping_event(self, event: wsproto.events.Ping) -> None: ''' Handle a PingReceived event. @@ -1323,7 +1456,7 @@ async def _handle_ping_event(self, event): logger.debug('%s ping %r', self, event.payload) await self._send(event.response()) - async def _handle_pong_event(self, event): + async def _handle_pong_event(self, event: wsproto.events.Pong) -> None: ''' Handle a PongReceived event. @@ -1339,20 +1472,20 @@ async def _handle_pong_event(self, event): ''' payload = bytes(event.payload) try: - event = self._pings[payload] + ping_event = self._pings[payload] except KeyError: # We received a pong that doesn't match any in-flight pongs. Nothing # we can do with it, so ignore it. return while self._pings: - key, event = self._pings.popitem(0) + key, ping_event = self._pings.popitem(False) skipped = ' [skipped] ' if payload != key else ' ' logger.debug('%s pong%s%r', self, skipped, key) - event.set() + ping_event.set() if payload == key: break - async def _reader_task(self): + async def _reader_task(self) -> None: ''' A background task that reads network data and generates events. ''' handlers = { AcceptConnection: self._handle_accept_connection_event, @@ -1382,7 +1515,10 @@ async def _reader_task(self): handler = handlers[event_type] logger.debug('%s received event: %s', self, event_type) - await handler(event) + # Type checkers don't understand looking up type in handlers. + # If we wanted to fix, best I can figure is we'd need a huge + # if-else or case block for every type individually. + await handler(event) # type: ignore[operator] except KeyError: logger.warning('%s received unknown event type: "%s"', self, event_type) @@ -1416,7 +1552,7 @@ async def _reader_task(self): logger.debug('%s reader task finished', self) - async def _send(self, event): + async def _send(self, event: wsproto.events.Event) -> None: ''' Send an event to the remote WebSocket. @@ -1433,12 +1569,13 @@ async def _send(self, event): await self._stream.send_all(data) except (trio.BrokenResourceError, trio.ClosedResourceError): await self._abort_web_socket() + assert self._close_reason is not None raise ConnectionClosed(self._close_reason) from None class Endpoint: ''' Represents a connection endpoint. ''' - def __init__(self, address, port, is_ssl): + def __init__(self, address: str | int, port: int, is_ssl: bool) -> None: #: IP address :class:`ipaddress.ip_address` self.address = ip_address(address) #: TCP port @@ -1447,7 +1584,7 @@ def __init__(self, address, port, is_ssl): self.is_ssl = is_ssl @property - def url(self): + def url(self) -> str: ''' Return a URL representation of a TCP endpoint, e.g. ``ws://127.0.0.1:80``. ''' scheme = 'wss' if self.is_ssl else 'ws' @@ -1460,7 +1597,7 @@ def url(self): return f'{scheme}://{self.address}{port_str}' return f'{scheme}://[{self.address}]{port_str}' - def __repr__(self): + def __repr__(self) -> str: ''' Return endpoint info as string. ''' return f'Endpoint(address="{self.address}", port={self.port}, is_ssl={self.is_ssl})' @@ -1474,10 +1611,17 @@ class WebSocketServer: instance and starts some background tasks, ''' - def __init__(self, handler, listeners, *, handler_nursery=None, - message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT, - disconnect_timeout=CONN_TIMEOUT): + def __init__( + self, + handler: Callable[[WebSocketRequest], Awaitable[None]], + listeners: Sequence[trio.SSLListener[trio.SocketStream] | trio.SocketListener], + *, + handler_nursery: trio.Nursery | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, + connect_timeout: float = CONN_TIMEOUT, + disconnect_timeout: float = CONN_TIMEOUT, + ) -> None: ''' Constructor. @@ -1509,7 +1653,7 @@ def __init__(self, handler, listeners, *, handler_nursery=None, self._disconnect_timeout = disconnect_timeout @property - def port(self): + def port(self) -> int: """Returns the requested or kernel-assigned port number. In the case of kernel-assigned port (requested with port=0 in the @@ -1522,15 +1666,15 @@ def port(self): """ if len(self._listeners) > 1: raise RuntimeError('Cannot get port because this server has' - ' more than 1 listeners.') + ' more than 1 listener.') listener = self.listeners[0] try: - return listener.port + return listener.port # type: ignore[union-attr] except AttributeError: raise RuntimeError(f'This socket does not have a port: {repr(listener)}') from None @property - def listeners(self): + def listeners(self) -> list[Endpoint | str]: ''' Return a list of listener metadata. Each TCP listener is represented as an ``Endpoint`` instance. Other listener types are represented by their @@ -1539,13 +1683,15 @@ def listeners(self): :returns: Listeners :rtype list[Endpoint or str]: ''' - listeners = [] + listeners: list[Endpoint | str] = [] for listener in self._listeners: socket, is_ssl = None, False if isinstance(listener, trio.SocketListener): socket = listener.socket elif isinstance(listener, trio.SSLListener): - socket = listener.transport_listener.socket + internal_listener = listener.transport_listener + assert isinstance(internal_listener, trio.SocketListener) + socket = internal_listener.socket is_ssl = True if socket: sockname = socket.getsockname() @@ -1554,7 +1700,15 @@ def listeners(self): listeners.append(repr(listener)) return listeners - async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): + # Type ignore is because type checker does not think NoReturn is + # real for Trio 0.25.1 (current version used in requirements file as + # of writing). Not a problem for newer versions however, which is + # why we have unused-ignore as well. + async def run( # type: ignore[misc,unused-ignore] + self, + *, + task_status: trio.TaskStatus[WebSocketServer] = trio.TASK_STATUS_IGNORED, + ) -> NoReturn: ''' Start serving incoming connections requests. @@ -1567,7 +1721,7 @@ async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): ''' async with trio.open_nursery() as nursery: serve_listeners = partial(trio.serve_listeners, - self._handle_connection, self._listeners, + self._handle_connection, list(self._listeners), handler_nursery=self._handler_nursery) await nursery.start(serve_listeners) logger.debug('Listening on %s', @@ -1575,7 +1729,7 @@ async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): task_status.started(self) await trio.sleep_forever() - async def _handle_connection(self, stream): + async def _handle_connection(self, stream: trio.abc.Stream) -> None: ''' Handle an incoming connection by spawning a connection background task and a handler task inside a new nursery. diff --git a/trio_websocket/py.typed b/trio_websocket/py.typed new file mode 100644 index 0000000..e69de29