1
1
package org .neo4j .spark .service
2
2
3
3
import org .apache .commons .lang3 .StringUtils
4
+ import org .apache .spark .internal .Logging
4
5
import org .apache .spark .sql .SaveMode
6
+ import org .apache .spark .sql .connector .expressions .{SortDirection , SortOrder }
5
7
import org .apache .spark .sql .connector .expressions .aggregate .{AggregateFunc , Count , CountStar , Max , Min , Sum }
6
8
import org .apache .spark .sql .sources .{And , Filter , Or }
7
9
import org .neo4j .cypherdsl .core ._
@@ -102,21 +104,28 @@ class Neo4jQueryWriteStrategy(private val saveMode: SaveMode) extends Neo4jQuery
102
104
}
103
105
104
106
class Neo4jQueryReadStrategy (filters : Array [Filter ] = Array .empty[Filter ],
105
- partitionSkipLimit : PartitionSkipLimit = PartitionSkipLimit .EMPTY ,
107
+ partitionPagination : PartitionPagination = PartitionPagination .EMPTY ,
106
108
requiredColumns : Seq [String ] = Seq .empty,
107
109
aggregateColumns : Array [AggregateFunc ] = Array .empty,
108
- jobId : String = " " ) extends Neo4jQueryStrategy {
110
+ jobId : String = " " ) extends Neo4jQueryStrategy with Logging {
109
111
private val renderer : Renderer = Renderer .getDefaultRenderer
110
112
111
- private val hasSkipLimit : Boolean = partitionSkipLimit .skip != - 1 && partitionSkipLimit .limit != - 1
113
+ private val hasSkipLimit : Boolean = partitionPagination .skip != - 1 && partitionPagination.topN .limit != - 1
112
114
113
115
override def createStatementForQuery (options : Neo4jOptions ): String = {
116
+ if (partitionPagination.topN.orders.nonEmpty) {
117
+ logWarning(
118
+ s """ Top N push-down optimizations with aggregations are not supported for custom queries.
119
+ |\tThese aggregations are going to be ignored.
120
+ |\tPlease specify the aggregations in the custom query directly """ .stripMargin)
121
+ }
114
122
val limitedQuery = if (hasSkipLimit) {
115
123
s """ ${options.query.value}
116
- |SKIP ${partitionSkipLimit .skip} LIMIT ${partitionSkipLimit .limit}
124
+ |SKIP ${partitionPagination .skip} LIMIT ${partitionPagination.topN .limit}
117
125
| """ .stripMargin
118
126
} else {
119
- options.query.value
127
+ s """ ${options.query.value}
128
+ | """ .stripMargin
120
129
}
121
130
s """ WITH ${" $" }scriptResult AS ${Neo4jQueryStrategy .VARIABLE_SCRIPT_RESULT }
122
131
| $limitedQuery""" .stripMargin
@@ -130,16 +139,39 @@ class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
130
139
.named(Neo4jUtil .RELATIONSHIP_ALIAS )
131
140
132
141
val matchQuery : StatementBuilder .OngoingReadingWithoutWhere = filterRelationship(sourceNode, targetNode, relationship)
133
-
134
142
val returnExpressions : Seq [Expression ] = buildReturnExpression(sourceNode, targetNode, relationship)
135
143
val stmt = if (aggregateColumns.isEmpty) {
136
- buildStatement(options, matchQuery.returning(returnExpressions : _* ), relationship)
144
+ val query = matchQuery.returning(returnExpressions : _* )
145
+ buildStatement(options, query, relationship)
137
146
} else {
138
147
buildStatementAggregation(options, matchQuery, relationship, returnExpressions)
139
148
}
140
149
renderer.render(stmt)
141
150
}
142
151
152
+ private def convertSort (entity : PropertyContainer , order : SortOrder ): SortItem = {
153
+ val sortExpression = order.expression().describe()
154
+
155
+ val container : Option [PropertyContainer ] = entity match {
156
+ case relationship : Relationship =>
157
+ if (sortExpression.contains(s " ${Neo4jUtil .RELATIONSHIP_SOURCE_ALIAS }. " )) {
158
+ Some (relationship.getLeft)
159
+ } else if (sortExpression.contains(s " ${Neo4jUtil .RELATIONSHIP_TARGET_ALIAS }. " )) {
160
+ Some (relationship.getRight)
161
+ } else if (sortExpression.contains(s " ${Neo4jUtil .RELATIONSHIP_ALIAS }. " )) {
162
+ Some (relationship)
163
+ } else {
164
+ None
165
+ }
166
+ case _ => Some (entity)
167
+ }
168
+ val direction = if (order.direction() == SortDirection .ASCENDING ) SortItem .Direction .ASC else SortItem .Direction .DESC
169
+
170
+ Cypher .sort(container
171
+ .map(_.property(sortExpression.removeAlias()))
172
+ .getOrElse(Cypher .name(sortExpression.unquote())), direction)
173
+ }
174
+
143
175
private def buildReturnExpression (sourceNode : Node , targetNode : Node , relationship : Relationship ): Seq [Expression ] = {
144
176
if (requiredColumns.isEmpty) {
145
177
Seq (relationship.getRequiredSymbolicName, sourceNode.as(Neo4jUtil .RELATIONSHIP_SOURCE_ALIAS ), targetNode.as(Neo4jUtil .RELATIONSHIP_TARGET_ALIAS ))
@@ -186,13 +218,13 @@ class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
186
218
}
187
219
query
188
220
.`with`(entity)
189
- // Spark does not push down limits when aggregation is involved
221
+ // Spark does not push down limits/top N when aggregation is involved
190
222
.orderBy(id)
191
- .skip(partitionSkipLimit .skip)
192
- .limit(partitionSkipLimit .limit)
223
+ .skip(partitionPagination .skip)
224
+ .limit(partitionPagination.topN .limit)
193
225
.returning(fields : _* )
194
226
} else {
195
- val orderByProp = options.orderBy
227
+ val orderByProp = options.streamingOrderBy
196
228
if (StringUtils .isBlank(orderByProp)) {
197
229
query.returning(fields : _* )
198
230
} else {
@@ -207,37 +239,40 @@ class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
207
239
}
208
240
209
241
private def buildStatement (options : Neo4jOptions ,
210
- returning : StatementBuilder .OngoingReadingAndReturn ,
242
+ returning : StatementBuilder .TerminalExposesSkip
243
+ with StatementBuilder .TerminalExposesLimit
244
+ with StatementBuilder .TerminalExposesOrderBy
245
+ with StatementBuilder .BuildableStatement [_],
211
246
entity : PropertyContainer = null ): Statement = {
212
247
213
248
def addSkipLimit (ret : StatementBuilder .TerminalExposesSkip
214
- with StatementBuilder .TerminalExposesLimit
215
- with StatementBuilder .BuildableStatement [_]) = {
249
+ with StatementBuilder .TerminalExposesLimit
250
+ with StatementBuilder .BuildableStatement [_]) = {
216
251
217
- if (partitionSkipLimit .skip == 0 ) {
218
- ret.limit(partitionSkipLimit .limit)
252
+ if (partitionPagination .skip == 0 ) {
253
+ ret.limit(partitionPagination.topN .limit)
219
254
}
220
255
else {
221
- ret.skip(partitionSkipLimit .skip). asInstanceOf [ StatementBuilder . TerminalExposesLimit ]
222
- .limit(partitionSkipLimit .limit)
256
+ ret.skip(partitionPagination .skip)
257
+ .limit(partitionPagination.topN .limit)
223
258
}
224
259
}
225
260
226
261
val ret = if (entity == null ) {
227
262
if (hasSkipLimit) addSkipLimit(returning) else returning
228
263
} else {
229
264
if (hasSkipLimit) {
230
- val id = entity match {
231
- case node : Node => Functions .id(node)
232
- case rel : Relationship => Functions .id(rel)
233
- }
234
- if (options.partitions == 1 ) {
235
- addSkipLimit(returning)
265
+ if (options.partitions == 1 || partitionPagination.topN.orders.nonEmpty) {
266
+ addSkipLimit(returning.orderBy(partitionPagination.topN.orders.map(order => convertSort(entity, order)): _* ))
236
267
} else {
268
+ val id = entity match {
269
+ case node : Node => Functions .id(node)
270
+ case rel : Relationship => Functions .id(rel)
271
+ }
237
272
addSkipLimit(returning.orderBy(id))
238
273
}
239
274
} else {
240
- val orderByProp = options.orderBy
275
+ val orderByProp = options.streamingOrderBy
241
276
if (StringUtils .isBlank(orderByProp)) returning else returning.orderBy(entity.property(orderByProp))
242
277
}
243
278
}
@@ -282,6 +317,7 @@ class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
282
317
def propertyOrSymbolicName (col : String ) = {
283
318
if (entity != null ) entity.property(col) else Cypher .name(col)
284
319
}
320
+
285
321
column match {
286
322
case Neo4jUtil .INTERNAL_ID_FIELD => Functions .id(entity.asInstanceOf [Node ]).as(Neo4jUtil .INTERNAL_ID_FIELD )
287
323
case Neo4jUtil .INTERNAL_REL_ID_FIELD => Functions .id(entity.asInstanceOf [Relationship ]).as(Neo4jUtil .INTERNAL_REL_ID_FIELD )
@@ -340,7 +376,7 @@ class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
340
376
val ret = if (requiredColumns.isEmpty) {
341
377
matchQuery.returning(node)
342
378
} else {
343
- matchQuery.returning(expressions : _* )
379
+ matchQuery.returning(expressions : _* )
344
380
}
345
381
buildStatement(options, ret, node)
346
382
}
@@ -416,9 +452,9 @@ class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
416
452
.map(_._1)
417
453
.map(Cypher .parameter)
418
454
val statement = Cypher .call(options.query.value)
419
- .withArgs(cypherParams : _* )
420
- .`yield`(yieldFields : _* )
421
- .returning(retCols : _* )
455
+ .withArgs(cypherParams : _* )
456
+ .`yield`(yieldFields : _* )
457
+ .returning(retCols : _* )
422
458
.build()
423
459
renderer.render(statement)
424
460
}
@@ -450,7 +486,8 @@ class Neo4jQueryService(private val options: Neo4jOptions,
450
486
case QueryType .RELATIONSHIP => strategy.createStatementForRelationships(options)
451
487
case QueryType .QUERY => strategy.createStatementForQuery(options)
452
488
case QueryType .GDS => strategy.createStatementForGDS(options)
453
- case _ => throw new UnsupportedOperationException (s """ Query Type not supported.
489
+ case _ => throw new UnsupportedOperationException (
490
+ s """ Query Type not supported.
454
491
|You provided ${options.query.queryType},
455
492
|supported types: ${QueryType .values.mkString(" ," )}""" .stripMargin)
456
493
}
0 commit comments