Skip to content

Commit 70dd008

Browse files
committed
Move non-static data config from params into query
1 parent 53c4733 commit 70dd008

File tree

2 files changed

+28
-25
lines changed

2 files changed

+28
-25
lines changed

graphdatascience/graph/graph_cypher_runner.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def project(
120120
query_params = {"graph_name": graph_name}
121121

122122
data_config = {}
123+
data_config_is_static = True
123124

124125
nodes = self._node_projections_spec(nodes)
125126
rels = self._rel_projections_spec(relationships)
@@ -152,6 +153,7 @@ def project(
152153

153154
data_config["sourceNodeLabels"] = "labels(source)"
154155
data_config["targetNodeLabels"] = "labels(target)"
156+
data_config_is_static = False
155157
else:
156158
raise ValueError(f"Invalid value for combine_labels_with: {combine_labels_with}")
157159

@@ -162,6 +164,7 @@ def project(
162164
else:
163165
rel_var = "rel"
164166
data_config["relationshipTypes"] = "type(rel)"
167+
data_config_is_static = False
165168
match_pattern = match_pattern._replace(
166169
type_filter=f"[{rel_var}:{'|'.join(spec.source_type for spec in rels)}]"
167170
)
@@ -179,8 +182,11 @@ def project(
179182
args = ["$graph_name", "source", "target"]
180183

181184
if data_config:
182-
query_params["data_config"] = data_config
183-
args += ["$data_config"]
185+
if data_config_is_static:
186+
query_params["data_config"] = data_config
187+
args += ["$data_config"]
188+
else:
189+
args += [self._render_map(data_config)]
184190

185191
if config:
186192
query_params["config"] = config
@@ -247,6 +253,9 @@ def _rel_projection_spec(self, spec: Any, name: Optional[str] = None) -> Relatio
247253
def _rel_properties_spec(self, properties: dict[str, Any]) -> list[RelationshipProperty]:
248254
raise TypeError(f"Invalid relationship projection specification: {properties}")
249255

256+
def _render_map(self, mapping: dict[str, Any]) -> str:
257+
return "{" + ", ".join(f"{key}: {value}" for key, value in mapping.items()) + "}"
258+
250259
#
251260
# def estimate(self, *, nodes: Any, relationships: Any, **config: Any) -> "Series[Any]":
252261
# pass

graphdatascience/tests/unit/test_graph_cypher.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -136,15 +136,14 @@ def test_multiple_node_labels_or(runner: CollectingQueryRunner, gds: GraphDataSc
136136
G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], combine_labels_with="OR")
137137

138138
assert G.name() == "g"
139-
assert runner.last_params() == dict(
140-
graph_name="g", data_config={"sourceNodeLabels": "labels(source)", "targetNodeLabels": "labels(target)"}
141-
)
139+
assert runner.last_params() == dict(graph_name="g")
142140

143-
assert (
144-
runner.last_query()
145-
== """MATCH (source)-->(target)
141+
assert runner.last_query() == (
142+
"""MATCH (source)-->(target)
146143
WHERE (source:A OR source:B) AND (target:A OR target:B)
147-
RETURN gds.graph.project($graph_name, source, target, $data_config)"""
144+
RETURN gds.graph.project($graph_name, source, target, {"""
145+
"sourceNodeLabels: labels(source), "
146+
"targetNodeLabels: labels(target)})"
148147
)
149148

150149

@@ -153,17 +152,16 @@ def test_disconnected_nodes_multiple_node_labels_or(runner: CollectingQueryRunne
153152
G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], combine_labels_with="OR", allow_disconnected_nodes=True)
154153

155154
assert G.name() == "g"
156-
assert runner.last_params() == dict(
157-
graph_name="g", data_config={"sourceNodeLabels": "labels(source)", "targetNodeLabels": "labels(target)"}
158-
)
155+
assert runner.last_params() == dict(graph_name="g")
159156

160-
assert (
161-
runner.last_query()
162-
== """MATCH (source)
157+
assert runner.last_query() == (
158+
"""MATCH (source)
163159
WHERE source:A OR source:B
164160
OPTIONAL MATCH (source)-->(target)
165161
WHERE target:A OR target:B
166-
RETURN gds.graph.project($graph_name, source, target, $data_config)"""
162+
RETURN gds.graph.project($graph_name, source, target, {"""
163+
"sourceNodeLabels: labels(source), "
164+
"targetNodeLabels: labels(target)})"
167165
)
168166

169167

@@ -207,18 +205,14 @@ def test_multiple_multi_graph(runner: CollectingQueryRunner, gds: GraphDataScien
207205
G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], relationships=["REL1", "REL2"])
208206

209207
assert G.name() == "g"
210-
assert runner.last_params() == dict(
211-
graph_name="g",
212-
data_config={
213-
"sourceNodeLabels": "labels(source)",
214-
"targetNodeLabels": "labels(target)",
215-
"relationshipTypes": "type(rel)",
216-
},
217-
)
208+
assert runner.last_params() == dict(graph_name="g")
218209

219210
assert (
220211
runner.last_query()
221212
== """MATCH (source)-[rel:REL1|REL2]->(target)
222213
WHERE (source:A OR source:B) AND (target:A OR target:B)
223-
RETURN gds.graph.project($graph_name, source, target, $data_config)"""
214+
RETURN gds.graph.project($graph_name, source, target, {"""
215+
"sourceNodeLabels: labels(source), "
216+
"targetNodeLabels: labels(target), "
217+
"relationshipTypes: type(rel)})"
224218
)

0 commit comments

Comments
 (0)