From 097259d4679993aae233d387ea66552e943e92ad Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 5 May 2026 14:12:23 -0500 Subject: [PATCH] fix: coerce operand types in `Interval::mul`/`div`/`intersect`/`union`/`contains` Previously these methods asserted that both intervals shared an identical data type, causing internal errors during interval propagation for queries like `numeric / count(*)` where the operands end up as different `Decimal128` precisions/scales (e.g. `Decimal128(38, 10) / Decimal128(20, 0)`). `Interval::add` and `Interval::sub` already used `BinaryTypeCoercer` to find a common arithmetic type. This change brings `mul`/`div` in line, and similarly relaxes `intersect`/`union`/`contains` to coerce via `comparison_coercion`, since CP-solver propagation feeds the result of an arithmetic op into `intersect` with a child interval whose type may differ. Adds a unit test in `interval_arithmetic.rs` for mismatched `Decimal128` mul/div, and an end-to-end `decimal.slt` regression covering the `numeric / bigint` shape from the failing customer query. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../expr-common/src/interval_arithmetic.rs | 253 ++++++++++++------ .../sqllogictest/test_files/decimal.slt | 25 ++ 2 files changed, 189 insertions(+), 89 deletions(-) diff --git a/datafusion/expr-common/src/interval_arithmetic.rs b/datafusion/expr-common/src/interval_arithmetic.rs index 71b150eb92c94..097d5b255520d 100644 --- a/datafusion/expr-common/src/interval_arithmetic.rs +++ b/datafusion/expr-common/src/interval_arithmetic.rs @@ -646,31 +646,25 @@ impl Interval { /// Compute the intersection of this interval with the given interval. /// If the intersection is empty, return `None`. /// - /// NOTE: This function only works with intervals of the same data type. - /// Attempting to compare intervals of different data types will lead - /// to an error. + /// If the two intervals have different data types, both are coerced to a + /// common comparison type via [`comparison_coercion`] before computing the + /// intersection. pub fn intersect>(&self, other: T) -> Result> { let rhs = other.borrow(); - let lhs_type = self.data_type(); - let rhs_type = rhs.data_type(); - assert_eq_or_internal_err!( - lhs_type, - rhs_type, - "Only intervals with the same data type are intersectable, lhs:{}, rhs:{}", - self.data_type(), - rhs.data_type() - ); + let (lhs_owned, rhs_owned) = coerce_for_comparison(self, rhs)?; + let lhs = lhs_owned.as_ref().unwrap_or(self); + let rhs = rhs_owned.as_ref().unwrap_or(rhs); // If it is evident that the result is an empty interval, short-circuit // and directly return `None`. - if (!(self.lower.is_null() || rhs.upper.is_null()) && self.lower > rhs.upper) - || (!(self.upper.is_null() || rhs.lower.is_null()) && self.upper < rhs.lower) + if (!(lhs.lower.is_null() || rhs.upper.is_null()) && lhs.lower > rhs.upper) + || (!(lhs.upper.is_null() || rhs.lower.is_null()) && lhs.upper < rhs.lower) { return Ok(None); } - let lower = max_of_bounds(&self.lower, &rhs.lower); - let upper = min_of_bounds(&self.upper, &rhs.upper); + let lower = max_of_bounds(&lhs.lower, &rhs.lower); + let upper = min_of_bounds(&lhs.upper, &rhs.upper); // New lower and upper bounds must always construct a valid interval. debug_assert!( @@ -683,35 +677,27 @@ impl Interval { /// Compute the union of this interval with the given interval. /// - /// NOTE: This function only works with intervals of the same data type. - /// Attempting to compare intervals of different data types will lead - /// to an error. + /// If the two intervals have different data types, both are coerced to a + /// common comparison type via [`comparison_coercion`] before computing the + /// union. pub fn union>(&self, other: T) -> Result { let rhs = other.borrow(); - let lhs_type = self.data_type(); - let rhs_type = rhs.data_type(); - assert_eq_or_internal_err!( - lhs_type, - rhs_type, - "Cannot calculate the union of intervals with different data types, lhs:{}, rhs:{}", - self.data_type(), - rhs.data_type() - ); + let (lhs_owned, rhs_owned) = coerce_for_comparison(self, rhs)?; + let lhs = lhs_owned.as_ref().unwrap_or(self); + let rhs = rhs_owned.as_ref().unwrap_or(rhs); - let lower = if self.lower.is_null() - || (!rhs.lower.is_null() && self.lower <= rhs.lower) - { - self.lower.clone() - } else { - rhs.lower.clone() - }; - let upper = if self.upper.is_null() - || (!rhs.upper.is_null() && self.upper >= rhs.upper) - { - self.upper.clone() - } else { - rhs.upper.clone() - }; + let lower = + if lhs.lower.is_null() || (!rhs.lower.is_null() && lhs.lower <= rhs.lower) { + lhs.lower.clone() + } else { + rhs.lower.clone() + }; + let upper = + if lhs.upper.is_null() || (!rhs.upper.is_null() && lhs.upper >= rhs.upper) { + lhs.upper.clone() + } else { + rhs.upper.clone() + }; // New lower and upper bounds must always construct a valid interval. debug_assert!( @@ -754,22 +740,16 @@ impl Interval { /// disjoint with `other` by returning `[true, true]`, `[false, true]` or /// `[false, false]` respectively. /// - /// NOTE: This function only works with intervals of the same data type. - /// Attempting to compare intervals of different data types will lead - /// to an error. + /// If the two intervals have different data types, both are coerced to a + /// common comparison type via [`comparison_coercion`] before checking + /// containment. pub fn contains>(&self, other: T) -> Result { let rhs = other.borrow(); - let lhs_type = self.data_type(); - let rhs_type = rhs.data_type(); - assert_eq_or_internal_err!( - lhs_type, - rhs_type, - "Interval data types must match for containment checks, lhs:{}, rhs:{}", - self.data_type(), - rhs.data_type() - ); + let (lhs_owned, rhs_owned) = coerce_for_comparison(self, rhs)?; + let lhs = lhs_owned.as_ref().unwrap_or(self); + let rhs = rhs_owned.as_ref().unwrap_or(rhs); - match self.intersect(rhs)? { + match lhs.intersect(rhs)? { Some(intersection) => { if &intersection == rhs { Ok(Self::TRUE) @@ -830,36 +810,29 @@ impl Interval { /// Note that this represents all possible values the product can take if /// one can choose single values arbitrarily from each of the operands. /// - /// NOTE: This function only works with intervals of the same data type. - /// Attempting to compare intervals of different data types will lead - /// to an error. + /// If the two intervals have different data types, both are coerced to a + /// common type via [`BinaryTypeCoercer`] before computing the product. pub fn mul>(&self, other: T) -> Result { let rhs = other.borrow(); - let dt = self.data_type(); - let rhs_type = rhs.data_type(); - assert_eq_or_internal_err!( - dt.clone(), - rhs_type.clone(), - "Intervals must have the same data type for multiplication, lhs:{}, rhs:{}", - dt.clone(), - rhs_type.clone() - ); + let (lhs_owned, rhs_owned, dt) = coerce_operands(self, rhs, &Operator::Multiply)?; + let lhs_ref = lhs_owned.as_ref().unwrap_or(self); + let rhs_ref = rhs_owned.as_ref().unwrap_or(rhs); let zero = ScalarValue::new_zero(&dt)?; let result = match ( - self.contains_value(&zero)?, - rhs.contains_value(&zero)?, + lhs_ref.contains_value(&zero)?, + rhs_ref.contains_value(&zero)?, dt.is_unsigned_integer(), ) { - (true, true, false) => mul_helper_multi_zero_inclusive(&dt, self, rhs), + (true, true, false) => mul_helper_multi_zero_inclusive(&dt, lhs_ref, rhs_ref), (true, false, false) => { - mul_helper_single_zero_inclusive(&dt, self, rhs, &zero) + mul_helper_single_zero_inclusive(&dt, lhs_ref, rhs_ref, &zero) } (false, true, false) => { - mul_helper_single_zero_inclusive(&dt, rhs, self, &zero) + mul_helper_single_zero_inclusive(&dt, rhs_ref, lhs_ref, &zero) } - _ => mul_helper_zero_exclusive(&dt, self, rhs, &zero), + _ => mul_helper_zero_exclusive(&dt, lhs_ref, rhs_ref, &zero), }; Ok(result) } @@ -870,23 +843,16 @@ impl Interval { /// all possible values the quotient can take if one can choose single values /// arbitrarily from each of the operands. /// - /// NOTE: This function only works with intervals of the same data type. - /// Attempting to compare intervals of different data types will lead - /// to an error. + /// If the two intervals have different data types, both are coerced to a + /// common type via [`BinaryTypeCoercer`] before computing the quotient. /// /// **TODO**: Once interval sets are supported, cases where the divisor contains /// zero should result in an interval set, not the universal set. pub fn div>(&self, other: T) -> Result { let rhs = other.borrow(); - let dt = self.data_type(); - let rhs_type = rhs.data_type(); - assert_eq_or_internal_err!( - dt.clone(), - rhs_type.clone(), - "Intervals must have the same data type for division, lhs:{}, rhs:{}", - dt.clone(), - rhs_type.clone() - ); + let (lhs_owned, rhs_owned, dt) = coerce_operands(self, rhs, &Operator::Divide)?; + let lhs_ref = lhs_owned.as_ref().unwrap_or(self); + let rhs_ref = rhs_owned.as_ref().unwrap_or(rhs); let zero = ScalarValue::new_zero(&dt)?; // We want 0 to be approachable from both negative and positive sides. @@ -897,15 +863,27 @@ impl Interval { // Exit early with an unbounded interval if zero is strictly inside the // right hand side: - if rhs.contains(&zero_point)? == Self::TRUE && !dt.is_unsigned_integer() { + if rhs_ref.contains(&zero_point)? == Self::TRUE && !dt.is_unsigned_integer() { Self::make_unbounded(&dt) } // At this point, we know that only one endpoint of the right hand side // can be zero. - else if self.contains(&zero_point)? == Self::TRUE && !dt.is_unsigned_integer() { - Ok(div_helper_lhs_zero_inclusive(&dt, self, rhs, &zero_point)) + else if lhs_ref.contains(&zero_point)? == Self::TRUE + && !dt.is_unsigned_integer() + { + Ok(div_helper_lhs_zero_inclusive( + &dt, + lhs_ref, + rhs_ref, + &zero_point, + )) } else { - Ok(div_helper_zero_exclusive(&dt, self, rhs, &zero_point)) + Ok(div_helper_zero_exclusive( + &dt, + lhs_ref, + rhs_ref, + &zero_point, + )) } } @@ -1000,6 +978,70 @@ impl From<&ScalarValue> for Interval { } } +/// Coerces two intervals to a common comparison type so that lower/upper +/// bounds from each can be compared directly. +/// +/// Returns `(coerced_lhs, coerced_rhs)` where each is `Some(...)` if a cast +/// was required and `None` otherwise. Returns an internal error if the two +/// types cannot be unified for comparison. +fn coerce_for_comparison( + lhs: &Interval, + rhs: &Interval, +) -> Result<(Option, Option)> { + let lhs_type = lhs.data_type(); + let rhs_type = rhs.data_type(); + if lhs_type == rhs_type { + return Ok((None, None)); + } + let maybe_common = comparison_coercion(&lhs_type, &rhs_type); + assert_or_internal_err!( + maybe_common.is_some(), + "Data types must be compatible for interval comparison, lhs:{}, rhs:{}", + lhs_type, + rhs_type + ); + let common = maybe_common.expect("checked for Some"); + let cast_options = CastOptions::default(); + let new_lhs = (lhs_type != common) + .then(|| lhs.cast_to(&common, &cast_options)) + .transpose()?; + let new_rhs = (rhs_type != common) + .then(|| rhs.cast_to(&common, &cast_options)) + .transpose()?; + Ok((new_lhs, new_rhs)) +} + +/// Coerces two intervals to a common type for the given binary `op` so that +/// downstream interval helpers can operate on a single, consistent data type. +/// +/// Returns `(coerced_lhs, coerced_rhs, common_type)`. Each `coerced_*` is +/// `Some(...)` when a cast was required, and `None` when the original interval +/// already had the common type (the caller should use the original in that +/// case). The returned `common_type` is the type both (possibly cast) operands +/// share, taken from [`BinaryTypeCoercer::get_result_type`] — this mirrors +/// what arrow's numeric kernels would produce when computing the operation. +fn coerce_operands( + lhs: &Interval, + rhs: &Interval, + op: &Operator, +) -> Result<(Option, Option, DataType)> { + let lhs_type = lhs.data_type(); + let rhs_type = rhs.data_type(); + if lhs_type == rhs_type { + return Ok((None, None, lhs_type)); + } + let common_type = + BinaryTypeCoercer::new(&lhs_type, op, &rhs_type).get_result_type()?; + let cast_options = CastOptions::default(); + let new_lhs = (lhs_type != common_type) + .then(|| lhs.cast_to(&common_type, &cast_options)) + .transpose()?; + let new_rhs = (rhs_type != common_type) + .then(|| rhs.cast_to(&common_type, &cast_options)) + .transpose()?; + Ok((new_lhs, new_rhs, common_type)) +} + /// Applies the given binary operator the `lhs` and `rhs` arguments. pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result { match *op { @@ -3794,6 +3836,39 @@ mod tests { Ok(()) } + #[test] + fn test_mul_div_mismatched_operand_types() -> Result<()> { + // Regression test: previously `Interval::div` and `Interval::mul` + // asserted that both operands had identical data types. That broke + // interval propagation for queries like `numeric / count(*)` where + // the operands end up as different `Decimal128` precisions/scales. + // Now both operations coerce to a common type via `BinaryTypeCoercer`. + + // `Decimal128(38, 10)` / `Decimal128(20, 0)` — the shape produced when + // dividing an unqualified `NUMERIC` by an `Int64` (e.g. `count(*)`). + let lhs = Interval::try_new( + ScalarValue::Decimal128(Some(0), 38, 10), + ScalarValue::Decimal128(Some(100_000_000_000), 38, 10), // 10.0 + )?; + let rhs = Interval::try_new( + ScalarValue::Decimal128(Some(1), 20, 0), + ScalarValue::Decimal128(Some(10), 20, 0), + )?; + let div_result = lhs.div(&rhs)?; + assert!(matches!(div_result.data_type(), DataType::Decimal128(_, _))); + let mul_result = lhs.mul(&rhs)?; + assert!(matches!(mul_result.data_type(), DataType::Decimal128(_, _))); + + // Cross-type Decimal128 / Int64 also goes through coercion. + let int_rhs = Interval::make(Some(1_i64), Some(10_i64))?; + let div_int = lhs.div(&int_rhs)?; + assert!(matches!(div_int.data_type(), DataType::Decimal128(_, _))); + let mul_int = lhs.mul(&int_rhs)?; + assert!(matches!(mul_int.data_type(), DataType::Decimal128(_, _))); + + Ok(()) + } + #[test] fn test_overflow_handling() -> Result<()> { // Test integer overflow handling: diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt index 5485a5fd30141..5faf801c84652 100644 --- a/datafusion/sqllogictest/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -1235,3 +1235,28 @@ NULL query error Arrow error: Invalid argument error: 1.10 is too large to store in a Decimal128 of precision 2. Max is 0.99 select cast(1.1 as decimal(2, 2)) + 1; + +# Regression test for division of two `Decimal128` values with different +# precision/scale. The combined CASE/WHERE shape triggers interval propagation +# through the optimizer; previously this hit an internal assertion in +# `Interval::div` because `numeric / count(*)` gives `Decimal128(38, 10)` / +# `Decimal128(20, 0)`. +statement ok +CREATE TABLE decimal_div_mismatch(c1 BIGINT) AS VALUES (1::bigint), (2::bigint), (10::bigint); + +query IR +SELECT + c1, + CASE WHEN c1 = 0 THEN 100.0 + ELSE ROUND((1.0 - (c1::numeric / c1)) * 100, 2) + END AS rate +FROM decimal_div_mismatch +WHERE (1.0 - (c1::numeric / c1)) * 100 < 95.0 +ORDER BY c1; +---- +1 0 +2 0 +10 0 + +statement ok +DROP TABLE decimal_div_mismatch;