Skip to content

Commit 3d44d18

Browse files
author
Tomás Link
committed
Add beam.testing.utils.equal_to function to display colorful differences
1 parent db7be55 commit 3d44d18

File tree

10 files changed

+442
-4
lines changed

10 files changed

+442
-4
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ classifiers = [
4949
"Topic :: Scientific/Engineering",
5050
]
5151
dependencies = [
52+
"jinja2~=3.1",
5253
"pyyaml~=6.0",
5354
"rich~=14.0",
54-
"jinja2~=3.1",
5555
]
5656

5757
[project.optional-dependencies]
@@ -182,7 +182,7 @@ strict = true
182182
ignore_missing_imports = true
183183
files = "src"
184184
mypy_path = "src"
185-
disable_error_code = ["union-attr", "no-any-return"]
185+
disable_error_code = ["union-attr", "no-any-return", "var-annotated"]
186186
explicit_package_bases = true
187187

188188
[[tool.mypy.overrides]]
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Testing utilities for Apache Beam pipelines."""

src/gfw/common/beam/testing/utils.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""General utilities for testing Apache Beam pipelines."""
2+
3+
from itertools import zip_longest
4+
from typing import Any, Callable, Iterable, List, Sequence
5+
6+
from apache_beam.testing.util import BeamAssertException
7+
from rich.console import Console, Group, RenderableType
8+
9+
from gfw.common.diff import compare_items, render_diff_panel
10+
from gfw.common.sorting import sort_dicts
11+
12+
13+
def _default_equals_fn(e: Any, a: Any) -> bool:
14+
return e == a
15+
16+
17+
def _raise_with_diff(diffs: Sequence[RenderableType]) -> None:
18+
# Set up a Rich Console that records output
19+
console = Console(record=True, force_terminal=True, width=130)
20+
21+
# Render diffs to console (only into memory, not to screen)
22+
console.print(Group(*diffs))
23+
24+
# Export the captured diff as text with ANSI codes
25+
diff_text = console.export_text(styles=True)
26+
27+
# Raise exception with embedded colored diff
28+
raise BeamAssertException(f"PCollection contents differ: \n{diff_text}.")
29+
30+
31+
def equal_to(
32+
expected: List[Any], equals_fn: Callable[[Any, Any], bool] = _default_equals_fn
33+
) -> Callable[[List[Any]], None]:
34+
"""Drop-in replacement for `apache_beam.testing.util.equal_to` with rich diff output.
35+
36+
This matcher performs unordered comparison of top-level elements in actual and expected
37+
PCollection outputs, just like Apache Beam's `equal_to`. However, it adds a rich diff
38+
visualization to help debug mismatches by rendering side-by-side differences.
39+
40+
Use in tests with `assert_that(pcoll, equal_to(expected))`.
41+
42+
Note:
43+
- Only top-level permutations are considered equal:
44+
`[1, 2]` and `[2, 1]` are equal, but `[[1, 2]]` and `[[2, 1]]` are not.
45+
46+
- If elements are not directly comparable, a fallback comparison using
47+
a custom equality function or deep diff is used. This helps handle:
48+
1) Collections with types that don't have a deterministic sort order
49+
(e.g., pyarrow Tables as of 0.14.1).
50+
2) Collections containing elements of different types.
51+
52+
Args:
53+
expected: Iterable of expected PCollection elements.
54+
equals_fn: Optional function `(expected_item, actual_item) -> bool` to customize equality.
55+
56+
Returns:
57+
A matcher function for use with `apache_beam.testing.util.assert_that`.
58+
"""
59+
60+
def _matcher(actual: Iterable[Any]) -> None:
61+
expected_list = [sort_dicts(e) for e in expected]
62+
actual_list = [sort_dicts(e) for e in actual]
63+
64+
try:
65+
if actual_list == expected_list:
66+
return
67+
except TypeError:
68+
pass
69+
70+
# Slower method, fallback comparison.
71+
unmatched_expected = expected_list[:]
72+
unmatched_actual = []
73+
for a in actual_list:
74+
for i, e in enumerate(unmatched_expected):
75+
if equals_fn(e, a):
76+
unmatched_expected.pop(i)
77+
break
78+
else:
79+
unmatched_actual.append(a)
80+
81+
if not unmatched_actual and not unmatched_expected:
82+
return
83+
84+
diffs = []
85+
for i, (a, b) in enumerate(
86+
zip_longest(unmatched_actual, unmatched_expected, fillvalue={}), 1
87+
):
88+
left, right, changed = compare_items(a, b)
89+
if changed:
90+
diffs.append(render_diff_panel(left, right, i))
91+
92+
if diffs: # Diffs found. Raise exception with colorized output.
93+
_raise_with_diff(diffs)
94+
95+
return _matcher

src/gfw/common/diff.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
"""General utilities for generating diffs between objects."""
2+
3+
from difflib import ndiff
4+
from typing import Any, Tuple
5+
6+
from rich.columns import Columns
7+
from rich.panel import Panel
8+
from rich.pretty import pretty_repr
9+
10+
11+
def diff_lines(a: str, b: str) -> Tuple[str, str, bool]:
12+
"""Generate a line-by-line diff of two strings with rich markup.
13+
14+
Args:
15+
a:
16+
First multi-line string to compare.
17+
b:
18+
Second multi-line string to compare.
19+
20+
Returns:
21+
A tuple of (a_diff, b_diff, changed) where:
22+
- a_diff: The first string annotated with diff highlights.
23+
- b_diff: The second string annotated with diff highlights.
24+
- changed: True if any differences were found, False otherwise.
25+
"""
26+
a_lines, b_lines = a.splitlines(), b.splitlines()
27+
a_out, b_out = [], []
28+
changed = False
29+
for line in ndiff(a_lines, b_lines):
30+
tag, content = line[0], line[2:]
31+
if tag == " ":
32+
a_out.append(f" {content}")
33+
b_out.append(f" {content}")
34+
elif tag == "-":
35+
changed = True
36+
a_out.append(f"[red]- {content}[/red]")
37+
b_out.append("") # line not in b
38+
elif tag == "+":
39+
changed = True
40+
a_out.append("") # line not in a
41+
b_out.append(f"[green]+ {content}[/green]")
42+
43+
return "\n".join(a_out), "\n".join(b_out), changed
44+
45+
46+
def compare_items(a: Any, b: Any) -> Tuple[str, str, bool]:
47+
"""Generate a rich diff of two objects' pretty-printed representations.
48+
49+
Args:
50+
a:
51+
First object to compare.
52+
b:
53+
Second object to compare.
54+
55+
Returns:
56+
The object returned by diff_lines.
57+
"""
58+
return diff_lines(
59+
pretty_repr(a, indent_size=4, max_width=20),
60+
pretty_repr(b, indent_size=4, max_width=20),
61+
)
62+
63+
64+
def render_diff_panel(left: str, right: str, idx: int) -> Columns:
65+
"""Render side-by-side panels of diff strings for visual comparison.
66+
67+
Args:
68+
left:
69+
The left-side diff string (usually actual output).
70+
71+
right:
72+
The right-side diff string (usually expected output).
73+
74+
idx:
75+
Index number for labeling the diff panels.
76+
77+
Returns:
78+
A rich Columns object containing two Panels side-by-side.
79+
"""
80+
return Columns(
81+
[
82+
Panel(left, title=f"Actual #{idx}", expand=True),
83+
Panel(right, title=f"Expected #{idx}", expand=True),
84+
],
85+
expand=True,
86+
equal=True,
87+
)

src/gfw/common/sorting.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""Utilities for sorting data structures."""
2+
3+
from typing import Any
4+
5+
6+
def sort_dicts(obj: Any) -> Any:
7+
"""Recursively sorts dict keys to get consistent ordering for comparison.
8+
9+
Lists, tuples, and other types are returned unchanged (except their contents
10+
get sorted recursively if they are dicts).
11+
12+
Args:
13+
obj: Any nested structure (dict, list, tuple, or other).
14+
15+
Returns:
16+
A new structure with dict keys sorted recursively.
17+
"""
18+
if isinstance(obj, dict):
19+
# Sort keys and recursively apply to values
20+
return {k: sort_dicts(obj[k]) for k in sorted(obj)}
21+
elif isinstance(obj, list):
22+
# Recursively apply to each element
23+
return [sort_dicts(e) for e in obj]
24+
elif isinstance(obj, tuple):
25+
# Recursively apply to each element and keep tuple type
26+
return tuple(sort_dicts(e) for e in obj)
27+
else:
28+
# Return base case unchanged
29+
return obj

tests/beam/testing/__init__.py

Whitespace-only changes.

tests/beam/testing/test_utils.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import pytest
2+
from apache_beam.testing.util import BeamAssertException
3+
4+
from gfw.common.beam.testing.utils import equal_to, _default_equals_fn
5+
6+
7+
def test_default_equals_fn():
8+
assert _default_equals_fn(1, 1)
9+
assert not _default_equals_fn(1, 2)
10+
11+
12+
def test_equal_to_match_exact_order():
13+
expected = [1, 2, 3]
14+
actual = [1, 2, 3]
15+
16+
matcher = equal_to(expected)
17+
matcher(actual) # should not raise
18+
19+
20+
def test_equal_to_match_different_order():
21+
expected = [1, 2, 3]
22+
actual = [3, 1, 2]
23+
24+
matcher = equal_to(expected)
25+
matcher(actual) # should not raise
26+
27+
28+
def test_equal_to_empty_lists():
29+
matcher = equal_to([])
30+
matcher([]) # should not raise
31+
32+
33+
def test_equal_to_mismatch_raises():
34+
expected = [1, 2]
35+
actual = [1, 3]
36+
37+
matcher = equal_to(expected)
38+
39+
with pytest.raises(BeamAssertException) as e:
40+
matcher(actual)
41+
42+
# The exception message should contain a substring hinting at mismatch
43+
assert "PCollection contents differ" in str(e.value)
44+
45+
46+
def test_equal_to_type_error_handling():
47+
class Uncomparable:
48+
def __eq__(self, other):
49+
raise TypeError()
50+
51+
a = [Uncomparable()]
52+
b = [Uncomparable()]
53+
54+
def safe_equals(x, y):
55+
return type(x) is type(y)
56+
57+
matcher = equal_to(b, equals_fn=safe_equals)
58+
59+
# Should not raise, because our fallback considers them equal by type
60+
matcher(a)
61+
62+
63+
def test_equal_to_custom_equals_fn():
64+
expected = [1, 2, 3]
65+
actual = [3, 2, 1]
66+
67+
def reversed_equals(e, a):
68+
return e == a
69+
70+
matcher = equal_to(expected, equals_fn=reversed_equals)
71+
matcher(actual) # should not raise
72+
73+
# A custom equals that never matches causes exception
74+
def never_equals(e, a):
75+
return False
76+
77+
matcher = equal_to(expected, equals_fn=never_equals)
78+
with pytest.raises(BeamAssertException):
79+
matcher(actual)
80+
81+
82+
def test_equal_to_handles_nested_dicts_order():
83+
expected = [{"b": 1, "a": 2}]
84+
actual = [{"a": 2, "b": 1}]
85+
86+
matcher = equal_to(expected)
87+
matcher(actual) # Should not raise because dict keys sorted recursively
88+
89+
90+
def test_equal_to_handles_unmatched_extra_and_missing():
91+
expected = [1, 2]
92+
actual = [1, 2, 3]
93+
94+
matcher = equal_to(expected)
95+
96+
with pytest.raises(BeamAssertException):
97+
matcher(actual)
98+
99+
expected = [1, 2, 3]
100+
actual = [1, 2]
101+
102+
matcher = equal_to(expected)
103+
with pytest.raises(BeamAssertException):
104+
matcher(actual)

tests/beam/transforms/test_pubsub.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ def test_read_and_decode_from_pubsub():
3737
pubsub_messages = [
3838
dict(
3939
data=b'{"test": 123}',
40-
attributes={"key": "value"},
40+
attributes={
41+
"key2": "value2",
42+
"key1": "value1",
43+
},
4144
)
4245
]
4346

@@ -57,7 +60,10 @@ def test_read_and_decode_from_pubsub():
5760
expected = [
5861
{
5962
"data": '{"test": 123}',
60-
"attributes": {"key": "value"}
63+
"attributes": {
64+
"key1": "value1",
65+
"key2": "value2",
66+
}
6167
}
6268
]
6369

0 commit comments

Comments
 (0)