Skip to content

Commit 8ae0f25

Browse files
committed
Improve packstream Structure class
* Implement `repr` to match Python implementation * Copy tests for `Structure` from the driver project
1 parent c5b2e9f commit 8ae0f25

File tree

5 files changed

+116
-4
lines changed

5 files changed

+116
-4
lines changed

bin/target_driver.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ git checkout "$version"
1212
git pull origin "$version"
1313
cd ..
1414
cp driver/tests/unit/common/codec/packstream/v1/test_packstream.py tests/codec/packstream/v1/from_driver/test_packstream.py
15+
cp driver/tests/unit/common/codec/packstream/test_structure.py tests/unit/common/codec/packstream/from_driver/test_structure.py
1516
cp -r driver/tests/unit/common/vector/* tests/vector/from_driver
1617

1718
towncrier create -c "Target driver version ${version}<ISSUES_LIST>.

changelog.d/63.clean.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Improve packstream `Structure` class<ISSUES_LIST>.
2+
3+
* Implement `repr` to match Python implementation.
4+
* Copy tests for `Structure` from the driver project.

src/codec/packstream.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ mod v1;
1818
use pyo3::basic::CompareOp;
1919
use pyo3::exceptions::PyValueError;
2020
use pyo3::prelude::*;
21-
use pyo3::types::{PyBytes, PyTuple};
21+
use pyo3::types::PyBytes;
2222
use pyo3::{IntoPyObjectExt, PyTraverseError, PyVisit};
2323

2424
use crate::register_package;
@@ -64,9 +64,20 @@ impl Structure {
6464
PyBytes::new(py, &[self.tag])
6565
}
6666

67-
#[getter(fields)]
68-
fn read_fields<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
69-
PyTuple::new(py, &self.fields)
67+
fn __repr__(slf: &Bound<'_, Self>) -> PyResult<String> {
68+
let args = {
69+
let slf = slf.borrow();
70+
let py = slf.py();
71+
let mut args = format!(r"b'{}'", slf.tag as char);
72+
slf.fields.iter().try_for_each(|field| {
73+
let repr = field.bind(py).repr()?;
74+
args.push_str(", ");
75+
args.push_str(&repr.to_cow()?);
76+
Ok::<_, PyErr>(())
77+
})?;
78+
args
79+
};
80+
Ok(format!("Structure({args})"))
7081
}
7182

7283
fn eq(&self, other: &Self, py: Python<'_>) -> PyResult<bool> {
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
import pytest
18+
19+
from neo4j._codec.packstream import Structure
20+
21+
22+
@pytest.mark.parametrize(
23+
"args",
24+
(
25+
(b"T", 1, 2, 3, "abc", 1.2, None, False),
26+
(b"F",),
27+
),
28+
)
29+
def test_structure_accessors(args):
30+
tag = args[0]
31+
fields = list(args[1:])
32+
s1 = Structure(*args)
33+
assert s1.tag == tag
34+
assert s1.fields == fields
35+
36+
37+
@pytest.mark.parametrize(
38+
("other", "expected"),
39+
(
40+
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, None]), True),
41+
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, 0]), False),
42+
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "B"}, None]), False),
43+
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"A": "b"}, None]), False),
44+
(Structure(b"T", 1, 2, 3, "abc", 1.3, [{"a": "b"}, None]), False),
45+
(
46+
Structure(b"T", 1, 2, 3, "aBc", float("Nan"), [{"a": "b"}, None]),
47+
False,
48+
),
49+
(Structure(b"T", 2, 2, 3, "abc", 1.2, [{"a": "b"}, None]), False),
50+
(Structure(b"T", 2, 3, "abc", 1.2, [{"a": "b"}, None]), False),
51+
(Structure(b"T", [1, 2, 3, "abc", 1.2, [{"a": "b"}, None]]), False),
52+
(object(), NotImplemented),
53+
),
54+
)
55+
def test_structure_equality(other, expected):
56+
s1 = Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, None])
57+
assert s1.__eq__(other) is expected # noqa: PLC2801
58+
if expected is NotImplemented:
59+
assert s1.__ne__(other) is NotImplemented # noqa: PLC2801
60+
else:
61+
assert s1.__ne__(other) is not expected # noqa: PLC2801
62+
63+
64+
@pytest.mark.parametrize(
65+
("args", "expected"),
66+
(
67+
((b"F", 1, 2), "Structure(b'F', 1, 2)"),
68+
((b"f", [1, 2]), "Structure(b'f', [1, 2])"),
69+
(
70+
(b"T", 1.3, None, {"a": "b"}),
71+
"Structure(b'T', 1.3, None, {'a': 'b'})",
72+
),
73+
),
74+
)
75+
def test_structure_repr(args, expected):
76+
s1 = Structure(*args)
77+
assert repr(s1) == expected
78+
assert str(s1) == expected
79+
80+
# Ensure that the repr is consistent with the constructor
81+
assert eval(repr(s1)) == s1
82+
assert eval(str(s1)) == s1

0 commit comments

Comments
 (0)