Skip to content

Commit 9ac5ec5

Browse files
Fix cleanup ops for inferred schema (#403)
* Add filtering on inferred node types/rel types with empty string
* Update Unit tests for schema
* Update changelog
* Filter out invalid labels/rel types
1 parent 4e1e783 commit 9ac5ec5

File tree

3 files changed

+42
-52
lines changed

3 files changed

+42
-52
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
- Fixed documentation for PdfLoader
88
- Fixed a bug where the `format` argument for `OllamaLLM` was not propagated to the client.
9+
- Fixed `AttributeError` in `SchemaFromTextExtractor` when filtering out node/relationship types with no labels.
910
- Fixed an import error in `VertexAIEmbeddings`.
1011

1112
## 1.9.0

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 25 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -498,49 +498,38 @@ def _filter_invalid_patterns(
498498

499499
return filtered_patterns
500500

501-
def _filter_nodes_without_labels(
502-
self, node_types: List[Dict[str, Any]]
501+
def _filter_items_without_labels(
502+
self, items: List[Dict[str, Any]], item_type: str
503503
) -> List[Dict[str, Any]]:
504-
"""
505-
Filter out node types that have no labels.
506-
507-
Args:
508-
node_types: List of node type definitions.
509-
510-
Returns:
511-
Filtered list of node types containing only those with valid labels.
512-
"""
513-
filtered_nodes = []
514-
for node_type in node_types:
515-
if node_type.get("label"):
516-
filtered_nodes.append(node_type)
504+
"""Filter out items that have no valid labels."""
505+
filtered_items = []
506+
for item in items:
507+
if isinstance(item, str):
508+
if item and " " not in item and not item.startswith("{"):
509+
filtered_items.append({"label": item})
510+
elif item:
511+
logging.info(
512+
f"Filtering out {item_type} with invalid label: {item}"
513+
)
514+
elif isinstance(item, dict) and item.get("label"):
515+
filtered_items.append(item)
517516
else:
518-
logging.info(f"Filtering out node type with missing label: {node_type}")
517+
logging.info(f"Filtering out {item_type} with missing label: {item}")
518+
return filtered_items
519519

520-
return filtered_nodes
520+
def _filter_nodes_without_labels(
521+
self, node_types: List[Dict[str, Any]]
522+
) -> List[Dict[str, Any]]:
523+
"""Filter out node types that have no labels."""
524+
return self._filter_items_without_labels(node_types, "node type")
521525

522526
def _filter_relationships_without_labels(
523527
self, relationship_types: List[Dict[str, Any]]
524528
) -> List[Dict[str, Any]]:
525-
"""
526-
Filter out relationship types that have no labels.
527-
528-
Args:
529-
relationship_types: List of relationship type definitions.
530-
531-
Returns:
532-
Filtered list of relationship types containing only those with valid labels.
533-
"""
534-
filtered_relationships = []
535-
for rel_type in relationship_types:
536-
if rel_type.get("label"):
537-
filtered_relationships.append(rel_type)
538-
else:
539-
logging.info(
540-
f"Filtering out relationship type with missing label: {rel_type}"
541-
)
542-
543-
return filtered_relationships
529+
"""Filter out relationship types that have no labels."""
530+
return self._filter_items_without_labels(
531+
relationship_types, "relationship type"
532+
)
544533

545534
@validate_call
546535
async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema:

tests/unit/experimental/components/test_schema.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -794,12 +794,11 @@ def schema_json_with_nodes_without_labels() -> str:
794794
{"name": "name", "type": "STRING"}
795795
]
796796
},
797-
{
798-
"label": "Organization",
799-
"properties": [
800-
{"name": "name", "type": "STRING"}
801-
]
802-
}
797+
"Organization",
798+
"",
799+
"Company",
800+
"Invalid description with spaces",
801+
"{\\"invalid\\": \\"json object\\"}"
803802
],
804803
"relationship_types": [
805804
{
@@ -852,12 +851,11 @@ def schema_json_with_relationships_without_labels() -> str:
852851
{"name": "since", "type": "DATE"}
853852
]
854853
},
855-
{
856-
"label": "MANAGES",
857-
"properties": [
858-
{"name": "since", "type": "DATE"}
859-
]
860-
}
854+
"MANAGES",
855+
"",
856+
"SUPERVISES",
857+
"invalid relationship description",
858+
"{\\"invalid\\": \\"json\\"}"
861859
],
862860
"patterns": [
863861
["Person", "WORKS_FOR", "Organization"],
@@ -921,10 +919,11 @@ async def test_schema_from_text_filters_nodes_without_labels(
921919
# run the schema extraction
922920
schema = await schema_from_text.run(text="Sample text for extraction")
923921

924-
# verify that nodes without labels were filtered out (2 out of 4 nodes should be removed)
925-
assert len(schema.node_types) == 2
922+
# verify that nodes without labels were filtered out (5 out of 8 nodes should be removed)
923+
assert len(schema.node_types) == 3
926924
assert schema.node_type_from_label("Person") is not None
927925
assert schema.node_type_from_label("Organization") is not None
926+
assert schema.node_type_from_label("Company") is not None
928927

929928
# verify that the pattern is still valid with the remaining nodes
930929
assert schema.patterns is not None
@@ -946,11 +945,12 @@ async def test_schema_from_text_filters_relationships_without_labels(
946945
# run the schema extraction
947946
schema = await schema_from_text.run(text="Sample text for extraction")
948947

949-
# verify that relationships without labels were filtered out (2 out of 4 relationships should be removed)
948+
# verify that relationships without labels were filtered out (5 out of 8 relationships should be removed)
950949
assert schema.relationship_types is not None
951-
assert len(schema.relationship_types) == 2
950+
assert len(schema.relationship_types) == 3
952951
assert schema.relationship_type_from_label("WORKS_FOR") is not None
953952
assert schema.relationship_type_from_label("MANAGES") is not None
953+
assert schema.relationship_type_from_label("SUPERVISES") is not None
954954

955955
# verify that the patterns are still valid with the remaining relationships
956956
assert schema.patterns is not None

0 commit comments

Comments
 (0)