Skip to content

Commit aa27f33

Browse files
committed
Add basic node property support
1 parent 70dd008 commit aa27f33

File tree

2 files changed

+94
-8
lines changed

2 files changed

+94
-8
lines changed

graphdatascience/graph/graph_cypher_runner.py

Lines changed: 63 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections import namedtuple
1+
from collections import defaultdict, namedtuple
22
from typing import Any, NamedTuple, Optional, Tuple
33

44
from pandas import Series
@@ -72,6 +72,12 @@ def __str__(self) -> str:
7272
return f"{self.left_arrow}{self.type_filter}{self.right_arrow}(target{self.label_filter})"
7373

7474

75+
class LabelPropertyMapping(NamedTuple):
    """Associates a single node property with the label it was declared under.

    Instances are built while walking the node projection specs and are
    grouped per label so the Cypher projection can emit one property map
    per label.
    """

    # Source label whose projection spec listed the property.
    label: str
    # Key of the property as declared in that projection spec.
    property_key: str
    # Default carried over from the property spec; None when none was given.
    default_value: Optional[Any] = None
79+
80+
7581
class GraphCypherRunner(IllegalAttrChecker):
7682
def __init__(self, query_runner: QueryRunner, namespace: str, server_version: ServerVersion) -> None:
7783
if server_version < ServerVersion(2, 4, 0):
@@ -131,6 +137,8 @@ def project(
131137
right_arrow="-" if inverse else "->",
132138
)
133139

140+
label_mappings = defaultdict(list)
141+
134142
if nodes:
135143
if len(nodes) == 1 or combine_labels_with == "AND":
136144
match_pattern = match_pattern._replace(label_filter=f":{':'.join(spec.source_label for spec in nodes)}")
@@ -157,14 +165,22 @@ def project(
157165
else:
158166
raise ValueError(f"Invalid value for combine_labels_with: {combine_labels_with}")
159167

168+
for spec in nodes:
169+
if spec.properties:
170+
for prop in spec.properties:
171+
label_mappings[spec.source_label].append(
172+
LabelPropertyMapping(spec.source_label, prop.property_key, prop.default_value)
173+
)
174+
175+
rel_var = ""
160176
if rels:
161177
if len(rels) == 1:
162-
rel_var = ""
163178
data_config["relationshipType"] = rels[0].source_type
164179
else:
165180
rel_var = "rel"
166181
data_config["relationshipTypes"] = "type(rel)"
167182
data_config_is_static = False
183+
168184
match_pattern = match_pattern._replace(
169185
type_filter=f"[{rel_var}:{'|'.join(spec.source_type for spec in rels)}]"
170186
)
@@ -179,6 +195,24 @@ def project(
179195

180196
match_part = str(match_part)
181197

198+
case_part = []
199+
if label_mappings:
200+
with_rel = f", {rel_var}" if rel_var else ""
201+
case_part = [f"WITH source, target{with_rel}"]
202+
for kind in ["source", "target"]:
203+
case_part.append("CASE")
204+
205+
for label, mappings in label_mappings.items():
206+
mappings = ", ".join(f".{key.property_key}" for key in mappings)
207+
when_part = f"WHEN '{label}' in labels({kind}) THEN [{kind} {{{mappings}}}]"
208+
case_part.append(when_part)
209+
210+
case_part.append(f"END AS {kind}NodeProperties")
211+
212+
data_config["sourceNodeProperties"] = "sourceNodeProperties"
213+
data_config["targetNodeProperties"] = "targetNodeProperties"
214+
data_config_is_static = False
215+
182216
args = ["$graph_name", "source", "target"]
183217

184218
if data_config:
@@ -194,9 +228,7 @@ def project(
194228

195229
return_part = f"RETURN {self._namespace}({', '.join(args)})"
196230

197-
query = "\n".join(part for part in [match_part, return_part] if part)
198-
199-
print(query)
231+
query = "\n".join(part for part in [match_part, *case_part, return_part] if part)
200232

201233
result = self._query_runner.run_query_with_logging(
202234
query,
@@ -218,16 +250,39 @@ def _node_projections_spec(self, spec: Any) -> list[NodeProjection]:
218250
if isinstance(spec, dict):
219251
return [self._node_projection_spec(node, name) for name, node in spec.items()]
220252

221-
raise TypeError(f"Invalid node projection specification: {spec}")
253+
raise TypeError(f"Invalid node projections specification: {spec}")
222254

223255
def _node_projection_spec(self, spec: Any, name: Optional[str] = None) -> NodeProjection:
    """Build one ``NodeProjection`` from a single user-supplied node spec.

    Accepted shapes of ``spec``:

    * ``str``  -- a bare label; projected under ``name`` when given,
      otherwise under the label itself.
    * ``dict`` -- maps property names to property specs. Requires ``name``,
      which is used as both the projection name and the source label.
    * ``list`` -- a list of property specs. Requires ``name`` in the same way.

    Raises:
        ValueError: ``spec`` is not a string and no ``name`` was supplied.
        TypeError: ``spec`` has none of the supported shapes.
    """
    if isinstance(spec, str):
        return NodeProjection(name=name or spec, source_label=spec)

    if name is None:
        raise ValueError(f"Node projections with properties must use the dict syntax: {spec}")

    # Use a distinct loop variable so the ``name`` parameter, which is still
    # needed below for the projection itself, is never shadowed.
    if isinstance(spec, dict):
        properties = [self._node_properties_spec(prop, prop_name) for prop_name, prop in spec.items()]
    elif isinstance(spec, list):
        properties = [self._node_properties_spec(prop) for prop in spec]
    else:
        raise TypeError(f"Invalid node projection specification: {spec}")

    return NodeProjection(name=name, source_label=name, properties=properties)
228271

229-
def _node_properties_spec(self, properties: dict[str, Any]) -> list[NodeProperty]:
230-
raise TypeError(f"Invalid node projection specification: {properties}")
272+
def _node_properties_spec(self, spec: Any, name: Optional[str] = None) -> NodeProperty:
    """Build one ``NodeProperty`` from a single user-supplied property spec.

    Accepted shapes of ``spec``:

    * ``str``  -- the source property key; exposed under ``name`` when given,
      otherwise under the key itself.
    * ``True`` -- requires ``name``; the property is read from and exposed
      under ``name``.
    * ``dict`` -- requires ``name``; its entries are forwarded as
      ``NodeProperty`` keyword arguments. ``name`` and ``property_key``
      default to the ``name`` argument, and entries in the dict take
      precedence over those defaults.

    Raises:
        ValueError: ``spec`` is not a string and no ``name`` was supplied.
        TypeError: ``spec`` has none of the supported shapes, or the dict
            contains keys ``NodeProperty`` does not accept.
    """
    if isinstance(spec, str):
        return NodeProperty(name=name or spec, property_key=spec)

    if name is None:
        raise ValueError(f"Node properties spec must be used with the dict syntax: {spec}")

    if spec is True:
        return NodeProperty(name=name, property_key=name)

    if isinstance(spec, dict):
        # Merge instead of passing ``name=..., property_key=..., **spec``:
        # the latter raises "got multiple values for keyword argument" as
        # soon as the user supplies ``property_key`` (or ``name``) explicitly.
        # Building a new dict also leaves the caller's spec untouched.
        fields = {"name": name, "property_key": name, **spec}
        return NodeProperty(**fields)

    raise TypeError(f"Invalid node property specification: {spec}")
231286

232287
def _rel_projections_spec(self, spec: Any) -> list[RelationshipProjection]:
233288
if spec is None or spec is False:

graphdatascience/tests/unit/test_graph_cypher.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,34 @@ def test_multiple_multi_graph(runner: CollectingQueryRunner, gds: GraphDataScien
216216
"targetNodeLabels: labels(target), "
217217
"relationshipTypes: type(rel)})"
218218
)
219+
220+
221+
@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
def test_node_properties(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
    """Pin the query generated for node projections that carry properties.

    Covers the list syntax (L1, L2) and the dict syntax with both ``True``
    and an empty dict as property specs (L3).
    """
    node_spec = dict(L1=["prop1"], L2=["prop2", "prop3"], L3=dict(prop4=True, prop5=dict()))

    G, _ = gds.graph.cypher.project("g", nodes=node_spec)

    assert G.name() == "g"
    assert runner.last_params() == dict(graph_name="g")

    # The statement is emitted one clause per line, without indentation.
    expected_lines = [
        "MATCH (source)-->(target)",
        "WHERE (source:L1 OR source:L2 OR source:L3) AND (target:L1 OR target:L2 OR target:L3)",
        "WITH source, target",
        "CASE",
        "WHEN 'L1' in labels(source) THEN [source {.prop1}]",
        "WHEN 'L2' in labels(source) THEN [source {.prop2, .prop3}]",
        "WHEN 'L3' in labels(source) THEN [source {.prop4, .prop5}]",
        "END AS sourceNodeProperties",
        "CASE",
        "WHEN 'L1' in labels(target) THEN [target {.prop1}]",
        "WHEN 'L2' in labels(target) THEN [target {.prop2, .prop3}]",
        "WHEN 'L3' in labels(target) THEN [target {.prop4, .prop5}]",
        "END AS targetNodeProperties",
        "RETURN gds.graph.project($graph_name, source, target, {",
    ]
    # The data config map is appended to the RETURN line itself.
    expected_config = (
        "sourceNodeLabels: labels(source), "
        "targetNodeLabels: labels(target), "
        "sourceNodeProperties: sourceNodeProperties, "
        "targetNodeProperties: targetNodeProperties})"
    )

    assert runner.last_query() == "\n".join(expected_lines) + expected_config

0 commit comments

Comments
 (0)