Skip to content

Commit e971d09

Browse files
fbivilleAndrea Santurbano
andauthored
Implement Top N pushdown (#526)
Fixes #519 Signed-off-by: Florent Biville <florent.biville@neo4j.com> Co-authored-by: Andrea Santurbano <andrea.santurbano@larus-ba.it>
1 parent d3154ae commit e971d09

19 files changed

+370
-201
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package org.neo4j.spark.config
2+
3+
import org.apache.spark.sql.connector.expressions.SortOrder
4+
5+
case class TopN(limit: Int, orders: Array[SortOrder] = Array.empty)

common/src/main/scala/org/neo4j/spark/reader/BasePartitionReader.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc
66
import org.apache.spark.sql.sources.Filter
77
import org.apache.spark.sql.types.StructType
88
import org.neo4j.driver.{Record, Session, Transaction, Values}
9-
import org.neo4j.spark.service.{MappingService, Neo4jQueryReadStrategy, Neo4jQueryService, Neo4jQueryStrategy, Neo4jReadMappingStrategy, PartitionSkipLimit}
9+
import org.neo4j.spark.service.{MappingService, Neo4jQueryReadStrategy, Neo4jQueryService, Neo4jQueryStrategy, Neo4jReadMappingStrategy, PartitionPagination}
1010
import org.neo4j.spark.util.{DriverCache, Neo4jOptions, Neo4jUtil, QueryType}
1111

1212
import java.io.IOException
@@ -17,7 +17,7 @@ abstract class BasePartitionReader(private val options: Neo4jOptions,
1717
private val filters: Array[Filter],
1818
private val schema: StructType,
1919
private val jobId: String,
20-
private val partitionSkipLimit: PartitionSkipLimit,
20+
private val partitionSkipLimit: PartitionPagination,
2121
private val scriptResult: java.util.List[java.util.Map[String, AnyRef]],
2222
private val requiredColumns: StructType,
2323
private val aggregateColumns: Array[AggregateFunc]) extends Logging {

common/src/main/scala/org/neo4j/spark/service/Neo4jQueryService.scala

Lines changed: 67 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package org.neo4j.spark.service
22

33
import org.apache.commons.lang3.StringUtils
4+
import org.apache.spark.internal.Logging
45
import org.apache.spark.sql.SaveMode
6+
import org.apache.spark.sql.connector.expressions.{SortDirection, SortOrder}
57
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
68
import org.apache.spark.sql.sources.{And, Filter, Or}
79
import org.neo4j.cypherdsl.core._
@@ -102,21 +104,28 @@ class Neo4jQueryWriteStrategy(private val saveMode: SaveMode) extends Neo4jQuery
102104
}
103105

104106
class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
105-
partitionSkipLimit: PartitionSkipLimit = PartitionSkipLimit.EMPTY,
107+
partitionPagination: PartitionPagination = PartitionPagination.EMPTY,
106108
requiredColumns: Seq[String] = Seq.empty,
107109
aggregateColumns: Array[AggregateFunc] = Array.empty,
108-
jobId: String = "") extends Neo4jQueryStrategy {
110+
jobId: String = "") extends Neo4jQueryStrategy with Logging {
109111
private val renderer: Renderer = Renderer.getDefaultRenderer
110112

111-
private val hasSkipLimit: Boolean = partitionSkipLimit.skip != -1 && partitionSkipLimit.limit != -1
113+
private val hasSkipLimit: Boolean = partitionPagination.skip != -1 && partitionPagination.topN.limit != -1
112114

113115
override def createStatementForQuery(options: Neo4jOptions): String = {
116+
if (partitionPagination.topN.orders.nonEmpty) {
117+
logWarning(
118+
s"""Top N push-down optimizations with aggregations are not supported for custom queries.
119+
|\tThese aggregations are going to be ignored.
120+
|\tPlease specify the aggregations in the custom query directly""".stripMargin)
121+
}
114122
val limitedQuery = if (hasSkipLimit) {
115123
s"""${options.query.value}
116-
|SKIP ${partitionSkipLimit.skip} LIMIT ${partitionSkipLimit.limit}
124+
|SKIP ${partitionPagination.skip} LIMIT ${partitionPagination.topN.limit}
117125
|""".stripMargin
118126
} else {
119-
options.query.value
127+
s"""${options.query.value}
128+
|""".stripMargin
120129
}
121130
s"""WITH ${"$"}scriptResult AS ${Neo4jQueryStrategy.VARIABLE_SCRIPT_RESULT}
122131
|$limitedQuery""".stripMargin
@@ -130,16 +139,39 @@ class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
130139
.named(Neo4jUtil.RELATIONSHIP_ALIAS)
131140

132141
val matchQuery: StatementBuilder.OngoingReadingWithoutWhere = filterRelationship(sourceNode, targetNode, relationship)
133-
134142
val returnExpressions: Seq[Expression] = buildReturnExpression(sourceNode, targetNode, relationship)
135143
val stmt = if (aggregateColumns.isEmpty) {
136-
buildStatement(options, matchQuery.returning(returnExpressions : _*), relationship)
144+
val query = matchQuery.returning(returnExpressions: _*)
145+
buildStatement(options, query, relationship)
137146
} else {
138147
buildStatementAggregation(options, matchQuery, relationship, returnExpressions)
139148
}
140149
renderer.render(stmt)
141150
}
142151

152+
private def convertSort(entity: PropertyContainer, order: SortOrder): SortItem = {
153+
val sortExpression = order.expression().describe()
154+
155+
val container: Option[PropertyContainer] = entity match {
156+
case relationship: Relationship =>
157+
if (sortExpression.contains(s"${Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS}.")) {
158+
Some(relationship.getLeft)
159+
} else if (sortExpression.contains(s"${Neo4jUtil.RELATIONSHIP_TARGET_ALIAS}.")) {
160+
Some(relationship.getRight)
161+
} else if (sortExpression.contains(s"${Neo4jUtil.RELATIONSHIP_ALIAS}.")) {
162+
Some(relationship)
163+
} else {
164+
None
165+
}
166+
case _ => Some(entity)
167+
}
168+
val direction = if (order.direction() == SortDirection.ASCENDING) SortItem.Direction.ASC else SortItem.Direction.DESC
169+
170+
Cypher.sort(container
171+
.map(_.property(sortExpression.removeAlias()))
172+
.getOrElse(Cypher.name(sortExpression.unquote())), direction)
173+
}
174+
143175
private def buildReturnExpression(sourceNode: Node, targetNode: Node, relationship: Relationship): Seq[Expression] = {
144176
if (requiredColumns.isEmpty) {
145177
Seq(relationship.getRequiredSymbolicName, sourceNode.as(Neo4jUtil.RELATIONSHIP_SOURCE_ALIAS), targetNode.as(Neo4jUtil.RELATIONSHIP_TARGET_ALIAS))
@@ -186,13 +218,13 @@ class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
186218
}
187219
query
188220
.`with`(entity)
189-
// Spark does not push down limits when aggregation is involved
221+
// Spark does not push down limits/top N when aggregation is involved
190222
.orderBy(id)
191-
.skip(partitionSkipLimit.skip)
192-
.limit(partitionSkipLimit.limit)
223+
.skip(partitionPagination.skip)
224+
.limit(partitionPagination.topN.limit)
193225
.returning(fields: _*)
194226
} else {
195-
val orderByProp = options.orderBy
227+
val orderByProp = options.streamingOrderBy
196228
if (StringUtils.isBlank(orderByProp)) {
197229
query.returning(fields: _*)
198230
} else {
@@ -207,37 +239,40 @@ class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
207239
}
208240

209241
private def buildStatement(options: Neo4jOptions,
210-
returning: StatementBuilder.OngoingReadingAndReturn,
242+
returning: StatementBuilder.TerminalExposesSkip
243+
with StatementBuilder.TerminalExposesLimit
244+
with StatementBuilder.TerminalExposesOrderBy
245+
with StatementBuilder.BuildableStatement[_],
211246
entity: PropertyContainer = null): Statement = {
212247

213248
def addSkipLimit(ret: StatementBuilder.TerminalExposesSkip
214-
with StatementBuilder.TerminalExposesLimit
215-
with StatementBuilder.BuildableStatement[_]) = {
249+
with StatementBuilder.TerminalExposesLimit
250+
with StatementBuilder.BuildableStatement[_]) = {
216251

217-
if (partitionSkipLimit.skip == 0) {
218-
ret.limit(partitionSkipLimit.limit)
252+
if (partitionPagination.skip == 0) {
253+
ret.limit(partitionPagination.topN.limit)
219254
}
220255
else {
221-
ret.skip(partitionSkipLimit.skip).asInstanceOf[StatementBuilder.TerminalExposesLimit]
222-
.limit(partitionSkipLimit.limit)
256+
ret.skip(partitionPagination.skip)
257+
.limit(partitionPagination.topN.limit)
223258
}
224259
}
225260

226261
val ret = if (entity == null) {
227262
if (hasSkipLimit) addSkipLimit(returning) else returning
228263
} else {
229264
if (hasSkipLimit) {
230-
val id = entity match {
231-
case node: Node => Functions.id(node)
232-
case rel: Relationship => Functions.id(rel)
233-
}
234-
if (options.partitions == 1) {
235-
addSkipLimit(returning)
265+
if (options.partitions == 1 || partitionPagination.topN.orders.nonEmpty) {
266+
addSkipLimit(returning.orderBy(partitionPagination.topN.orders.map(order => convertSort(entity, order)): _*))
236267
} else {
268+
val id = entity match {
269+
case node: Node => Functions.id(node)
270+
case rel: Relationship => Functions.id(rel)
271+
}
237272
addSkipLimit(returning.orderBy(id))
238273
}
239274
} else {
240-
val orderByProp = options.orderBy
275+
val orderByProp = options.streamingOrderBy
241276
if (StringUtils.isBlank(orderByProp)) returning else returning.orderBy(entity.property(orderByProp))
242277
}
243278
}
@@ -282,6 +317,7 @@ class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
282317
def propertyOrSymbolicName(col: String) = {
283318
if (entity != null) entity.property(col) else Cypher.name(col)
284319
}
320+
285321
column match {
286322
case Neo4jUtil.INTERNAL_ID_FIELD => Functions.id(entity.asInstanceOf[Node]).as(Neo4jUtil.INTERNAL_ID_FIELD)
287323
case Neo4jUtil.INTERNAL_REL_ID_FIELD => Functions.id(entity.asInstanceOf[Relationship]).as(Neo4jUtil.INTERNAL_REL_ID_FIELD)
@@ -340,7 +376,7 @@ class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
340376
val ret = if (requiredColumns.isEmpty) {
341377
matchQuery.returning(node)
342378
} else {
343-
matchQuery.returning(expressions : _*)
379+
matchQuery.returning(expressions: _*)
344380
}
345381
buildStatement(options, ret, node)
346382
}
@@ -416,9 +452,9 @@ class Neo4jQueryReadStrategy(filters: Array[Filter] = Array.empty[Filter],
416452
.map(_._1)
417453
.map(Cypher.parameter)
418454
val statement = Cypher.call(options.query.value)
419-
.withArgs(cypherParams : _*)
420-
.`yield`(yieldFields : _*)
421-
.returning(retCols : _*)
455+
.withArgs(cypherParams: _*)
456+
.`yield`(yieldFields: _*)
457+
.returning(retCols: _*)
422458
.build()
423459
renderer.render(statement)
424460
}
@@ -450,7 +486,8 @@ class Neo4jQueryService(private val options: Neo4jOptions,
450486
case QueryType.RELATIONSHIP => strategy.createStatementForRelationships(options)
451487
case QueryType.QUERY => strategy.createStatementForQuery(options)
452488
case QueryType.GDS => strategy.createStatementForGDS(options)
453-
case _ => throw new UnsupportedOperationException(s"""Query Type not supported.
489+
case _ => throw new UnsupportedOperationException(
490+
s"""Query Type not supported.
454491
|You provided ${options.query.queryType},
455492
|supported types: ${QueryType.values.mkString(",")}""".stripMargin)
456493
}

common/src/main/scala/org/neo4j/spark/service/SchemaService.scala

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
66
import org.neo4j.driver.exceptions.ClientException
77
import org.neo4j.driver.types.Entity
88
import org.neo4j.driver.{Record, Session, Transaction, TransactionWork, Value, Values, summary}
9+
import org.neo4j.spark.config.TopN
910
import org.neo4j.spark.service.SchemaService.{cypherToSparkType, normalizedClassName, normalizedClassNameFromGraphEntity}
1011
import org.neo4j.spark.util.Neo4jImplicits.{CypherImplicits, EntityImplicits}
1112
import org.neo4j.spark.util._
@@ -16,12 +17,12 @@ import scala.collection.JavaConverters._
1617
import scala.collection.mutable
1718
import scala.collection.mutable.ArrayBuffer
1819

19-
object PartitionSkipLimit {
20-
val EMPTY = PartitionSkipLimit(0, -1, -1)
21-
val EMPTY_FOR_QUERY = PartitionSkipLimit(0, 0, 0)
20+
object PartitionPagination {
21+
val EMPTY = PartitionPagination(0, -1, TopN(-1))
22+
val EMPTY_FOR_QUERY = PartitionPagination(0, 0, TopN(0))
2223
}
2324

24-
case class PartitionSkipLimit(partitionNumber: Int, skip: Long, limit: Long)
25+
case class PartitionPagination(partitionNumber: Int, skip: Long, topN: TopN)
2526

2627
case class Neo4jVersion(name: String, versions: Seq[String], edition: String)
2728

@@ -301,7 +302,7 @@ class SchemaService(private val options: Neo4jOptions, private val driverCache:
301302
StructType(fields)
302303
}
303304

304-
def inputForGDSProc(procName: String): Seq[(String, Boolean)] = {
305+
def inputForGDSProc(procName: String): Seq[(String, Boolean)] = {
305306
val query =
306307
"""
307308
|WITH $procName AS procName
@@ -472,27 +473,28 @@ class SchemaService(private val options: Neo4jOptions, private val driverCache:
472473
case QueryType.QUERY => countForQuery()
473474
}
474475

475-
def skipLimitFromPartition(limit: Option[Int]): Seq[PartitionSkipLimit] = if (options.partitions == 1) {
476-
val skipLimit = limit.map(l => PartitionSkipLimit(0, 0, l)).getOrElse(PartitionSkipLimit.EMPTY)
477-
Seq(skipLimit)
478-
} else {
479-
val count: Long = this.count()
480-
if (count <= 0) {
481-
Seq(PartitionSkipLimit.EMPTY)
476+
def skipLimitFromPartition(topN: Option[TopN]): Seq[PartitionPagination] =
477+
if (options.partitions == 1) {
478+
val skipLimit = topN.map(top => PartitionPagination(0, 0, top)).getOrElse(PartitionPagination.EMPTY)
479+
Seq(skipLimit)
482480
} else {
483-
val partitionSize = Math.ceil(count.toDouble / options.partitions).toLong
484-
val partitions = options.query.queryType match {
485-
case QueryType.QUERY => if (options.queryMetadata.queryCount.nonEmpty) {
486-
options.partitions // for custom query count we overfetch
487-
} else {
488-
options.partitions - 1
481+
val count: Long = this.count()
482+
if (count <= 0) {
483+
Seq(PartitionPagination.EMPTY)
484+
} else {
485+
val partitionSize = Math.ceil(count.toDouble / options.partitions).toInt
486+
val partitions = options.query.queryType match {
487+
case QueryType.QUERY => if (options.queryMetadata.queryCount.nonEmpty) {
488+
options.partitions // for custom query count we overfetch
489+
} else {
490+
options.partitions - 1
491+
}
492+
case _ => options.partitions - 1
489493
}
490-
case _ => options.partitions - 1
494+
(0 to partitions)
495+
.map(index => PartitionPagination(index, index * partitionSize, TopN(partitionSize, Array.empty)))
491496
}
492-
(0 to partitions)
493-
.map(index => PartitionSkipLimit(index, index * partitionSize, partitionSize))
494497
}
495-
}
496498

497499
def isGdsProcedure(procName: String): Boolean = {
498500
val params: util.Map[String, AnyRef] = Map[String, AnyRef]("procName" -> procName).asJava
@@ -614,7 +616,7 @@ class SchemaService(private val options: Neo4jOptions, private val driverCache:
614616
|RETURN count(*) > 0 AS isPresent""".stripMargin
615617
val params: util.Map[String, AnyRef] = Map("labels" -> Seq(label).asJava,
616618
"properties" -> props.asJava).asJava.asInstanceOf[util.Map[String, AnyRef]]
617-
session.run(queryCheck, params)
619+
session.run(queryCheck, params)
618620
.single()
619621
.get("isPresent")
620622
.asBoolean()
@@ -699,7 +701,7 @@ class SchemaService(private val options: Neo4jOptions, private val driverCache:
699701
val label = options.nodeMetadata.labels.head
700702
session.run(
701703
s"""MATCH (n:$label)
702-
|RETURN max(n.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin)
704+
|RETURN max(n.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin)
703705
.single()
704706
.get(options.streamingOptions.propertyName)
705707
.asLong(-1)
@@ -731,7 +733,7 @@ class SchemaService(private val options: Neo4jOptions, private val driverCache:
731733

732734
private def logResolutionChange(message: String, e: ClientException): Unit = {
733735
log.warn(message)
734-
if(!e.code().equals("Neo.ClientError.Procedure.ProcedureNotFound")) {
736+
if (!e.code().equals("Neo.ClientError.Procedure.ProcedureNotFound")) {
735737
log.warn(s"For the following exception", e)
736738
}
737739
}

common/src/main/scala/org/neo4j/spark/streaming/BaseStreamingPartitionReader.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc
55
import org.apache.spark.sql.sources.Filter
66
import org.apache.spark.sql.types.{DataTypes, StructType}
77
import org.neo4j.spark.reader.BasePartitionReader
8-
import org.neo4j.spark.service.{Neo4jQueryStrategy, PartitionSkipLimit}
8+
import org.neo4j.spark.service.{Neo4jQueryStrategy, PartitionPagination}
99
import org.neo4j.spark.util.Neo4jImplicits._
1010
import org.neo4j.spark.util.{Neo4jOptions, Neo4jUtil, StreamingFrom}
1111

@@ -16,7 +16,7 @@ class BaseStreamingPartitionReader(private val options: Neo4jOptions,
1616
private val filters: Array[Filter],
1717
private val schema: StructType,
1818
private val jobId: String,
19-
private val partitionSkipLimit: PartitionSkipLimit,
19+
private val partitionSkipLimit: PartitionPagination,
2020
private val scriptResult: java.util.List[java.util.Map[String, AnyRef]],
2121
private val offsetAccumulator: OffsetStorage[java.lang.Long, java.lang.Long],
2222
private val requiredColumns: StructType,

common/src/main/scala/org/neo4j/spark/util/Neo4jImplicits.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ object Neo4jImplicits {
2727
def isQuoted(): Boolean = str.startsWith("`");
2828

2929
def removeAlias(): String = {
30-
val splatString = str.split('.')
30+
val splatString = str.unquote().split('.')
3131

3232
if (splatString.size > 1) {
3333
splatString.tail.mkString(".")

common/src/main/scala/org/neo4j/spark/util/Neo4jOptions.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class Neo4jOptions(private val options: java.util.Map[String, String]) extends S
5454
val pushdownColumnsEnabled: Boolean = getParameter(PUSHDOWN_COLUMNS_ENABLED, DEFAULT_PUSHDOWN_COLUMNS_ENABLED.toString).toBoolean
5555
val pushdownAggregateEnabled: Boolean = getParameter(PUSHDOWN_AGGREGATE_ENABLED, DEFAULT_PUSHDOWN_AGGREGATE_ENABLED.toString).toBoolean
5656
val pushdownLimitEnabled: Boolean = getParameter(PUSHDOWN_LIMIT_ENABLED, DEFAULT_PUSHDOWN_LIMIT_ENABLED.toString).toBoolean
57+
val pushdownTopNEnabled: Boolean = getParameter(PUSHDOWN_TOPN_ENABLED, DEFAULT_PUSHDOWN_TOPN_ENABLED.toString).toBoolean
5758

5859
val schemaMetadata: Neo4jSchemaMetadata = Neo4jSchemaMetadata(getParameter(SCHEMA_FLATTEN_LIMIT, DEFAULT_SCHEMA_FLATTEN_LIMIT.toString).toInt,
5960
SchemaStrategy.withCaseInsensitiveName(getParameter(SCHEMA_STRATEGY, DEFAULT_SCHEMA_STRATEGY.toString).toUpperCase),
@@ -202,7 +203,7 @@ class Neo4jOptions(private val options: java.util.Map[String, String]) extends S
202203

203204
val partitions: Int = getParameter(PARTITIONS, DEFAULT_PARTITIONS.toString).toInt
204205

205-
val orderBy: String = getParameter(ORDER_BY, getParameter(STREAMING_PROPERTY_NAME))
206+
val streamingOrderBy: String = getParameter(ORDER_BY, getParameter(STREAMING_PROPERTY_NAME))
206207

207208
val apocConfig: Neo4jApocConfig = Neo4jApocConfig(parameters.asScala
208209
.filterKeys(_.startsWith("apoc."))
@@ -391,6 +392,7 @@ object Neo4jOptions {
391392
val PUSHDOWN_COLUMNS_ENABLED = "pushdown.columns.enabled"
392393
val PUSHDOWN_AGGREGATE_ENABLED = "pushdown.aggregate.enabled"
393394
val PUSHDOWN_LIMIT_ENABLED = "pushdown.limit.enabled"
395+
val PUSHDOWN_TOPN_ENABLED = "pushdown.topN.enabled"
394396

395397
// schema options
396398
val SCHEMA_STRATEGY = "schema.strategy"
@@ -462,6 +464,7 @@ object Neo4jOptions {
462464
val DEFAULT_PUSHDOWN_COLUMNS_ENABLED = true
463465
val DEFAULT_PUSHDOWN_AGGREGATE_ENABLED = true
464466
val DEFAULT_PUSHDOWN_LIMIT_ENABLED = true
467+
val DEFAULT_PUSHDOWN_TOPN_ENABLED = true
465468
val DEFAULT_PARTITIONS = 1
466469
val DEFAULT_OPTIMIZATION_TYPE = OptimizationType.NONE
467470
val DEFAULT_SAVE_MODE = SaveMode.Overwrite

0 commit comments

Comments
 (0)