Skip to content

Ensure WebSocket max message size is not limited #153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 6 additions & 14 deletions src/surrealdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,10 @@ def __call__(cls, *args, **kwargs):

constructed_url = Url(url)


# Extract `max_size` with a default if not explicitly provided
max_size = kwargs.get("max_size", 2 ** 20)

if constructed_url.scheme == UrlScheme.HTTP or constructed_url.scheme == UrlScheme.HTTPS:
return AsyncHttpSurrealConnection(url=url)
elif constructed_url.scheme == UrlScheme.WS or constructed_url.scheme == UrlScheme.WSS:
return AsyncWsSurrealConnection(url=url, max_size=max_size)
return AsyncWsSurrealConnection(url=url)
else:
raise ValueError(f"Unsupported protocol in URL: {url}. Use 'ws://' or 'http://'.")

Expand All @@ -54,32 +50,28 @@ def __call__(cls, *args, **kwargs):

constructed_url = Url(url)


# Extract `max_size` with a default if not explicitly provided
max_size = kwargs.get("max_size", 2 ** 20)

if constructed_url.scheme == UrlScheme.HTTP or constructed_url.scheme == UrlScheme.HTTPS:
return BlockingHttpSurrealConnection(url=url)
elif constructed_url.scheme == UrlScheme.WS or constructed_url.scheme == UrlScheme.WSS:
return BlockingWsSurrealConnection(url=url, max_size=max_size)
return BlockingWsSurrealConnection(url=url)
else:
raise ValueError(f"Unsupported protocol in URL: {url}. Use 'ws://' or 'http://'.")

def Surreal(url: Optional[str] = None, max_size: int = 2 ** 20) -> Union[BlockingWsSurrealConnection, BlockingHttpSurrealConnection]:
def Surreal(url: Optional[str] = None) -> Union[BlockingWsSurrealConnection, BlockingHttpSurrealConnection]:
constructed_url = Url(url)
if constructed_url.scheme == UrlScheme.HTTP or constructed_url.scheme == UrlScheme.HTTPS:
return BlockingHttpSurrealConnection(url=url)
elif constructed_url.scheme == UrlScheme.WS or constructed_url.scheme == UrlScheme.WSS:
return BlockingWsSurrealConnection(url=url, max_size=max_size)
return BlockingWsSurrealConnection(url=url)
else:
raise ValueError(f"Unsupported protocol in URL: {url}. Use 'ws://' or 'http://'.")


def AsyncSurreal(url: Optional[str] = None, max_size: int = 2 ** 20) -> Union[AsyncWsSurrealConnection, AsyncHttpSurrealConnection]:
def AsyncSurreal(url: Optional[str] = None) -> Union[AsyncWsSurrealConnection, AsyncHttpSurrealConnection]:
constructed_url = Url(url)
if constructed_url.scheme == UrlScheme.HTTP or constructed_url.scheme == UrlScheme.HTTPS:
return AsyncHttpSurrealConnection(url=url)
elif constructed_url.scheme == UrlScheme.WS or constructed_url.scheme == UrlScheme.WSS:
return AsyncWsSurrealConnection(url=url, max_size=max_size)
return AsyncWsSurrealConnection(url=url)
else:
raise ValueError(f"Unsupported protocol in URL: {url}. Use 'ws://' or 'http://'.")
1 change: 0 additions & 1 deletion src/surrealdb/connections/async_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class AsyncHttpSurrealConnection(AsyncTemplate, UtilsMixin):

Attributes:
url: The URL of the database to process queries for.
max_size: The maximum size of the connection payload.
id: The ID of the connection.
"""

Expand Down
16 changes: 3 additions & 13 deletions src/surrealdb/connections/async_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,35 +23,27 @@ class AsyncWsSurrealConnection(AsyncTemplate, UtilsMixin):
"""
A single async connection to a SurrealDB instance. To be used once and discarded.

# Notes
A new connection is created for each query. This is because the async websocket connection is
dropped

Attributes:
url: The URL of the database to process queries for.
user: The username to login on.
password: The password to login on.
namespace: The namespace that the connection will stick to.
database: The database that the connection will stick to.
max_size: The maximum size of the connection.
id: The ID of the connection.
"""
def __init__(
self,
url: str,
max_size: int = 2 ** 20,
) -> None:
"""
The constructor for the AsyncSurrealConnection class.

:param url: The URL of the database to process queries for.
:param max_size: The maximum size of the connection.
"""
self.url: Url = Url(url)
self.raw_url: str = f"{self.url.raw_url}/rpc"
self.host: str = self.url.hostname
self.port: int = self.url.port
self.max_size: int = max_size
self.id: str = str(uuid.uuid4())
self.token: Optional[str] = None
self.socket = None
Expand All @@ -64,19 +56,17 @@ async def _send(self, message: RequestMessage, process: str, bypass: bool = Fals
self.check_response_for_error(response, process)
return response

async def connect(self, url: Optional[str] = None, max_size: Optional[int] = None) -> None:
async def connect(self, url: Optional[str] = None) -> None:
# overwrite params if passed in
if url is not None:
self.url = Url(url)
self.raw_url: str = f"{self.url.raw_url}/rpc"
self.host: str = self.url.hostname
self.port: int = self.url.port
if max_size is not None:
self.max_size = max_size
if self.socket is None:
self.socket = await websockets.connect(
self.raw_url,
max_size=self.max_size,
max_size=None,
subprotocols=[websockets.Subprotocol("cbor")]
)

Expand Down Expand Up @@ -361,7 +351,7 @@ async def __aenter__(self) -> "AsyncWsSurrealConnection":
"""
self.socket = await websockets.connect(
self.raw_url,
max_size=self.max_size,
max_size=None,
subprotocols=[websockets.Subprotocol("cbor")]
)
return self
Expand Down
13 changes: 3 additions & 10 deletions src/surrealdb/connections/blocking_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,25 @@ class BlockingWsSurrealConnection(SyncTemplate, UtilsMixin):
"""
A single blocking connection to a SurrealDB instance. To be used once and discarded.

# Notes
A new connection is created for each query. This is because the WebSocket connection is
dropped after the query is completed.

Attributes:
url: The URL of the database to process queries for.
user: The username to login on.
password: The password to login on.
namespace: The namespace that the connection will stick to.
database: The database that the connection will stick to.
max_size: The maximum size of the connection.
id: The ID of the connection.
"""

def __init__(self, url: str, max_size: int = 2 ** 20) -> None:
def __init__(self, url: str) -> None:
"""
The constructor for the BlockingWsSurrealConnection class.

:param url: (str) the URL of the database to process queries for.
:param max_size: (int) The maximum size of the connection.
"""
self.url: Url = Url(url)
self.raw_url: str = f"{self.url.raw_url}/rpc"
self.host: str = self.url.hostname
self.port: int = self.url.port
self.max_size: int = max_size
self.id: str = str(uuid.uuid4())
self.token: Optional[str] = None
self.socket = None
Expand All @@ -56,7 +49,7 @@ def _send(self, message: RequestMessage, process: str, bypass: bool = False) ->
if self.socket is None:
self.socket = ws_sync.connect(
self.raw_url,
max_size=self.max_size,
max_size=None,
subprotocols=[websockets.Subprotocol("cbor")],
)
self.socket.send(message.WS_CBOR_DESCRIPTOR)
Expand Down Expand Up @@ -361,7 +354,7 @@ def __enter__(self) -> "BlockingWsSurrealConnection":
"""
self.socket = ws_sync.connect(
self.raw_url,
max_size=self.max_size,
max_size=None,
subprotocols=[websockets.Subprotocol("cbor")]
)
return self
Expand Down
Loading