Skip to content

Commit 4e89316

Browse files
authored
Issue fix for Notification queue (#121)
1 parent ae48005 commit 4e89316

File tree

8 files changed

+36
-32
lines changed

8 files changed

+36
-32
lines changed

surrealdb/connection.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from typing import Dict, Tuple
99
from surrealdb.constants import REQUEST_ID_LENGTH
10-
from surrealdb.data.cbor import encode, decode
1110
from asyncio import Queue
1211

1312

@@ -34,7 +33,12 @@ def __init__(
3433
self,
3534
base_url: str,
3635
logger: logging.Logger,
36+
encoder,
37+
decoder,
3738
):
39+
self._encoder = encoder
40+
self._decoder = decoder
41+
3842
self._locks = {
3943
ResponseType.SEND: threading.Lock(),
4044
ResponseType.NOTIFICATION: threading.Lock(),
@@ -58,7 +62,7 @@ async def connect(self) -> None:
5862
async def close(self) -> None:
5963
pass
6064

61-
async def _make_request(self, request_data: RequestData, encoder, decoder):
65+
async def _make_request(self, request_data: RequestData):
6266
pass
6367

6468
async def set(self, key: str, value):
@@ -104,9 +108,7 @@ async def send(self, method: str, *params):
104108
self._logger.debug(f"Request {request_data.id}:", request_data)
105109

106110
try:
107-
result = await self._make_request(
108-
request_data, encoder=encode, decoder=decode
109-
)
111+
result = await self._make_request(request_data)
110112

111113
self._logger.debug(f"Result {request_data.id}:", result)
112114
self._logger.debug(

surrealdb/connection_clib.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ class sr_notification_t(ctypes.Structure):
7777

7878

7979
class CLibConnection(Connection):
80-
def __init__(self, base_url: str, logger: logging.Logger):
81-
super().__init__(base_url, logger)
80+
def __init__(self, base_url: str, logger: logging.Logger, encoder, decoder):
81+
super().__init__(base_url, logger, encoder, decoder)
8282

8383
lib_path = get_lib_path()
8484
self._lib = ctypes.CDLL(lib_path)
@@ -194,8 +194,8 @@ async def set(self, key: str, value):
194194
async def unset(self, key: str):
195195
await self.send("unset", key)
196196

197-
async def _make_request(self, request_data: RequestData, encoder, decoder):
198-
request_payload = encoder(
197+
async def _make_request(self, request_data: RequestData):
198+
request_payload = self._encoder(
199199
{
200200
"id": request_data.id,
201201
"method": request_data.method,
@@ -226,4 +226,6 @@ async def _make_request(self, request_data: RequestData, encoder, decoder):
226226

227227
# Free the allocated byte array returned by the C library
228228
self._lib.sr_free_byte_arr(c_res_ptr, result)
229-
return True, response
229+
response_data = self._decoder(response)
230+
231+
return response_data

surrealdb/connection_factory.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from surrealdb.connection_clib import CLibConnection
1313
from surrealdb.connection_http import HTTPConnection
1414
from surrealdb.connection_ws import WebsocketConnection
15+
from surrealdb.data.cbor import encode, decode
1516
from surrealdb.errors import SurrealDbConnectionError
1617

1718

@@ -26,14 +27,14 @@ def create_connection_factory(url: str) -> Connection:
2627

2728
if parsed_url.scheme in WS_CONNECTION_SCHEMES:
2829
logger.debug("websocket url detected, creating a websocket connection")
29-
return WebsocketConnection(url, logger)
30+
return WebsocketConnection(url, logger, encoder=encode, decoder=decode)
3031

3132
if parsed_url.scheme in HTTP_CONNECTION_SCHEMES:
3233
logger.debug("http url detected, creating a http connection")
33-
return HTTPConnection(url, logger)
34+
return HTTPConnection(url, logger, encoder=encode, decoder=decode)
3435

3536
if parsed_url.scheme in CLIB_CONNECTION_SCHEMES:
3637
logger.debug("embedded url detected, creating a clib connection")
37-
return CLibConnection(url, logger)
38+
return CLibConnection(url, logger, encoder=encode, decoder=decode)
3839

3940
raise Exception("no connection type available")

surrealdb/connection_http.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ async def connect(self) -> None:
3434
"connection failed. check server is up and base url is correct"
3535
)
3636

37-
async def _make_request(self, request_data: RequestData, encoder, decoder):
37+
async def _make_request(self, request_data: RequestData):
3838
if self._namespace is None:
3939
raise SurrealDbConnectionError("namespace not set")
4040

@@ -51,7 +51,7 @@ async def _make_request(self, request_data: RequestData, encoder, decoder):
5151
if self._auth_token is not None:
5252
headers["Authorization"] = f"Bearer {self._auth_token}"
5353

54-
request_payload = encoder(
54+
request_payload = self._encoder(
5555
{
5656
"id": request_data.id,
5757
"method": request_data.method,
@@ -62,7 +62,7 @@ async def _make_request(self, request_data: RequestData, encoder, decoder):
6262
response = requests.post(
6363
f"{self._base_url}/rpc", data=request_payload, headers=headers
6464
)
65-
response_data = decoder(response.content)
65+
response_data = self._decoder(response.content)
6666

6767
if 200 > response.status_code > 299 or response_data.get("error"):
6868
raise SurrealDbConnectionError(response_data.get("error").get("message"))

surrealdb/connection_ws.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,18 @@
11
import asyncio
2-
import logging
32
from asyncio import Task
43

54
from websockets import Subprotocol, ConnectionClosed, connect
65
from websockets.asyncio.client import ClientConnection
76

87
from surrealdb.connection import Connection, ResponseType, RequestData
98
from surrealdb.constants import WS_REQUEST_TIMEOUT
10-
from surrealdb.data.cbor import decode
119
from surrealdb.errors import SurrealDbConnectionError
1210

1311

1412
class WebsocketConnection(Connection):
1513
_ws: ClientConnection
1614
_receiver_task: Task
1715

18-
def __init__(self, base_url: str, logger: logging.Logger):
19-
super().__init__(base_url, logger)
20-
21-
# self._ws = None
22-
# self._receiver_task = None
23-
2416
async def connect(self):
2517
try:
2618
self._ws = await connect(
@@ -49,8 +41,8 @@ async def close(self):
4941
if self._ws:
5042
await self._ws.close()
5143

52-
async def _make_request(self, request_data: RequestData, encoder, decoder):
53-
request_payload = encoder(
44+
async def _make_request(self, request_data: RequestData):
45+
request_payload = self._encoder(
5446
{
5547
"id": request_data.id,
5648
"method": request_data.method,
@@ -85,14 +77,18 @@ async def _make_request(self, request_data: RequestData, encoder, decoder):
8577

8678
async def listen_to_ws(self, ws):
8779
async for message in ws:
88-
response_data = decode(message)
80+
response_data = self._decoder(message)
8981

9082
response_id = response_data.get("id")
9183
if response_id:
9284
queue = self.get_response_queue(ResponseType.SEND, response_id)
9385
await queue.put(response_data)
9486
continue
9587

96-
live_id = response_data.get("result").get("id")
97-
queue = self.get_response_queue(ResponseType.NOTIFICATION, live_id)
88+
live_id = response_data.get("result").get("id") # returned as uuid
89+
queue = self.get_response_queue(ResponseType.NOTIFICATION, str(live_id))
90+
if queue is None:
91+
self._logger.error(f"No notification queue set for {live_id}")
92+
continue
93+
9894
await queue.put(response_data.get("result"))

tests/unit/test_clib_connection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from unittest import IsolatedAsyncioTestCase
22
from logging import getLogger
33
from surrealdb.connection_clib import CLibConnection
4+
from surrealdb.data.cbor import encode, decode
45

56

67
class TestCLibConnection(IsolatedAsyncioTestCase):
78
async def asyncSetUp(self):
89
self.logger = getLogger(__name__)
910

10-
self.clib = CLibConnection(base_url='surrealkv://', logger=self.logger)
11+
self.clib = CLibConnection(base_url='surrealkv://', logger=self.logger, encoder=encode, decoder=decode)
1112
await self.clib.connect()
1213

1314
async def test_send(self):

tests/unit/test_http_connection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
from unittest import IsolatedAsyncioTestCase
44

55
from surrealdb.connection_http import HTTPConnection
6+
from surrealdb.data.cbor import encode, decode
67

78

89
class TestHTTPConnection(IsolatedAsyncioTestCase):
910
async def asyncSetUp(self):
1011
logger = logging.getLogger(__name__)
1112

12-
self.http_con = HTTPConnection(base_url='http://localhost:8000', logger=logger)
13+
self.http_con = HTTPConnection(base_url='http://localhost:8000', logger=logger, encoder=encode, decoder=decode)
1314
await self.http_con.connect()
1415

1516
async def test_send(self):

tests/unit/test_ws_connection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
from unittest import IsolatedAsyncioTestCase
55

66
from surrealdb.connection_ws import WebsocketConnection
7+
from surrealdb.data.cbor import encode, decode
78

89

910
class TestWSConnection(IsolatedAsyncioTestCase):
1011
async def asyncSetUp(self):
1112
logger = logging.getLogger(__name__)
12-
self.ws_con = WebsocketConnection(base_url='ws://localhost:8000', logger=logger)
13+
self.ws_con = WebsocketConnection(base_url='ws://localhost:8000', logger=logger, encoder=encode, decoder=decode)
1314
await self.ws_con.connect()
1415

1516
async def test_send(self):

0 commit comments

Comments
 (0)