diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 08e545cb8c204..850bfb4a89857 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -408,6 +408,10 @@ class ParquetFileFormat } override def supportDataType(dataType: DataType): Boolean = dataType match { + // Geospatial types are supported only when their SRID passes the isSridSupported check. + case g: GeometryType => GeometryType.isSridSupported(g.srid) + case g: GeographyType => GeographyType.isSridSupported(g.srid) + case _: AtomicType | _: NullType => true case st: StructType => st.forall { f => supportDataType(f.dataType) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index dcaf88fa8dfdb..b84148992e32b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -34,7 +34,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY, SPARK_VERSION_METADATA_KEY} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, STUtils} import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} import 
org.apache.spark.sql.types._ @@ -276,6 +276,20 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging { (row: SpecializedGetters, ordinal: Int) => recordConsumer.addBinary(Binary.fromReusedByteArray(row.getBinary(ordinal))) + case _: GeometryType => + (row: SpecializedGetters, ordinal: Int) => + // Write the geometry value as WKB bytes, as required by the Parquet geospatial spec: + // https://parquet.apache.org/docs/file-format/types/geospatial/. + val wkb = STUtils.stAsBinary(row.getGeometry(ordinal)) + recordConsumer.addBinary(Binary.fromReusedByteArray(wkb)) + + case _: GeographyType => + (row: SpecializedGetters, ordinal: Int) => + // Write the geography value as WKB bytes, as required by the Parquet geospatial spec: + // https://parquet.apache.org/docs/file-format/types/geospatial/. + val wkb = STUtils.stAsBinary(row.getGeography(ordinal)) + recordConsumer.addBinary(Binary.fromReusedByteArray(wkb)) + case DecimalType.Fixed(precision, scale) => makeDecimalWriter(precision, scale) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala index 28c5a62f91ecb..67052c201a9df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala @@ -51,6 +51,10 @@ case class ParquetTable( } override def supportsDataType(dataType: DataType): Boolean = dataType match { + // Geospatial types are supported only when their SRID passes the isSridSupported check. 
+ case g: GeometryType => GeometryType.isSridSupported(g.srid) + case g: GeographyType => GeographyType.isSridSupported(g.srid) + case _: AtomicType => true case st: StructType => st.forall { f => supportsDataType(f.dataType) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/STExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/STExpressionsSuite.scala index 323826ca38202..d7f1ab49e89c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/STExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/STExpressionsSuite.scala @@ -42,6 +42,60 @@ class STExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(expectedDataType)) } + // Test data: hex-encoded little-endian WKB for POINT(1 2) and POINT(3 4), respectively. + private final val wkbString1 = "0101000000000000000000F03F0000000000000040" + private final val wkbString2 = "010100000000000000000008400000000000001040" + + /** Geospatial type storage. */ + + test("Parquet tables - unsupported geospatial types") { + val tableName = "tst_tbl" + // Run with Parquet in USE_V1_SOURCE_LIST (v1 path) and with an empty list (v2 path). + Seq(true, false).foreach { useV1 => + val useV1List = if (useV1) "parquet" else "" + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> useV1List) { + Seq("GEOMETRY(ANY)", "GEOGRAPHY(ANY)").foreach { unsupportedType => + withTable(tableName) { + checkError( + exception = intercept[AnalysisException] { + sql(s"CREATE TABLE $tableName (g $unsupportedType) USING PARQUET") + }, + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + parameters = Map( + "columnName" -> "`g`", + "columnType" -> s""""$unsupportedType"""", + "format" -> "Parquet")) + } + } + } + } + } + + test("Parquet write support for geometry and geography types") { + val tableName = "tst_tbl" + // Run with Parquet in USE_V1_SOURCE_LIST (v1 path) and with an empty list (v2 path). 
+ Seq(true, false).foreach { useV1 => + val useV1List = if (useV1) "parquet" else "" + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> useV1List) { + withTable(tableName) { + sql(s"CREATE TABLE $tableName (geom GEOMETRY(0), geog GEOGRAPHY(4326)) USING PARQUET") + + val geomNull = "ST_GeomFromWKB(NULL)" + val geomNotNull = s"ST_GeomFromWKB(X'$wkbString1')" + val geogNull = "ST_GeogFromWKB(NULL)" + val geogNotNull = s"ST_GeogFromWKB(X'$wkbString2')" + + sql(s"INSERT INTO $tableName VALUES ($geomNull, $geogNull)") + sql(s"INSERT INTO $tableName VALUES ($geomNotNull, $geogNull)") + sql(s"INSERT INTO $tableName VALUES ($geomNull, $geogNotNull)") + sql(s"INSERT INTO $tableName VALUES ($geomNotNull, $geogNotNull)") + + checkAnswer(sql(s"SELECT COUNT(*) FROM $tableName"), Seq(Row(4))) + } + } + } + } + /** Geospatial type casting. */ test("Cast GEOGRAPHY(srid) to GEOGRAPHY(ANY)") {