diff --git a/bin/target_driver.sh b/bin/target_driver.sh index 5f6e055..4a170f6 100755 --- a/bin/target_driver.sh +++ b/bin/target_driver.sh @@ -12,6 +12,7 @@ git checkout "$version" git pull origin "$version" cd .. cp driver/tests/unit/common/codec/packstream/v1/test_packstream.py tests/codec/packstream/v1/from_driver/test_packstream.py +cp driver/tests/unit/common/codec/packstream/test_structure.py tests/unit/common/codec/packstream/from_driver/test_structure.py cp -r driver/tests/unit/common/vector/* tests/vector/from_driver towncrier create -c "Target driver version ${version}. diff --git a/changelog.d/63.clean.md b/changelog.d/63.clean.md new file mode 100644 index 0000000..23b0985 --- /dev/null +++ b/changelog.d/63.clean.md @@ -0,0 +1,6 @@ +Improve packstream `Structure` class. + + * Implement `repr` to match Python implementation. + * Remove `__hash__` implementation to match Python implementation. + * Implement `__getitem__` and `__setitem__` to be on par with Python implementation. + * Copy tests for `Structure` from the driver project. diff --git a/src/codec/packstream.rs b/src/codec/packstream.rs index 55a2e36..8f6f178 100644 --- a/src/codec/packstream.rs +++ b/src/codec/packstream.rs @@ -16,9 +16,9 @@ mod v1; use pyo3::basic::CompareOp; -use pyo3::exceptions::PyValueError; +use pyo3::exceptions::{PyIndexError, PyValueError}; use pyo3::prelude::*; -use pyo3::types::{PyBytes, PyTuple}; +use pyo3::types::PyBytes; use pyo3::{IntoPyObjectExt, PyTraverseError, PyVisit}; use crate::register_package; @@ -46,6 +46,40 @@ pub struct Structure { fields: Vec, } +impl Structure { + fn eq(&self, other: &Self, py: Python<'_>) -> PyResult { + if self.tag != other.tag || self.fields.len() != other.fields.len() { + return Ok(false); + } + for (a, b) in self + .fields + .iter() + .map(|e| e.bind(py)) + .zip(other.fields.iter().map(|e| e.bind(py))) + { + if !a.eq(b)? { + return Ok(false); + } + } + Ok(true) + } + + fn compute_index(&self, index: isize) -> PyResult { + Ok(if index < 0 { + self.fields + .len() + .checked_sub(-index as usize) + .ok_or_else(|| PyErr::new::("field index out of range"))? + } else { + let index = index as usize; + if index >= self.fields.len() { + return Err(PyErr::new::("field index out of range")); + } + index + }) + } +} + #[pymethods] impl Structure { #[new] @@ -64,26 +98,15 @@ impl Structure { PyBytes::new(py, &[self.tag]) } - #[getter(fields)] - fn read_fields<'py>(&self, py: Python<'py>) -> PyResult> { - PyTuple::new(py, &self.fields) - } - - fn eq(&self, other: &Self, py: Python<'_>) -> PyResult { - if self.tag != other.tag || self.fields.len() != other.fields.len() { - return Ok(false); - } - for (a, b) in self - .fields - .iter() - .map(|e| e.bind(py)) - .zip(other.fields.iter().map(|e| e.bind(py))) - { - if !a.eq(b)? { - return Ok(false); - } - } - Ok(true) + fn __repr__(&self, py: Python<'_>) -> PyResult { + let mut args = format!(r"b'{}'", self.tag as char); + self.fields.iter().try_for_each(|field| { + let repr = field.bind(py).repr()?; + args.push_str(", "); + args.push_str(&repr.to_cow()?); + Ok::<_, PyErr>(()) + })?; + Ok(format!("Structure({args})")) } fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyResult { @@ -94,12 +117,18 @@ impl Structure { }) } - fn __hash__(&self, py: Python<'_>) -> PyResult { - let mut fields_hash = 0; - for field in &self.fields { - fields_hash += field.bind(py).hash()?; - } - Ok(fields_hash.wrapping_add(self.tag.into())) + fn __len__(&self) -> usize { + self.fields.len() + } + + fn __getitem__(&self, index: isize, py: Python<'_>) -> PyResult { + Ok(self.fields[self.compute_index(index)?].clone_ref(py)) + } + + fn __setitem__(&mut self, index: isize, value: PyObject) -> PyResult<()> { + let index = self.compute_index(index)?; + self.fields[index] = value; + Ok(()) } fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { diff --git a/tests/codec/packstream/from_driver/__init__.py b/tests/codec/packstream/from_driver/__init__.py new file mode 100644 index 0000000..3f96809 --- /dev/null +++ b/tests/codec/packstream/from_driver/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/codec/packstream/from_driver/test_structure.py b/tests/codec/packstream/from_driver/test_structure.py new file mode 100644 index 0000000..7f031b2 --- /dev/null +++ b/tests/codec/packstream/from_driver/test_structure.py @@ -0,0 +1,139 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._codec.packstream import Structure + + +@pytest.mark.parametrize( + "args", + ( + (b"T", 1, 2, 3, "abc", 1.2, None, False), + (b"F",), + ), +) +def test_structure_accessors(args): + tag = args[0] + fields = list(args[1:]) + s1 = Structure(*args) + assert s1.tag == tag + assert s1.fields == fields + + +@pytest.mark.parametrize( + ("other", "expected"), + ( + (Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, None]), True), + (Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, 0]), False), + (Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "B"}, None]), False), + (Structure(b"T", 1, 2, 3, "abc", 1.2, [{"A": "b"}, None]), False), + (Structure(b"T", 1, 2, 3, "abc", 1.3, [{"a": "b"}, None]), False), + ( + Structure(b"T", 1, 2, 3, "aBc", float("Nan"), [{"a": "b"}, None]), + False, + ), + (Structure(b"T", 2, 2, 3, "abc", 1.2, [{"a": "b"}, None]), False), + (Structure(b"T", 2, 3, "abc", 1.2, [{"a": "b"}, None]), False), + (Structure(b"T", [1, 2, 3, "abc", 1.2, [{"a": "b"}, None]]), False), + (object(), NotImplemented), + ), +) +def test_structure_equality(other, expected): + s1 = Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, None]) + assert s1.__eq__(other) is expected # noqa: PLC2801 + if expected is NotImplemented: + assert s1.__ne__(other) is NotImplemented # noqa: PLC2801 + else: + assert s1.__ne__(other) is not expected # noqa: PLC2801 + + +@pytest.mark.parametrize( + ("args", "expected"), + ( + ((b"F", 1, 2), "Structure(b'F', 1, 2)"), + ((b"f", [1, 2]), "Structure(b'f', [1, 2])"), + ( + (b"T", 1.3, None, {"a": "b"}), + "Structure(b'T', 1.3, None, {'a': 'b'})", + ), + ), +) +def test_structure_repr(args, expected): + s1 = Structure(*args) + assert repr(s1) == expected + assert str(s1) == expected + + # Ensure that the repr is consistent with the constructor + assert eval(repr(s1)) == s1 + assert eval(str(s1)) == s1 + + +@pytest.mark.parametrize( + ("fields", "expected"), + ( + ((), 0), + (([],), 1), + ((1, 2), 2), + ((1, 2, []), 3), + (([1, 2], {"a": "foo", "b": "bar"}), 2), + ), +) +def test_structure_len(fields, expected): + structure = Structure(b"F", *fields) + assert len(structure) == expected + + +def test_structure_getitem(): + fields = [1, 2, 3, "abc", 1.2, None, False, {"a": "b"}] + structure = Structure(b"F", *fields) + for i, field in enumerate(fields): + assert structure[i] == field + assert structure[-len(fields) + i] == field + with pytest.raises(IndexError): + _ = structure[len(fields)] + with pytest.raises(IndexError): + _ = structure[-len(fields) - 1] + + +def test_structure_setitem(): + test_value = object() + fields = [1, 2, 3, "abc", 1.2, None, False, {"a": "b"}] + structure = Structure(b"F", *fields) + for i, original_value in enumerate(fields): + structure[i] = test_value + assert structure[i] == test_value + assert structure[-len(fields) + i] == test_value + assert structure[i] != original_value + assert structure[-len(fields) + i] != original_value + + structure[i] = original_value + assert structure[i] == original_value + assert structure[-len(fields) + i] == original_value + + structure[-len(fields) + i] = test_value + assert structure[i] == test_value + assert structure[-len(fields) + i] == test_value + assert structure[i] != original_value + assert structure[-len(fields) + i] != original_value + + structure[-len(fields) + i] = original_value + assert structure[i] == original_value + assert structure[-len(fields) + i] == original_value + with pytest.raises(IndexError): + structure[len(fields)] = test_value + with pytest.raises(IndexError): + structure[-len(fields) - 1] = test_value diff --git a/tests/codec/packstream/v1/test_injection.py b/tests/codec/packstream/v1/test_injection.py index 6869af9..d2f3384 100644 --- a/tests/codec/packstream/v1/test_injection.py +++ b/tests/codec/packstream/v1/test_injection.py @@ -124,7 +124,7 @@ def test_rust_struct_access(): assert struct.tag == tag assert isinstance(struct.tag, bytes) - assert struct.fields == tuple(fields) + assert struct.fields == fields def test_rust_struct_equal():