Merged
Changes from 3 commits
1 change: 1 addition & 0 deletions docs/source/user-guide/latest/configs.md
@@ -234,6 +234,7 @@ These settings can be used to determine which parts of the plan are accelerated
| `spark.comet.expression.CreateArray.enabled` | Enable Comet acceleration for `CreateArray` | true |
| `spark.comet.expression.CreateNamedStruct.enabled` | Enable Comet acceleration for `CreateNamedStruct` | true |
| `spark.comet.expression.DateAdd.enabled` | Enable Comet acceleration for `DateAdd` | true |
+| `spark.comet.expression.DateDiff.enabled` | Enable Comet acceleration for `DateDiff` | true |
| `spark.comet.expression.DateSub.enabled` | Enable Comet acceleration for `DateSub` | true |
| `spark.comet.expression.DayOfMonth.enabled` | Enable Comet acceleration for `DayOfMonth` | true |
| `spark.comet.expression.DayOfWeek.enabled` | Enable Comet acceleration for `DayOfWeek` | true |
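Like the other per-expression flags in this table, the new entry is an escape hatch: if the native `datediff` implementation ever misbehaves, users can fall back to Spark for just this expression. A minimal sketch of toggling it from a Spark session (assuming Comet is enabled; the config key comes from the table above):

```scala
// Sketch: disable native DateDiff so Spark evaluates it instead.
spark.conf.set("spark.comet.expression.DateDiff.enabled", "false")
```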
5 changes: 3 additions & 2 deletions native/spark-expr/src/comet_scalar_funcs.rs
@@ -22,8 +22,8 @@ use crate::math_funcs::modulo_expr::spark_modulo;
use crate::{
    spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
    spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
-    spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkDateTrunc, SparkSizeFunc,
-    SparkStringSpace,
+    spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkDateDiff, SparkDateTrunc,
+    SparkSizeFunc, SparkStringSpace,
};
use arrow::datatypes::DataType;
use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -192,6 +192,7 @@ pub fn create_comet_physical_fun_with_eval_mode(
fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
    vec![
        Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
+        Arc::new(ScalarUDF::new_from_impl(SparkDateDiff::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkSizeFunc::default())),
104 changes: 104 additions & 0 deletions native/spark-expr/src/datetime_funcs/date_diff.rs
@@ -0,0 +1,104 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Array, Date32Array, Int32Array};
use arrow::compute::kernels::arity::binary;
use arrow::datatypes::DataType;
use datafusion::common::{utils::take_function_args, DataFusionError, Result};
use datafusion::logical_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use std::any::Any;
use std::sync::Arc;

/// Spark-compatible date_diff function.
/// Returns the number of days from startDate to endDate (endDate - startDate).
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkDateDiff {
    signature: Signature,
    aliases: Vec<String>,
}

impl SparkDateDiff {
    pub fn new() -> Self {
        Self {
            signature: Signature::exact(
                vec![DataType::Date32, DataType::Date32],
                Volatility::Immutable,
            ),
            aliases: vec!["datediff".to_string()],
        }
    }
}

impl Default for SparkDateDiff {
    fn default() -> Self {
        Self::new()
    }
}

impl ScalarUDFImpl for SparkDateDiff {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "date_diff"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int32)
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let num_rows = args.number_rows;
        let [end_date, start_date] = take_function_args(self.name(), args.args)?;

        // Expand scalar arguments to arrays of the batch length so that
        // mixed scalar/array inputs line up for the binary kernel.
        let end_arr = end_date.into_array(num_rows)?;
        let start_arr = start_date.into_array(num_rows)?;

        let end_date_array = end_arr
            .as_any()
            .downcast_ref::<Date32Array>()
            .ok_or_else(|| {
                DataFusionError::Execution("date_diff expects Date32Array for end_date".to_string())
            })?;

        let start_date_array = start_arr
            .as_any()
            .downcast_ref::<Date32Array>()
            .ok_or_else(|| {
                DataFusionError::Execution(
                    "date_diff expects Date32Array for start_date".to_string(),
                )
            })?;

        // Date32 stores days since the Unix epoch, so the difference is plain
        // integer subtraction; nulls propagate through the kernel.
        let result: Int32Array =
            binary(end_date_array, start_date_array, |end, start| end - start)?;

        Ok(ColumnarValue::Array(Arc::new(result)))
    }

    fn aliases(&self) -> &[String] {
        &self.aliases
    }
}
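The kernel leans on Date32's representation: each value is the number of days since 1970-01-01, so Spark's `datediff(end, start) = end - start` semantics reduce to integer subtraction. For example, 2009-07-30 is day 14455 and 2009-07-31 is day 14456 since the epoch, so their difference is 1. The `aliases` vector means the function resolves under both `date_diff` and `datediff`. A minimal sketch of the Spark-side behavior the kernel must match, assuming a session with Comet enabled:

```scala
// Sketch: datediff(end, start) is positive when end > start,
// and negative when the arguments are reversed.
spark.sql("SELECT datediff(DATE'2009-07-31', DATE'2009-07-30')").show() // 1
spark.sql("SELECT datediff(DATE'2009-07-30', DATE'2009-07-31')").show() // -1
```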
2 changes: 2 additions & 0 deletions native/spark-expr/src/datetime_funcs/mod.rs
@@ -15,10 +15,12 @@
// specific language governing permissions and limitations
// under the License.

+mod date_diff;
mod date_trunc;
mod extract_date_part;
mod timestamp_trunc;

+pub use date_diff::SparkDateDiff;
pub use date_trunc::SparkDateTrunc;
pub use extract_date_part::SparkHour;
pub use extract_date_part::SparkMinute;
4 changes: 3 additions & 1 deletion native/spark-expr/src/lib.rs
@@ -69,7 +69,9 @@ pub use comet_scalar_funcs::{
    create_comet_physical_fun, create_comet_physical_fun_with_eval_mode,
    register_all_comet_functions,
};
-pub use datetime_funcs::{SparkDateTrunc, SparkHour, SparkMinute, SparkSecond, TimestampTruncExpr};
+pub use datetime_funcs::{
+    SparkDateDiff, SparkDateTrunc, SparkHour, SparkMinute, SparkSecond, TimestampTruncExpr,
+};
pub use error::{SparkError, SparkResult};
pub use hash_funcs::*;
pub use json_funcs::{FromJson, ToJson};
spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -185,6 +185,7 @@ object QueryPlanSerde extends Logging with CometExprShim {

  private val temporalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
    classOf[DateAdd] -> CometDateAdd,
+    classOf[DateDiff] -> CometDateDiff,
    classOf[DateSub] -> CometDateSub,
    classOf[FromUnixTime] -> CometFromUnixTime,
    classOf[Hour] -> CometHour,
4 changes: 3 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/datetime.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde

import java.util.Locale

-import org.apache.spark.sql.catalyst.expressions.{Attribute, DateAdd, DateSub, DayOfMonth, DayOfWeek, DayOfYear, GetDateField, Hour, Literal, Minute, Month, Quarter, Second, TruncDate, TruncTimestamp, WeekDay, WeekOfYear, Year}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, DateAdd, DateDiff, DateSub, DayOfMonth, DayOfWeek, DayOfYear, GetDateField, Hour, Literal, Minute, Month, Quarter, Second, TruncDate, TruncTimestamp, WeekDay, WeekOfYear, Year}
import org.apache.spark.sql.types.{DateType, IntegerType}
import org.apache.spark.unsafe.types.UTF8String

@@ -258,6 +258,8 @@ object CometDateAdd extends CometScalarFunction[DateAdd]("date_add")

object CometDateSub extends CometScalarFunction[DateSub]("date_sub")

+object CometDateDiff extends CometScalarFunction[DateDiff]("date_diff")
+
object CometTruncDate extends CometExpressionSerde[TruncDate] {

  val supportedFormats: Seq[String] =
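Taken together, the two Scala changes show the full recipe for delegating a Spark expression to an existing native scalar function: add the expression class to the serde map, and define a one-line `CometScalarFunction` object whose string argument names the native UDF. In the sketch below, `FooExpr`, `CometFoo`, and `"foo"` are hypothetical placeholders; the trailing comment traces the real name chain from this PR.

```scala
// Hypothetical placeholders following the same two-step pattern:
//   1. Map the Spark expression class in QueryPlanSerde:
//        classOf[FooExpr] -> CometFoo,
//   2. Point the serde object at the native function's registered name:
object CometFoo extends CometScalarFunction[FooExpr]("foo")

// For this PR the chain is: classOf[DateDiff] -> CometDateDiff, then
// CometScalarFunction[DateDiff]("date_diff"), which pairs with
// SparkDateDiff::name() returning "date_diff" on the Rust side.
```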
spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala
@@ -122,4 +122,38 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper
StructField("fmt", DataTypes.StringType, true)))
FuzzDataGenerator.generateDataFrame(r, spark, schema, 1000, DataGenOptions())
}

test("datediff") {
[Inline review thread]

Contributor:

Tests using dates from leap years would be fun to try. Try this -

datediff(1900-03-01, 1900-02-27) != datediff(2000-03-01, 2000-02-27)

Member (Author):

Good suggestion! I've added leap year edge case tests:

- datediff('1900-03-01', '1900-02-27') = 2 days (1900 was NOT a leap year: divisible by 100 but not 400)
- datediff('2000-03-01', '2000-02-27') = 3 days (2000 WAS a leap year: divisible by 400)
- datediff('2004-03-01', '2004-02-28') = 2 days (2004 was a leap year: divisible by 4, not by 100)
- datediff('2100-03-01', '2100-02-28') = 1 day (2100 will NOT be a leap year: divisible by 100 but not 400)

All tests pass.

[End of review thread]

+    val r = new Random(42)
+    val schema = StructType(
+      Seq(
+        StructField("c0", DataTypes.DateType, true),
+        StructField("c1", DataTypes.DateType, true)))
+    val df = FuzzDataGenerator.generateDataFrame(r, spark, schema, 1000, DataGenOptions())
+    df.createOrReplaceTempView("tbl")
+
+    // Basic test with random dates
+    checkSparkAnswerAndOperator("SELECT c0, c1, datediff(c0, c1) FROM tbl ORDER BY c0, c1")
+
+    // Disable constant folding to ensure literal expressions are executed by Comet
+    withSQLConf(
+      SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
+        "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {
+      // Test positive difference (end date > start date)
+      checkSparkAnswerAndOperator("SELECT datediff(DATE('2009-07-31'), DATE('2009-07-30'))")
+
+      // Test negative difference (end date < start date)
+      checkSparkAnswerAndOperator("SELECT datediff(DATE('2009-07-30'), DATE('2009-07-31'))")
+
+      // Test same dates (should be 0)
+      checkSparkAnswerAndOperator("SELECT datediff(DATE('2009-07-30'), DATE('2009-07-30'))")
+
+      // Test larger date differences
+      checkSparkAnswerAndOperator("SELECT datediff(DATE('2024-01-01'), DATE('2020-01-01'))")
+
+      // Test null handling
+      checkSparkAnswerAndOperator("SELECT datediff(NULL, DATE('2009-07-30'))")
+      checkSparkAnswerAndOperator("SELECT datediff(DATE('2009-07-30'), NULL)")
+    }
+  }
}
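The leap-year cases from the review thread are not visible in this three-commit snapshot, but based on the reply they would look something like the following sketch, assuming the same `checkSparkAnswerAndOperator` helper and constant-folding exclusion used above:

```scala
// Sketch of the leap-year edge cases described in the review reply.
// Expected values follow the Gregorian rule: divisible by 4, except
// centuries, unless divisible by 400.
checkSparkAnswerAndOperator(
  "SELECT datediff(DATE('1900-03-01'), DATE('1900-02-27'))") // 2: 1900 not a leap year
checkSparkAnswerAndOperator(
  "SELECT datediff(DATE('2000-03-01'), DATE('2000-02-27'))") // 3: 2000 is a leap year
checkSparkAnswerAndOperator(
  "SELECT datediff(DATE('2004-03-01'), DATE('2004-02-28'))") // 2: 2004 is a leap year
checkSparkAnswerAndOperator(
  "SELECT datediff(DATE('2100-03-01'), DATE('2100-02-28'))") // 1: 2100 not a leap year
```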