Commit e09f94b

Improve schema type constraint and constraint management (#581)
1 parent a159a31 commit e09f94b

File tree: 20 files changed, +1133 −356 lines
Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
package org.neo4j.spark.converter

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema, UnsafeRow}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.neo4j.driver.internal._
import org.neo4j.driver.types.{IsoDuration, Node, Relationship}
import org.neo4j.driver.{Value, Values}
import org.neo4j.spark.service.SchemaService
import org.neo4j.spark.util.Neo4jUtil

import java.time._
import java.time.format.DateTimeFormatter
import scala.annotation.tailrec
import scala.collection.JavaConverters._

trait DataConverter[T] {
  def convert(value: Any, dataType: DataType = null): T

  @tailrec
  private[converter] final def extractStructType(dataType: DataType): StructType = dataType match {
    case structType: StructType => structType
    case mapType: MapType => extractStructType(mapType.valueType)
    case arrayType: ArrayType => extractStructType(arrayType.elementType)
    case _ => throw new UnsupportedOperationException(s"$dataType not supported")
  }
}

object SparkToNeo4jDataConverter {
  def apply(): SparkToNeo4jDataConverter = new SparkToNeo4jDataConverter()
}

class SparkToNeo4jDataConverter extends DataConverter[Value] {
  override def convert(value: Any, dataType: DataType): Value = {
    value match {
      case date: java.sql.Date => convert(date.toLocalDate, dataType)
      case timestamp: java.sql.Timestamp => convert(timestamp.toLocalDateTime, dataType)
      case intValue: Int if dataType == DataTypes.DateType => convert(DateTimeUtils
        .toJavaDate(intValue), dataType)
      case longValue: Long if dataType == DataTypes.TimestampType => convert(DateTimeUtils
        .toJavaTimestamp(longValue), dataType)
      case unsafeRow: UnsafeRow => {
        val structType = extractStructType(dataType)
        val row = new GenericRowWithSchema(unsafeRow.toSeq(structType).toArray, structType)
        convert(row)
      }
      case struct: GenericRow => {
        def toMap(struct: GenericRow): Value = {
          Values.value(
            struct.schema.fields.map(
              f => f.name -> convert(struct.getAs(f.name), f.dataType)
            ).toMap.asJava)
        }

        try {
          struct.getAs[UTF8String]("type").toString match {
            case SchemaService.POINT_TYPE_2D => Values.point(struct.getAs[Number]("srid").intValue(),
              struct.getAs[Number]("x").doubleValue(),
              struct.getAs[Number]("y").doubleValue())
            case SchemaService.POINT_TYPE_3D => Values.point(struct.getAs[Number]("srid").intValue(),
              struct.getAs[Number]("x").doubleValue(),
              struct.getAs[Number]("y").doubleValue(),
              struct.getAs[Number]("z").doubleValue())
            case SchemaService.DURATION_TYPE => Values.isoDuration(struct.getAs[Number]("months").longValue(),
              struct.getAs[Number]("days").longValue(),
              struct.getAs[Number]("seconds").longValue(),
              struct.getAs[Number]("nanoseconds").intValue())
            case SchemaService.TIME_TYPE_OFFSET => Values.value(OffsetTime.parse(struct.getAs[UTF8String]("value").toString))
            case SchemaService.TIME_TYPE_LOCAL => Values.value(LocalTime.parse(struct.getAs[UTF8String]("value").toString))
            case _ => toMap(struct)
          }
        } catch {
          case _: Throwable => toMap(struct)
        }
      }
      case unsafeArray: ArrayData => {
        val sparkType = dataType match {
          case arrayType: ArrayType => arrayType.elementType
          case _ => dataType
        }
        val javaList = unsafeArray.toSeq[AnyRef](sparkType)
          .map(elem => convert(elem, sparkType))
          .asJava
        Values.value(javaList)
      }
      case unsafeMapData: MapData => { // Neo4j only supports Map[String, AnyRef]
        val mapType = dataType.asInstanceOf[MapType]
        val map: Map[String, AnyRef] = (0 until unsafeMapData.numElements())
          .map(i => (unsafeMapData.keyArray().getUTF8String(i).toString, unsafeMapData.valueArray().get(i, mapType.valueType)))
          .toMap[String, AnyRef]
          .mapValues(innerValue => convert(innerValue, mapType.valueType))
          .toMap[String, AnyRef]
        Values.value(map.asJava)
      }
      case string: UTF8String => convert(string.toString)
      case _ => Values.value(value)
    }
  }
}

object Neo4jToSparkDataConverter {
  def apply(): Neo4jToSparkDataConverter = new Neo4jToSparkDataConverter()
}

class Neo4jToSparkDataConverter extends DataConverter[Any] {
  override def convert(value: Any, dataType: DataType): Any = {
    if (dataType != null && dataType == DataTypes.StringType && value != null && !value.isInstanceOf[String]) {
      convert(Neo4jUtil.mapper.writeValueAsString(value), dataType)
    } else {
      value match {
        case node: Node => {
          val map = node.asMap()
          val structType = extractStructType(dataType)
          val fields = structType
            .filter(field => field.name != Neo4jUtil.INTERNAL_ID_FIELD && field.name != Neo4jUtil.INTERNAL_LABELS_FIELD)
            .map(field => convert(map.get(field.name), field.dataType))
          InternalRow.fromSeq(Seq(convert(node.id()), convert(node.labels())) ++ fields)
        }
        case rel: Relationship => {
          val map = rel.asMap()
          val structType = extractStructType(dataType)
          val fields = structType
            .filter(field => field.name != Neo4jUtil.INTERNAL_REL_ID_FIELD
              && field.name != Neo4jUtil.INTERNAL_REL_TYPE_FIELD
              && field.name != Neo4jUtil.INTERNAL_REL_SOURCE_ID_FIELD
              && field.name != Neo4jUtil.INTERNAL_REL_TARGET_ID_FIELD)
            .map(field => convert(map.get(field.name), field.dataType))
          InternalRow.fromSeq(Seq(convert(rel.id()),
            convert(rel.`type`()),
            convert(rel.startNodeId()),
            convert(rel.endNodeId())) ++ fields)
        }
        case d: IsoDuration => {
          val months = d.months()
          val days = d.days()
          val nanoseconds: Integer = d.nanoseconds()
          val seconds = d.seconds()
          InternalRow.fromSeq(Seq(UTF8String.fromString(SchemaService.DURATION_TYPE), months, days, seconds, nanoseconds, UTF8String.fromString(d.toString)))
        }
        case zt: ZonedDateTime => DateTimeUtils.instantToMicros(zt.toInstant)
        case dt: LocalDateTime => DateTimeUtils.instantToMicros(dt.toInstant(ZoneOffset.UTC))
        case d: LocalDate => d.toEpochDay.toInt
        case lt: LocalTime => {
          InternalRow.fromSeq(Seq(
            UTF8String.fromString(SchemaService.TIME_TYPE_LOCAL),
            UTF8String.fromString(lt.format(DateTimeFormatter.ISO_TIME))
          ))
        }
        case t: OffsetTime => {
          InternalRow.fromSeq(Seq(
            UTF8String.fromString(SchemaService.TIME_TYPE_OFFSET),
            UTF8String.fromString(t.format(DateTimeFormatter.ISO_TIME))
          ))
        }
        case p: InternalPoint2D => {
          val srid: Integer = p.srid()
          InternalRow.fromSeq(Seq(UTF8String.fromString(SchemaService.POINT_TYPE_2D), srid, p.x(), p.y(), null))
        }
        case p: InternalPoint3D => {
          val srid: Integer = p.srid()
          InternalRow.fromSeq(Seq(UTF8String.fromString(SchemaService.POINT_TYPE_3D), srid, p.x(), p.y(), p.z()))
        }
        case l: java.util.List[_] => {
          val elementType = if (dataType != null) dataType.asInstanceOf[ArrayType].elementType else null
          ArrayData.toArrayData(l.asScala.map(e => convert(e, elementType)).toArray)
        }
        case map: java.util.Map[_, _] => {
          if (dataType != null) {
            val mapType = dataType.asInstanceOf[MapType]
            ArrayBasedMapData(map.asScala.map(t => (convert(t._1, mapType.keyType), convert(t._2, mapType.valueType))))
          } else {
            ArrayBasedMapData(map.asScala.map(t => (convert(t._1), convert(t._2))))
          }
        }
        case s: String => UTF8String.fromString(s)
        case _ => value
      }
    }
  }
}
Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
package org.neo4j.spark.converter

import org.apache.spark.sql.types.{DataType, DataTypes}
import org.neo4j.driver.types.Entity
import org.neo4j.spark.converter.CypherToSparkTypeConverter.{cleanTerms, durationType, pointType, timeType}
import org.neo4j.spark.converter.SparkToCypherTypeConverter.mapping
import org.neo4j.spark.service.SchemaService.normalizedClassName
import org.neo4j.spark.util.Neo4jImplicits.EntityImplicits

import scala.collection.JavaConverters._

trait TypeConverter[SOURCE_TYPE, DESTINATION_TYPE] {

  def convert(sourceType: SOURCE_TYPE, value: Any = null): DESTINATION_TYPE

}

object CypherToSparkTypeConverter {
  def apply(): CypherToSparkTypeConverter = new CypherToSparkTypeConverter()

  private val cleanTerms: String = "Unmodifiable|Internal|Iso|2D|3D|Offset|Local|Zoned"

  val durationType: DataType = DataTypes.createStructType(Array(
    DataTypes.createStructField("type", DataTypes.StringType, false),
    DataTypes.createStructField("months", DataTypes.LongType, false),
    DataTypes.createStructField("days", DataTypes.LongType, false),
    DataTypes.createStructField("seconds", DataTypes.LongType, false),
    DataTypes.createStructField("nanoseconds", DataTypes.IntegerType, false),
    DataTypes.createStructField("value", DataTypes.StringType, false)
  ))

  val pointType: DataType = DataTypes.createStructType(Array(
    DataTypes.createStructField("type", DataTypes.StringType, false),
    DataTypes.createStructField("srid", DataTypes.IntegerType, false),
    DataTypes.createStructField("x", DataTypes.DoubleType, false),
    DataTypes.createStructField("y", DataTypes.DoubleType, false),
    DataTypes.createStructField("z", DataTypes.DoubleType, true)
  ))

  val timeType: DataType = DataTypes.createStructType(Array(
    DataTypes.createStructField("type", DataTypes.StringType, false),
    DataTypes.createStructField("value", DataTypes.StringType, false)
  ))
}

class CypherToSparkTypeConverter extends TypeConverter[String, DataType] {
  override def convert(sourceType: String, value: Any = null): DataType = sourceType
    .replaceAll(cleanTerms, "") match {
    case "Node" | "Relationship" => if (value != null) value.asInstanceOf[Entity].toStruct else DataTypes.NullType
    case "NodeArray" | "RelationshipArray" => if (value != null) DataTypes.createArrayType(value.asInstanceOf[Entity].toStruct) else DataTypes.NullType
    case "Boolean" => DataTypes.BooleanType
    case "Long" => DataTypes.LongType
    case "Double" => DataTypes.DoubleType
    case "Point" => pointType
    case "DateTime" | "ZonedDateTime" | "LocalDateTime" => DataTypes.TimestampType
    case "Time" => timeType
    case "Date" => DataTypes.DateType
    case "Duration" => durationType
    case "Map" => {
      val valueType = if (value == null) {
        DataTypes.NullType
      } else {
        val map = value.asInstanceOf[java.util.Map[String, AnyRef]].asScala
        val types = map.values
          .map(normalizedClassName)
          .toSet
        if (types.size == 1) convert(types.head, map.values.head) else DataTypes.StringType
      }
      DataTypes.createMapType(DataTypes.StringType, valueType)
    }
    case "Array" => {
      val valueType = if (value == null) {
        DataTypes.NullType
      } else {
        val list = value.asInstanceOf[java.util.List[AnyRef]].asScala
        val types = list
          .map(normalizedClassName)
          .toSet
        if (types.size == 1) convert(types.head, list.head) else DataTypes.StringType
      }
      DataTypes.createArrayType(valueType)
    }
    // These are from APOC
    case "StringArray" => DataTypes.createArrayType(DataTypes.StringType)
    case "LongArray" => DataTypes.createArrayType(DataTypes.LongType)
    case "DoubleArray" => DataTypes.createArrayType(DataTypes.DoubleType)
    case "BooleanArray" => DataTypes.createArrayType(DataTypes.BooleanType)
    case "PointArray" => DataTypes.createArrayType(pointType)
    case "DateTimeArray" => DataTypes.createArrayType(DataTypes.TimestampType)
    case "TimeArray" => DataTypes.createArrayType(timeType)
    case "DateArray" => DataTypes.createArrayType(DataTypes.DateType)
    case "DurationArray" => DataTypes.createArrayType(durationType)
    // Default is String
    case _ => DataTypes.StringType
  }
}

object SparkToCypherTypeConverter {
  def apply(): SparkToCypherTypeConverter = new SparkToCypherTypeConverter()

  private val mapping: Map[DataType, String] = Map(
    DataTypes.BooleanType -> "BOOLEAN",
    DataTypes.StringType -> "STRING",
    DataTypes.IntegerType -> "INTEGER",
    DataTypes.LongType -> "INTEGER",
    DataTypes.FloatType -> "FLOAT",
    DataTypes.DoubleType -> "FLOAT",
    DataTypes.DateType -> "DATE",
    DataTypes.TimestampType -> "LOCAL DATETIME",
    durationType -> "DURATION",
    pointType -> "POINT",
    // Cypher graph entities do not allow null values in arrays
    DataTypes.createArrayType(DataTypes.BooleanType, false) -> "LIST<BOOLEAN NOT NULL>",
    DataTypes.createArrayType(DataTypes.StringType, false) -> "LIST<STRING NOT NULL>",
    DataTypes.createArrayType(DataTypes.IntegerType, false) -> "LIST<INTEGER NOT NULL>",
    DataTypes.createArrayType(DataTypes.LongType, false) -> "LIST<INTEGER NOT NULL>",
    DataTypes.createArrayType(DataTypes.FloatType, false) -> "LIST<FLOAT NOT NULL>",
    DataTypes.createArrayType(DataTypes.DoubleType, false) -> "LIST<FLOAT NOT NULL>",
    DataTypes.createArrayType(DataTypes.DateType, false) -> "LIST<DATE NOT NULL>",
    DataTypes.createArrayType(DataTypes.TimestampType, false) -> "LIST<LOCAL DATETIME NOT NULL>",
    DataTypes.createArrayType(DataTypes.TimestampType, true) -> "LIST<LOCAL DATETIME NOT NULL>",
    DataTypes.createArrayType(durationType, false) -> "LIST<DURATION NOT NULL>",
    DataTypes.createArrayType(pointType, false) -> "LIST<POINT NOT NULL>"
  )
}

class SparkToCypherTypeConverter extends TypeConverter[DataType, String] {
  override def convert(sourceType: DataType, value: Any): String = mapping(sourceType)
}
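
A short sketch of how the two type converters behave; the inputs are illustrative, and the Cypher type names on the Spark-to-Cypher side are presumably what the connector uses when generating property type constraints (e.g. REQUIRE n.prop IS :: STRING), in line with the commit title:

import org.apache.spark.sql.types.DataTypes
import org.neo4j.spark.converter.{CypherToSparkTypeConverter, SparkToCypherTypeConverter}

val cypherToSpark = CypherToSparkTypeConverter()
val sparkToCypher = SparkToCypherTypeConverter()

// Driver-side type names are cleaned ("LocalDateTime" -> "DateTime") and mapped to Spark types.
val sparkType = cypherToSpark.convert("LocalDateTime") // DataTypes.TimestampType

// Spark types map to Cypher type names; non-nullable arrays become LIST<... NOT NULL>.
val cypherType = sparkToCypher.convert(DataTypes.createArrayType(DataTypes.StringType, false), null) // "LIST<STRING NOT NULL>"
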

common/src/main/scala/org/neo4j/spark/service/MappingService.scala

Lines changed: 14 additions & 6 deletions
@@ -10,6 +10,7 @@ import org.apache.spark.sql.types.StructType
import org.neo4j.driver.internal.value.MapValue
import org.neo4j.driver.types.Node
import org.neo4j.driver.{Record, Value, Values}
+import org.neo4j.spark.converter.{Neo4jToSparkDataConverter, SparkToNeo4jDataConverter}
import org.neo4j.spark.service.Neo4jWriteMappingStrategy.{KEYS, PROPERTIES}
import org.neo4j.spark.util.{Neo4jNodeMetadata, Neo4jOptions, Neo4jUtil, QueryType, RelationshipSaveStrategy, ValidateSchemaOptions, Validations}

@@ -21,6 +22,8 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)
  extends Neo4jMappingStrategy[InternalRow, java.util.Map[String, AnyRef]]
  with Logging {

+  private val dataConverter = SparkToNeo4jDataConverter()
+
  override def node(row: InternalRow, schema: StructType): java.util.Map[String, AnyRef] = {
    Validations.validate(ValidateSchemaOptions(options, schema))

@@ -35,7 +38,7 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)
      override def accept(key: String, value: AnyRef): Unit = if (options.nodeMetadata.nodeKeys.contains(key)) {
        keys.put(options.nodeMetadata.nodeKeys.getOrElse(key, key), value)
      } else {
-       properties.put(options.nodeMetadata.nodeProps.getOrElse(key, key), value)
+       properties.put(options.nodeMetadata.properties.getOrElse(key, key), value)
      }
    })

@@ -70,8 +73,8 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)
    if (nodeMetadata.nodeKeys.contains(key)) {
      nodeMap.get(KEYS).put(nodeMetadata.nodeKeys.getOrElse(key, key), value)
    }
-   if (nodeMetadata.nodeProps.contains(key)) {
-     nodeMap.get(PROPERTIES).put(nodeMetadata.nodeProps.getOrElse(key, key), value)
+   if (nodeMetadata.properties.contains(key)) {
+     nodeMap.get(PROPERTIES).put(nodeMetadata.properties.getOrElse(key, key), value)
    }
  }

@@ -83,7 +86,9 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)
    addToNodeMap(sourceNodeMap, source, key, value)
    addToNodeMap(targetNodeMap, target, key, value)

-   if (options.relationshipMetadata.properties.contains(key)) {
+   if (options.relationshipMetadata.relationshipKeys.contains(key)) {
+     relMap.get(KEYS).put(options.relationshipMetadata.relationshipKeys.getOrElse(key, key), value)
+   } else {
      relMap.get(PROPERTIES).put(options.relationshipMetadata.properties.getOrElse(key, key), value)
    }
  }

@@ -123,7 +128,7 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)
    schema.indices
      .flatMap(i => {
        val field = schema(i)
-       val neo4jValue = Neo4jUtil.convertFromSpark(seq(i), field.dataType)
+       val neo4jValue = dataConverter.convert(seq(i), field.dataType)
        neo4jValue match {
          case map: MapValue =>
            map.asMap().asScala.toMap

@@ -140,6 +145,8 @@ class Neo4jWriteMappingStrategy(private val options: Neo4jOptions)

class Neo4jReadMappingStrategy(private val options: Neo4jOptions, requiredColumns: StructType) extends Neo4jMappingStrategy[Record, InternalRow] {

+  private val dataConverter = Neo4jToSparkDataConverter()
+
  override def node(record: Record, schema: StructType): InternalRow = {
    if (requiredColumns.nonEmpty) {
      query(record, schema)

@@ -158,7 +165,7 @@ class Neo4jReadMappingStrategy(private val options: Neo4jOptions, requiredColumn
                    schema: StructType) = InternalRow
    .fromSeq(
      schema.map(
-       field => Neo4jUtil.convertFromNeo4j(map.get(field.name), field.dataType)
+       field => dataConverter.convert(map.get(field.name), field.dataType)
      )
    )

@@ -254,6 +261,7 @@ private abstract class MappingBiConsumer extends BiConsumer[String, AnyRef] {
  val sourceNodeMap = new util.HashMap[String, util.Map[String, AnyRef]]()
  val targetNodeMap = new util.HashMap[String, util.Map[String, AnyRef]]()

+  relMap.put(KEYS, new util.HashMap[String, AnyRef]())
  relMap.put(PROPERTIES, new util.HashMap[String, AnyRef]())
  sourceNodeMap.put(PROPERTIES, new util.HashMap[String, AnyRef]())
  sourceNodeMap.put(KEYS, new util.HashMap[String, AnyRef]())