From 0f969ab1daa36d633c374eae7fce0e4429d45f4d Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Wed, 26 Feb 2025 15:18:45 +0000 Subject: [PATCH 01/10] lock working for async recv --- src/surrealdb/connections/async_http.py | 50 +++++++++----- src/surrealdb/connections/async_ws.py | 65 ++++++++++++------- src/surrealdb/connections/blocking_http.py | 50 +++++++++----- src/surrealdb/connections/blocking_ws.py | 61 ++++++++++------- src/surrealdb/request_message/message.py | 6 +- .../connections/batch_async/__init__.py | 0 .../connections/batch_async/test_async_ws.py | 37 +++++++++++ .../descriptors/test_cbor_ws.py | 29 ++------- .../request_message/test_request_message.py | 2 +- 9 files changed, 192 insertions(+), 108 deletions(-) create mode 100644 tests/unit_tests/connections/batch_async/__init__.py create mode 100644 tests/unit_tests/connections/batch_async/test_async_ws.py diff --git a/src/surrealdb/connections/async_http.py b/src/surrealdb/connections/async_http.py index 6e5fa89c..313950bf 100644 --- a/src/surrealdb/connections/async_http.py +++ b/src/surrealdb/connections/async_http.py @@ -98,16 +98,19 @@ def set_token(self, token: str) -> None: self.token = token async def authenticate(self) -> None: - message = RequestMessage(self.id, RequestMethod.AUTHENTICATE, token=self.token) + message = RequestMessage(RequestMethod.AUTHENTICATE, token=self.token) + self.id = message.id return await self._send(message, "authenticating") async def invalidate(self) -> None: - message = RequestMessage(self.id, RequestMethod.INVALIDATE) + message = RequestMessage(RequestMethod.INVALIDATE) + self.id = message.id await self._send(message, "invalidating") self.token = None async def signup(self, vars: Dict) -> str: - message = RequestMessage(self.id, RequestMethod.SIGN_UP, data=vars) + message = RequestMessage(RequestMethod.SIGN_UP, data=vars) + self.id = message.id response = await self._send(message, "signup") self.check_response_for_result(response, "signup") self.token = response["result"] @@ -115,7 +118,6 @@ async def signup(self, vars: Dict) -> str: async def signin(self, vars: dict) -> dict: message = RequestMessage( - self.id, RequestMethod.SIGN_IN, username=vars.get("username"), password=vars.get("password"), @@ -124,24 +126,26 @@ async def signin(self, vars: dict) -> dict: namespace=vars.get("namespace"), variables=vars.get("variables"), ) + self.id = message.id response = await self._send(message, "signing in") self.check_response_for_result(response, "signing in") self.token = response["result"] return response["result"] async def info(self) -> dict: - message = RequestMessage(self.id, RequestMethod.INFO) + message = RequestMessage(RequestMethod.INFO) + self.id = message.id response = await self._send(message, "getting database information") self.check_response_for_result(response, "getting database information") return response["result"] async def use(self, namespace: str, database: str) -> None: message = RequestMessage( - self.token, RequestMethod.USE, namespace=namespace, database=database, ) + self.id = message.id _ = await self._send(message, "use") self.namespace = namespace self.database = database @@ -152,11 +156,11 @@ async def query(self, query: str, params: Optional[dict] = None) -> dict: for key, value in self.vars.items(): params[key] = value message = RequestMessage( - self.id, RequestMethod.QUERY, query=query, params=params, ) + self.id = message.id response = await self._send(message, "query") self.check_response_for_result(response, "query") return response["result"][0]["result"] @@ -167,11 +171,11 @@ async def query_raw(self, query: str, params: Optional[dict] = None) -> dict: for key, value in self.vars.items(): params[key] = value message = RequestMessage( - self.id, RequestMethod.QUERY, query=query, params=params, ) + self.id = message.id response = await self._send(message, "query", bypass=True) return response @@ -185,8 +189,9 @@ async def create( buffer = thing.split(":") thing = RecordID(table_name=buffer[0], identifier=buffer[1]) message = RequestMessage( - self.id, RequestMethod.CREATE, collection=thing, data=data + RequestMethod.CREATE, collection=thing, data=data ) + self.id = message.id response = await self._send(message, "create") self.check_response_for_result(response, "create") return response["result"] @@ -194,7 +199,8 @@ async def create( async def delete( self, thing: Union[str, RecordID, Table] ) -> Union[List[dict], dict]: - message = RequestMessage(self.id, RequestMethod.DELETE, record_id=thing) + message = RequestMessage(RequestMethod.DELETE, record_id=thing) + self.id = message.id response = await self._send(message, "delete") self.check_response_for_result(response, "delete") return response["result"] @@ -203,8 +209,9 @@ async def insert( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.INSERT, collection=table, params=data + RequestMethod.INSERT, collection=table, params=data ) + self.id = message.id response = await self._send(message, "insert") self.check_response_for_result(response, "insert") return response["result"] @@ -213,8 +220,9 @@ async def insert_relation( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.INSERT_RELATION, table=table, params=data + RequestMethod.INSERT_RELATION, table=table, params=data ) + self.id = message.id response = await self._send(message, "insert_relation") self.check_response_for_result(response, "insert_relation") return response["result"] @@ -229,8 +237,9 @@ async def merge( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.MERGE, record_id=thing, data=data + RequestMethod.MERGE, record_id=thing, data=data ) + self.id = message.id response = await self._send(message, "merge") self.check_response_for_result(response, "merge") return response["result"] @@ -239,14 +248,16 @@ async def patch( self, thing: Union[str, RecordID, Table], data: Optional[List[dict]] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.PATCH, collection=thing, params=data + RequestMethod.PATCH, collection=thing, params=data ) + self.id = message.id response = await self._send(message, "patch") self.check_response_for_result(response, "patch") return response["result"] async def select(self, thing: str) -> Union[List[dict], dict]: - message = RequestMessage(self.id, RequestMethod.SELECT, params=[thing]) + message = RequestMessage(RequestMethod.SELECT, params=[thing]) + self.id = message.id response = await self._send(message, "select") self.check_response_for_result(response, "select") return response["result"] @@ -255,14 +266,16 @@ async def update( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.UPDATE, record_id=thing, data=data + RequestMethod.UPDATE, record_id=thing, data=data ) + self.id = message.id response = await self._send(message, "update") self.check_response_for_result(response, "update") return response["result"] async def version(self) -> str: - message = RequestMessage(self.id, RequestMethod.VERSION) + message = RequestMessage(RequestMethod.VERSION) + self.id = message.id response = await self._send(message, "getting database version") self.check_response_for_result(response, "getting database version") return response["result"] @@ -271,8 +284,9 @@ async def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.UPSERT, record_id=thing, data=data + RequestMethod.UPSERT, record_id=thing, data=data ) + self.id = message.id response = await self._send(message, "upsert") self.check_response_for_result(response, "upsert") return response["result"] diff --git a/src/surrealdb/connections/async_ws.py b/src/surrealdb/connections/async_ws.py index 4cfe4e0b..e0ea6343 100644 --- a/src/surrealdb/connections/async_ws.py +++ b/src/surrealdb/connections/async_ws.py @@ -49,6 +49,7 @@ def __init__( self.id: str = str(uuid.uuid4()) self.token: Optional[str] = None self.socket = None + self.recv_lock = asyncio.Lock() async def _send( self, message: RequestMessage, process: str, bypass: bool = False @@ -58,7 +59,8 @@ async def _send( self.socket is not None ) # will always not be None as the self.connect ensures there's a connection await self.socket.send(message.WS_CBOR_DESCRIPTOR) - response = decode(await self.socket.recv()) + async with self.recv_lock: + response = decode(await self.socket.recv()) if bypass is False: self.check_response_for_error(response, process) return response @@ -78,23 +80,25 @@ async def connect(self, url: Optional[str] = None) -> None: ) async def authenticate(self, token: str) -> dict: - message = RequestMessage(self.id, RequestMethod.AUTHENTICATE, token=token) + message = RequestMessage(RequestMethod.AUTHENTICATE, token=token) + self.id = message.id return await self._send(message, "authenticating") async def invalidate(self) -> None: - message = RequestMessage(self.id, RequestMethod.INVALIDATE) + message = RequestMessage(RequestMethod.INVALIDATE) + self.id = message.id await self._send(message, "invalidating") self.token = None async def signup(self, vars: Dict) -> str: - message = RequestMessage(self.id, RequestMethod.SIGN_UP, data=vars) + message = RequestMessage(RequestMethod.SIGN_UP, data=vars) + self.id = message.id response = await self._send(message, "signup") self.check_response_for_result(response, "signup") return response["result"] async def signin(self, vars: Dict[str, Any]) -> str: message = RequestMessage( - self.id, RequestMethod.SIGN_IN, username=vars.get("username"), password=vars.get("password"), @@ -103,35 +107,37 @@ async def signin(self, vars: Dict[str, Any]) -> str: namespace=vars.get("namespace"), variables=vars.get("variables"), ) + self.id = message.id response = await self._send(message, "signing in") self.check_response_for_result(response, "signing in") self.token = response["result"] return response["result"] async def info(self) -> Optional[dict]: - message = RequestMessage(self.id, RequestMethod.INFO) + message = RequestMessage(RequestMethod.INFO) + self.id = message.id outcome = await self._send(message, "getting database information") self.check_response_for_result(outcome, "getting database information") return outcome["result"] async def use(self, namespace: str, database: str) -> None: message = RequestMessage( - self.id, RequestMethod.USE, namespace=namespace, database=database, ) + self.id = message.id await self._send(message, "use") async def query(self, query: str, params: Optional[dict] = None) -> dict: if params is None: params = {} message = RequestMessage( - self.id, RequestMethod.QUERY, query=query, params=params, ) + self.id = message.id response = await self._send(message, "query") self.check_response_for_result(response, "query") return response["result"][0]["result"] @@ -140,32 +146,36 @@ async def query_raw(self, query: str, params: Optional[dict] = None) -> dict: if params is None: params = {} message = RequestMessage( - self.id, RequestMethod.QUERY, query=query, params=params, ) + self.id = message.id response = await self._send(message, "query", bypass=True) return response async def version(self) -> str: - message = RequestMessage(self.id, RequestMethod.VERSION) + message = RequestMessage(RequestMethod.VERSION) + self.id = message.id response = await self._send(message, "getting database version") self.check_response_for_result(response, "getting database version") return response["result"] async def let(self, key: str, value: Any) -> None: - message = RequestMessage(self.id, RequestMethod.LET, key=key, value=value) + message = RequestMessage(RequestMethod.LET, key=key, value=value) + self.id = message.id await self._send(message, "letting") async def unset(self, key: str) -> None: - message = RequestMessage(self.id, RequestMethod.UNSET, params=[key]) + message = RequestMessage(RequestMethod.UNSET, params=[key]) + self.id = message.id await self._send(message, "unsetting") async def select( self, thing: Union[str, RecordID, Table] ) -> Union[List[dict], dict]: - message = RequestMessage(self.id, RequestMethod.SELECT, params=[thing]) + message = RequestMessage(RequestMethod.SELECT, params=[thing]) + self.id = message.id response = await self._send(message, "select") self.check_response_for_result(response, "select") return response["result"] @@ -180,8 +190,9 @@ async def create( buffer = thing.split(":") thing = RecordID(table_name=buffer[0], identifier=buffer[1]) message = RequestMessage( - self.id, RequestMethod.CREATE, collection=thing, data=data + RequestMethod.CREATE, collection=thing, data=data ) + self.id = message.id response = await self._send(message, "create") self.check_response_for_result(response, "create") return response["result"] @@ -190,8 +201,9 @@ async def update( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.UPDATE, record_id=thing, data=data + RequestMethod.UPDATE, record_id=thing, data=data ) + self.id = message.id response = await self._send(message, "update") self.check_response_for_result(response, "update") return response["result"] @@ -200,8 +212,9 @@ async def merge( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.MERGE, record_id=thing, data=data + RequestMethod.MERGE, record_id=thing, data=data ) + self.id = message.id response = await self._send(message, "merge") self.check_response_for_result(response, "merge") return response["result"] @@ -210,8 +223,9 @@ async def patch( self, thing: Union[str, RecordID, Table], data: Optional[List[dict]] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.PATCH, collection=thing, params=data + RequestMethod.PATCH, collection=thing, params=data ) + self.id = message.id response = await self._send(message, "patch") self.check_response_for_result(response, "patch") return response["result"] @@ -219,7 +233,8 @@ async def patch( async def delete( self, thing: Union[str, RecordID, Table] ) -> Union[List[dict], dict]: - message = RequestMessage(self.id, RequestMethod.DELETE, record_id=thing) + message = RequestMessage(RequestMethod.DELETE, record_id=thing) + self.id = message.id response = await self._send(message, "delete") self.check_response_for_result(response, "delete") return response["result"] @@ -228,8 +243,9 @@ async def insert( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.INSERT, collection=table, params=data + RequestMethod.INSERT, collection=table, params=data ) + self.id = message.id response = await self._send(message, "insert") self.check_response_for_result(response, "insert") return response["result"] @@ -238,18 +254,19 @@ async def insert_relation( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.INSERT_RELATION, table=table, params=data + RequestMethod.INSERT_RELATION, table=table, params=data ) + self.id = message.id response = await self._send(message, "insert_relation") self.check_response_for_result(response, "insert_relation") return response["result"] async def live(self, table: Union[str, Table], diff: bool = False) -> UUID: message = RequestMessage( - self.id, RequestMethod.LIVE, table=table, ) + self.id = message.id response = await self._send(message, "live") self.check_response_for_result(response, "live") return response["result"] @@ -281,15 +298,17 @@ async def listen_live(): yield result async def kill(self, query_uuid: Union[str, UUID]) -> None: - message = RequestMessage(self.id, RequestMethod.KILL, uuid=query_uuid) + message = RequestMessage(RequestMethod.KILL, uuid=query_uuid) + self.id = message.id await self._send(message, "kill") async def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.UPSERT, record_id=thing, data=data + RequestMethod.UPSERT, record_id=thing, data=data ) + self.id = message.id response = await self._send(message, "upsert") self.check_response_for_result(response, "upsert") return response["result"] diff --git a/src/surrealdb/connections/blocking_http.py b/src/surrealdb/connections/blocking_http.py index 01380b48..af82ddd0 100644 --- a/src/surrealdb/connections/blocking_http.py +++ b/src/surrealdb/connections/blocking_http.py @@ -57,16 +57,19 @@ def set_token(self, token: str) -> None: self.token = token def authenticate(self, token: str) -> dict: - message = RequestMessage(self.id, RequestMethod.AUTHENTICATE, token=token) + message = RequestMessage(RequestMethod.AUTHENTICATE, token=token) + self.id = message.id return self._send(message, "authenticating") def invalidate(self) -> None: - message = RequestMessage(self.id, RequestMethod.INVALIDATE) + message = RequestMessage(RequestMethod.INVALIDATE) + self.id = message.id self._send(message, "invalidating") self.token = None def signup(self, vars: Dict) -> str: - message = RequestMessage(self.id, RequestMethod.SIGN_UP, data=vars) + message = RequestMessage(RequestMethod.SIGN_UP, data=vars) + self.id = message.id response = self._send(message, "signup") self.check_response_for_result(response, "signup") self.token = response["result"] @@ -74,7 +77,6 @@ def signup(self, vars: Dict) -> str: def signin(self, vars: dict) -> str: message = RequestMessage( - self.id, RequestMethod.SIGN_IN, username=vars.get("username"), password=vars.get("password"), @@ -83,24 +85,26 @@ def signin(self, vars: dict) -> str: namespace=vars.get("namespace"), variables=vars.get("variables"), ) + self.id = message.id response = self._send(message, "signing in") self.check_response_for_result(response, "signing in") self.token = response["result"] return str(response["result"]) def info(self): - message = RequestMessage(self.id, RequestMethod.INFO) + message = RequestMessage(RequestMethod.INFO) + self.id = message.id response = self._send(message, "getting database information") self.check_response_for_result(response, "getting database information") return response["result"] def use(self, namespace: str, database: str) -> None: message = RequestMessage( - self.token, RequestMethod.USE, namespace=namespace, database=database, ) + self.id = message.id _ = self._send(message, "use") self.namespace = namespace self.database = database @@ -111,11 +115,11 @@ def query(self, query: str, params: Optional[dict] = None) -> dict: for key, value in self.vars.items(): params[key] = value message = RequestMessage( - self.id, RequestMethod.QUERY, query=query, params=params, ) + self.id = message.id response = self._send(message, "query") self.check_response_for_result(response, "query") return response["result"][0]["result"] @@ -126,11 +130,11 @@ def query_raw(self, query: str, params: Optional[dict] = None) -> dict: for key, value in self.vars.items(): params[key] = value message = RequestMessage( - self.id, RequestMethod.QUERY, query=query, params=params, ) + self.id = message.id response = self._send(message, "query", bypass=True) return response @@ -144,14 +148,16 @@ def create( buffer = thing.split(":") thing = RecordID(table_name=buffer[0], identifier=buffer[1]) message = RequestMessage( - self.id, RequestMethod.CREATE, collection=thing, data=data + RequestMethod.CREATE, collection=thing, data=data ) + self.id = message.id response = self._send(message, "create") self.check_response_for_result(response, "create") return response["result"] def delete(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: - message = RequestMessage(self.id, RequestMethod.DELETE, record_id=thing) + message = RequestMessage(RequestMethod.DELETE, record_id=thing) + self.id = message.id response = self._send(message, "delete") self.check_response_for_result(response, "delete") return response["result"] @@ -160,8 +166,9 @@ def insert( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.INSERT, collection=table, params=data + RequestMethod.INSERT, collection=table, params=data ) + self.id = message.id response = self._send(message, "insert") self.check_response_for_result(response, "insert") return response["result"] @@ -170,8 +177,9 @@ def insert_relation( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.INSERT_RELATION, table=table, params=data + RequestMethod.INSERT_RELATION, table=table, params=data ) + self.id = message.id response = self._send(message, "insert_relation") self.check_response_for_result(response, "insert_relation") return response["result"] @@ -186,8 +194,9 @@ def merge( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.MERGE, record_id=thing, data=data + RequestMethod.MERGE, record_id=thing, data=data ) + self.id = message.id response = self._send(message, "merge") self.check_response_for_result(response, "merge") return response["result"] @@ -196,14 +205,16 @@ def patch( self, thing: Union[str, RecordID, Table], data: Optional[Dict[Any, Any]] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.PATCH, collection=thing, params=data + RequestMethod.PATCH, collection=thing, params=data ) + self.id = message.id response = self._send(message, "patch") self.check_response_for_result(response, "patch") return response["result"] def select(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: - message = RequestMessage(self.id, RequestMethod.SELECT, params=[thing]) + message = RequestMessage(RequestMethod.SELECT, params=[thing]) + self.id = message.id response = self._send(message, "select") self.check_response_for_result(response, "select") return response["result"] @@ -212,14 +223,16 @@ def update( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.UPDATE, record_id=thing, data=data + RequestMethod.UPDATE, record_id=thing, data=data ) + self.id = message.id response = self._send(message, "update") self.check_response_for_result(response, "update") return response["result"] def version(self) -> str: - message = RequestMessage(self.id, RequestMethod.VERSION) + message = RequestMessage(RequestMethod.VERSION) + self.id = message.id response = self._send(message, "getting database version") self.check_response_for_result(response, "getting database version") return response["result"] @@ -228,8 +241,9 @@ def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.UPSERT, record_id=thing, data=data + RequestMethod.UPSERT, record_id=thing, data=data ) + self.id = message.id response = self._send(message, "upsert") self.check_response_for_result(response, "upsert") return response["result"] diff --git a/src/surrealdb/connections/blocking_ws.py b/src/surrealdb/connections/blocking_ws.py index 97d1c99f..d00b4e4b 100644 --- a/src/surrealdb/connections/blocking_ws.py +++ b/src/surrealdb/connections/blocking_ws.py @@ -62,23 +62,25 @@ def _send( return response def authenticate(self, token: str) -> dict: - message = RequestMessage(self.id, RequestMethod.AUTHENTICATE, token=token) + message = RequestMessage(RequestMethod.AUTHENTICATE, token=token) + self.id = message.id return self._send(message, "authenticating") def invalidate(self) -> None: - message = RequestMessage(self.id, RequestMethod.INVALIDATE) + message = RequestMessage(RequestMethod.INVALIDATE) + self.id = message.id self._send(message, "invalidating") self.token = None def signup(self, vars: Dict) -> str: - message = RequestMessage(self.id, RequestMethod.SIGN_UP, data=vars) + message = RequestMessage(RequestMethod.SIGN_UP, data=vars) + self.id = message.id response = self._send(message, "signup") self.check_response_for_result(response, "signup") return response["result"] def signin(self, vars: Dict[str, Any]) -> str: message = RequestMessage( - self.id, RequestMethod.SIGN_IN, username=vars.get("username"), password=vars.get("password"), @@ -87,35 +89,37 @@ def signin(self, vars: Dict[str, Any]) -> str: namespace=vars.get("namespace"), variables=vars.get("variables"), ) + self.id = message.id response = self._send(message, "signing in") self.check_response_for_result(response, "signing in") self.token = response["result"] return response["result"] def info(self) -> dict: - message = RequestMessage(self.id, RequestMethod.INFO) + message = RequestMessage(RequestMethod.INFO) + self.id = message.id response = self._send(message, "getting database information") self.check_response_for_result(response, "getting database information") return response["result"] def use(self, namespace: str, database: str) -> None: message = RequestMessage( - self.id, RequestMethod.USE, namespace=namespace, database=database, ) + self.id = message.id self._send(message, "use") def query(self, query: str, params: Optional[dict] = None) -> dict: if params is None: params = {} message = RequestMessage( - self.id, RequestMethod.QUERY, query=query, params=params, ) + self.id = message.id response = self._send(message, "query") self.check_response_for_result(response, "query") return response["result"][0]["result"] @@ -124,30 +128,34 @@ def query_raw(self, query: str, params: Optional[dict] = None) -> dict: if params is None: params = {} message = RequestMessage( - self.id, RequestMethod.QUERY, query=query, params=params, ) + self.id = message.id response = self._send(message, "query", bypass=True) return response def version(self) -> str: - message = RequestMessage(self.id, RequestMethod.VERSION) + message = RequestMessage(RequestMethod.VERSION) + self.id = message.id response = self._send(message, "getting database version") self.check_response_for_result(response, "getting database version") return response["result"] def let(self, key: str, value: Any) -> None: - message = RequestMessage(self.id, RequestMethod.LET, key=key, value=value) + message = RequestMessage(RequestMethod.LET, key=key, value=value) + self.id = message.id self._send(message, "letting") def unset(self, key: str) -> None: - message = RequestMessage(self.id, RequestMethod.UNSET, params=[key]) + message = RequestMessage(RequestMethod.UNSET, params=[key]) + self.id = message.id self._send(message, "unsetting") def select(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: - message = RequestMessage(self.id, RequestMethod.SELECT, params=[thing]) + message = RequestMessage(RequestMethod.SELECT, params=[thing]) + self.id = message.id response = self._send(message, "select") self.check_response_for_result(response, "select") return response["result"] @@ -162,28 +170,31 @@ def create( buffer = thing.split(":") thing = RecordID(table_name=buffer[0], identifier=buffer[1]) message = RequestMessage( - self.id, RequestMethod.CREATE, collection=thing, data=data + RequestMethod.CREATE, collection=thing, data=data ) + self.id = message.id response = self._send(message, "create") self.check_response_for_result(response, "create") return response["result"] def live(self, table: Union[str, Table], diff: bool = False) -> UUID: message = RequestMessage( - self.id, RequestMethod.LIVE, table=table, ) + self.id = message.id response = self._send(message, "live") self.check_response_for_result(response, "live") return response["result"] def kill(self, query_uuid: Union[str, UUID]) -> None: - message = RequestMessage(self.id, RequestMethod.KILL, uuid=query_uuid) + message = RequestMessage(RequestMethod.KILL, uuid=query_uuid) + self.id = message.id self._send(message, "kill") def delete(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: - message = RequestMessage(self.id, RequestMethod.DELETE, record_id=thing) + message = RequestMessage(RequestMethod.DELETE, record_id=thing) + self.id = message.id response = self._send(message, "delete") self.check_response_for_result(response, "delete") return response["result"] @@ -192,8 +203,9 @@ def insert( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.INSERT, collection=table, params=data + RequestMethod.INSERT, collection=table, params=data ) + self.id = message.id response = self._send(message, "insert") self.check_response_for_result(response, "insert") return response["result"] @@ -202,8 +214,9 @@ def insert_relation( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.INSERT_RELATION, table=table, params=data + RequestMethod.INSERT_RELATION, table=table, params=data ) + self.id = message.id response = self._send(message, "insert_relation") self.check_response_for_result(response, "insert_relation") return response["result"] @@ -212,8 +225,9 @@ def merge( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.MERGE, record_id=thing, data=data + RequestMethod.MERGE, record_id=thing, data=data ) + self.id = message.id response = self._send(message, "merge") self.check_response_for_result(response, "merge") return response["result"] @@ -222,8 +236,9 @@ def patch( self, thing: Union[str, RecordID, Table], data: Optional[List[dict]] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.PATCH, collection=thing, params=data + RequestMethod.PATCH, collection=thing, params=data ) + self.id = message.id response = self._send(message, "patch") self.check_response_for_result(response, "patch") return response["result"] @@ -261,8 +276,9 @@ def update( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.UPDATE, record_id=thing, data=data + RequestMethod.UPDATE, record_id=thing, data=data ) + self.id = message.id response = self._send(message, "update") self.check_response_for_result(response, "update") return response["result"] @@ -271,8 +287,9 @@ def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, RequestMethod.UPSERT, record_id=thing, data=data + RequestMethod.UPSERT, record_id=thing, data=data ) + self.id = message.id response = self._send(message, "upsert") self.check_response_for_result(response, "upsert") return response["result"] diff --git a/src/surrealdb/request_message/message.py b/src/surrealdb/request_message/message.py index 509dbb93..59ce5aac 100644 --- a/src/surrealdb/request_message/message.py +++ b/src/surrealdb/request_message/message.py @@ -1,3 +1,5 @@ +import uuid + from surrealdb.request_message.descriptors.cbor_ws import WsCborDescriptor from surrealdb.request_message.methods import RequestMethod @@ -6,7 +8,7 @@ class RequestMessage: WS_CBOR_DESCRIPTOR = WsCborDescriptor() - def __init__(self, id_for_request, method: RequestMethod, **kwargs) -> None: - self.id = id_for_request + def __init__(self, method: RequestMethod, **kwargs) -> None: + self.id = str(uuid.uuid4()) self.method = method self.kwargs = kwargs diff --git a/tests/unit_tests/connections/batch_async/__init__.py b/tests/unit_tests/connections/batch_async/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/connections/batch_async/test_async_ws.py b/tests/unit_tests/connections/batch_async/test_async_ws.py new file mode 100644 index 00000000..e71638f9 --- /dev/null +++ b/tests/unit_tests/connections/batch_async/test_async_ws.py @@ -0,0 +1,37 @@ +import asyncio +from unittest import main, IsolatedAsyncioTestCase + +from surrealdb.connections.async_ws import AsyncWsSurrealConnection + + +class TestAsyncWsSurrealConnection(IsolatedAsyncioTestCase): + + async def asyncSetUp(self): + self.url = "ws://localhost:8000" + self.password = "root" + self.username = "root" + self.vars_params = { + "username": self.username, + "password": self.password, + } + self.database_name = "test_db" + self.namespace = "test_ns" + self.data = { + "username": self.username, + "password": self.password, + } + self.connection = AsyncWsSurrealConnection(self.url) + _ = await self.connection.signin(self.vars_params) + _ = await self.connection.use(namespace=self.namespace, database=self.database_name) + + async def test_batch(self): + async with asyncio.TaskGroup() as tg: + tasks = [tg.create_task(self.connection.query("select $p**2 as ret from {}", dict(p=num))) for num in range(5)] + + outcome = [t.result()[0]["ret"] for t in tasks] + self.assertEqual([0, 1, 4, 9, 16], outcome) + await self.connection.socket.close() + + +if __name__ == "__main__": + main() diff --git a/tests/unit_tests/request_message/descriptors/test_cbor_ws.py b/tests/unit_tests/request_message/descriptors/test_cbor_ws.py index 2db08555..ed3dece6 100644 --- a/tests/unit_tests/request_message/descriptors/test_cbor_ws.py +++ b/tests/unit_tests/request_message/descriptors/test_cbor_ws.py @@ -7,19 +7,19 @@ class TestWsCborAdapter(TestCase): def test_use_pass(self): - message = RequestMessage(1, RequestMethod.USE, namespace="ns", database="db") + message = RequestMessage(RequestMethod.USE, namespace="ns", database="db") outcome = message.WS_CBOR_DESCRIPTOR self.assertIsInstance(outcome, bytes) def test_use_fail(self): - message = RequestMessage(1, RequestMethod.USE, namespace="ns", database=1) + message = RequestMessage(RequestMethod.USE, namespace="ns", database=1) with self.assertRaises(ValueError) as context: message.WS_CBOR_DESCRIPTOR self.assertEqual( "Invalid schema for Cbor WS encoding for use: {'params': [{1: ['must be of string type']}]}", str(context.exception) ) - message = RequestMessage(1, RequestMethod.USE, namespace="ns") + message = RequestMessage(RequestMethod.USE, namespace="ns") with self.assertRaises(ValueError) as context: message.WS_CBOR_DESCRIPTOR self.assertEqual( @@ -28,18 +28,17 @@ def test_use_fail(self): ) def test_info_pass(self): - message = RequestMessage(1, RequestMethod.INFO) + message = RequestMessage(RequestMethod.INFO) outcome = message.WS_CBOR_DESCRIPTOR self.assertIsInstance(outcome, bytes) def test_version_pass(self): - message = RequestMessage(1, RequestMethod.VERSION) + message = RequestMessage(RequestMethod.VERSION) outcome = message.WS_CBOR_DESCRIPTOR self.assertIsInstance(outcome, bytes) def test_signin_pass_root(self): message = RequestMessage( - 1, RequestMethod.SIGN_IN, username="user", password="pass" @@ -49,7 +48,6 @@ def test_signin_pass_root(self): def test_signin_pass_root_with_none(self): message = RequestMessage( - 1, RequestMethod.SIGN_IN, username="username", password="pass", @@ -62,7 +60,6 @@ def test_signin_pass_root_with_none(self): def test_signin_pass_account(self): message = RequestMessage( - 1, RequestMethod.SIGN_IN, username="username", password="pass", @@ -75,7 +72,6 @@ def test_signin_pass_account(self): def test_authenticate_pass(self): message = RequestMessage( - 1, RequestMethod.AUTHENTICATE, token="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJTdXJyZWFsREIiLCJpYXQiOjE1MTYyMzkwMjIsIm5iZiI6MTUxNjIzOTAyMiwiZXhwIjoxODM2NDM5MDIyLCJOUyI6InRlc3QiLCJEQiI6InRlc3QiLCJTQyI6InVzZXIiLCJJRCI6InVzZXI6dG9iaWUifQ.N22Gp9ze0rdR06McGj1G-h2vu6a6n9IVqUbMFJlOxxA" ) @@ -84,7 +80,6 @@ def test_authenticate_pass(self): def test_invalidate_pass(self): message = RequestMessage( - 1, RequestMethod.INVALIDATE ) outcome = message.WS_CBOR_DESCRIPTOR @@ -92,7 +87,6 @@ def test_invalidate_pass(self): def test_let_pass(self): message = RequestMessage( - 1, RequestMethod.LET, key="key", value="value" @@ -102,7 +96,6 @@ def test_let_pass(self): def test_unset_pass(self): message = RequestMessage( - 1, RequestMethod.UNSET, params=["one", "two", "three"] ) @@ -111,7 +104,6 @@ def test_unset_pass(self): def test_live_pass(self): message = RequestMessage( - 1, RequestMethod.LIVE, table="person" ) @@ -120,7 +112,6 @@ def test_live_pass(self): def test_kill_pass(self): message = RequestMessage( - 1, RequestMethod.KILL, uuid="0189d6e3-8eac-703a-9a48-d9faa78b44b9" ) @@ -129,7 +120,6 @@ def test_kill_pass(self): def test_query_pass(self): message = RequestMessage( - 1, RequestMethod.QUERY, query="query" ) @@ -138,7 +128,6 @@ def test_query_pass(self): def test_create_pass_params(self): message = RequestMessage( - 1, RequestMethod.CREATE, collection="person", data={"table": "table"} @@ -148,7 +137,6 @@ def test_create_pass_params(self): def test_insert_pass_dict(self): message = RequestMessage( - 1, RequestMethod.INSERT, collection="table", params={"key": "value"} @@ -158,7 +146,6 @@ def test_insert_pass_dict(self): def test_insert_pass_list(self): message = RequestMessage( - 1, RequestMethod.INSERT, collection="table", params=[{"key": "value"}, {"key": "value"}] @@ -168,7 +155,6 @@ def test_insert_pass_list(self): def test_patch_pass(self): message = RequestMessage( - 1, RequestMethod.PATCH, collection="table", params=[{"key": "value"}, {"key": "value"}] @@ -178,7 +164,6 @@ def test_patch_pass(self): def test_select_pass(self): message = RequestMessage( - 1, RequestMethod.SELECT, params=["table", "user"], ) @@ -187,7 +172,6 @@ def test_select_pass(self): def test_update_pass(self): message = RequestMessage( - 1, RequestMethod.UPDATE, record_id="test", data={"table": "table"} @@ -197,7 +181,6 @@ def test_update_pass(self): def test_upsert_pass(self): message = RequestMessage( - 1, RequestMethod.UPSERT, record_id="test", data={"table": "table"} @@ -207,7 +190,6 @@ def test_upsert_pass(self): def test_merge_pass(self): message = RequestMessage( - 1, RequestMethod.MERGE, record_id="test", data={"table": "table"} @@ -217,7 +199,6 @@ def test_merge_pass(self): def test_delete_pass(self): message = RequestMessage( - 1, RequestMethod.DELETE, record_id="test", ) diff --git a/tests/unit_tests/request_message/test_request_message.py b/tests/unit_tests/request_message/test_request_message.py index 90c87541..7b1afd0b 100644 --- a/tests/unit_tests/request_message/test_request_message.py +++ b/tests/unit_tests/request_message/test_request_message.py @@ -9,7 +9,7 @@ def setUp(self): self.method = RequestMethod.USE def test_init(self): - request_message = RequestMessage(1, self.method, one="two", three="four") + request_message = RequestMessage(self.method, one="two", three="four") self.assertEqual(request_message.method, self.method) self.assertEqual(request_message.kwargs, {"one": "two", "three": "four"}) From 6ede6bacae2d6a5d76460f7a0686223c19653c46 Mon Sep 17 00:00:00 2001 From: Antonin ENFRUN Date: Thu, 27 Feb 2025 13:19:47 +0100 Subject: [PATCH 02/10] tests: test ordering of concurrent query in async mode (#166) CI tests are failing on this merge but this is because it's introducing a new test case into the `async-batching` branch that the `async-batching` branch needs to address --- tests/unit_tests/connections/batch_async/test_async_ws.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/connections/batch_async/test_async_ws.py b/tests/unit_tests/connections/batch_async/test_async_ws.py index e71638f9..2d07bddd 100644 --- a/tests/unit_tests/connections/batch_async/test_async_ws.py +++ b/tests/unit_tests/connections/batch_async/test_async_ws.py @@ -26,9 +26,9 @@ async def asyncSetUp(self): async def test_batch(self): async with asyncio.TaskGroup() as tg: - tasks = [tg.create_task(self.connection.query("select $p**2 as ret from {}", dict(p=num))) for num in range(5)] + tasks = [tg.create_task(self.connection.query("RETURN sleep(duration::from::millis($d)) or $p**2", dict(d=10 if num%2 else 0, p=num))) for num in range(5)] - outcome = [t.result()[0]["ret"] for t in tasks] + outcome = [t.result() for t in tasks] self.assertEqual([0, 1, 4, 9, 16], outcome) await self.connection.socket.close() From 3499108fbae42a551d67633be07fc1f91a7ce89c Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Thu, 27 Feb 2025 15:56:28 +0000 Subject: [PATCH 03/10] adding a lock for async batching --- src/surrealdb/connections/async_ws.py | 19 ++++++++++++++++--- src/surrealdb/connections/socket_state.py | 9 +++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) create mode 100644 src/surrealdb/connections/socket_state.py diff --git a/src/surrealdb/connections/async_ws.py b/src/surrealdb/connections/async_ws.py index e0ea6343..276a85e7 100644 --- a/src/surrealdb/connections/async_ws.py +++ b/src/surrealdb/connections/async_ws.py @@ -1,7 +1,6 @@ """ A basic async connection to a SurrealDB instance. """ - import asyncio import uuid from asyncio import Queue @@ -49,7 +48,8 @@ def __init__( self.id: str = str(uuid.uuid4()) self.token: Optional[str] = None self.socket = None - self.recv_lock = asyncio.Lock() + self.lock = asyncio.Lock() + self.ref = dict() async def _send( self, message: RequestMessage, process: str, bypass: bool = False @@ -58,9 +58,22 @@ async def _send( assert ( self.socket is not None ) # will always not be None as the self.connect ensures there's a connection + + query_id = message.id await self.socket.send(message.WS_CBOR_DESCRIPTOR) - async with self.recv_lock: + + async with self.lock: response = decode(await self.socket.recv()) + self.ref[response["id"]] = response + + # wait for ID to be returned + while self.ref.get(query_id) is None: + await asyncio.sleep(0) # The await simply yields to the executor to avoid deadlocks + + # set the response and clean up + response = self.ref[query_id] + del self.ref[query_id] + if bypass is False: self.check_response_for_error(response, process) return response diff --git a/src/surrealdb/connections/socket_state.py b/src/surrealdb/connections/socket_state.py new file mode 100644 index 00000000..2b630738 --- /dev/null +++ b/src/surrealdb/connections/socket_state.py @@ -0,0 +1,9 @@ + + +# singleton + +# dict for sockets + +# socket with smart reference counter + +# \ No newline at end of file From 8ba7570ce4a27dba8ce2f2dee565d54384c6f98b Mon Sep 17 00:00:00 2001 From: Antonin ENFRUN Date: Fri, 28 Feb 2025 14:35:49 +0100 Subject: [PATCH 04/10] Allow concurrent query for async ws connections (#167) Co-authored-by: Maxwell Flitton --- src/surrealdb/connections/async_template.py | 74 ++++----- src/surrealdb/connections/async_ws.py | 142 +++++++++--------- .../connections/authenticate/test_async_ws.py | 1 - .../connections/batch_async/test_async_ws.py | 1 - .../connections/create/test_async_ws.py | 7 - .../connections/delete/test_async_ws.py | 3 - .../connections/info/test_async_ws.py | 1 - .../connections/insert/test_async_ws.py | 2 - .../insert_relation/test_async_ws.py | 2 - .../connections/let/test_async_ws.py | 1 - .../connections/live/test_async_http.py | 1 - .../connections/live/test_async_ws.py | 1 - .../connections/merge/test_async_ws.py | 6 - .../connections/patch/test_async_ws.py | 3 - .../connections/query/test_async_http.py | 1 - .../connections/query/test_async_ws.py | 1 - .../connections/select/test_async_ws.py | 1 - .../connections/signin/test_async_ws.py | 8 - .../subscribe_live/test_async_ws.py | 2 - .../connections/unset/test_async_ws.py | 1 - .../connections/update/test_async_ws.py | 6 - .../connections/upsert/test_async_ws.py | 6 - .../connections/version/test_async_ws.py | 1 - tests/unit_tests/data_types/test_datetimes.py | 1 - 24 files changed, 104 insertions(+), 169 deletions(-) diff --git a/src/surrealdb/connections/async_template.py b/src/surrealdb/connections/async_template.py index 70e6053c..c63dac0f 100644 --- a/src/surrealdb/connections/async_template.py +++ b/src/surrealdb/connections/async_template.py @@ -7,7 +7,7 @@ class AsyncTemplate: - async def connect(self, url: str) -> Coroutine[Any, Any, None]: + async def connect(self, url: str) -> None: """Connects to a local or remote database endpoint. Args: @@ -18,17 +18,17 @@ async def connect(self, url: str) -> Coroutine[Any, Any, None]: # Connect to a remote endpoint await db.connect('https://cloud.surrealdb.com/rpc'); """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"connect not implemented for: {self}") - async def close(self) -> Coroutine[Any, Any, None]: + async def close(self) -> None: """Closes the persistent connection to the database. Example: await db.close() """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"close not implemented for: {self}") - async def use(self, namespace: str, database: str) -> Coroutine[Any, Any, None]: + async def use(self, namespace: str, database: str) -> None: """Switch to a specific namespace and database. Args: @@ -38,9 +38,9 @@ async def use(self, namespace: str, database: str) -> Coroutine[Any, Any, None]: Example: await db.use('test', 'test') """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"use not implemented for: {self}") - async def authenticate(self, token: str) -> Coroutine[Any, Any, None]: + async def authenticate(self, token: str) -> None: """Authenticate the current connection with a JWT token. Args: @@ -51,7 +51,7 @@ async def authenticate(self, token: str) -> Coroutine[Any, Any, None]: """ raise NotImplementedError(f"authenticate not implemented for: {self}") - async def invalidate(self) -> Coroutine[Any, Any, None]: + async def invalidate(self) -> None: """Invalidate the authentication for the current connection. Example: @@ -59,7 +59,7 @@ async def invalidate(self) -> Coroutine[Any, Any, None]: """ raise NotImplementedError(f"invalidate not implemented for: {self}") - async def signup(self, vars: Dict) -> Coroutine[Any, Any, str]: + async def signup(self, vars: Dict) -> str: """Sign this connection up to a specific authentication scope. [See the docs](https://surrealdb.com/docs/sdk/python/methods/signup) @@ -81,7 +81,7 @@ async def signup(self, vars: Dict) -> Coroutine[Any, Any, str]: """ raise NotImplementedError(f"signup not implemented for: {self}") - async def signin(self, vars: Dict) -> Coroutine[Any, Any, str]: + async def signin(self, vars: Dict) -> str: """Sign this connection in to a specific authentication scope. [See the docs](https://surrealdb.com/docs/sdk/python/methods/signin) @@ -94,9 +94,9 @@ async def signin(self, vars: Dict) -> Coroutine[Any, Any, str]: password: 'surrealdb', }) """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"signin not implemented for: {self}") - async def let(self, key: str, value: Any) -> Coroutine[Any, Any, None]: + async def let(self, key: str, value: Any) -> None: """Assign a value as a variable for this connection. Args: @@ -115,7 +115,7 @@ async def let(self, key: str, value: Any) -> Coroutine[Any, Any, None]: """ raise NotImplementedError(f"let not implemented for: {self}") - async def unset(self, key: str) -> Coroutine[Any, Any, None]: + async def unset(self, key: str) -> None: """Removes a variable for this connection. Args: @@ -124,11 +124,11 @@ async def unset(self, key: str) -> Coroutine[Any, Any, None]: Example: await db.unset('name') """ - raise NotImplementedError(f"let not implemented for: {self}") + raise NotImplementedError(f"unset not implemented for: {self}") async def query( self, query: str, vars: Optional[Dict] = None - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """Run a unset of SurrealQL statements against the database. Args: @@ -145,7 +145,7 @@ async def query( async def select( self, thing: Union[str, RecordID, Table] - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """Select all records in a table (or other entity), or a specific record, in the database. @@ -158,13 +158,13 @@ async def select( Example: db.select('person') """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"select not implemented for: {self}") async def create( self, thing: Union[str, RecordID, Table], data: Optional[Union[List[dict], dict]] = None, - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """Create a record in the database. This function will run the following query in the database: @@ -181,7 +181,7 @@ async def create( async def update( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """Update all records in a table, or a specific record, in the database. This function replaces the current document / record data with the @@ -207,11 +207,11 @@ async def update( }, }) """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"update not implemented for: {self}") async def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """Insert records into the database, or to update them if they exist. @@ -239,7 +239,7 @@ async def upsert( async def merge( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """Modify by deep merging all records in a table, or a specific record, in the database. This function merges the current document / record data with the @@ -267,11 +267,11 @@ async def merge( }) """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"merge not implemented for: {self}") async def patch( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """Apply JSON Patch changes to all records, or a specific record, in the database. This function patches the current document / record data with @@ -296,11 +296,11 @@ async def patch( { 'op': "remove", "path": "/temp" }, ]) """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"patch not implemented for: {self}") async def delete( self, thing: Union[str, RecordID, Table] - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """Delete all records in a table, or a specific record, from the database. This function will run the following query in the database: @@ -324,11 +324,11 @@ async def info(self) -> Coroutine[Any, Any, dict]: Example: await db.info() """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"info not implemented for: {self}") async def insert( self, table: Union[str, Table], data: Union[List[dict], dict] - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """ Inserts one or multiple records in the database. @@ -343,11 +343,11 @@ async def insert( await db.insert('person', [{ name: 'Tobie'}, { name: 'Jaime'}]) """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"insert not implemented for: {self}") async def insert_relation( self, table: Union[str, Table], data: Union[List[dict], dict] - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: + ) -> Union[List[dict], dict]: """ Inserts one or multiple relations in the database. @@ -362,11 +362,11 @@ async def insert_relation( await db.insert_relation('likes', { in: person:1, id: 'object', out: person:2}) """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"insert_relation not implemented for: {self}") async def live( self, table: Union[str, Table], diff: bool = False - ) -> Coroutine[Any, Any, UUID]: + ) -> UUID: """Initiates a live query for a specified table name. Args: @@ -381,11 +381,11 @@ async def live( Example: await db.live('person') """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"live not implemented for: {self}") async def subscribe_live( self, query_uuid: Union[str, UUID] - ) -> Coroutine[Any, Any, Queue]: + ) -> Queue: """Returns a queue that receives notification messages from a running live query. Args: @@ -397,9 +397,9 @@ async def subscribe_live( Example: await db.subscribe_live(UUID) """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"subscribe_live not implemented for: {self}") - async def kill(self, query_uuid: Union[str, UUID]) -> Coroutine[Any, Any, None]: + async def kill(self, query_uuid: Union[str, UUID]) -> None: """Kills a running live query by it's UUID. Args: @@ -409,4 +409,4 @@ async def kill(self, query_uuid: Union[str, UUID]) -> Coroutine[Any, Any, None]: await db.kill(UUID) """ - raise NotImplementedError(f"query not implemented for: {self}") + raise NotImplementedError(f"kill not implemented for: {self}") diff --git a/src/surrealdb/connections/async_ws.py b/src/surrealdb/connections/async_ws.py index 276a85e7..7eb0f693 100644 --- a/src/surrealdb/connections/async_ws.py +++ b/src/surrealdb/connections/async_ws.py @@ -1,9 +1,10 @@ """ A basic async connection to a SurrealDB instance. """ + import asyncio import uuid -from asyncio import Queue +from asyncio import Queue, Task, Future, AbstractEventLoop from typing import Optional, Any, Dict, Union, List, AsyncGenerator from uuid import UUID @@ -45,67 +46,81 @@ def __init__( self.raw_url: str = f"{self.url.raw_url}/rpc" self.host: Optional[str] = self.url.hostname self.port: Optional[int] = self.url.port - self.id: str = str(uuid.uuid4()) self.token: Optional[str] = None self.socket = None - self.lock = asyncio.Lock() - self.ref = dict() + self.loop: AbstractEventLoop|None = None + self.qry:dict[str, Future] = {} + self.recv_task:Task[None]|None = None + self.live_queues:dict[str, list] = {} + + async def _recv_task(self): + assert self.socket + async for data in self.socket: + response = decode(data) + if (response_id := response.get("id")): + if fut := self.qry.get(response_id): + fut.set_result(response) + else: + live_id = str(response['result']['id']) + for queue in self.live_queues.get(live_id, []): + queue.put_nowait(response['result']) async def _send( self, message: RequestMessage, process: str, bypass: bool = False ) -> dict: await self.connect() assert ( - self.socket is not None + self.socket is not None and self.loop is not None ) # will always not be None as the self.connect ensures there's a connection + # setup future to wait for response + fut = self.loop.create_future() query_id = message.id - await self.socket.send(message.WS_CBOR_DESCRIPTOR) - - async with self.lock: - response = decode(await self.socket.recv()) - self.ref[response["id"]] = response - - # wait for ID to be returned - while self.ref.get(query_id) is None: - await asyncio.sleep(0) # The await simply yields to the executor to avoid deadlocks + self.qry[query_id] = fut + try: + # correlate mesage to query, send and forget it + await self.socket.send(message.WS_CBOR_DESCRIPTOR) + del message - # set the response and clean up - response = self.ref[query_id] - del self.ref[query_id] + # wait for response + response = await fut + finally: + del self.qry[query_id] if bypass is False: self.check_response_for_error(response, process) return response async def connect(self, url: Optional[str] = None) -> None: + if self.socket: + return + # overwrite params if passed in if url is not None: self.url = Url(url) self.raw_url = f"{self.url.raw_url}/rpc" self.host = self.url.hostname self.port = self.url.port - if self.socket is None: - self.socket = await websockets.connect( - self.raw_url, - max_size=None, - subprotocols=[websockets.Subprotocol("cbor")], - ) + + self.socket = await websockets.connect( + self.raw_url, + max_size=None, + subprotocols=[websockets.Subprotocol("cbor")], + ) + self.loop = asyncio.get_running_loop() + self.recv_task = asyncio.create_task(self._recv_task()) async def authenticate(self, token: str) -> dict: message = RequestMessage(RequestMethod.AUTHENTICATE, token=token) - self.id = message.id return await self._send(message, "authenticating") async def invalidate(self) -> None: message = RequestMessage(RequestMethod.INVALIDATE) - self.id = message.id await self._send(message, "invalidating") self.token = None async def signup(self, vars: Dict) -> str: message = RequestMessage(RequestMethod.SIGN_UP, data=vars) - self.id = message.id response = await self._send(message, "signup") self.check_response_for_result(response, "signup") return response["result"] @@ -120,7 +135,6 @@ async def signin(self, vars: Dict[str, Any]) -> str: namespace=vars.get("namespace"), variables=vars.get("variables"), ) - self.id = message.id response = await self._send(message, "signing in") self.check_response_for_result(response, "signing in") self.token = response["result"] @@ -128,7 +142,6 @@ async def signin(self, vars: Dict[str, Any]) -> str: async def info(self) -> Optional[dict]: message = RequestMessage(RequestMethod.INFO) - self.id = message.id outcome = await self._send(message, "getting database information") self.check_response_for_result(outcome, "getting database information") return outcome["result"] @@ -139,7 +152,6 @@ async def use(self, namespace: str, database: str) -> None: namespace=namespace, database=database, ) - self.id = message.id await self._send(message, "use") async def query(self, query: str, params: Optional[dict] = None) -> dict: @@ -150,7 +162,6 @@ async def query(self, query: str, params: Optional[dict] = None) -> dict: query=query, params=params, ) - self.id = message.id response = await self._send(message, "query") self.check_response_for_result(response, "query") return response["result"][0]["result"] @@ -163,32 +174,27 @@ async def query_raw(self, query: str, params: Optional[dict] = None) -> dict: query=query, params=params, ) - self.id = message.id response = await self._send(message, "query", bypass=True) return response async def version(self) -> str: message = RequestMessage(RequestMethod.VERSION) - self.id = message.id response = await self._send(message, "getting database version") self.check_response_for_result(response, "getting database version") return response["result"] async def let(self, key: str, value: Any) -> None: message = RequestMessage(RequestMethod.LET, key=key, value=value) - self.id = message.id await self._send(message, "letting") async def unset(self, key: str) -> None: message = RequestMessage(RequestMethod.UNSET, params=[key]) - self.id = message.id await self._send(message, "unsetting") async def select( self, thing: Union[str, RecordID, Table] ) -> Union[List[dict], dict]: message = RequestMessage(RequestMethod.SELECT, params=[thing]) - self.id = message.id response = await self._send(message, "select") self.check_response_for_result(response, "select") return response["result"] @@ -205,7 +211,6 @@ async def create( message = RequestMessage( RequestMethod.CREATE, collection=thing, data=data ) - self.id = message.id response = await self._send(message, "create") self.check_response_for_result(response, "create") return response["result"] @@ -216,7 +221,6 @@ async def update( message = RequestMessage( RequestMethod.UPDATE, record_id=thing, data=data ) - self.id = message.id response = await self._send(message, "update") self.check_response_for_result(response, "update") return response["result"] @@ -227,7 +231,6 @@ async def merge( message = RequestMessage( RequestMethod.MERGE, record_id=thing, data=data ) - self.id = message.id response = await self._send(message, "merge") self.check_response_for_result(response, "merge") return response["result"] @@ -238,7 +241,6 @@ async def patch( message = RequestMessage( RequestMethod.PATCH, collection=thing, params=data ) - self.id = message.id response = await self._send(message, "patch") self.check_response_for_result(response, "patch") return response["result"] @@ -247,7 +249,6 @@ async def delete( self, thing: Union[str, RecordID, Table] ) -> Union[List[dict], dict]: message = RequestMessage(RequestMethod.DELETE, record_id=thing) - self.id = message.id response = await self._send(message, "delete") self.check_response_for_result(response, "delete") return response["result"] @@ -258,7 +259,6 @@ async def insert( message = RequestMessage( RequestMethod.INSERT, collection=table, params=data ) - self.id = message.id response = await self._send(message, "insert") self.check_response_for_result(response, "insert") return response["result"] @@ -269,7 +269,6 @@ async def insert_relation( message = RequestMessage( RequestMethod.INSERT_RELATION, table=table, params=data ) - self.id = message.id response = await self._send(message, "insert_relation") self.check_response_for_result(response, "insert_relation") return response["result"] @@ -279,41 +278,29 @@ async def live(self, table: Union[str, Table], diff: bool = False) -> UUID: RequestMethod.LIVE, table=table, ) - self.id = message.id response = await self._send(message, "live") self.check_response_for_result(response, "live") - return response["result"] + uuid = response["result"] + assert uuid not in self.live_queues + self.live_queues[str(uuid)] = [] + return uuid - async def subscribe_live( + def subscribe_live( self, query_uuid: Union[str, UUID] - ) -> AsyncGenerator[dict, None]: + ) -> Queue: result_queue = Queue() - - async def listen_live(): - """ - Listen for live updates from the WebSocket and put them into the queue. - """ - try: - while True: - response = decode(await self.socket.recv()) - if response.get("result", {}).get("id") == query_uuid: - await result_queue.put(response["result"]["result"]) - except Exception as e: - print("Error in live subscription:", e) - await result_queue.put({"error": str(e)}) - - asyncio.create_task(listen_live()) - - while True: - result = await result_queue.get() - if "error" in result: - raise Exception(f"Error in live subscription: {result['error']}") - yield result + suid = str(query_uuid) + self.live_queues[suid].append(result_queue) + async def _iter(): + while True: + ret = await result_queue.get() + yield ret['result'] + return _iter() async def kill(self, query_uuid: Union[str, UUID]) -> None: message = RequestMessage(RequestMethod.KILL, uuid=query_uuid) - self.id = message.id await self._send(message, "kill") + self.live_queues.pop(str(query_uuid), None) async def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None @@ -321,22 +308,28 @@ async def upsert( message = RequestMessage( RequestMethod.UPSERT, record_id=thing, data=data ) - self.id = message.id response = await self._send(message, "upsert") self.check_response_for_result(response, "upsert") return response["result"] async def close(self): - await self.socket.close() + if self.recv_task: + self.recv_task.cancel() + try: + await self.recv_task + except asyncio.CancelledError: + pass + + if self.socket is not None: + await self.socket.close() + async def __aenter__(self) -> "AsyncWsSurrealConnection": """ Asynchronous context manager entry. Initializes a websocket connection and returns the connection instance. """ - self.socket = await websockets.connect( - self.raw_url, max_size=None, subprotocols=[websockets.Subprotocol("cbor")] - ) + await self.connect() return self async def __aexit__(self, exc_type, exc_value, traceback) -> None: @@ -344,5 +337,4 @@ async def __aexit__(self, exc_type, exc_value, traceback) -> None: Asynchronous context manager exit. Closes the websocket connection upon exiting the context. """ - if self.socket is not None: - await self.socket.close() + await self.close() \ No newline at end of file diff --git a/tests/unit_tests/connections/authenticate/test_async_ws.py b/tests/unit_tests/connections/authenticate/test_async_ws.py index 53707a45..2304191e 100644 --- a/tests/unit_tests/connections/authenticate/test_async_ws.py +++ b/tests/unit_tests/connections/authenticate/test_async_ws.py @@ -21,7 +21,6 @@ async def asyncSetUp(self): async def test_authenticate(self): outcome = await self.connection.authenticate(token=self.connection.token) - await self.connection.socket.close() diff --git a/tests/unit_tests/connections/batch_async/test_async_ws.py b/tests/unit_tests/connections/batch_async/test_async_ws.py index 2d07bddd..6ee6a0ab 100644 --- a/tests/unit_tests/connections/batch_async/test_async_ws.py +++ b/tests/unit_tests/connections/batch_async/test_async_ws.py @@ -30,7 +30,6 @@ async def test_batch(self): outcome = [t.result() for t in tasks] self.assertEqual([0, 1, 4, 9, 16], outcome) - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/create/test_async_ws.py b/tests/unit_tests/connections/create/test_async_ws.py index 6db2b7fe..ff545baa 100644 --- a/tests/unit_tests/connections/create/test_async_ws.py +++ b/tests/unit_tests/connections/create/test_async_ws.py @@ -35,7 +35,6 @@ async def test_create_string(self): 1 ) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_create_string_with_data(self): outcome = await self.connection.create("user", self.data) @@ -53,7 +52,6 @@ async def test_create_string_with_data(self): self.assertEqual(self.username, outcome[0]["username"]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_create_string_with_data_and_id(self): first_outcome = await self.connection.create("user:tobie", self.data) @@ -73,7 +71,6 @@ async def test_create_string_with_data_and_id(self): self.assertEqual(self.username, outcome[0]["username"]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_create_record_id(self): record_id = RecordID("user",1) @@ -87,7 +84,6 @@ async def test_create_record_id(self): ) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_create_record_id_with_data(self): record_id = RecordID("user", 1) @@ -107,7 +103,6 @@ async def test_create_record_id_with_data(self): self.assertEqual(self.username, outcome[0]["username"]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_create_table(self): table = Table("user") @@ -120,7 +115,6 @@ async def test_create_table(self): ) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_create_table_with_data(self): table = Table("user") @@ -139,7 +133,6 @@ async def test_create_table_with_data(self): self.assertEqual(self.username, outcome[0]["username"]) await self.connection.query("DELETE user;") - await self.connection.socket.close() diff --git a/tests/unit_tests/connections/delete/test_async_ws.py b/tests/unit_tests/connections/delete/test_async_ws.py index 4a6f73f7..cc6ad40c 100644 --- a/tests/unit_tests/connections/delete/test_async_ws.py +++ b/tests/unit_tests/connections/delete/test_async_ws.py @@ -43,14 +43,12 @@ async def test_delete_string(self): self.check_no_change(outcome) outcome = await self.connection.query("SELECT * FROM user;") self.assertEqual(outcome, []) - await self.connection.socket.close() async def test_delete_record_id(self): first_outcome = await self.connection.delete(self.record_id) self.check_no_change(first_outcome) outcome = await self.connection.query("SELECT * FROM user;") self.assertEqual(outcome, []) - await self.connection.socket.close() async def test_delete_table(self): await self.connection.query("CREATE user:jaime SET name = 'Jaime';") @@ -59,7 +57,6 @@ async def test_delete_table(self): self.assertEqual(2, len(first_outcome)) outcome = await self.connection.query("SELECT * FROM user;") self.assertEqual(outcome, []) - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/info/test_async_ws.py b/tests/unit_tests/connections/info/test_async_ws.py index 7c77fe1e..85816de2 100644 --- a/tests/unit_tests/connections/info/test_async_ws.py +++ b/tests/unit_tests/connections/info/test_async_ws.py @@ -21,7 +21,6 @@ async def asyncSetUp(self): async def test_info(self): outcome = await self.connection.info() - await self.connection.socket.close() # TODO => confirm that the info is what we expect diff --git a/tests/unit_tests/connections/insert/test_async_ws.py b/tests/unit_tests/connections/insert/test_async_ws.py index 88f51dc8..9f6ee9d2 100644 --- a/tests/unit_tests/connections/insert/test_async_ws.py +++ b/tests/unit_tests/connections/insert/test_async_ws.py @@ -46,7 +46,6 @@ async def test_insert_string_with_data(self): 2 ) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_insert_record_id_result_error(self): record_id = RecordID("user","tobie") @@ -59,7 +58,6 @@ async def test_insert_record_id_result_error(self): True ) await self.connection.query("DELETE user;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/insert_relation/test_async_ws.py b/tests/unit_tests/connections/insert_relation/test_async_ws.py index 9ae31560..3f484575 100644 --- a/tests/unit_tests/connections/insert_relation/test_async_ws.py +++ b/tests/unit_tests/connections/insert_relation/test_async_ws.py @@ -76,7 +76,6 @@ async def test_insert_relation_record_ids(self): ) await self.connection.query("DELETE user;") await self.connection.query("DELETE likes;") - await self.connection.socket.close() async def test_insert_relation_record_id(self): data = { @@ -94,7 +93,6 @@ async def test_insert_relation_record_id(self): ) await self.connection.query("DELETE user;") await self.connection.query("DELETE likes;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/let/test_async_ws.py b/tests/unit_tests/connections/let/test_async_ws.py index a4faeda5..fb6a98f7 100644 --- a/tests/unit_tests/connections/let/test_async_ws.py +++ b/tests/unit_tests/connections/let/test_async_ws.py @@ -30,7 +30,6 @@ async def test_let(self): outcome = await self.connection.query('SELECT * FROM person WHERE name.first = $name.first') self.assertEqual({'first': 'Tobie', 'last': 'Morgan Hitchcock'}, outcome[0]["name"]) await self.connection.query("DELETE person;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/live/test_async_http.py b/tests/unit_tests/connections/live/test_async_http.py index 80b054e7..1253dea4 100644 --- a/tests/unit_tests/connections/live/test_async_http.py +++ b/tests/unit_tests/connections/live/test_async_http.py @@ -26,7 +26,6 @@ async def test_query(self): outcome = await self.connection.live("user") self.assertEqual(UUID, type(outcome)) await self.connection.query("DELETE user;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/live/test_async_ws.py b/tests/unit_tests/connections/live/test_async_ws.py index 80b054e7..1253dea4 100644 --- a/tests/unit_tests/connections/live/test_async_ws.py +++ b/tests/unit_tests/connections/live/test_async_ws.py @@ -26,7 +26,6 @@ async def test_query(self): outcome = await self.connection.live("user") self.assertEqual(UUID, type(outcome)) await self.connection.query("DELETE user;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/merge/test_async_ws.py b/tests/unit_tests/connections/merge/test_async_ws.py index 1f151842..92a5f09b 100644 --- a/tests/unit_tests/connections/merge/test_async_ws.py +++ b/tests/unit_tests/connections/merge/test_async_ws.py @@ -51,7 +51,6 @@ async def test_merge_string(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_no_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_merge_string_with_data(self): first_outcome = await self.connection.merge("user:tobie", self.data) @@ -59,7 +58,6 @@ async def test_merge_string_with_data(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_merge_record_id(self): first_outcome = await self.connection.merge(self.record_id) @@ -67,7 +65,6 @@ async def test_merge_record_id(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_no_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_merge_record_id_with_data(self): outcome = await self.connection.merge(self.record_id, self.data) @@ -77,7 +74,6 @@ async def test_merge_record_id_with_data(self): outcome[0] ) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_merge_table(self): table = Table("user") @@ -87,7 +83,6 @@ async def test_merge_table(self): self.check_no_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_merge_table_with_data(self): table = Table("user") @@ -96,7 +91,6 @@ async def test_merge_table_with_data(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/patch/test_async_ws.py b/tests/unit_tests/connections/patch/test_async_ws.py index f4c54327..d6c408bd 100644 --- a/tests/unit_tests/connections/patch/test_async_ws.py +++ b/tests/unit_tests/connections/patch/test_async_ws.py @@ -43,7 +43,6 @@ async def test_patch_string_with_data(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_patch_record_id_with_data(self): outcome = await self.connection.patch(self.record_id, self.data) @@ -51,7 +50,6 @@ async def test_patch_record_id_with_data(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_patch_table_with_data(self): table = Table("user") @@ -60,7 +58,6 @@ async def test_patch_table_with_data(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/query/test_async_http.py b/tests/unit_tests/connections/query/test_async_http.py index 444f6fb1..4044d34b 100644 --- a/tests/unit_tests/connections/query/test_async_http.py +++ b/tests/unit_tests/connections/query/test_async_http.py @@ -55,7 +55,6 @@ async def test_query(self): ] ) await self.connection.query("DELETE user;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/query/test_async_ws.py b/tests/unit_tests/connections/query/test_async_ws.py index 7ffcf483..541bb116 100644 --- a/tests/unit_tests/connections/query/test_async_ws.py +++ b/tests/unit_tests/connections/query/test_async_ws.py @@ -55,7 +55,6 @@ async def test_query(self): ] ) await self.connection.query("DELETE user;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/select/test_async_ws.py b/tests/unit_tests/connections/select/test_async_ws.py index 71328260..11f830db 100644 --- a/tests/unit_tests/connections/select/test_async_ws.py +++ b/tests/unit_tests/connections/select/test_async_ws.py @@ -42,7 +42,6 @@ async def test_select(self): await self.connection.query("DELETE user;") await self.connection.query("DELETE users;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/signin/test_async_ws.py b/tests/unit_tests/connections/signin/test_async_ws.py index 23245bdc..c49618c1 100644 --- a/tests/unit_tests/connections/signin/test_async_ws.py +++ b/tests/unit_tests/connections/signin/test_async_ws.py @@ -47,8 +47,6 @@ async def test_signin_root(self): self.assertIsNotNone(response) _ = await self.connection.query("DELETE user;") _ = await self.connection.query("REMOVE TABLE user;") - await self.connection.socket.close() - await connection.socket.close() async def test_signin_namespace(self): connection = AsyncWsSurrealConnection(self.url) @@ -61,8 +59,6 @@ async def test_signin_namespace(self): self.assertIsNotNone(response) _ = await self.connection.query("DELETE user;") _ = await self.connection.query("REMOVE TABLE user;") - await self.connection.socket.close() - await connection.socket.close() async def test_signin_database(self): connection = AsyncWsSurrealConnection(self.url) @@ -76,8 +72,6 @@ async def test_signin_database(self): self.assertIsNotNone(response) _ = await self.connection.query("DELETE user;") _ = await self.connection.query("REMOVE TABLE user;") - await self.connection.socket.close() - await connection.socket.close() async def test_signin_record(self): vars = { @@ -99,8 +93,6 @@ async def test_signin_record(self): await self.connection.query("DELETE user;") await self.connection.query("REMOVE TABLE user;") - await self.connection.socket.close() - await connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/subscribe_live/test_async_ws.py b/tests/unit_tests/connections/subscribe_live/test_async_ws.py index 3afdb882..2254382c 100644 --- a/tests/unit_tests/connections/subscribe_live/test_async_ws.py +++ b/tests/unit_tests/connections/subscribe_live/test_async_ws.py @@ -50,8 +50,6 @@ async def test_live_subscription(self): # Cleanup the subscription await self.pub_connection.query("DELETE user;") - await self.pub_connection.socket.close() - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/unset/test_async_ws.py b/tests/unit_tests/connections/unset/test_async_ws.py index bb4633d1..113c7ec6 100644 --- a/tests/unit_tests/connections/unset/test_async_ws.py +++ b/tests/unit_tests/connections/unset/test_async_ws.py @@ -38,7 +38,6 @@ async def test_unset(self): self.assertEqual([], outcome) await self.connection.query("DELETE person;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/update/test_async_ws.py b/tests/unit_tests/connections/update/test_async_ws.py index 0452d59b..3e8c2e93 100644 --- a/tests/unit_tests/connections/update/test_async_ws.py +++ b/tests/unit_tests/connections/update/test_async_ws.py @@ -51,7 +51,6 @@ async def test_update_string(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_no_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_update_string_with_data(self): first_outcome = await self.connection.update("user:tobie", self.data) @@ -59,7 +58,6 @@ async def test_update_string_with_data(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_update_record_id(self): first_outcome = await self.connection.update(self.record_id) @@ -67,7 +65,6 @@ async def test_update_record_id(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_no_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_update_record_id_with_data(self): outcome = await self.connection.update(self.record_id, self.data) @@ -77,7 +74,6 @@ async def test_update_record_id_with_data(self): outcome[0] ) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_update_table(self): table = Table("user") @@ -87,7 +83,6 @@ async def test_update_table(self): self.check_no_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_update_table_with_data(self): table = Table("user") @@ -96,7 +91,6 @@ async def test_update_table_with_data(self): outcome = await self.connection.query("SELECT * FROM user;") self.check_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/upsert/test_async_ws.py b/tests/unit_tests/connections/upsert/test_async_ws.py index 5bcd949b..e237d29a 100644 --- a/tests/unit_tests/connections/upsert/test_async_ws.py +++ b/tests/unit_tests/connections/upsert/test_async_ws.py @@ -52,7 +52,6 @@ async def test_upsert_string(self): outcome = await self.connection.query("SELECT * FROM user;") # self.check_no_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_upsert_string_with_data(self): first_outcome = await self.connection.upsert("user:tobie", self.data) @@ -60,7 +59,6 @@ async def test_upsert_string_with_data(self): outcome = await self.connection.query("SELECT * FROM user;") # self.check_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_upsert_record_id(self): first_outcome = await self.connection.upsert(self.record_id) @@ -68,7 +66,6 @@ async def test_upsert_record_id(self): outcome = await self.connection.query("SELECT * FROM user;") # self.check_no_change(outcome[0]) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_upsert_record_id_with_data(self): outcome = await self.connection.upsert(self.record_id, self.data) @@ -78,7 +75,6 @@ async def test_upsert_record_id_with_data(self): # outcome[0] # ) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_upsert_table(self): table = Table("user") @@ -89,7 +85,6 @@ async def test_upsert_table(self): # self.check_no_change(outcome[1], random_id=True) await self.connection.query("DELETE user;") - await self.connection.socket.close() async def test_upsert_table_with_data(self): table = Table("user") @@ -99,7 +94,6 @@ async def test_upsert_table_with_data(self): self.assertEqual(2, len(outcome)) # self.check_change(outcome[0], random_id=True) await self.connection.query("DELETE user;") - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/connections/version/test_async_ws.py b/tests/unit_tests/connections/version/test_async_ws.py index 5e602d60..fa29f2b8 100644 --- a/tests/unit_tests/connections/version/test_async_ws.py +++ b/tests/unit_tests/connections/version/test_async_ws.py @@ -21,7 +21,6 @@ async def asyncSetUp(self): async def test_version(self): self.assertEqual(str, type(await self.connection.version())) - await self.connection.socket.close() if __name__ == "__main__": diff --git a/tests/unit_tests/data_types/test_datetimes.py b/tests/unit_tests/data_types/test_datetimes.py index 311a6f6a..a221710b 100644 --- a/tests/unit_tests/data_types/test_datetimes.py +++ b/tests/unit_tests/data_types/test_datetimes.py @@ -76,7 +76,6 @@ async def test_datetime_iso_format(self): # Cleanup await self.connection.query("DELETE datetime_tests;") - await self.connection.socket.close() if __name__ == "__main__": From 2ea5d994894ae7af5b0e752ceb17f5d49b35a265 Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Fri, 28 Feb 2025 14:19:24 +0000 Subject: [PATCH 05/10] updating tests --- .../connections/batch_async/test_async_ws.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/unit_tests/connections/batch_async/test_async_ws.py b/tests/unit_tests/connections/batch_async/test_async_ws.py index 6ee6a0ab..b6f6e0d6 100644 --- a/tests/unit_tests/connections/batch_async/test_async_ws.py +++ b/tests/unit_tests/connections/batch_async/test_async_ws.py @@ -2,6 +2,7 @@ from unittest import main, IsolatedAsyncioTestCase from surrealdb.connections.async_ws import AsyncWsSurrealConnection +import sys class TestAsyncWsSurrealConnection(IsolatedAsyncioTestCase): @@ -25,11 +26,16 @@ async def asyncSetUp(self): _ = await self.connection.use(namespace=self.namespace, database=self.database_name) async def test_batch(self): - async with asyncio.TaskGroup() as tg: - tasks = [tg.create_task(self.connection.query("RETURN sleep(duration::from::millis($d)) or $p**2", dict(d=10 if num%2 else 0, p=num))) for num in range(5)] + python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + if python_version == "3.9" or python_version == "3.10": + print("async batching is being bypassed due to python versions 3.9 and 3.10 not supporting async task group") + else: + async with asyncio.TaskGroup() as tg: + tasks = [tg.create_task(self.connection.query("RETURN sleep(duration::from::millis($d)) or $p**2", dict(d=10 if num%2 else 0, p=num))) for num in range(5)] + + outcome = [t.result() for t in tasks] + self.assertEqual([0, 1, 4, 9, 16], outcome) - outcome = [t.result() for t in tasks] - self.assertEqual([0, 1, 4, 9, 16], outcome) if __name__ == "__main__": From 6315897267127e07effcb0c5e4146c84354f668c Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Fri, 28 Feb 2025 14:24:39 +0000 Subject: [PATCH 06/10] updating tests --- tests/unit_tests/connections/batch_async/test_async_ws.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/connections/batch_async/test_async_ws.py b/tests/unit_tests/connections/batch_async/test_async_ws.py index b6f6e0d6..17bfa36e 100644 --- a/tests/unit_tests/connections/batch_async/test_async_ws.py +++ b/tests/unit_tests/connections/batch_async/test_async_ws.py @@ -35,7 +35,7 @@ async def test_batch(self): outcome = [t.result() for t in tasks] self.assertEqual([0, 1, 4, 9, 16], outcome) - + await self.connection.close() if __name__ == "__main__": From 4f167353114b5dc80e36d8113658e5a9983a2bc7 Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Mon, 3 Mar 2025 09:38:32 +0000 Subject: [PATCH 07/10] updating tests --- .github/workflows/tests.yml | 13 +++++++------ .../connections/batch_async/test_async_ws.py | 17 +++++++++++------ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2c4c1cad..24443275 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -42,17 +42,18 @@ jobs: - name: Install dependencies run: pip install -r requirements.txt - - name: Run unit tests (HTTP) + - name: Run unit tests run: python -m unittest discover -s tests env: PYTHONPATH: ./src SURREALDB_URL: http://localhost:8000 + SURREALDB_VERSION: ${{ matrix.surrealdb-version }} - - name: Run unit tests (WebSocket) - run: python -m unittest discover -s tests - env: - PYTHONPATH: ./src - SURREALDB_URL: ws://localhost:8000 +# - name: Run unit tests (WebSocket) +# run: python -m unittest discover -s tests +# env: +# PYTHONPATH: ./src +# SURREALDB_URL: ws://localhost:8000 diff --git a/tests/unit_tests/connections/batch_async/test_async_ws.py b/tests/unit_tests/connections/batch_async/test_async_ws.py index 2d07bddd..6a747338 100644 --- a/tests/unit_tests/connections/batch_async/test_async_ws.py +++ b/tests/unit_tests/connections/batch_async/test_async_ws.py @@ -2,6 +2,7 @@ from unittest import main, IsolatedAsyncioTestCase from surrealdb.connections.async_ws import AsyncWsSurrealConnection +import os class TestAsyncWsSurrealConnection(IsolatedAsyncioTestCase): @@ -25,12 +26,16 @@ async def asyncSetUp(self): _ = await self.connection.use(namespace=self.namespace, database=self.database_name) async def test_batch(self): - async with asyncio.TaskGroup() as tg: - tasks = [tg.create_task(self.connection.query("RETURN sleep(duration::from::millis($d)) or $p**2", dict(d=10 if num%2 else 0, p=num))) for num in range(5)] - - outcome = [t.result() for t in tasks] - self.assertEqual([0, 1, 4, 9, 16], outcome) - await self.connection.socket.close() + # async batching doesn't work for surrealDB v2.1.0" or lower + if os.environ.get("SURREALDB_VERSION") == "v2.1.0": + pass + else: + async with asyncio.TaskGroup() as tg: + tasks = [tg.create_task(self.connection.query("RETURN sleep(duration::from::millis($d)) or $p**2", dict(d=10 if num%2 else 0, p=num))) for num in range(5)] + + outcome = [t.result() for t in tasks] + self.assertEqual([0, 1, 4, 9, 16], outcome) + await self.connection.socket.close() if __name__ == "__main__": From 3476b1f13ef91e69acc9c402d04db014f908c3c1 Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Mon, 3 Mar 2025 10:03:47 +0000 Subject: [PATCH 08/10] updating tests --- .github/workflows/tests.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 24443275..6d580bac 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -19,7 +19,8 @@ jobs: fail-fast: false matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - surrealdb-version: ["v2.1.0", "v2.1.1", "v2.1.2", "v2.1.3", "v2.1.4"] # v2.0.0 has different UPSERT behaviour + # surrealdb-version: ["v2.1.0", "v2.1.1", "v2.1.2", "v2.1.3", "v2.1.4"] # v2.0.0 has different UPSERT behaviour + surrealdb-version: ["v2.1.1", "v2.1.2", "v2.1.3", "v2.1.4"] # v2.0.0 has different UPSERT behaviour name: Python ${{ matrix.python-version }} - SurrealDB ${{ matrix.surrealdb-version }} steps: - name: Checkout repository From 58604a1f9c5719d18eda4d3bbb2a9937d4f4d02f Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Mon, 3 Mar 2025 10:19:52 +0000 Subject: [PATCH 09/10] updating tests --- .github/workflows/tests.yml | 2 +- src/surrealdb/connections/async_ws.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6d580bac..e6bda247 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -20,7 +20,7 @@ jobs: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] # surrealdb-version: ["v2.1.0", "v2.1.1", "v2.1.2", "v2.1.3", "v2.1.4"] # v2.0.0 has different UPSERT behaviour - surrealdb-version: ["v2.1.1", "v2.1.2", "v2.1.3", "v2.1.4"] # v2.0.0 has different UPSERT behaviour + surrealdb-version: ["v2.1.1", "v2.1.2", "v2.1.3", "v2.1.4"] # v2.0.0 has different UPSERT behaviour and v2.1.0 does not support async batching name: Python ${{ matrix.python-version }} - SurrealDB ${{ matrix.surrealdb-version }} steps: - name: Checkout repository diff --git a/src/surrealdb/connections/async_ws.py b/src/surrealdb/connections/async_ws.py index 7eb0f693..ea337383 100644 --- a/src/surrealdb/connections/async_ws.py +++ b/src/surrealdb/connections/async_ws.py @@ -1,11 +1,9 @@ """ A basic async connection to a SurrealDB instance. """ - import asyncio -import uuid from asyncio import Queue, Task, Future, AbstractEventLoop -from typing import Optional, Any, Dict, Union, List, AsyncGenerator +from typing import Optional, Any, Dict, Union, List from uuid import UUID import websockets From 0f30f93b83d29aa76d8cedc3f53fb374923a76a9 Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Mon, 3 Mar 2025 10:26:19 +0000 Subject: [PATCH 10/10] formatting code --- src/surrealdb/connections/async_http.py | 24 +++------- src/surrealdb/connections/async_template.py | 8 +--- src/surrealdb/connections/async_ws.py | 50 ++++++++------------- src/surrealdb/connections/blocking_http.py | 24 +++------- src/surrealdb/connections/blocking_ws.py | 24 +++------- src/surrealdb/connections/socket_state.py | 9 ---- 6 files changed, 39 insertions(+), 100 deletions(-) delete mode 100644 src/surrealdb/connections/socket_state.py diff --git a/src/surrealdb/connections/async_http.py b/src/surrealdb/connections/async_http.py index 313950bf..0d027752 100644 --- a/src/surrealdb/connections/async_http.py +++ b/src/surrealdb/connections/async_http.py @@ -188,9 +188,7 @@ async def create( if ":" in thing: buffer = thing.split(":") thing = RecordID(table_name=buffer[0], identifier=buffer[1]) - message = RequestMessage( - RequestMethod.CREATE, collection=thing, data=data - ) + message = RequestMessage(RequestMethod.CREATE, collection=thing, data=data) self.id = message.id response = await self._send(message, "create") self.check_response_for_result(response, "create") @@ -208,9 +206,7 @@ async def delete( async def insert( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.INSERT, collection=table, params=data - ) + message = RequestMessage(RequestMethod.INSERT, collection=table, params=data) self.id = message.id response = await self._send(message, "insert") self.check_response_for_result(response, "insert") @@ -236,9 +232,7 @@ async def unset(self, key: str) -> None: async def merge( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.MERGE, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.MERGE, record_id=thing, data=data) self.id = message.id response = await self._send(message, "merge") self.check_response_for_result(response, "merge") @@ -247,9 +241,7 @@ async def merge( async def patch( self, thing: Union[str, RecordID, Table], data: Optional[List[dict]] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.PATCH, collection=thing, params=data - ) + message = RequestMessage(RequestMethod.PATCH, collection=thing, params=data) self.id = message.id response = await self._send(message, "patch") self.check_response_for_result(response, "patch") @@ -265,9 +257,7 @@ async def select(self, thing: str) -> Union[List[dict], dict]: async def update( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.UPDATE, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.UPDATE, record_id=thing, data=data) self.id = message.id response = await self._send(message, "update") self.check_response_for_result(response, "update") @@ -283,9 +273,7 @@ async def version(self) -> str: async def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.UPSERT, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.UPSERT, record_id=thing, data=data) self.id = message.id response = await self._send(message, "upsert") self.check_response_for_result(response, "upsert") diff --git a/src/surrealdb/connections/async_template.py b/src/surrealdb/connections/async_template.py index c63dac0f..340b809d 100644 --- a/src/surrealdb/connections/async_template.py +++ b/src/surrealdb/connections/async_template.py @@ -364,9 +364,7 @@ async def insert_relation( """ raise NotImplementedError(f"insert_relation not implemented for: {self}") - async def live( - self, table: Union[str, Table], diff: bool = False - ) -> UUID: + async def live(self, table: Union[str, Table], diff: bool = False) -> UUID: """Initiates a live query for a specified table name. Args: @@ -383,9 +381,7 @@ async def live( """ raise NotImplementedError(f"live not implemented for: {self}") - async def subscribe_live( - self, query_uuid: Union[str, UUID] - ) -> Queue: + async def subscribe_live(self, query_uuid: Union[str, UUID]) -> Queue: """Returns a queue that receives notification messages from a running live query. Args: diff --git a/src/surrealdb/connections/async_ws.py b/src/surrealdb/connections/async_ws.py index ea337383..bf362278 100644 --- a/src/surrealdb/connections/async_ws.py +++ b/src/surrealdb/connections/async_ws.py @@ -1,6 +1,7 @@ """ A basic async connection to a SurrealDB instance. """ + import asyncio from asyncio import Queue, Task, Future, AbstractEventLoop from typing import Optional, Any, Dict, Union, List @@ -46,22 +47,22 @@ def __init__( self.port: Optional[int] = self.url.port self.token: Optional[str] = None self.socket = None - self.loop: AbstractEventLoop|None = None - self.qry:dict[str, Future] = {} - self.recv_task:Task[None]|None = None - self.live_queues:dict[str, list] = {} + self.loop: AbstractEventLoop | None = None + self.qry: dict[str, Future] = {} + self.recv_task: Task[None] | None = None + self.live_queues: dict[str, list] = {} async def _recv_task(self): assert self.socket async for data in self.socket: response = decode(data) - if (response_id := response.get("id")): + if response_id := response.get("id"): if fut := self.qry.get(response_id): fut.set_result(response) else: - live_id = str(response['result']['id']) + live_id = str(response["result"]["id"]) for queue in self.live_queues.get(live_id, []): - queue.put_nowait(response['result']) + queue.put_nowait(response["result"]) async def _send( self, message: RequestMessage, process: str, bypass: bool = False @@ -206,9 +207,7 @@ async def create( if ":" in thing: buffer = thing.split(":") thing = RecordID(table_name=buffer[0], identifier=buffer[1]) - message = RequestMessage( - RequestMethod.CREATE, collection=thing, data=data - ) + message = RequestMessage(RequestMethod.CREATE, collection=thing, data=data) response = await self._send(message, "create") self.check_response_for_result(response, "create") return response["result"] @@ -216,9 +215,7 @@ async def create( async def update( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.UPDATE, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.UPDATE, record_id=thing, data=data) response = await self._send(message, "update") self.check_response_for_result(response, "update") return response["result"] @@ -226,9 +223,7 @@ async def update( async def merge( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.MERGE, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.MERGE, record_id=thing, data=data) response = await self._send(message, "merge") self.check_response_for_result(response, "merge") return response["result"] @@ -236,9 +231,7 @@ async def merge( async def patch( self, thing: Union[str, RecordID, Table], data: Optional[List[dict]] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.PATCH, collection=thing, params=data - ) + message = RequestMessage(RequestMethod.PATCH, collection=thing, params=data) response = await self._send(message, "patch") self.check_response_for_result(response, "patch") return response["result"] @@ -254,9 +247,7 @@ async def delete( async def insert( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.INSERT, collection=table, params=data - ) + message = RequestMessage(RequestMethod.INSERT, collection=table, params=data) response = await self._send(message, "insert") self.check_response_for_result(response, "insert") return response["result"] @@ -283,16 +274,16 @@ async def live(self, table: Union[str, Table], diff: bool = False) -> UUID: self.live_queues[str(uuid)] = [] return uuid - def subscribe_live( - self, query_uuid: Union[str, UUID] - ) -> Queue: + def subscribe_live(self, query_uuid: Union[str, UUID]) -> Queue: result_queue = Queue() suid = str(query_uuid) self.live_queues[suid].append(result_queue) + async def _iter(): while True: ret = await result_queue.get() - yield ret['result'] + yield ret["result"] + return _iter() async def kill(self, query_uuid: Union[str, UUID]) -> None: @@ -303,9 +294,7 @@ async def kill(self, query_uuid: Union[str, UUID]) -> None: async def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.UPSERT, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.UPSERT, record_id=thing, data=data) response = await self._send(message, "upsert") self.check_response_for_result(response, "upsert") return response["result"] @@ -321,7 +310,6 @@ async def close(self): if self.socket is not None: await self.socket.close() - async def __aenter__(self) -> "AsyncWsSurrealConnection": """ Asynchronous context manager entry. @@ -335,4 +323,4 @@ async def __aexit__(self, exc_type, exc_value, traceback) -> None: Asynchronous context manager exit. Closes the websocket connection upon exiting the context. """ - await self.close() \ No newline at end of file + await self.close() diff --git a/src/surrealdb/connections/blocking_http.py b/src/surrealdb/connections/blocking_http.py index af82ddd0..f7cbc356 100644 --- a/src/surrealdb/connections/blocking_http.py +++ b/src/surrealdb/connections/blocking_http.py @@ -147,9 +147,7 @@ def create( if ":" in thing: buffer = thing.split(":") thing = RecordID(table_name=buffer[0], identifier=buffer[1]) - message = RequestMessage( - RequestMethod.CREATE, collection=thing, data=data - ) + message = RequestMessage(RequestMethod.CREATE, collection=thing, data=data) self.id = message.id response = self._send(message, "create") self.check_response_for_result(response, "create") @@ -165,9 +163,7 @@ def delete(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: def insert( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.INSERT, collection=table, params=data - ) + message = RequestMessage(RequestMethod.INSERT, collection=table, params=data) self.id = message.id response = self._send(message, "insert") self.check_response_for_result(response, "insert") @@ -193,9 +189,7 @@ def unset(self, key: str) -> None: def merge( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.MERGE, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.MERGE, record_id=thing, data=data) self.id = message.id response = self._send(message, "merge") self.check_response_for_result(response, "merge") @@ -204,9 +198,7 @@ def merge( def patch( self, thing: Union[str, RecordID, Table], data: Optional[Dict[Any, Any]] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.PATCH, collection=thing, params=data - ) + message = RequestMessage(RequestMethod.PATCH, collection=thing, params=data) self.id = message.id response = self._send(message, "patch") self.check_response_for_result(response, "patch") @@ -222,9 +214,7 @@ def select(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: def update( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.UPDATE, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.UPDATE, record_id=thing, data=data) self.id = message.id response = self._send(message, "update") self.check_response_for_result(response, "update") @@ -240,9 +230,7 @@ def version(self) -> str: def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.UPSERT, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.UPSERT, record_id=thing, data=data) self.id = message.id response = self._send(message, "upsert") self.check_response_for_result(response, "upsert") diff --git a/src/surrealdb/connections/blocking_ws.py b/src/surrealdb/connections/blocking_ws.py index d00b4e4b..0ba4548d 100644 --- a/src/surrealdb/connections/blocking_ws.py +++ b/src/surrealdb/connections/blocking_ws.py @@ -169,9 +169,7 @@ def create( if ":" in thing: buffer = thing.split(":") thing = RecordID(table_name=buffer[0], identifier=buffer[1]) - message = RequestMessage( - RequestMethod.CREATE, collection=thing, data=data - ) + message = RequestMessage(RequestMethod.CREATE, collection=thing, data=data) self.id = message.id response = self._send(message, "create") self.check_response_for_result(response, "create") @@ -202,9 +200,7 @@ def delete(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: def insert( self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.INSERT, collection=table, params=data - ) + message = RequestMessage(RequestMethod.INSERT, collection=table, params=data) self.id = message.id response = self._send(message, "insert") self.check_response_for_result(response, "insert") @@ -224,9 +220,7 @@ def insert_relation( def merge( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.MERGE, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.MERGE, record_id=thing, data=data) self.id = message.id response = self._send(message, "merge") self.check_response_for_result(response, "merge") @@ -235,9 +229,7 @@ def merge( def patch( self, thing: Union[str, RecordID, Table], data: Optional[List[dict]] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.PATCH, collection=thing, params=data - ) + message = RequestMessage(RequestMethod.PATCH, collection=thing, params=data) self.id = message.id response = self._send(message, "patch") self.check_response_for_result(response, "patch") @@ -275,9 +267,7 @@ def subscribe_live( def update( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.UPDATE, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.UPDATE, record_id=thing, data=data) self.id = message.id response = self._send(message, "update") self.check_response_for_result(response, "update") @@ -286,9 +276,7 @@ def update( def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: - message = RequestMessage( - RequestMethod.UPSERT, record_id=thing, data=data - ) + message = RequestMessage(RequestMethod.UPSERT, record_id=thing, data=data) self.id = message.id response = self._send(message, "upsert") self.check_response_for_result(response, "upsert") diff --git a/src/surrealdb/connections/socket_state.py b/src/surrealdb/connections/socket_state.py deleted file mode 100644 index 2b630738..00000000 --- a/src/surrealdb/connections/socket_state.py +++ /dev/null @@ -1,9 +0,0 @@ - - -# singleton - -# dict for sockets - -# socket with smart reference counter - -# \ No newline at end of file