1
- from collections import namedtuple
1
+ from collections import defaultdict , namedtuple
2
2
from typing import Any , NamedTuple , Optional , Tuple
3
3
4
4
from pandas import Series
@@ -72,6 +72,12 @@ def __str__(self) -> str:
72
72
return f"{ self .left_arrow } { self .type_filter } { self .right_arrow } (target{ self .label_filter } )"
73
73
74
74
75
+ class LabelPropertyMapping (NamedTuple ):
76
+ label : str
77
+ property_key : str
78
+ default_value : Optional [Any ] = None
79
+
80
+
75
81
class GraphCypherRunner (IllegalAttrChecker ):
76
82
def __init__ (self , query_runner : QueryRunner , namespace : str , server_version : ServerVersion ) -> None :
77
83
if server_version < ServerVersion (2 , 4 , 0 ):
@@ -131,6 +137,8 @@ def project(
131
137
right_arrow = "-" if inverse else "->" ,
132
138
)
133
139
140
+ label_mappings = defaultdict (list )
141
+
134
142
if nodes :
135
143
if len (nodes ) == 1 or combine_labels_with == "AND" :
136
144
match_pattern = match_pattern ._replace (label_filter = f":{ ':' .join (spec .source_label for spec in nodes )} " )
@@ -157,14 +165,22 @@ def project(
157
165
else :
158
166
raise ValueError (f"Invalid value for combine_labels_with: { combine_labels_with } " )
159
167
168
+ for spec in nodes :
169
+ if spec .properties :
170
+ for prop in spec .properties :
171
+ label_mappings [spec .source_label ].append (
172
+ LabelPropertyMapping (spec .source_label , prop .property_key , prop .default_value )
173
+ )
174
+
175
+ rel_var = ""
160
176
if rels :
161
177
if len (rels ) == 1 :
162
- rel_var = ""
163
178
data_config ["relationshipType" ] = rels [0 ].source_type
164
179
else :
165
180
rel_var = "rel"
166
181
data_config ["relationshipTypes" ] = "type(rel)"
167
182
data_config_is_static = False
183
+
168
184
match_pattern = match_pattern ._replace (
169
185
type_filter = f"[{ rel_var } :{ '|' .join (spec .source_type for spec in rels )} ]"
170
186
)
@@ -179,6 +195,24 @@ def project(
179
195
180
196
match_part = str (match_part )
181
197
198
+ case_part = []
199
+ if label_mappings :
200
+ with_rel = f", { rel_var } " if rel_var else ""
201
+ case_part = [f"WITH source, target{ with_rel } " ]
202
+ for kind in ["source" , "target" ]:
203
+ case_part .append ("CASE" )
204
+
205
+ for label , mappings in label_mappings .items ():
206
+ mappings = ", " .join (f".{ key .property_key } " for key in mappings )
207
+ when_part = f"WHEN '{ label } ' in labels({ kind } ) THEN [{ kind } {{{ mappings } }}]"
208
+ case_part .append (when_part )
209
+
210
+ case_part .append (f"END AS { kind } NodeProperties" )
211
+
212
+ data_config ["sourceNodeProperties" ] = "sourceNodeProperties"
213
+ data_config ["targetNodeProperties" ] = "targetNodeProperties"
214
+ data_config_is_static = False
215
+
182
216
args = ["$graph_name" , "source" , "target" ]
183
217
184
218
if data_config :
@@ -194,9 +228,7 @@ def project(
194
228
195
229
return_part = f"RETURN { self ._namespace } ({ ', ' .join (args )} )"
196
230
197
- query = "\n " .join (part for part in [match_part , return_part ] if part )
198
-
199
- print (query )
231
+ query = "\n " .join (part for part in [match_part , * case_part , return_part ] if part )
200
232
201
233
result = self ._query_runner .run_query_with_logging (
202
234
query ,
@@ -218,16 +250,39 @@ def _node_projections_spec(self, spec: Any) -> list[NodeProjection]:
218
250
if isinstance (spec , dict ):
219
251
return [self ._node_projection_spec (node , name ) for name , node in spec .items ()]
220
252
221
- raise TypeError (f"Invalid node projection specification: { spec } " )
253
+ raise TypeError (f"Invalid node projections specification: { spec } " )
222
254
223
255
def _node_projection_spec (self , spec : Any , name : Optional [str ] = None ) -> NodeProjection :
224
256
if isinstance (spec , str ):
225
257
return NodeProjection (name = name or spec , source_label = spec )
226
258
259
+ if name is None :
260
+ raise ValueError (f"Node projections with properties must use the dict syntax: { spec } " )
261
+
262
+ if isinstance (spec , dict ):
263
+ properties = [self ._node_properties_spec (prop , name ) for name , prop in spec .items ()]
264
+ return NodeProjection (name = name , source_label = name , properties = properties )
265
+
266
+ if isinstance (spec , list ):
267
+ properties = [self ._node_properties_spec (prop ) for prop in spec ]
268
+ return NodeProjection (name = name , source_label = name , properties = properties )
269
+
227
270
raise TypeError (f"Invalid node projection specification: { spec } " )
228
271
229
- def _node_properties_spec (self , properties : dict [str , Any ]) -> list [NodeProperty ]:
230
- raise TypeError (f"Invalid node projection specification: { properties } " )
272
+ def _node_properties_spec (self , spec : Any , name : Optional [str ] = None ) -> NodeProperty :
273
+ if isinstance (spec , str ):
274
+ return NodeProperty (name = name or spec , property_key = spec )
275
+
276
+ if name is None :
277
+ raise ValueError (f"Node properties spec must be used with the dict syntax: { spec } " )
278
+
279
+ if spec is True :
280
+ return NodeProperty (name = name , property_key = name )
281
+
282
+ if isinstance (spec , dict ):
283
+ return NodeProperty (name = name , property_key = name , ** spec )
284
+
285
+ raise TypeError (f"Invalid node property specification: { spec } " )
231
286
232
287
def _rel_projections_spec (self , spec : Any ) -> list [RelationshipProjection ]:
233
288
if spec is None or spec is False :
0 commit comments