Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 133 additions & 44 deletions spark/src/main/scala/org/apache/comet/serde/datetime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package org.apache.comet.serde

import java.util.Locale

import org.apache.spark.sql.catalyst.expressions.{AddMonths, Attribute, ConvertTimezone, DateAdd, DateDiff, DateFormatClass, DateFromUnixDate, DateSub, DayOfMonth, DayOfWeek, DayOfYear, Days, FromUTCTimestamp, GetDateField, GetTimestamp, Hour, Hours, LastDay, Literal, MakeDate, MakeTimestamp, MicrosToTimestamp, MillisToTimestamp, Minute, Month, MonthsBetween, NextDay, Quarter, Second, SecondsToTimestamp, ToUnixTimestamp, ToUTCTimestamp, TruncDate, TruncTimestamp, UnixDate, UnixMicros, UnixMillis, UnixSeconds, UnixTimestamp, WeekDay, WeekOfYear, Year}
import org.apache.spark.sql.catalyst.expressions.{AddMonths, Attribute, ConvertTimezone, DateAdd, DateDiff, DateFormatClass, DateFromUnixDate, DateSub, DayOfMonth, DayOfWeek, DayOfYear, Days, Expression, FromUTCTimestamp, GetDateField, GetTimestamp, Hour, Hours, LastDay, Literal, MakeDate, MakeTimestamp, MicrosToTimestamp, MillisToTimestamp, Minute, Month, MonthsBetween, NextDay, Quarter, Second, SecondsToTimestamp, ToUnixTimestamp, ToUTCTimestamp, TruncDate, TruncTimestamp, UnixDate, UnixMicros, UnixMillis, UnixSeconds, UnixTimestamp, WeekDay, WeekOfYear, Year}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, DateType, DoubleType, FloatType, IntegerType, LongType, StringType, TimestampNTZType, TimestampType}
import org.apache.spark.unsafe.types.UTF8String
Expand All @@ -32,6 +32,7 @@ import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.serde.CometGetDateField.CometGetDateField
import org.apache.comet.serde.ExprOuterClass.Expr
import org.apache.comet.serde.QueryPlanSerde._
import org.apache.comet.shims.CometTypeShim

private object CometGetDateField extends Enumeration {
type CometGetDateField = Value
Expand Down Expand Up @@ -289,13 +290,26 @@ object CometSecond extends CometExpressionSerde[Second] with CodegenDispatchFall
}
}

private[serde] object DatetimeCollation extends CometTypeShim {
def reason(functionName: String): String =
s"$functionName does not support non-UTF8_BINARY collations " +
"(https://github.com/apache/datafusion-comet/issues/4646)"

def hasNonDefaultCollation(expr: Expression): Boolean =
expr.children.exists(c => hasNonDefaultStringCollation(c.dataType))
}

object CometUnixTimestamp extends CometExpressionSerde[UnixTimestamp] {

private val collationReason = DatetimeCollation.reason("unix_timestamp")

override def getUnsupportedReasons(): Seq[String] = Seq(
"Only `TimestampType` and `DateType` inputs are supported." +
" `TimestampNTZType` is not supported because Comet incorrectly applies timezone" +
" conversion to TimestampNTZ values.")

override def getIncompatibleReasons(): Seq[String] = Seq(collationReason)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor nit: this incompatibility will get documented for all Spark versions even though it is specific to Spark 4.x

I wonder if we can shim getIncompatibleReasons so it only applies for 4.x?


private def isSupportedInputType(expr: UnixTimestamp): Boolean = {
expr.children.head.dataType match {
case TimestampType | DateType => true
Expand All @@ -305,7 +319,9 @@ object CometUnixTimestamp extends CometExpressionSerde[UnixTimestamp] {
}

override def getSupportLevel(expr: UnixTimestamp): SupportLevel = {
if (isSupportedInputType(expr)) {
if (DatetimeCollation.hasNonDefaultCollation(expr)) {
Incompatible(Some(collationReason))
} else if (isSupportedInputType(expr)) {
Compatible()
} else {
val inputType = expr.children.head.dataType
Expand Down Expand Up @@ -401,11 +417,18 @@ object CometConvertTimezone
extends CometExpressionSerde[ConvertTimezone]
with CodegenDispatchFallback {

override def getSupportLevel(expr: ConvertTimezone): SupportLevel =
Incompatible(Some(UTCTimestampSerde.tzParseIncompatReason))
private val collationReason = DatetimeCollation.reason("convert_timezone")

override def getSupportLevel(expr: ConvertTimezone): SupportLevel = {
if (DatetimeCollation.hasNonDefaultCollation(expr)) {
Incompatible(Some(collationReason))
} else {
Incompatible(Some(UTCTimestampSerde.tzParseIncompatReason))
}
}

override def getIncompatibleReasons(): Seq[String] =
Seq(UTCTimestampSerde.tzParseIncompatReason)
Seq(UTCTimestampSerde.tzParseIncompatReason, collationReason)

override def convert(
expr: ConvertTimezone,
Expand All @@ -427,6 +450,17 @@ object CometNextDay extends CometExpressionSerde[NextDay] {
* `dayOfWeek` rather than returning NULL. The resolved flag is passed to native via the
* `ScalarFunc.fail_on_error` field.
*/
private val collationReason = DatetimeCollation.reason("next_day")

override def getIncompatibleReasons(): Seq[String] = Seq(collationReason)

override def getSupportLevel(expr: NextDay): SupportLevel = {
if (DatetimeCollation.hasNonDefaultCollation(expr)) {
Incompatible(Some(collationReason))
} else {
Compatible()
}
}
override def convert(expr: NextDay, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding))
val optExpr = scalarFunctionExprToProtoWithReturnType(
Expand Down Expand Up @@ -508,28 +542,35 @@ object CometTruncDate extends CometExpressionSerde[TruncDate] with CodegenDispat
val supportedFormats: Seq[String] =
Seq("year", "yyyy", "yy", "quarter", "mon", "month", "mm", "week")

private val collationReason = DatetimeCollation.reason("trunc")

private val nonLiteralFormatIncompatReason: String =
"Non-literal format strings will throw an exception instead of returning NULL"

private def unsupportedFormatReason(fmt: Any): String =
s"Format $fmt is not supported. Only the following formats are supported: " +
supportedFormats.mkString(", ")

override def getIncompatibleReasons(): Seq[String] = Seq(nonLiteralFormatIncompatReason)
override def getIncompatibleReasons(): Seq[String] =
Seq(nonLiteralFormatIncompatReason, collationReason)

override def getUnsupportedReasons(): Seq[String] = Seq(
"Only the following formats are supported: " + supportedFormats.mkString(", "))

override def getSupportLevel(expr: TruncDate): SupportLevel = {
expr.format match {
case Literal(fmt: UTF8String, _) =>
if (supportedFormats.contains(fmt.toString.toLowerCase(Locale.ROOT))) {
Compatible()
} else {
Unsupported(Some(unsupportedFormatReason(fmt)))
}
case _ =>
Incompatible(Some(nonLiteralFormatIncompatReason))
if (DatetimeCollation.hasNonDefaultCollation(expr)) {
Incompatible(Some(collationReason))
} else {
expr.format match {
case Literal(fmt: UTF8String, _) =>
if (supportedFormats.contains(fmt.toString.toLowerCase(Locale.ROOT))) {
Compatible()
} else {
Unsupported(Some(unsupportedFormatReason(fmt)))
}
case _ =>
Incompatible(Some(nonLiteralFormatIncompatReason))
}
}
}

Expand Down Expand Up @@ -572,6 +613,8 @@ object CometTruncTimestamp
"millisecond",
"microsecond")

private val collationReason = DatetimeCollation.reason("date_trunc")

private val nonUtcIncompatReason: String =
"Produces incorrect results when used with non-UTC timezones. Compatible when timezone is" +
" UTC. (https://github.com/apache/datafusion-comet/issues/2649)"
Expand All @@ -584,27 +627,31 @@ object CometTruncTimestamp
supportedFormats.mkString(", ")

override def getIncompatibleReasons(): Seq[String] =
Seq(nonUtcIncompatReason, nonLiteralFormatIncompatReason)
Seq(nonUtcIncompatReason, nonLiteralFormatIncompatReason, collationReason)

override def getUnsupportedReasons(): Seq[String] = Seq(
"Only the following formats are supported: " + supportedFormats.mkString(", "))

override def getSupportLevel(expr: TruncTimestamp): SupportLevel = {
val timezone = expr.timeZoneId.getOrElse("UTC")
val isUtc = timezone == "UTC" || timezone == "Etc/UTC"
expr.format match {
case Literal(fmt: UTF8String, _) =>
if (supportedFormats.contains(fmt.toString.toLowerCase(Locale.ROOT))) {
if (isUtc) {
Compatible()
if (DatetimeCollation.hasNonDefaultCollation(expr)) {
Incompatible(Some(collationReason))
} else {
val timezone = expr.timeZoneId.getOrElse("UTC")
val isUtc = timezone == "UTC" || timezone == "Etc/UTC"
expr.format match {
case Literal(fmt: UTF8String, _) =>
if (supportedFormats.contains(fmt.toString.toLowerCase(Locale.ROOT))) {
if (isUtc) {
Compatible()
} else {
Incompatible(Some(nonUtcIncompatReason))
}
} else {
Incompatible(Some(nonUtcIncompatReason))
Unsupported(Some(unsupportedFormatReason(fmt)))
}
} else {
Unsupported(Some(unsupportedFormatReason(fmt)))
}
case _ =>
Incompatible(Some(nonLiteralFormatIncompatReason))
case _ =>
Incompatible(Some(nonLiteralFormatIncompatReason))
}
}
}

Expand Down Expand Up @@ -648,7 +695,10 @@ object CometTruncTimestamp
* by [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]]. When that flag is disabled the operator
* falls back to Spark.
*/
object CometDateFormat extends CometExpressionSerde[DateFormatClass] {
object CometDateFormat
extends CometExpressionSerde[DateFormatClass]
with CodegenDispatchFallback
with CometTypeShim {

/**
* Mapping from Spark SimpleDateFormat patterns to strftime patterns. Only formats in this map
Expand Down Expand Up @@ -686,18 +736,26 @@ object CometDateFormat extends CometExpressionSerde[DateFormatClass] {
// ISO formats
"yyyy-MM-dd'T'HH:mm:ss" -> "%Y-%m-%dT%H:%M:%S")

// Compatibility is decided inside `convert`: the native path covers a subset, and the codegen
// dispatcher covers everything else when enabled. Plan-time tagging happens via
// `withFallbackReason` on the path that returns None.
override def getSupportLevel(expr: DateFormatClass): SupportLevel = Compatible()
private val collationReason = DatetimeCollation.reason("date_format")

override def getIncompatibleReasons(): Seq[String] = Seq(collationReason)

// Non-default collations return Incompatible; all other inputs are Compatible. In both cases
// convert() decides between the native to_char path and the codegen dispatcher.
override def getSupportLevel(expr: DateFormatClass): SupportLevel = {
if (DatetimeCollation.hasNonDefaultCollation(expr)) {
Incompatible(Some(collationReason))
} else {
Compatible()
}
}

override def getCompatibleNotes(): Seq[String] = Seq(
"Format strings in a curated allow-list run natively via DataFusion's `to_char` for UTC " +
"sessions. Other format strings (including non-literal formats), as well as non-UTC " +
"sessions, route through Spark's own `DateFormatClass.doGenCode` via the Arrow-direct " +
"codegen dispatcher when `spark.comet.exec.scalaUDF.codegen.enabled=true`. When the " +
"codegen dispatcher is disabled (default) the operator falls back to Spark in those " +
"cases.")
"sessions. Other format strings (including non-literal formats) and non-UTC sessions " +
"route through Spark's own `DateFormatClass.doGenCode` via the Arrow-direct codegen " +
"dispatcher when `spark.comet.exec.scalaUDF.codegen.enabled=true`. When the codegen " +
"dispatcher is disabled (default) the operator falls back to Spark in those cases.")

override def convert(
expr: DateFormatClass,
Expand All @@ -711,9 +769,10 @@ object CometDateFormat extends CometExpressionSerde[DateFormatClass] {
case _ => None
}

val canUseNative = nativeFormat.isDefined && {
isUtc || CometConf.isExprAllowIncompat(getExprConfigName(expr))
}
val canUseNative = nativeFormat.isDefined &&
!expr.children.exists(c => hasNonDefaultStringCollation(c.dataType)) && {
isUtc || CometConf.isExprAllowIncompat(getExprConfigName(expr))
}

if (canUseNative) {
val childExpr = exprToProtoInternal(expr.left, inputs, binding)
Expand Down Expand Up @@ -829,7 +888,22 @@ object CometAddMonths extends CometCodegenDispatch[AddMonths]

object CometMonthsBetween extends CometCodegenDispatch[MonthsBetween]

object CometMakeTimestamp extends CometCodegenDispatch[MakeTimestamp]
object CometMakeTimestamp
extends CometCodegenDispatch[MakeTimestamp]
with CodegenDispatchFallback {

private val collationReason = DatetimeCollation.reason("make_timestamp")

override def getIncompatibleReasons(): Seq[String] = Seq(collationReason)

override def getSupportLevel(expr: MakeTimestamp): SupportLevel = {
if (DatetimeCollation.hasNonDefaultCollation(expr)) {
Incompatible(Some(collationReason))
} else {
Compatible()
}
}
}

object CometMicrosToTimestamp extends CometCodegenDispatch[MicrosToTimestamp]

Expand All @@ -841,6 +915,21 @@ object CometUnixMillis extends CometCodegenDispatch[UnixMillis]

object CometUnixMicros extends CometCodegenDispatch[UnixMicros]

object CometToUnixTimestamp extends CometCodegenDispatch[ToUnixTimestamp]
object CometToUnixTimestamp
extends CometCodegenDispatch[ToUnixTimestamp]
with CodegenDispatchFallback {

private val collationReason = DatetimeCollation.reason("to_unix_timestamp")

override def getIncompatibleReasons(): Seq[String] = Seq(collationReason)

override def getSupportLevel(expr: ToUnixTimestamp): SupportLevel = {
if (DatetimeCollation.hasNonDefaultCollation(expr)) {
Incompatible(Some(collationReason))
} else {
Compatible()
}
}
}

object CometGetTimestamp extends CometCodegenDispatch[GetTimestamp]
18 changes: 15 additions & 3 deletions spark/src/main/scala/org/apache/comet/serde/unixtime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,24 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithFa
// https://github.com/apache/datafusion/issues/16594
object CometFromUnixTime extends CometExpressionSerde[FromUnixTime] with CodegenDispatchFallback {

override def getIncompatibleReasons(): Seq[String] = Seq(
private val collationReason =
"from_unixtime does not support non-UTF8_BINARY collations " +
"(https://github.com/apache/datafusion-comet/issues/4646)"

private val formatReason =
"Only supports the default datetime format pattern `yyyy-MM-dd HH:mm:ss`." +
" DataFusion's valid timestamp range differs from Spark" +
" (https://github.com/apache/datafusion/issues/16594)")
" (https://github.com/apache/datafusion/issues/16594)"

override def getIncompatibleReasons(): Seq[String] = Seq(formatReason, collationReason)

override def getSupportLevel(expr: FromUnixTime): SupportLevel = Incompatible(None)
override def getSupportLevel(expr: FromUnixTime): SupportLevel = {
if (DatetimeCollation.hasNonDefaultCollation(expr)) {
Incompatible(Some(collationReason))
} else {
Incompatible(Some(formatReason))
}
}

override def convert(
expr: FromUnixTime,
Expand Down
Loading
Loading