Skip to content

Commit a04ac46

Browse files
committed
[Python] Predicate Push Down for Scan / Read
1 parent 0cc6721 commit a04ac46

File tree

13 files changed

+468
-224
lines changed

13 files changed

+468
-224
lines changed

paimon-python/pypaimon/common/predicate.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,107 @@ def test(self, record: InternalRow) -> bool:
8282
else:
8383
raise ValueError("Unsupported predicate method: {}".format(self.method))
8484

85+
def test_by_value(self, value: Any) -> bool:
86+
if self.method == 'and':
87+
return all(p.test_by_value(value) for p in self.literals)
88+
if self.method == 'or':
89+
t = any(p.test_by_value(value) for p in self.literals)
90+
return t
91+
92+
if self.method == 'equal':
93+
return value == self.literals[0]
94+
if self.method == 'notEqual':
95+
return value != self.literals[0]
96+
if self.method == 'lessThan':
97+
return value < self.literals[0]
98+
if self.method == 'lessOrEqual':
99+
return value <= self.literals[0]
100+
if self.method == 'greaterThan':
101+
return value > self.literals[0]
102+
if self.method == 'greaterOrEqual':
103+
return value >= self.literals[0]
104+
if self.method == 'isNull':
105+
return value is None
106+
if self.method == 'isNotNull':
107+
return value is not None
108+
if self.method == 'startsWith':
109+
if not isinstance(value, str):
110+
return False
111+
return value.startswith(self.literals[0])
112+
if self.method == 'endsWith':
113+
if not isinstance(value, str):
114+
return False
115+
return value.endswith(self.literals[0])
116+
if self.method == 'contains':
117+
if not isinstance(value, str):
118+
return False
119+
return self.literals[0] in value
120+
if self.method == 'in':
121+
return value in self.literals
122+
if self.method == 'notIn':
123+
return value not in self.literals
124+
if self.method == 'between':
125+
return self.literals[0] <= value <= self.literals[1]
126+
127+
raise ValueError(f"Unsupported predicate method: {self.method}")
128+
129+
def test_by_stats(self, stat: dict) -> bool:
130+
if self.method == 'and':
131+
return all(p.test_by_stats(stat) for p in self.literals)
132+
if self.method == 'or':
133+
t = any(p.test_by_stats(stat) for p in self.literals)
134+
return t
135+
136+
null_count = stat["null_counts"][self.field]
137+
row_count = stat["row_count"]
138+
139+
if self.method == 'isNull':
140+
return null_count is not None and null_count > 0
141+
if self.method == 'isNotNull':
142+
return null_count is None or row_count is None or null_count < row_count
143+
144+
min_value = stat["min_values"][self.field]
145+
max_value = stat["max_values"][self.field]
146+
147+
if min_value is None or max_value is None or (null_count is not None and null_count == row_count):
148+
return False
149+
150+
if self.method == 'equal':
151+
return min_value <= self.literals[0] <= max_value
152+
if self.method == 'notEqual':
153+
return not (min_value == self.literals[0] == max_value)
154+
if self.method == 'lessThan':
155+
return self.literals[0] > min_value
156+
if self.method == 'lessOrEqual':
157+
return self.literals[0] >= min_value
158+
if self.method == 'greaterThan':
159+
return self.literals[0] < max_value
160+
if self.method == 'greaterOrEqual':
161+
return self.literals[0] <= max_value
162+
if self.method == 'startsWith':
163+
if not isinstance(min_value, str) or not isinstance(max_value, str):
164+
raise RuntimeError("startsWith predicate on non-str field")
165+
return ((min_value.startswith(self.literals[0]) or min_value < self.literals[0])
166+
and (max_value.startswith(self.literals[0]) or max_value > self.literals[0]))
167+
if self.method == 'endsWith':
168+
return True
169+
if self.method == 'contains':
170+
return True
171+
if self.method == 'in':
172+
for literal in self.literals:
173+
if min_value <= literal <= max_value:
174+
return True
175+
return False
176+
if self.method == 'notIn':
177+
for literal in self.literals:
178+
if min_value == literal == max_value:
179+
return False
180+
return True
181+
if self.method == 'between':
182+
return self.literals[0] <= max_value and self.literals[1] >= min_value
183+
else:
184+
raise ValueError(f"Unsupported predicate method: {self.method}")
185+
85186
def to_arrow(self) -> pyarrow_compute.Expression | bool:
86187
if self.method == 'equal':
87188
return pyarrow_dataset.field(self.field) == self.literals[0]

paimon-python/pypaimon/manifest/manifest_file_manager.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,19 +55,19 @@ def read(self, manifest_file_name: str, shard_filter=None) -> List[ManifestEntry
5555
file_dict = dict(record['_FILE'])
5656
key_dict = dict(file_dict['_KEY_STATS'])
5757
key_stats = SimpleStats(
58-
min_value=BinaryRowDeserializer.from_bytes(key_dict['_MIN_VALUES'],
59-
self.trimmed_primary_key_fields),
60-
max_value=BinaryRowDeserializer.from_bytes(key_dict['_MAX_VALUES'],
61-
self.trimmed_primary_key_fields),
62-
null_count=key_dict['_NULL_COUNTS'],
58+
min_values=BinaryRowDeserializer.from_bytes(key_dict['_MIN_VALUES'],
59+
self.trimmed_primary_key_fields),
60+
max_values=BinaryRowDeserializer.from_bytes(key_dict['_MAX_VALUES'],
61+
self.trimmed_primary_key_fields),
62+
null_counts=key_dict['_NULL_COUNTS'],
6363
)
6464
value_dict = dict(file_dict['_VALUE_STATS'])
6565
value_stats = SimpleStats(
66-
min_value=BinaryRowDeserializer.from_bytes(value_dict['_MIN_VALUES'],
67-
self.table.table_schema.fields),
68-
max_value=BinaryRowDeserializer.from_bytes(value_dict['_MAX_VALUES'],
69-
self.table.table_schema.fields),
70-
null_count=value_dict['_NULL_COUNTS'],
66+
min_values=BinaryRowDeserializer.from_bytes(value_dict['_MIN_VALUES'],
67+
self.table.table_schema.fields),
68+
max_values=BinaryRowDeserializer.from_bytes(value_dict['_MAX_VALUES'],
69+
self.table.table_schema.fields),
70+
null_counts=value_dict['_NULL_COUNTS'],
7171
)
7272
file_meta = DataFileMeta(
7373
file_name=file_dict['_FILE_NAME'],
@@ -118,14 +118,14 @@ def write(self, file_name, commit_messages: List[CommitMessage]):
118118
"_MIN_KEY": BinaryRowSerializer.to_bytes(file.min_key),
119119
"_MAX_KEY": BinaryRowSerializer.to_bytes(file.max_key),
120120
"_KEY_STATS": {
121-
"_MIN_VALUES": BinaryRowSerializer.to_bytes(file.key_stats.min_value),
122-
"_MAX_VALUES": BinaryRowSerializer.to_bytes(file.key_stats.max_value),
123-
"_NULL_COUNTS": file.key_stats.null_count,
121+
"_MIN_VALUES": BinaryRowSerializer.to_bytes(file.key_stats.min_values),
122+
"_MAX_VALUES": BinaryRowSerializer.to_bytes(file.key_stats.max_values),
123+
"_NULL_COUNTS": file.key_stats.null_counts,
124124
},
125125
"_VALUE_STATS": {
126-
"_MIN_VALUES": BinaryRowSerializer.to_bytes(file.value_stats.min_value),
127-
"_MAX_VALUES": BinaryRowSerializer.to_bytes(file.value_stats.max_value),
128-
"_NULL_COUNTS": file.value_stats.null_count,
126+
"_MIN_VALUES": BinaryRowSerializer.to_bytes(file.value_stats.min_values),
127+
"_MAX_VALUES": BinaryRowSerializer.to_bytes(file.value_stats.max_values),
128+
"_NULL_COUNTS": file.value_stats.null_counts,
129129
},
130130
"_MIN_SEQUENCE_NUMBER": file.min_sequence_number,
131131
"_MAX_SEQUENCE_NUMBER": file.max_sequence_number,

paimon-python/pypaimon/manifest/manifest_list_manager.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,15 @@ def read(self, manifest_list_name: str) -> List[ManifestFileMeta]:
5858
for record in reader:
5959
stats_dict = dict(record['_PARTITION_STATS'])
6060
partition_stats = SimpleStats(
61-
min_value=BinaryRowDeserializer.from_bytes(
61+
min_values=BinaryRowDeserializer.from_bytes(
6262
stats_dict['_MIN_VALUES'],
6363
self.table.table_schema.get_partition_key_fields()
6464
),
65-
max_value=BinaryRowDeserializer.from_bytes(
65+
max_values=BinaryRowDeserializer.from_bytes(
6666
stats_dict['_MAX_VALUES'],
6767
self.table.table_schema.get_partition_key_fields()
6868
),
69-
null_count=stats_dict['_NULL_COUNTS'],
69+
null_counts=stats_dict['_NULL_COUNTS'],
7070
)
7171
manifest_file_meta = ManifestFileMeta(
7272
file_name=record['_FILE_NAME'],
@@ -90,9 +90,9 @@ def write(self, file_name, manifest_file_metas: List[ManifestFileMeta]):
9090
"_NUM_ADDED_FILES": meta.num_added_files,
9191
"_NUM_DELETED_FILES": meta.num_deleted_files,
9292
"_PARTITION_STATS": {
93-
"_MIN_VALUES": BinaryRowSerializer.to_bytes(meta.partition_stats.min_value),
94-
"_MAX_VALUES": BinaryRowSerializer.to_bytes(meta.partition_stats.max_value),
95-
"_NULL_COUNTS": meta.partition_stats.null_count,
93+
"_MIN_VALUES": BinaryRowSerializer.to_bytes(meta.partition_stats.min_values),
94+
"_MAX_VALUES": BinaryRowSerializer.to_bytes(meta.partition_stats.max_values),
95+
"_NULL_COUNTS": meta.partition_stats.null_counts,
9696
},
9797
"_SCHEMA_ID": meta.schema_id,
9898
}

paimon-python/pypaimon/manifest/schema/simple_stats.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424

2525
@dataclass
2626
class SimpleStats:
27-
min_value: BinaryRow
28-
max_value: BinaryRow
29-
null_count: Optional[List[int]]
27+
min_values: BinaryRow
28+
max_values: BinaryRow
29+
null_counts: Optional[List[int]]
3030

3131

3232
SIMPLE_STATS_SCHEMA = {
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
################################################################################
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
################################################################################
18+
19+
from typing import List, Set
20+
21+
from pypaimon.common.predicate import Predicate
22+
23+
24+
def extract_predicate_to_list(result: list, input_predicate: 'Predicate', keys: List[str]):
    """Append to ``result`` the parts of ``input_predicate`` that only use ``keys``.

    An 'and' predicate is split and each child is examined on its own.
    An 'or' predicate is kept whole, and only when every field it
    references is contained in ``keys``. A leaf predicate is kept when its
    field is contained in ``keys``. Nothing happens when either the
    predicate or ``keys`` is empty.
    """
    if not (input_predicate and keys):
        return

    kind = input_predicate.method
    if kind == 'and':
        for child in input_predicate.literals:
            extract_predicate_to_list(result, child, keys)
    elif kind == 'or':
        # A disjunction is only usable if all of its fields are in keys.
        referenced = _get_all_fields(input_predicate)
        if referenced and referenced.issubset(keys):
            result.append(input_predicate)
    elif input_predicate.field in keys:
        result.append(input_predicate)


def _get_all_fields(predicate: 'Predicate') -> Set[str]:
    """Return the set of field names referenced anywhere inside ``predicate``."""
    own_field = predicate.field
    if own_field is not None:
        return {own_field}
    # No field of its own: gather the fields of the child predicates, if any.
    children = predicate.literals or ()
    return set().union(*(_get_all_fields(child) for child in children))
51+
52+
53+
def extract_predicate_to_dict(result: dict, input_predicate: 'Predicate', keys: List[str]):
    """Group the parts of ``input_predicate`` that use ``keys`` by field name.

    Matching predicates are appended to ``result[field]``; ``result`` must
    already provide a list for every field that can match.
    An 'and' predicate is split and each child is examined on its own.
    An 'or' predicate is kept whole, and only when all of its children are
    leaf predicates on one single field contained in ``keys``. A leaf
    predicate is kept when its field is contained in ``keys``.
    """
    if not (input_predicate and keys):
        return

    kind = input_predicate.method
    if kind == 'and':
        for child in input_predicate.literals:
            extract_predicate_to_dict(result, child, keys)
        return

    if kind == 'or':
        children = input_predicate.literals
        if not children:
            return
        # A child without a field is itself compound; such an 'or' is skipped.
        for child in children:
            if child.field is None:
                return
        child_fields = {child.field for child in children}
        # Usable only when the whole disjunction constrains one key field.
        if len(child_fields) == 1:
            (only_field,) = child_fields
            if only_field in keys:
                result[only_field].append(input_predicate)
        return

    leaf_field = input_predicate.field
    if leaf_field in keys:
        result[leaf_field].append(input_predicate)

paimon-python/pypaimon/read/reader/format_avro_reader.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020

2121
import fastavro
2222
import pyarrow as pa
23+
import pyarrow.compute as pc
2324
import pyarrow.dataset as ds
2425
from pyarrow import RecordBatch
2526

2627
from pypaimon.common.file_io import FileIO
27-
from pypaimon.common.predicate import Predicate
2828
from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader
2929
from pypaimon.schema.data_types import DataField, PyarrowFieldParser
3030

@@ -35,26 +35,18 @@ class FormatAvroReader(RecordBatchReader):
3535
provided predicate and projection, and converts Avro records to RecordBatch format.
3636
"""
3737

38-
def __init__(self, file_io: FileIO, file_path: str, primary_keys: List[str],
39-
fields: List[str], full_fields: List[DataField], predicate: Predicate, batch_size: int = 4096):
38+
def __init__(self, file_io: FileIO, file_path: str, read_fields: List[str], full_fields: List[DataField],
39+
push_down_predicate: pc.Expression | bool, batch_size: int = 4096):
4040
self._file = file_io.filesystem.open_input_file(file_path)
4141
self._avro_reader = fastavro.reader(self._file)
4242
self._batch_size = batch_size
43-
self._primary_keys = primary_keys
43+
self._push_down_predicate = push_down_predicate
4444

45-
self._fields = fields
45+
self._fields = read_fields
4646
full_fields_map = {field.name: field for field in full_fields}
47-
projected_data_fields = [full_fields_map[name] for name in fields]
47+
projected_data_fields = [full_fields_map[name] for name in read_fields]
4848
self._schema = PyarrowFieldParser.from_paimon_schema(projected_data_fields)
4949

50-
if primary_keys:
51-
# TODO: utilize predicate to improve performance
52-
predicate = None
53-
if predicate is not None:
54-
self._predicate = predicate.to_arrow()
55-
else:
56-
self._predicate = None
57-
5850
def read_arrow_batch(self) -> Optional[RecordBatch]:
5951
pydict_data = {name: [] for name in self._fields}
6052
records_in_batch = 0
@@ -68,12 +60,12 @@ def read_arrow_batch(self) -> Optional[RecordBatch]:
6860

6961
if records_in_batch == 0:
7062
return None
71-
if self._predicate is None:
63+
if self._push_down_predicate is None:
7264
return pa.RecordBatch.from_pydict(pydict_data, self._schema)
7365
else:
7466
pa_batch = pa.Table.from_pydict(pydict_data, self._schema)
7567
dataset = ds.InMemoryDataset(pa_batch)
76-
scanner = dataset.scanner(filter=self._predicate)
68+
scanner = dataset.scanner(filter=self._push_down_predicate)
7769
combine_chunks = scanner.to_table().combine_chunks()
7870
if combine_chunks.num_rows > 0:
7971
return combine_chunks.to_batches()[0]

0 commit comments

Comments
 (0)