Skip to content

Commit e78e4f9

Browse files
authored
Implement set variables for http query (#122)
1 parent 4e89316 commit e78e4f9

File tree

3 files changed

+52
-6
lines changed

3 files changed

+52
-6
lines changed

surrealdb/connection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ class RequestData:
2525

2626
class Connection:
2727
_queues: Dict[int, dict]
28-
_namespace: str | None
29-
_database: str | None
30-
_auth_token: str | None
28+
_namespace: str | None = None
29+
_database: str | None = None
30+
_auth_token: str | None = None
3131

3232
def __init__(
3333
self,

surrealdb/connection_http.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import threading
2-
from typing import Any
2+
from typing import Any, Tuple
33

44
import requests
55

@@ -8,7 +8,7 @@
88

99

1010
class HTTPConnection(Connection):
11-
_request_variables: dict[str, Any]
11+
_request_variables: dict[str, Any] = {}
1212
_request_variables_lock = threading.Lock()
1313

1414
async def use(self, namespace: str, database: str) -> None:
@@ -21,7 +21,8 @@ async def set(self, key: str, value):
2121

2222
async def unset(self, key: str):
2323
with self._request_variables_lock:
24-
del self._request_variables[key]
24+
if self._request_variables.get(key) is not None:
25+
del self._request_variables[key]
2526

2627
async def connect(self) -> None:
2728
if self._base_url is None:
@@ -34,6 +35,15 @@ async def connect(self) -> None:
3435
"connection failed. check server is up and base url is correct"
3536
)
3637

38+
def _prepare_query_method_params(self, params: Tuple) -> Tuple:
39+
query, variables = params
40+
variables = (
41+
{**variables, **self._request_variables}
42+
if variables
43+
else self._request_variables.copy()
44+
)
45+
return query, variables
46+
3747
async def _make_request(self, request_data: RequestData):
3848
if self._namespace is None:
3949
raise SurrealDbConnectionError("namespace not set")
@@ -51,6 +61,9 @@ async def _make_request(self, request_data: RequestData):
5161
if self._auth_token is not None:
5262
headers["Authorization"] = f"Bearer {self._auth_token}"
5363

64+
if request_data.method.lower() == "query":
65+
request_data.params = self._prepare_query_method_params(request_data.params)
66+
5467
request_payload = self._encoder(
5568
{
5669
"id": request_data.id,

tests/unit/test_http_connection.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,36 @@ async def asyncSetUp(self):
1616
async def test_send(self):
1717
await self.http_con.use('test', 'test')
1818
_ = await self.http_con.send('signin', {'user': 'root', 'pass': 'root'})
19+
20+
async def test_prepare_query_params(self):
21+
query_params = ("SOME SQL QUERY;", {
22+
"key1": "key1"
23+
})
24+
await self.http_con.set("key2", "key2")
25+
await self.http_con.set("key3", "key3")
26+
27+
params = self.http_con._prepare_query_method_params(query_params)
28+
self.assertEqual(query_params[0], params[0])
29+
self.assertEqual({
30+
"key1": "key1",
31+
"key2": "key2",
32+
"key3": "key3",
33+
}, params[1])
34+
35+
await self.http_con.unset("key3")
36+
37+
params = self.http_con._prepare_query_method_params(query_params)
38+
self.assertEqual(query_params[0], params[0])
39+
self.assertEqual({
40+
"key1": "key1",
41+
"key2": "key2",
42+
}, params[1])
43+
44+
await self.http_con.unset("key1") # variable key not part of prev set variables
45+
46+
params = self.http_con._prepare_query_method_params(query_params)
47+
self.assertEqual(query_params[0], params[0])
48+
self.assertEqual({
49+
"key1": "key1",
50+
"key2": "key2",
51+
}, params[1])

0 commit comments

Comments
 (0)