Skip to content

Commit 72c426b

Browse files
committed
Improve packstream Structure class
* Implement `repr` to match Python implementation * Copy tests for `Structure` from the driver project
1 parent c5b2e9f commit 72c426b

File tree

6 files changed

+218
-29
lines changed

6 files changed

+218
-29
lines changed

bin/target_driver.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ git checkout "$version"
1212
git pull origin "$version"
1313
cd ..
1414
cp driver/tests/unit/common/codec/packstream/v1/test_packstream.py tests/codec/packstream/v1/from_driver/test_packstream.py
15+
cp driver/tests/unit/common/codec/packstream/test_structure.py tests/unit/common/codec/packstream/from_driver/test_structure.py
1516
cp -r driver/tests/unit/common/vector/* tests/vector/from_driver
1617

1718
towncrier create -c "Target driver version ${version}<ISSUES_LIST>.

changelog.d/63.clean.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Improve packstream `Structure` class<ISSUES_LIST>.
2+
3+
* Implement `repr` to match Python implementation.
4+
* Remove `__hash__` implementation to match Python implementation.
5+
* Implement `__getitem__` and `__setitem__` to be on par with Python implementation.
6+
* Copy tests for `Structure` from the driver project.

src/codec/packstream.rs

Lines changed: 57 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
mod v1;
1717

1818
use pyo3::basic::CompareOp;
19-
use pyo3::exceptions::PyValueError;
19+
use pyo3::exceptions::{PyIndexError, PyValueError};
2020
use pyo3::prelude::*;
21-
use pyo3::types::{PyBytes, PyTuple};
21+
use pyo3::types::PyBytes;
2222
use pyo3::{IntoPyObjectExt, PyTraverseError, PyVisit};
2323

2424
use crate::register_package;
@@ -46,6 +46,40 @@ pub struct Structure {
4646
fields: Vec<PyObject>,
4747
}
4848

49+
impl Structure {
50+
fn eq(&self, other: &Self, py: Python<'_>) -> PyResult<bool> {
51+
if self.tag != other.tag || self.fields.len() != other.fields.len() {
52+
return Ok(false);
53+
}
54+
for (a, b) in self
55+
.fields
56+
.iter()
57+
.map(|e| e.bind(py))
58+
.zip(other.fields.iter().map(|e| e.bind(py)))
59+
{
60+
if !a.eq(b)? {
61+
return Ok(false);
62+
}
63+
}
64+
Ok(true)
65+
}
66+
67+
fn compute_index(&self, index: isize) -> PyResult<usize> {
68+
Ok(if index < 0 {
69+
self.fields
70+
.len()
71+
.checked_sub(-index as usize)
72+
.ok_or_else(|| PyErr::new::<PyIndexError, _>("field index out of range"))?
73+
} else {
74+
let index = index as usize;
75+
if index >= self.fields.len() {
76+
return Err(PyErr::new::<PyIndexError, _>("field index out of range"));
77+
}
78+
index
79+
})
80+
}
81+
}
82+
4983
#[pymethods]
5084
impl Structure {
5185
#[new]
@@ -64,26 +98,15 @@ impl Structure {
6498
PyBytes::new(py, &[self.tag])
6599
}
66100

67-
#[getter(fields)]
68-
fn read_fields<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
69-
PyTuple::new(py, &self.fields)
70-
}
71-
72-
fn eq(&self, other: &Self, py: Python<'_>) -> PyResult<bool> {
73-
if self.tag != other.tag || self.fields.len() != other.fields.len() {
74-
return Ok(false);
75-
}
76-
for (a, b) in self
77-
.fields
78-
.iter()
79-
.map(|e| e.bind(py))
80-
.zip(other.fields.iter().map(|e| e.bind(py)))
81-
{
82-
if !a.eq(b)? {
83-
return Ok(false);
84-
}
85-
}
86-
Ok(true)
101+
fn __repr__(&self, py: Python<'_>) -> PyResult<String> {
102+
let mut args = format!(r"b'{}'", self.tag as char);
103+
self.fields.iter().try_for_each(|field| {
104+
let repr = field.bind(py).repr()?;
105+
args.push_str(", ");
106+
args.push_str(&repr.to_cow()?);
107+
Ok::<_, PyErr>(())
108+
})?;
109+
Ok(format!("Structure({args})"))
87110
}
88111

89112
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyResult<PyObject> {
@@ -94,12 +117,18 @@ impl Structure {
94117
})
95118
}
96119

97-
fn __hash__(&self, py: Python<'_>) -> PyResult<isize> {
98-
let mut fields_hash = 0;
99-
for field in &self.fields {
100-
fields_hash += field.bind(py).hash()?;
101-
}
102-
Ok(fields_hash.wrapping_add(self.tag.into()))
120+
fn __len__(&self) -> usize {
121+
self.fields.len()
122+
}
123+
124+
fn __getitem__(&self, index: isize, py: Python<'_>) -> PyResult<PyObject> {
125+
Ok(self.fields[self.compute_index(index)?].clone_ref(py))
126+
}
127+
128+
fn __setitem__(&mut self, index: isize, value: PyObject) -> PyResult<()> {
129+
let index = self.compute_index(index)?;
130+
self.fields[index] = value;
131+
Ok(())
103132
}
104133

105134
fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
import pytest
18+
19+
from neo4j._codec.packstream import Structure
20+
21+
22+
@pytest.mark.parametrize(
23+
"args",
24+
(
25+
(b"T", 1, 2, 3, "abc", 1.2, None, False),
26+
(b"F",),
27+
),
28+
)
29+
def test_structure_accessors(args):
30+
tag = args[0]
31+
fields = list(args[1:])
32+
s1 = Structure(*args)
33+
assert s1.tag == tag
34+
assert s1.fields == fields
35+
36+
37+
@pytest.mark.parametrize(
38+
("other", "expected"),
39+
(
40+
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, None]), True),
41+
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, 0]), False),
42+
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "B"}, None]), False),
43+
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"A": "b"}, None]), False),
44+
(Structure(b"T", 1, 2, 3, "abc", 1.3, [{"a": "b"}, None]), False),
45+
(
46+
Structure(b"T", 1, 2, 3, "aBc", float("Nan"), [{"a": "b"}, None]),
47+
False,
48+
),
49+
(Structure(b"T", 2, 2, 3, "abc", 1.2, [{"a": "b"}, None]), False),
50+
(Structure(b"T", 2, 3, "abc", 1.2, [{"a": "b"}, None]), False),
51+
(Structure(b"T", [1, 2, 3, "abc", 1.2, [{"a": "b"}, None]]), False),
52+
(object(), NotImplemented),
53+
),
54+
)
55+
def test_structure_equality(other, expected):
56+
s1 = Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, None])
57+
assert s1.__eq__(other) is expected # noqa: PLC2801
58+
if expected is NotImplemented:
59+
assert s1.__ne__(other) is NotImplemented # noqa: PLC2801
60+
else:
61+
assert s1.__ne__(other) is not expected # noqa: PLC2801
62+
63+
64+
@pytest.mark.parametrize(
65+
("args", "expected"),
66+
(
67+
((b"F", 1, 2), "Structure(b'F', 1, 2)"),
68+
((b"f", [1, 2]), "Structure(b'f', [1, 2])"),
69+
(
70+
(b"T", 1.3, None, {"a": "b"}),
71+
"Structure(b'T', 1.3, None, {'a': 'b'})",
72+
),
73+
),
74+
)
75+
def test_structure_repr(args, expected):
76+
s1 = Structure(*args)
77+
assert repr(s1) == expected
78+
assert str(s1) == expected
79+
80+
# Ensure that the repr is consistent with the constructor
81+
assert eval(repr(s1)) == s1
82+
assert eval(str(s1)) == s1
83+
84+
85+
@pytest.mark.parametrize(
86+
("fields", "expected"),
87+
(
88+
((), 0),
89+
(([],), 1),
90+
((1, 2), 2),
91+
((1, 2, []), 3),
92+
(([1, 2], {"a": "foo", "b": "bar"}), 2),
93+
),
94+
)
95+
def test_structure_len(fields, expected):
96+
structure = Structure(b"F", *fields)
97+
assert len(structure) == expected
98+
99+
100+
def test_structure_getitem():
101+
fields = [1, 2, 3, "abc", 1.2, None, False, {"a": "b"}]
102+
structure = Structure(b"F", *fields)
103+
for i, field in enumerate(fields):
104+
assert structure[i] == field
105+
assert structure[-len(fields) + i] == field
106+
with pytest.raises(IndexError):
107+
_ = structure[len(fields)]
108+
with pytest.raises(IndexError):
109+
_ = structure[-len(fields) - 1]
110+
111+
112+
def test_structure_setitem():
113+
test_value = object()
114+
fields = [1, 2, 3, "abc", 1.2, None, False, {"a": "b"}]
115+
structure = Structure(b"F", *fields)
116+
for i, original_value in enumerate(fields):
117+
structure[i] = test_value
118+
assert structure[i] == test_value
119+
assert structure[-len(fields) + i] == test_value
120+
assert structure[i] != original_value
121+
assert structure[-len(fields) + i] != original_value
122+
123+
structure[i] = original_value
124+
assert structure[i] == original_value
125+
assert structure[-len(fields) + i] == original_value
126+
127+
structure[-len(fields) + i] = test_value
128+
assert structure[i] == test_value
129+
assert structure[-len(fields) + i] == test_value
130+
assert structure[i] != original_value
131+
assert structure[-len(fields) + i] != original_value
132+
133+
structure[-len(fields) + i] = original_value
134+
assert structure[i] == original_value
135+
assert structure[-len(fields) + i] == original_value
136+
with pytest.raises(IndexError):
137+
structure[len(fields)] = test_value
138+
with pytest.raises(IndexError):
139+
structure[-len(fields) - 1] = test_value

tests/codec/packstream/v1/test_injection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def test_rust_struct_access():
124124

125125
assert struct.tag == tag
126126
assert isinstance(struct.tag, bytes)
127-
assert struct.fields == tuple(fields)
127+
assert struct.fields == fields
128128

129129

130130
def test_rust_struct_equal():

0 commit comments

Comments
 (0)