Skip to content

Improve packstream Structure class #63

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: 6.x
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bin/target_driver.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ git checkout "$version"
git pull origin "$version"
cd ..
cp driver/tests/unit/common/codec/packstream/v1/test_packstream.py tests/codec/packstream/v1/from_driver/test_packstream.py
cp driver/tests/unit/common/codec/packstream/test_structure.py tests/unit/common/codec/packstream/from_driver/test_structure.py
cp -r driver/tests/unit/common/vector/* tests/vector/from_driver

towncrier create -c "Target driver version ${version}<ISSUES_LIST>.
Expand Down
6 changes: 6 additions & 0 deletions changelog.d/63.clean.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Improve packstream `Structure` class<ISSUES_LIST>.

* Implement `repr` to match Python implementation.
* Remove `__hash__` implementation to match Python implementation.
* Implement `__getitem__` and `__setitem__` to be on par with Python implementation.
* Copy tests for `Structure` from the driver project.
85 changes: 57 additions & 28 deletions src/codec/packstream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
mod v1;

use pyo3::basic::CompareOp;
use pyo3::exceptions::PyValueError;
use pyo3::exceptions::{PyIndexError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyTuple};
use pyo3::types::PyBytes;
use pyo3::{IntoPyObjectExt, PyTraverseError, PyVisit};

use crate::register_package;
Expand Down Expand Up @@ -46,6 +46,40 @@ pub struct Structure {
fields: Vec<PyObject>,
}

impl Structure {
fn eq(&self, other: &Self, py: Python<'_>) -> PyResult<bool> {
if self.tag != other.tag || self.fields.len() != other.fields.len() {
return Ok(false);
}
for (a, b) in self
.fields
.iter()
.map(|e| e.bind(py))
.zip(other.fields.iter().map(|e| e.bind(py)))
{
if !a.eq(b)? {
return Ok(false);
}
}
Ok(true)
}

fn compute_index(&self, index: isize) -> PyResult<usize> {
Ok(if index < 0 {
self.fields
.len()
.checked_sub(-index as usize)
.ok_or_else(|| PyErr::new::<PyIndexError, _>("field index out of range"))?
} else {
let index = index as usize;
if index >= self.fields.len() {
return Err(PyErr::new::<PyIndexError, _>("field index out of range"));
}
index
})
}
}

#[pymethods]
impl Structure {
#[new]
Expand All @@ -64,26 +98,15 @@ impl Structure {
PyBytes::new(py, &[self.tag])
}

#[getter(fields)]
fn read_fields<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
PyTuple::new(py, &self.fields)
}

fn eq(&self, other: &Self, py: Python<'_>) -> PyResult<bool> {
if self.tag != other.tag || self.fields.len() != other.fields.len() {
return Ok(false);
}
for (a, b) in self
.fields
.iter()
.map(|e| e.bind(py))
.zip(other.fields.iter().map(|e| e.bind(py)))
{
if !a.eq(b)? {
return Ok(false);
}
}
Ok(true)
fn __repr__(&self, py: Python<'_>) -> PyResult<String> {
let mut args = format!(r"b'{}'", self.tag as char);
self.fields.iter().try_for_each(|field| {
let repr = field.bind(py).repr()?;
args.push_str(", ");
args.push_str(&repr.to_cow()?);
Ok::<_, PyErr>(())
})?;
Ok(format!("Structure({args})"))
}

fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyResult<PyObject> {
Expand All @@ -94,12 +117,18 @@ impl Structure {
})
}

fn __hash__(&self, py: Python<'_>) -> PyResult<isize> {
let mut fields_hash = 0;
for field in &self.fields {
fields_hash += field.bind(py).hash()?;
}
Ok(fields_hash.wrapping_add(self.tag.into()))
fn __len__(&self) -> usize {
self.fields.len()
}

fn __getitem__(&self, index: isize, py: Python<'_>) -> PyResult<PyObject> {
Ok(self.fields[self.compute_index(index)?].clone_ref(py))
}

fn __setitem__(&mut self, index: isize, value: PyObject) -> PyResult<()> {
let index = self.compute_index(index)?;
self.fields[index] = value;
Ok(())
}

fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Expand Down
14 changes: 14 additions & 0 deletions tests/codec/packstream/from_driver/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
139 changes: 139 additions & 0 deletions tests/codec/packstream/from_driver/test_structure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest

from neo4j._codec.packstream import Structure


@pytest.mark.parametrize(
"args",
(
(b"T", 1, 2, 3, "abc", 1.2, None, False),
(b"F",),
),
)
def test_structure_accessors(args):
tag = args[0]
fields = list(args[1:])
s1 = Structure(*args)
assert s1.tag == tag
assert s1.fields == fields


@pytest.mark.parametrize(
("other", "expected"),
(
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, None]), True),
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, 0]), False),
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "B"}, None]), False),
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"A": "b"}, None]), False),
(Structure(b"T", 1, 2, 3, "abc", 1.3, [{"a": "b"}, None]), False),
(
Structure(b"T", 1, 2, 3, "aBc", float("Nan"), [{"a": "b"}, None]),
False,
),
(Structure(b"T", 2, 2, 3, "abc", 1.2, [{"a": "b"}, None]), False),
(Structure(b"T", 2, 3, "abc", 1.2, [{"a": "b"}, None]), False),
(Structure(b"T", [1, 2, 3, "abc", 1.2, [{"a": "b"}, None]]), False),
(object(), NotImplemented),
),
)
def test_structure_equality(other, expected):
s1 = Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, None])
assert s1.__eq__(other) is expected # noqa: PLC2801
if expected is NotImplemented:
assert s1.__ne__(other) is NotImplemented # noqa: PLC2801
else:
assert s1.__ne__(other) is not expected # noqa: PLC2801


@pytest.mark.parametrize(
("args", "expected"),
(
((b"F", 1, 2), "Structure(b'F', 1, 2)"),
((b"f", [1, 2]), "Structure(b'f', [1, 2])"),
(
(b"T", 1.3, None, {"a": "b"}),
"Structure(b'T', 1.3, None, {'a': 'b'})",
),
),
)
def test_structure_repr(args, expected):
s1 = Structure(*args)
assert repr(s1) == expected
assert str(s1) == expected

# Ensure that the repr is consistent with the constructor
assert eval(repr(s1)) == s1
assert eval(str(s1)) == s1


@pytest.mark.parametrize(
("fields", "expected"),
(
((), 0),
(([],), 1),
((1, 2), 2),
((1, 2, []), 3),
(([1, 2], {"a": "foo", "b": "bar"}), 2),
),
)
def test_structure_len(fields, expected):
structure = Structure(b"F", *fields)
assert len(structure) == expected


def test_structure_getitem():
fields = [1, 2, 3, "abc", 1.2, None, False, {"a": "b"}]
structure = Structure(b"F", *fields)
for i, field in enumerate(fields):
assert structure[i] == field
assert structure[-len(fields) + i] == field
with pytest.raises(IndexError):
_ = structure[len(fields)]
with pytest.raises(IndexError):
_ = structure[-len(fields) - 1]


def test_structure_setitem():
test_value = object()
fields = [1, 2, 3, "abc", 1.2, None, False, {"a": "b"}]
structure = Structure(b"F", *fields)
for i, original_value in enumerate(fields):
structure[i] = test_value
assert structure[i] == test_value
assert structure[-len(fields) + i] == test_value
assert structure[i] != original_value
assert structure[-len(fields) + i] != original_value

structure[i] = original_value
assert structure[i] == original_value
assert structure[-len(fields) + i] == original_value

structure[-len(fields) + i] = test_value
assert structure[i] == test_value
assert structure[-len(fields) + i] == test_value
assert structure[i] != original_value
assert structure[-len(fields) + i] != original_value

structure[-len(fields) + i] = original_value
assert structure[i] == original_value
assert structure[-len(fields) + i] == original_value
with pytest.raises(IndexError):
structure[len(fields)] = test_value
with pytest.raises(IndexError):
structure[-len(fields) - 1] = test_value
2 changes: 1 addition & 1 deletion tests/codec/packstream/v1/test_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_rust_struct_access():

assert struct.tag == tag
assert isinstance(struct.tag, bytes)
assert struct.fields == tuple(fields)
assert struct.fields == fields


def test_rust_struct_equal():
Expand Down