Skip to content

Commit 89da78d

Browse files
committed
Add some more tests
1 parent 22f3d30 commit 89da78d

File tree

5 files changed

+328
-103
lines changed

5 files changed

+328
-103
lines changed

docs/sphinx/ug_objects.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,14 @@ GenericNodeData can also be initialized with keyword args like this::
207207

208208
obj = GenericNodeData(a=1, b=2)
209209

210+
Trees that contain GenericNodeData objects can be serialized and deserialized
211+
using the :meth:`~nutree.tree.Tree.save` and :meth:`~nutree.tree.Tree.load`
212+
methods::
213+
214+
tree.save(file_path, mapper=GenericNodeData.serialize_mapper)
215+
...
216+
tree2 = Tree.load(file_path, mapper=GenericNodeData.deserialize_mapper)
217+
210218
.. warning::
211219
The :class:`~nutree.common.GenericNodeData` provides a hash value because
212220
any class that is hashable, so it can be used as a data object. However, the

nutree/common.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,17 +107,17 @@ class StopTraversal(IterationControl):
107107
"""Raised or returned by traversal callbacks to stop iteration.
108108
109109
Optionally, a return value may be passed.
110-
Note that if a callback returns ``False``, this will be converted to an
110+
Note that if a callback returns ``False``, this will be converted to a
111111
``StopTraversal(None)`` exception.
112112
"""
113113

114114
def __init__(self, value=None):
115115
self.value = value
116116

117117

118-
#:
118+
#: Generic callback for `tree.filter()`, `tree.copy()`, ...
119119
PredicateCallbackType = Callable[["Node"], Union[None, bool, IterationControl]]
120-
#:
120+
#: Generic callback for `tree.to_dot()`, ...
121121
MapperCallbackType = Callable[["Node", dict], Union[None, Any]]
122122
#: Callback for `tree.save()`
123123
SerializeMapperType = Callable[["Node", dict], Union[None, dict]]
@@ -221,10 +221,27 @@ def __getattr__(self, name: str) -> Any:
221221
except KeyError:
222222
raise AttributeError(name) from None
223223

224-
@staticmethod
225-
def serialize_mapper(nutree_node, data):
224+
@classmethod
225+
def serialize_mapper(cls, nutree_node: Node, data: dict) -> Union[None, dict]:
226+
"""Serialize the data object to a dictionary.
227+
228+
Example::
229+
230+
tree.save(file_path, mapper=GenericNodeData.serialize_mapper)
231+
232+
"""
226233
return nutree_node.data._dict.copy()
227234

235+
@classmethod
236+
def deserialize_mapper(cls, nutree_node: Node, data: dict) -> Union[str, object]:
237+
"""Serialize the data object to a dictionary.
238+
239+
Example::
240+
241+
tree = Tree.load(file_path, mapper=GenericNodeData.deserialize_mapper)
242+
"""
243+
return cls(**data)
244+
228245

229246
def get_version() -> str:
230247
from nutree import __version__
@@ -323,7 +340,7 @@ def call_traversal_cb(fn: Callable, node: Node, memo: Any) -> False | None:
323340
RuntimeWarning,
324341
stacklevel=3,
325342
)
326-
raise StopTraversal(e.value) from None
343+
raise StopTraversal(e.value) from e
327344
return None
328345

329346

tests/test_core.py

Lines changed: 155 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import pytest
1111
from nutree import AmbiguousMatchError, IterMethod, Node, Tree
12-
from nutree.common import SkipBranch, StopTraversal
12+
from nutree.common import SkipBranch, StopTraversal, check_python_version
1313
from nutree.fs import load_tree_from_fs
1414

1515
from . import fixture
@@ -23,6 +23,12 @@ def _make_tree_2():
2323
return t
2424

2525

26+
class TestCommon:
27+
def test_check_python_version(self):
28+
assert check_python_version((3, 7)) is True
29+
assert check_python_version((99, 1)) is False
30+
31+
2632
class TestBasics:
2733
def test_add_child(self):
2834
tree = Tree("fixture")
@@ -354,7 +360,7 @@ def test_data_id(self):
354360
)
355361
assert tree._self_check()
356362

357-
def test_search(self):
363+
def test_find(self):
358364
tree = self.tree
359365

360366
records = tree["Records"]
@@ -637,20 +643,94 @@ def cb(node, memo):
637643
tree.visit(cb, method=IterMethod.LEVEL_ORDER)
638644
assert ",".join(res) == "A,B,a1,a2,b1,a11,a12,b11"
639645

646+
def test_visit_cb(self):
647+
"""
648+
Tree<'fixture'>
649+
├── A
650+
│ ├── a1
651+
│ │ ├── a11
652+
│ │ ╰── a12
653+
│ ╰── a2
654+
╰── B
655+
╰── b1
656+
╰── b11
657+
"""
658+
tree = fixture.create_tree()
659+
640660
res = []
641661

642662
def cb(node, memo):
643663
res.append(node.name)
644664
if node.name == "a1":
645665
return SkipBranch
666+
if node.name == "b1":
667+
return StopTraversal
668+
669+
res_2 = tree.visit(cb)
670+
671+
assert res_2 is None
672+
assert ",".join(res) == "A,a1,a2,B,b1"
673+
674+
res = []
675+
676+
def cb(node, memo):
677+
res.append(node.name)
678+
if node.name == "a1":
679+
raise SkipBranch(and_self=True)
646680
if node.name == "b1":
647681
raise StopTraversal("Found b1")
648682

649683
res_2 = tree.visit(cb)
650684

651685
assert res_2 == "Found b1"
686+
# and_self does not skip self in this case
652687
assert ",".join(res) == "A,a1,a2,B,b1"
653688

689+
res = []
690+
691+
def cb(node, memo):
692+
res.append(node.name)
693+
if node.name == "a12":
694+
raise StopIteration
695+
696+
res_2 = tree.visit(cb)
697+
698+
assert ",".join(res) == "A,a1,a11,a12"
699+
700+
res = []
701+
702+
def cb(node, memo):
703+
res.append(node.name)
704+
if node.name == "a12":
705+
return StopIteration
706+
707+
res_2 = tree.visit(cb)
708+
709+
assert ",".join(res) == "A,a1,a11,a12"
710+
711+
res = []
712+
713+
def cb(node, memo):
714+
res.append(node.name)
715+
if node.name == "a12":
716+
return False
717+
718+
res_2 = tree.visit(cb)
719+
720+
assert ",".join(res) == "A,a1,a11,a12"
721+
722+
res = []
723+
724+
def cb(node, memo):
725+
res.append(node.name)
726+
if node.name == "b1":
727+
return 17
728+
729+
with pytest.raises(
730+
ValueError, match="callback should not return values except for"
731+
):
732+
res_2 = tree.visit(cb)
733+
654734

655735
class TestMutate:
656736
def test_add(self):
@@ -983,6 +1063,17 @@ def test_tree_copy_to(self):
9831063
)
9841064

9851065
def test_filter(self):
1066+
"""
1067+
Tree<'fixture'>
1068+
├── A
1069+
│ ├── a1
1070+
│ │ ├── a11
1071+
│ │ ╰── a12
1072+
│ ╰── a2
1073+
╰── B
1074+
╰── b1
1075+
╰── b11
1076+
"""
9861077
tree = fixture.create_tree()
9871078

9881079
def pred(node):
@@ -1005,6 +1096,17 @@ def pred(node):
10051096
)
10061097

10071098
def test_filtered(self):
1099+
"""
1100+
Tree<'fixture'>
1101+
├── A
1102+
│ ├── a1
1103+
│ │ ├── a11
1104+
│ │ ╰── a12
1105+
│ ╰── a2
1106+
╰── B
1107+
╰── b1
1108+
╰── b11
1109+
"""
10081110
tree = fixture.create_tree()
10091111

10101112
def pred(node):
@@ -1026,6 +1128,57 @@ def pred(node):
10261128
""",
10271129
)
10281130

1131+
def pred(node):
1132+
if node.name == "a12":
1133+
raise SkipBranch
1134+
return "2" in node.name.lower()
1135+
1136+
tree_2 = tree.filtered(predicate=pred)
1137+
1138+
assert tree_2._self_check()
1139+
assert fixture.check_content(
1140+
tree_2,
1141+
"""
1142+
Tree<*>
1143+
╰── A
1144+
╰── a2
1145+
╰── a2
1146+
""",
1147+
)
1148+
1149+
def pred(node):
1150+
if node.name == "a12":
1151+
raise StopIteration
1152+
return "2" in node.name.lower()
1153+
1154+
tree_2 = tree.filtered(predicate=pred)
1155+
1156+
assert tree_2._self_check()
1157+
assert fixture.check_content(
1158+
tree_2,
1159+
"""
1160+
Tree<*>
1161+
""",
1162+
)
1163+
1164+
tree_2 = tree.filtered(predicate=None)
1165+
1166+
assert tree_2._self_check()
1167+
assert fixture.check_content(
1168+
tree_2,
1169+
"""
1170+
Tree<*>
1171+
├── A
1172+
│ ├── a1
1173+
│ │ ├── a11
1174+
│ │ ╰── a12
1175+
│ ╰── a2
1176+
╰── B
1177+
╰── b1
1178+
╰── b11
1179+
""",
1180+
)
1181+
10291182

10301183
class TestFS:
10311184
@pytest.mark.skipif(os.name == "nt", reason="windows has different eol size")

tests/test_serialize.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,15 @@ def test_serialize_compressed(self):
8787
with fixture.WritableTempFile("r+t") as temp_file:
8888
tree.save(temp_file.name, compression=zipfile.ZIP_DEFLATED)
8989
tree_2 = Tree.load(temp_file.name)
90+
91+
with pytest.raises(UnicodeDecodeError):
92+
_ = Tree.load(temp_file.name, auto_uncompress=False)
93+
94+
assert fixture.trees_equal(tree, tree_2)
95+
96+
with fixture.WritableTempFile("r+t") as temp_file:
97+
tree.save(temp_file.name, compression=True)
98+
tree_2 = Tree.load(temp_file.name)
9099
assert fixture.trees_equal(tree, tree_2)
91100

92101
with fixture.WritableTempFile("r+t") as temp_file:
@@ -99,6 +108,17 @@ def test_serialize_compressed(self):
99108
tree_2 = Tree.load(temp_file.name)
100109
assert fixture.trees_equal(tree, tree_2)
101110

111+
def test_serialize_uncompressed(self):
112+
tree = fixture.create_tree()
113+
tree.add_child("äöüß: \u00e4\u00f6\u00fc\u00df")
114+
tree.add_child("emoji: 😀")
115+
116+
with fixture.WritableTempFile("r+t") as temp_file:
117+
tree.save(temp_file.name, compression=False)
118+
tree_2 = Tree.load(temp_file.name)
119+
120+
assert fixture.trees_equal(tree, tree_2)
121+
102122
def _test_serialize_objects(self, *, mode: str):
103123
"""Save/load an object tree with clones.
104124

0 commit comments

Comments
 (0)