Skip to content

Commit 9622672

Browse files
authored
Fix None encoding and decoding for SurrealDB v2.2.x and later (surrealdb#184)
1 parent 79e7d03 commit 9622672

File tree

3 files changed

+59
-4
lines changed

3 files changed

+59
-4
lines changed

src/surrealdb/cbor2/_decoder.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,14 @@ def decode_bigfloat(self) -> Decimal:
604604

605605
return self.set_shareable(Decimal(sig) * (2 ** Decimal(exp)))
606606

607+
def decode_none(self) -> None:
608+
# Semantic tag 6
609+
value = self._decode()
610+
if not isinstance(value, type(None)):
611+
raise CBORDecodeValueError("invalid None value " + str(value))
612+
613+
return self.set_shareable(None)
614+
607615
def decode_stringref(self) -> str | bytes:
608616
# Semantic tag 25
609617
if self._stringref_namespace is None:
@@ -786,7 +794,7 @@ def decode_float64(self) -> float:
786794
3: CBORDecoder.decode_negative_bignum,
787795
4: CBORDecoder.decode_fraction,
788796
5: CBORDecoder.decode_bigfloat,
789-
6: lambda self: None,
797+
6: CBORDecoder.decode_none,
790798
25: CBORDecoder.decode_stringref,
791799
28: CBORDecoder.decode_shareable,
792800
29: CBORDecoder.decode_sharedref,

src/surrealdb/cbor2/_encoder.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -658,8 +658,14 @@ def encode_boolean(self, value: bool) -> None:
658658

659659
def encode_none(self, value: None) -> None:
660660
# tag(6, None)
661-
self._fp_write(b"\xd9\x00\x06\xf6")
662-
# self._fp_write(b"\xf6")
661+
662+
# Note that although just \xf6 (major type 6, 22) is already null in CBOR,
663+
# which this cbor2 implementation decodes it as Python None,
664+
# SurrealDB's variant of CBOR wraps it with \xc6 (major type 6, tag 6)
665+
# which results in \xc6\xf6.
666+
# We try to be compliant as much as SurrealDB's variant of CBOR
667+
# and therefore encode it as \xc6\xf6.
668+
self._fp_write(b"\xc6\xf6")
663669

664670
def encode_undefined(self, value: UndefinedType) -> None:
665671
self._fp_write(b"\xf7")

tests/unit_tests/data_types/test_none.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
will have to look into schema objects for more complete serialization.
77
"""
88
from unittest import main, IsolatedAsyncioTestCase
9+
from os import environ
910

1011
from surrealdb.connections.async_ws import AsyncWsSurrealConnection
1112
from surrealdb.data.types.record_id import RecordID
@@ -30,7 +31,7 @@ async def asyncSetUp(self):
3031
await self.connection.use(namespace=self.namespace, database=self.database_name)
3132

3233
# Cleanup
33-
await self.connection.query("DELETE person;")
34+
await self.connection.query("REMOVE TABLE person;")
3435

3536
async def test_none(self):
3637
schema = """
@@ -63,5 +64,45 @@ async def test_none(self):
6364
await self.connection.query("REMOVE TABLE person;")
6465
await self.connection.close()
6566

67+
async def test_none_with_query(self):
68+
is_sdb21 = environ.get("SURREALDB_VERSION", "v2.1.4").startswith("v2.1.")
69+
schema = """
70+
DEFINE TABLE person SCHEMAFULL TYPE NORMAL;
71+
DEFINE FIELD name ON person TYPE string;
72+
DEFINE FIELD nums ON person TYPE array<option<int>>;
73+
"""
74+
await self.connection.query(schema)
75+
outcome = await self.connection.query(
76+
"UPSERT person MERGE {id: $id, name: $name, nums: $nums}",
77+
{"id": [1,2], "name": "John", "nums": [None]}
78+
)
79+
record_check = RecordID(table_name="person", identifier=[1, 2])
80+
self.assertEqual(1, len(outcome))
81+
self.assertEqual(record_check, outcome[0]["id"])
82+
self.assertEqual("John", outcome[0]["name"])
83+
if is_sdb21:
84+
self.assertEqual([], outcome[0].get("nums"))
85+
else:
86+
self.assertEqual([None], outcome[0].get("nums"))
87+
88+
outcome = await self.connection.query(
89+
"UPSERT person MERGE {id: $id, name: $name, nums: $nums}",
90+
{"id": [3,4], "name": "Dave", "nums": [None]}
91+
)
92+
record_check = RecordID(table_name="person", identifier=[3, 4])
93+
self.assertEqual(1, len(outcome))
94+
self.assertEqual(record_check, outcome[0]["id"])
95+
self.assertEqual("Dave", outcome[0]["name"])
96+
if is_sdb21:
97+
self.assertEqual([], outcome[0].get("nums"))
98+
else:
99+
self.assertEqual([None], outcome[0].get("nums"))
100+
101+
outcome = await self.connection.query("SELECT * FROM person")
102+
self.assertEqual(2, len(outcome))
103+
104+
await self.connection.query("REMOVE TABLE person;")
105+
await self.connection.close()
106+
66107
if __name__ == "__main__":
67108
main()

0 commit comments

Comments
 (0)