Skip to content

Commit 937b43e

Browse files
maxwellflittontobiemhremade
authored
adding unittest for base connection object (#120)
Co-authored-by: Tobie Morgan Hitchcock <tobie@surrealdb.com> Co-authored-by: Remade <mail4remi@gmail.com>
1 parent ca5d4c6 commit 937b43e

File tree

8 files changed

+353
-22
lines changed

8 files changed

+353
-22
lines changed

surrealdb/async_surrealdb.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,4 +295,5 @@ async def kill(self, live_query_id: uuid.UUID) -> None:
295295
296296
:param live_query_id: The UUID of the live query to kill.
297297
"""
298+
298299
return await self.__connection.send(METHOD_KILL, live_query_id)

surrealdb/connection.py

Lines changed: 157 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1+
"""
2+
Defines the base Connection class for sending and receiving requests.
3+
"""
4+
5+
import logging
16
import secrets
27
import string
3-
import logging
48
import threading
59
import uuid
6-
from dataclasses import dataclass
710

11+
from dataclasses import dataclass
812
from typing import Dict, Tuple
913
from surrealdb.constants import (
1014
REQUEST_ID_LENGTH,
@@ -23,19 +27,50 @@
2327

2428

2529
class ResponseType:
30+
"""
31+
Enum-like class representing response types for the connection.
32+
33+
Attributes:
34+
SEND (int): Response type for standard requests.
35+
NOTIFICATION (int): Response type for notifications.
36+
ERROR (int): Response type for errors.
37+
"""
38+
2639
SEND = 1
2740
NOTIFICATION = 2
2841
ERROR = 3
2942

3043

3144
@dataclass
3245
class RequestData:
46+
"""
47+
Represents the data for a request sent over the connection.
48+
49+
Attributes:
50+
id (str): Unique identifier for the request.
51+
method (str): The method name to invoke.
52+
params (Tuple): Parameters for the method.
53+
"""
54+
3355
id: str
3456
method: str
3557
params: Tuple
3658

3759

3860
class Connection:
61+
"""
62+
Base class for managing a connection to the database.
63+
64+
Manages request/response lifecycle, including the use of queues for
65+
handling asynchronous communication.
66+
67+
Attributes:
68+
_queues (Dict[int, dict]): Mapping of response types to their queues.
69+
_namespace (str | None): Current namespace in use.
70+
_database (str | None): Current database in use.
71+
_auth_token (str | None): Authentication token.
72+
"""
73+
3974
_queues: Dict[int, dict]
4075
_locks: Dict[int, threading.Lock]
4176
_namespace: str | None = None
@@ -49,6 +84,15 @@ def __init__(
4984
encoder,
5085
decoder,
5186
):
87+
"""
88+
Initialize the Connection instance.
89+
90+
Args:
91+
base_url (str): The base URL of the server.
92+
logger (logging.Logger): Logger for debugging and tracking activities.
93+
encoder (function): Function to encode the request.
94+
decoder (function): Function to decode the response.
95+
"""
5296
self._encoder = encoder
5397
self._decoder = decoder
5498

@@ -67,47 +111,120 @@ def __init__(
67111
self._logger = logger
68112

69113
async def use(self, namespace: str, database: str) -> None:
70-
pass
114+
"""
115+
Set the namespace and database for subsequent operations.
116+
117+
Args:
118+
namespace (str): The namespace to use.
119+
database (str): The database to use.
120+
"""
121+
raise NotImplementedError("use method must be implemented")
71122

72123
async def connect(self) -> None:
73-
pass
124+
"""
125+
Establish a connection to the server.
126+
"""
127+
raise NotImplementedError("connect method must be implemented")
74128

75129
async def close(self) -> None:
76-
pass
130+
"""
131+
Close the connection to the server.
132+
"""
133+
raise NotImplementedError("close method must be implemented")
134+
135+
async def _make_request(self, request_data: RequestData) -> dict:
136+
"""
137+
Internal method to send a request and handle the response.
138+
Args:
139+
request_data (RequestData): The data to send.
140+
return:
141+
dict: The response data from the request.
142+
"""
143+
raise NotImplementedError("_make_request method must be implemented")
77144

78-
async def _make_request(self, request_data: RequestData):
79-
pass
145+
async def set(self, key: str, value) -> None:
146+
"""
147+
Set a key-value pair in the database.
80148
81-
async def set(self, key: str, value):
82-
pass
149+
Args:
150+
key (str): The key to set.
151+
value: The value to set.
152+
"""
153+
raise NotImplementedError("set method must be implemented")
83154

84-
async def unset(self, key: str):
85-
pass
155+
async def unset(self, key: str) -> None:
156+
"""
157+
Unset a key-value pair in the database.
158+
159+
Args:
160+
key (str): The key to unset.
161+
"""
162+
raise NotImplementedError("unset method must be implemented")
86163

87164
def set_token(self, token: str | None = None) -> None:
165+
"""
166+
Set the authentication token for the connection.
167+
168+
Args:
169+
token (str): The authentication token to be set
170+
"""
88171
self._auth_token = token
89172

90-
def create_response_queue(self, response_type: int, queue_id: str):
173+
def create_response_queue(self, response_type: int, queue_id: str) -> Queue:
174+
"""
175+
Create a response queue for a given response type.
176+
177+
Args:
178+
response_type (int): The response type for the queue (1: SEND, 2: NOTIFICATION, 3: ERROR).
179+
queue_id (str): The unique identifier for the queue.
180+
Returns:
181+
Queue: The response queue for the given response type and queue ID
182+
(existing queues will be overwritten if same ID is used, cannot get existing queue).
183+
"""
91184
lock = self._locks[response_type]
92185
with lock:
93186
response_type_queues = self._queues.get(response_type)
94187
if response_type_queues is None:
95188
response_type_queues = {}
96189

97-
if response_type_queues.get(queue_id) is None:
98-
queue: Queue = Queue(maxsize=0)
190+
queue = response_type_queues.get(queue_id)
191+
if queue is None:
192+
queue = Queue(maxsize=0)
99193
response_type_queues[queue_id] = queue
100194
self._queues[response_type] = response_type_queues
101-
return queue
102195

103-
def get_response_queue(self, response_type: int, queue_id: str):
196+
return queue
197+
198+
def get_response_queue(self, response_type: int, queue_id: str) -> Queue | None:
199+
"""
200+
Get a response queue for a given response type.
201+
202+
Args:
203+
response_type (int): The response type for the queue (1: SEND, 2: NOTIFICATION, 3: ERROR).
204+
queue_id (str): The unique identifier for the queue.
205+
206+
Returns:
207+
Queue: The response queue for the given response type and queue ID
208+
(existing queues will be overwritten if same ID is used).
209+
"""
104210
lock = self._locks[response_type]
105211
with lock:
106212
response_type_queues = self._queues.get(response_type)
107-
if response_type_queues:
108-
return response_type_queues.get(queue_id)
213+
if not response_type_queues:
214+
return None
215+
return response_type_queues.get(queue_id)
109216

110-
def remove_response_queue(self, response_type: int, queue_id: str):
217+
def remove_response_queue(self, response_type: int, queue_id: str) -> None:
218+
"""
219+
Remove a response queue for a given response type.
220+
221+
Notes:
222+
Does not alert if the key is missing
223+
224+
Args:
225+
response_type (int): The response type for the queue (1: SEND, 2: NOTIFICATION, 3: ERROR).
226+
queue_id (str): The unique identifier for the queue.
227+
"""
111228
lock = self._locks[response_type]
112229
with lock:
113230
response_type_queues = self._queues.get(response_type)
@@ -134,6 +251,17 @@ def _prepare_method_params(method: str, params) -> Tuple:
134251
return prepared_params
135252

136253
async def send(self, method: str, *params):
254+
"""
255+
Sends a request to the server with a unique ID and returns the response.
256+
257+
Args:
258+
method (str): The method of the request.
259+
params: Parameters for the request.
260+
261+
Returns:
262+
dict: The response data from the request.
263+
"""
264+
137265
prepared_params = self._prepare_method_params(method, params)
138266
request_data = RequestData(
139267
id=request_id(REQUEST_ID_LENGTH), method=method, params=prepared_params
@@ -156,7 +284,16 @@ async def send(self, method: str, *params):
156284
)
157285
raise e
158286

159-
async def live_notifications(self, live_query_id: uuid.UUID):
287+
async def live_notifications(self, live_query_id: uuid.UUID) -> Queue:
288+
"""
289+
Create a response queue for live notifications by essentially creating a NOTIFICATION response queue.
290+
291+
Args:
292+
live_query_id (uuid.UUID): The unique identifier for the live query.
293+
294+
Returns:
295+
Queue: The response queue for the live notifications.
296+
"""
160297
queue = self.create_response_queue(
161298
ResponseType.NOTIFICATION, str(live_query_id)
162299
)

surrealdb/connection_clib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ async def connect(self):
180180
self._lib.sr_free_string(c_err)
181181

182182
async def close(self):
183-
pass
183+
self._lib.sr_surreal_rpc_free(self._c_surreal_rpc)
184184

185185
async def use(self, namespace: str, database: str) -> None:
186186
self._namespace = namespace

surrealdb/connection_http.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
class HTTPConnection(Connection):
1111
_request_variables: dict[str, Any] = {}
1212
_request_variables_lock = threading.Lock()
13+
_is_ready: bool = False
1314

1415
async def use(self, namespace: str, database: str) -> None:
1516
self._namespace = namespace
@@ -34,6 +35,10 @@ async def connect(self) -> None:
3435
raise SurrealDbConnectionError(
3536
"connection failed. check server is up and base url is correct"
3637
)
38+
self._is_ready = True
39+
40+
async def close(self):
41+
self._is_ready = False
3742

3843
def _prepare_query_method_params(self, params: Tuple) -> Tuple:
3944
query, variables = params
@@ -45,6 +50,11 @@ def _prepare_query_method_params(self, params: Tuple) -> Tuple:
4550
return query, variables
4651

4752
async def _make_request(self, request_data: RequestData):
53+
if not self._is_ready:
54+
raise SurrealDbConnectionError(
55+
"connection not ready. Call the connect() method first"
56+
)
57+
4858
if self._namespace is None:
4959
raise SurrealDbConnectionError("namespace not set")
5060

surrealdb/connection_ws.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,23 @@ async def use(self, namespace: str, database: str) -> None:
2828

2929
await self.send(METHOD_USE, namespace, database)
3030

31-
async def set(self, key: str, value):
31+
async def set(self, key: str, value) -> None:
32+
"""
33+
Set a key-value pair in the database.
34+
35+
Args:
36+
key (str): The key to set.
37+
value: The value to set.
38+
"""
3239
await self.send(METHOD_SET, key, value)
3340

3441
async def unset(self, key: str):
42+
"""
43+
Unset a key-value pair in the database.
44+
45+
Args:
46+
key (str): The key to unset.
47+
"""
3548
await self.send(METHOD_UNSET, key)
3649

3750
async def close(self):

surrealdb/data/README.md

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
## What is CBOR?
2+
CBOR is a binary data serialization format similar to JSON but more compact, efficient, and capable of encoding a
3+
broader range of data types. It is useful for exchanging structured data between systems, especially when performance
4+
and size are critical.
5+
6+
## Purpose of the CBOR Implementation
7+
8+
The CBOR code here allows the custom SurrealDB types (e.g., `GeometryPoint`, `Table`, `Range`, etc.) to be serialized
9+
into CBOR binary format and deserialized back into Python objects. This is necessary because these types are not natively
10+
supported by CBOR; thus, custom encoding and decoding logic is implemented.
11+
12+
## Key Components
13+
14+
### Custom Types
15+
16+
`Range` Class: Represents a range with a beginning (`begin`) and end (`end`). These can either be included (`BoundIncluded`) or excluded (`BoundExcluded`).
17+
`Table`, `RecordID`, `GeometryPoint`, etc.: Custom SurrealDB-specific data types, representing domain-specific constructs like tables, records, and geometrical objects.
18+
19+
### CBOR Encoder
20+
21+
The function `default_encoder` is used to encode custom Python objects into CBOR's binary format. This is done by associating a specific CBOR tag (a numeric identifier) with each data type.
22+
23+
For example:
24+
25+
`GeometryPoint` objects are encoded using the tag `TAG_GEOMETRY_POINT` with its coordinates as the value.
26+
`Range` objects are encoded using the tag `TAG_BOUND_EXCLUDED` with a list [begin, end] as its value.
27+
The `CBORTag` class is used to represent tagged data in `CBOR`.
28+
29+
### CBOR Decoder
30+
31+
The function `tag_decoder` is the inverse of `default_encoder`. It takes tagged CBOR data and reconstructs the corresponding Python objects.
32+
33+
For example:
34+
35+
When encountering the `TAG_GEOMETRY_POINT` tag, it creates a `GeometryPoint` object using the tag's value (coordinates).
36+
When encountering the `TAG_RANGE` tag, it creates a `Range` object using the tag's value (begin and end).
37+
38+
### encode and decode Functions
39+
40+
These are high-level functions for serializing and deserializing data:
41+
42+
`encode(obj)`: Converts a Python object into CBOR binary format.
43+
`decode(data)`: Converts CBOR binary data back into a Python object using the custom decoding logic.

0 commit comments

Comments
 (0)