Skip to content

Commit 0e5ee40

Browse files
authored
Merge pull request #44 from linkml/add-specialized-group-by-implementations
Add specialized group_by implementations for DuckDB and MongoDB
2 parents 0057342 + 5ddefea commit 0e5ee40

File tree

4 files changed

+348
-1
lines changed

4 files changed

+348
-1
lines changed

src/linkml_store/api/collection.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,11 @@ def group_by(
641641
if isinstance(group_by_fields, str):
642642
group_by_fields = [group_by_fields]
643643
df = self.find(where=where, limit=-1).rows_dataframe
644+
645+
# Handle the case where agg_map is None
646+
if agg_map is None:
647+
agg_map = {}
648+
644649
pk_fields = agg_map.get("first", []) + group_by_fields
645650
list_fields = agg_map.get("list", [])
646651
if not list_fields:

src/linkml_store/api/stores/duckdb/duckdb_collection.py

Lines changed: 161 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from linkml_store.api import Collection
1010
from linkml_store.api.collection import DEFAULT_FACET_LIMIT, OBJECT
11-
from linkml_store.api.queries import Query
11+
from linkml_store.api.queries import Query, QueryResult
1212
from linkml_store.api.stores.duckdb.mappings import TMAP
1313
from linkml_store.utils.sql_utils import facet_count_sql
1414

@@ -145,6 +145,166 @@ def _check_if_initialized(self) -> bool:
145145
return True
146146
return False
147147

148+
def group_by(
    self,
    group_by_fields: List[str],
    inlined_field="objects",
    agg_map: Optional[Dict[str, str]] = None,
    where: Optional[Dict] = None,
    **kwargs,
) -> QueryResult:
    """
    Group objects in the collection by specified fields using SQLAlchemy.

    This implementation leverages DuckDB's SQL capabilities for more efficient grouping.
    It falls back to the generic parent implementation whenever the table or class
    definition is unavailable, or if any SQL error occurs.

    :param group_by_fields: List of fields to group by (a bare string is wrapped in a list)
    :param inlined_field: Field name to store aggregated objects
    :param agg_map: Dictionary mapping aggregation types to field lists;
        only the "list" entry is honored here (projects each grouped object)
    :param where: Filter conditions; values may be scalars (equality) or
        Mongo-style operator dicts ($gt, $gte, $lt, $lte, $ne, $in)
    :param kwargs: Additional arguments, forwarded to the parent on fallback
    :return: Query result containing one row per group, each with the grouped
        objects under ``inlined_field``
    """
    if isinstance(group_by_fields, str):
        group_by_fields = [group_by_fields]

    cd = self.class_definition()
    if not cd:
        logger.debug(f"No class definition defined for {self.alias} {self.target_class_name}")
        return super().group_by(group_by_fields, inlined_field, agg_map, where, **kwargs)

    # Check if the table exists before attempting direct SQL access
    if not self.parent._table_exists(self.alias):
        logger.debug(f"Table {self.alias} doesn't exist, falling back to parent implementation")
        return super().group_by(group_by_fields, inlined_field, agg_map, where, **kwargs)

    table = self._sqla_table(cd)
    engine = self.parent.engine

    from sqlalchemy import select

    # Filter to fields that exist as columns up front. This filtered list is
    # then used for BOTH the SELECT column list and the per-row value lookup,
    # so positional indices stay aligned even when some requested fields are
    # missing from the table (the previous code enumerated the unfiltered
    # list and could read the wrong column, or run off the end of the row).
    valid_fields = [f for f in group_by_fields if f in table.columns.keys()]

    if not valid_fields:
        logger.warning(f"None of the group_by fields {group_by_fields} found in table columns")
        return super().group_by(group_by_fields, inlined_field, agg_map, where, **kwargs)

    def _where_clauses(where_map: Optional[Dict]) -> list:
        """Translate a Mongo-style ``where`` mapping into SQLAlchemy clauses.

        Unknown columns are silently skipped; unknown operators degrade to
        equality with a warning. Shared by the distinct-groups query and the
        per-group row query, which previously duplicated this logic inline.
        """
        clauses = []
        if not where_map:
            return clauses
        for k, v in where_map.items():
            if k not in table.columns.keys():
                continue
            if isinstance(v, dict):
                # Operator dict, e.g. {"$gt": 5}
                for op, val in v.items():
                    if op == "$gt":
                        clauses.append(table.c[k] > val)
                    elif op == "$gte":
                        clauses.append(table.c[k] >= val)
                    elif op == "$lt":
                        clauses.append(table.c[k] < val)
                    elif op == "$lte":
                        clauses.append(table.c[k] <= val)
                    elif op == "$ne":
                        clauses.append(table.c[k] != val)
                    elif op == "$in":
                        clauses.append(table.c[k].in_(val))
                    else:
                        # Default to equality for unknown operators
                        logger.warning(f"Unknown operator {op}, using equality")
                        clauses.append(table.c[k] == val)
            else:
                # Direct equality comparison
                clauses.append(table.c[k] == v)
        return clauses

    group_cols = [table.c[f] for f in valid_fields]
    stmt = select(*group_cols).distinct()
    for clause in _where_clauses(where):
        stmt = stmt.where(clause)

    results = []
    try:
        with engine.connect() as conn:
            # First pass: enumerate the distinct groups
            group_rows = list(conn.execute(stmt))

            # Second pass: fetch all member rows for each group.
            # NOTE(review): this issues one query per group (N+1); acceptable
            # for modest group counts, but a single GROUP BY/window query
            # would scale better — confirm before using on large tables.
            for group_row in group_rows:
                group_dict = {}
                group_conditions = []
                for i, field in enumerate(valid_fields):
                    value = group_row[i]
                    group_dict[field] = value
                    # SQL NULL needs IS NULL, not "= NULL"
                    if value is None:
                        group_conditions.append(table.c[field].is_(None))
                    else:
                        group_conditions.append(table.c[field] == value)

                row_stmt = select(*table.columns)
                for condition in group_conditions:
                    row_stmt = row_stmt.where(condition)
                # Re-apply the original filter so group members also satisfy it
                for clause in _where_clauses(where):
                    row_stmt = row_stmt.where(clause)

                rows = list(conn.execute(row_stmt))

                # Convert SQLAlchemy row tuples to plain dictionaries
                objects = [dict(zip(row._fields, row)) for row in rows]

                # Apply agg_map "list" projection to each grouped object
                if agg_map and agg_map.get("list"):
                    list_fields = agg_map["list"]
                    objects = [{k: obj.get(k) for k in list_fields if k in obj} for obj in objects]

                result_obj = group_dict.copy()
                result_obj[inlined_field] = objects
                results.append(result_obj)

        return QueryResult(num_rows=len(results), rows=results)
    except Exception as e:
        # Best-effort: any SQL-level failure falls back to the generic path
        logger.warning(f"Error in DuckDB group_by: {e}")
        return super().group_by(group_by_fields, inlined_field, agg_map, where, **kwargs)
307+
148308
def _create_table(self, cd: ClassDefinition):
149309
if self._table_created or self.metadata.is_prepopulated:
150310
logger.info(f"Already have table for: {cd.name}")

src/linkml_store/api/stores/mongodb/mongodb_collection.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,3 +265,101 @@ def delete_where(self, where: Optional[Dict[str, Any]] = None, missing_ok=True,
265265
if deleted_rows_count == 0 and not missing_ok:
266266
raise ValueError(f"No rows found for {where}")
267267
return deleted_rows_count
268+
269+
def group_by(
    self,
    group_by_fields: List[str],
    inlined_field="objects",
    agg_map: Optional[Dict[str, str]] = None,
    where: Optional[Dict] = None,
    **kwargs,
) -> QueryResult:
    """
    Group objects in the collection by specified fields using MongoDB's aggregation pipeline.

    This implementation leverages MongoDB's native aggregation capabilities for efficient
    grouping: an optional ``$match`` stage applies ``where``, then a ``$group`` stage
    pushes every matching document into its group.

    :param group_by_fields: List of fields to group by (a bare string is wrapped in a list)
    :param inlined_field: Field name to store aggregated objects
    :param agg_map: Dictionary mapping aggregation types ("first", "list") to field lists;
        "list" projects each grouped object to those fields, "first" alone strips those
        fields from the grouped objects (they already appear on the group row)
    :param where: Filter conditions, passed through as a MongoDB ``$match`` expression
    :param kwargs: Additional arguments; ``skip_nulls=True`` drops the null group
    :return: Query result containing one row per group, each with the grouped
        objects under ``inlined_field``
    """
    if isinstance(group_by_fields, str):
        group_by_fields = [group_by_fields]

    # Build the $group _id: a single "$field" path, or a doc for compound keys
    if len(group_by_fields) == 1:
        group_id = f"${group_by_fields[0]}"
    else:
        group_id = {field: f"${field}" for field in group_by_fields}

    pipeline = []
    if where:
        pipeline.append({"$match": where})
    pipeline.append({"$group": {"_id": group_id, "objects": {"$push": "$$ROOT"}}})

    logger.debug(f"MongoDB group_by pipeline: {pipeline}")
    aggregation_results = list(self.mongo_collection.aggregate(pipeline))

    skip_nulls = kwargs.get("skip_nulls", False)
    results = []
    for result in aggregation_results:
        # Documents missing the group field all land in the null group;
        # drop it only when the caller opted in via skip_nulls
        if result["_id"] is None and skip_nulls:
            continue

        # Reconstruct the group key as a flat object
        if isinstance(result["_id"], dict):
            # Compound key: _id already maps field -> value
            group_obj = result["_id"]
        else:
            group_obj = {group_by_fields[0]: result["_id"]}

        objects = result["objects"]

        # Strip MongoDB's internal _id from each member document
        for obj in objects:
            obj.pop("_id", None)

        # Apply field selection based on agg_map. "first" fields are already
        # represented in group_obj via the $group _id, so they need no
        # per-object handling of their own.
        if agg_map:
            first_fields = agg_map.get("first", [])
            list_fields = agg_map.get("list", [])
            if list_fields:
                # Project each grouped object down to the requested fields
                objects = [{k: obj.get(k) for k in list_fields if k in obj} for obj in objects]
            elif first_fields:
                # No explicit list fields: remove the "first" fields from
                # each object to avoid duplicating the group key values
                objects = [{k: v for k, v in obj.items() if k not in first_fields} for obj in objects]

        group_obj[inlined_field] = objects
        results.append(group_obj)

    return QueryResult(num_rows=len(results), rows=results)

tests/test_api/test_api.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,90 @@ def test_group_by(handle):
334334
assert False, f"Unexpected id: {row['id']}"
335335

336336

337+
@pytest.mark.parametrize("handle", SCHEMES_PLUS)
def test_group_by_advanced(handle):
    """
    Test more advanced group_by features for specific store implementations.

    Tests various features:
    1. Multi-field grouping
    2. Filtering with where clause
    3. Aggregation of specific fields
    4. Different inlined field name
    """
    client = create_client(handle)
    database = client.get_database()

    # A dataset rich enough to exercise several grouping dimensions
    products = [
        {"id": 1, "category": "A", "name": "Item1", "price": 10.0, "qty": 5, "tags": ["red", "small"]},
        {"id": 2, "category": "A", "name": "Item2", "price": 20.0, "qty": 3, "tags": ["blue", "medium"]},
        {"id": 3, "category": "B", "name": "Item3", "price": 15.0, "qty": 7, "tags": ["red", "large"]},
        {"id": 4, "category": "B", "name": "Item4", "price": 25.0, "qty": 2, "tags": ["green", "small"]},
        {"id": 5, "category": "A", "name": "Item5", "price": 30.0, "qty": 1, "tags": ["blue", "large"]},
    ]

    collection = database.create_collection("Products", recreate_if_exists=True)
    collection.insert(products)

    # Test 1: Group by a single field
    result = collection.group_by(["category"])
    assert result.num_rows == 2

    # Verify correct grouping via an expected-size lookup
    expected_sizes = {"A": 3, "B": 2}
    for group in result.rows:
        category = group["category"]
        assert category in expected_sizes, f"Unexpected category: {group['category']}"
        assert len(group["objects"]) == expected_sizes[category]

    # Test 2: Group by multiple scalar fields (avoid using array fields in multi-field grouping)
    result = collection.group_by(["category", "name"])
    # Just check that it doesn't error - the exact results will depend on implementation

    # Test 3: Group with a where clause - use exact match for compatibility
    # Filter for category "A" items only
    result = collection.group_by(["category"], where={"category": "A"})
    assert result.num_rows == 1
    only_group = result.rows[0]
    assert only_group["category"] == "A"
    assert len(only_group["objects"]) == 3

    # For MongoDB specific test, if this is MongoDB handle
    if "mongodb" in handle:
        # Uses MongoDB's query operators
        result = collection.group_by(["category"], where={"price": {"$gt": 15.0}})

        # Find the group with category "A"
        a_group = next((g for g in result.rows if g["category"] == "A"), None)
        if a_group is not None:
            # Should only include items with price > 15.0
            assert all(item["price"] > 15.0 for item in a_group["objects"])

    # Test 4: Custom inlined field name
    result = collection.group_by(["category"], inlined_field="items")
    for group in result.rows:
        assert "items" in group
        assert "objects" not in group

    # Test 5: Test with agg_map for field selection (skip for file adapter which doesn't fully support agg_map)
    if "file:" not in handle:
        result = collection.group_by(
            ["category"],
            agg_map={"first": ["category"], "list": ["name", "price"]}
        )

        # Verify that only specified fields are included
        for group in result.rows:
            for item in group["objects"]:
                assert "name" in item
                assert "price" in item
                assert "qty" not in item  # This field should be excluded
                assert "tags" not in item  # This field should be excluded
419+
420+
337421
@pytest.mark.parametrize("handle", SCHEMES_PLUS)
338422
def test_collections_of_same_type(handle):
339423
"""

0 commit comments

Comments
 (0)