Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 164 additions & 89 deletions datafusion/expr-common/src/interval_arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -646,31 +646,25 @@ impl Interval {
/// Compute the intersection of this interval with the given interval.
/// If the intersection is empty, return `None`.
///
/// NOTE: This function only works with intervals of the same data type.
/// Attempting to compare intervals of different data types will lead
/// to an error.
/// If the two intervals have different data types, both are coerced to a
/// common comparison type via [`comparison_coercion`] before computing the
/// intersection.
pub fn intersect<T: Borrow<Self>>(&self, other: T) -> Result<Option<Self>> {
let rhs = other.borrow();
let lhs_type = self.data_type();
let rhs_type = rhs.data_type();
assert_eq_or_internal_err!(
lhs_type,
rhs_type,
"Only intervals with the same data type are intersectable, lhs:{}, rhs:{}",
self.data_type(),
rhs.data_type()
);
let (lhs_owned, rhs_owned) = coerce_for_comparison(self, rhs)?;
let lhs = lhs_owned.as_ref().unwrap_or(self);
let rhs = rhs_owned.as_ref().unwrap_or(rhs);

// If it is evident that the result is an empty interval, short-circuit
// and directly return `None`.
if (!(self.lower.is_null() || rhs.upper.is_null()) && self.lower > rhs.upper)
|| (!(self.upper.is_null() || rhs.lower.is_null()) && self.upper < rhs.lower)
if (!(lhs.lower.is_null() || rhs.upper.is_null()) && lhs.lower > rhs.upper)
|| (!(lhs.upper.is_null() || rhs.lower.is_null()) && lhs.upper < rhs.lower)
{
return Ok(None);
}

let lower = max_of_bounds(&self.lower, &rhs.lower);
let upper = min_of_bounds(&self.upper, &rhs.upper);
let lower = max_of_bounds(&lhs.lower, &rhs.lower);
let upper = min_of_bounds(&lhs.upper, &rhs.upper);

// New lower and upper bounds must always construct a valid interval.
debug_assert!(
Expand All @@ -683,35 +677,27 @@ impl Interval {

/// Compute the union of this interval with the given interval.
///
/// NOTE: This function only works with intervals of the same data type.
/// Attempting to compare intervals of different data types will lead
/// to an error.
/// If the two intervals have different data types, both are coerced to a
/// common comparison type via [`comparison_coercion`] before computing the
/// union.
pub fn union<T: Borrow<Self>>(&self, other: T) -> Result<Self> {
let rhs = other.borrow();
let lhs_type = self.data_type();
let rhs_type = rhs.data_type();
assert_eq_or_internal_err!(
lhs_type,
rhs_type,
"Cannot calculate the union of intervals with different data types, lhs:{}, rhs:{}",
self.data_type(),
rhs.data_type()
);
let (lhs_owned, rhs_owned) = coerce_for_comparison(self, rhs)?;
let lhs = lhs_owned.as_ref().unwrap_or(self);
let rhs = rhs_owned.as_ref().unwrap_or(rhs);

let lower = if self.lower.is_null()
|| (!rhs.lower.is_null() && self.lower <= rhs.lower)
{
self.lower.clone()
} else {
rhs.lower.clone()
};
let upper = if self.upper.is_null()
|| (!rhs.upper.is_null() && self.upper >= rhs.upper)
{
self.upper.clone()
} else {
rhs.upper.clone()
};
let lower =
if lhs.lower.is_null() || (!rhs.lower.is_null() && lhs.lower <= rhs.lower) {
lhs.lower.clone()
} else {
rhs.lower.clone()
};
let upper =
if lhs.upper.is_null() || (!rhs.upper.is_null() && lhs.upper >= rhs.upper) {
lhs.upper.clone()
} else {
rhs.upper.clone()
};

// New lower and upper bounds must always construct a valid interval.
debug_assert!(
Expand Down Expand Up @@ -754,22 +740,16 @@ impl Interval {
/// disjoint with `other` by returning `[true, true]`, `[false, true]` or
/// `[false, false]` respectively.
///
/// NOTE: This function only works with intervals of the same data type.
/// Attempting to compare intervals of different data types will lead
/// to an error.
/// If the two intervals have different data types, both are coerced to a
/// common comparison type via [`comparison_coercion`] before checking
/// containment.
pub fn contains<T: Borrow<Self>>(&self, other: T) -> Result<Self> {
let rhs = other.borrow();
let lhs_type = self.data_type();
let rhs_type = rhs.data_type();
assert_eq_or_internal_err!(
lhs_type,
rhs_type,
"Interval data types must match for containment checks, lhs:{}, rhs:{}",
self.data_type(),
rhs.data_type()
);
let (lhs_owned, rhs_owned) = coerce_for_comparison(self, rhs)?;
let lhs = lhs_owned.as_ref().unwrap_or(self);
let rhs = rhs_owned.as_ref().unwrap_or(rhs);

match self.intersect(rhs)? {
match lhs.intersect(rhs)? {
Some(intersection) => {
if &intersection == rhs {
Ok(Self::TRUE)
Expand Down Expand Up @@ -830,36 +810,29 @@ impl Interval {
/// Note that this represents all possible values the product can take if
/// one can choose single values arbitrarily from each of the operands.
///
/// NOTE: This function only works with intervals of the same data type.
/// Attempting to compare intervals of different data types will lead
/// to an error.
/// If the two intervals have different data types, both are coerced to a
/// common type via [`BinaryTypeCoercer`] before computing the product.
pub fn mul<T: Borrow<Self>>(&self, other: T) -> Result<Self> {
let rhs = other.borrow();
let dt = self.data_type();
let rhs_type = rhs.data_type();
assert_eq_or_internal_err!(
dt.clone(),
rhs_type.clone(),
"Intervals must have the same data type for multiplication, lhs:{}, rhs:{}",
dt.clone(),
rhs_type.clone()
);
let (lhs_owned, rhs_owned, dt) = coerce_operands(self, rhs, &Operator::Multiply)?;
let lhs_ref = lhs_owned.as_ref().unwrap_or(self);
let rhs_ref = rhs_owned.as_ref().unwrap_or(rhs);

let zero = ScalarValue::new_zero(&dt)?;

let result = match (
self.contains_value(&zero)?,
rhs.contains_value(&zero)?,
lhs_ref.contains_value(&zero)?,
rhs_ref.contains_value(&zero)?,
dt.is_unsigned_integer(),
) {
(true, true, false) => mul_helper_multi_zero_inclusive(&dt, self, rhs),
(true, true, false) => mul_helper_multi_zero_inclusive(&dt, lhs_ref, rhs_ref),
(true, false, false) => {
mul_helper_single_zero_inclusive(&dt, self, rhs, &zero)
mul_helper_single_zero_inclusive(&dt, lhs_ref, rhs_ref, &zero)
}
(false, true, false) => {
mul_helper_single_zero_inclusive(&dt, rhs, self, &zero)
mul_helper_single_zero_inclusive(&dt, rhs_ref, lhs_ref, &zero)
}
_ => mul_helper_zero_exclusive(&dt, self, rhs, &zero),
_ => mul_helper_zero_exclusive(&dt, lhs_ref, rhs_ref, &zero),
};
Ok(result)
}
Expand All @@ -870,23 +843,16 @@ impl Interval {
/// all possible values the quotient can take if one can choose single values
/// arbitrarily from each of the operands.
///
/// NOTE: This function only works with intervals of the same data type.
/// Attempting to compare intervals of different data types will lead
/// to an error.
/// If the two intervals have different data types, both are coerced to a
/// common type via [`BinaryTypeCoercer`] before computing the quotient.
///
/// **TODO**: Once interval sets are supported, cases where the divisor contains
/// zero should result in an interval set, not the universal set.
pub fn div<T: Borrow<Self>>(&self, other: T) -> Result<Self> {
let rhs = other.borrow();
let dt = self.data_type();
let rhs_type = rhs.data_type();
assert_eq_or_internal_err!(
dt.clone(),
rhs_type.clone(),
"Intervals must have the same data type for division, lhs:{}, rhs:{}",
dt.clone(),
rhs_type.clone()
);
let (lhs_owned, rhs_owned, dt) = coerce_operands(self, rhs, &Operator::Divide)?;
let lhs_ref = lhs_owned.as_ref().unwrap_or(self);
let rhs_ref = rhs_owned.as_ref().unwrap_or(rhs);

let zero = ScalarValue::new_zero(&dt)?;
// We want 0 to be approachable from both negative and positive sides.
Expand All @@ -897,15 +863,27 @@ impl Interval {

// Exit early with an unbounded interval if zero is strictly inside the
// right hand side:
if rhs.contains(&zero_point)? == Self::TRUE && !dt.is_unsigned_integer() {
if rhs_ref.contains(&zero_point)? == Self::TRUE && !dt.is_unsigned_integer() {
Self::make_unbounded(&dt)
}
// At this point, we know that only one endpoint of the right hand side
// can be zero.
else if self.contains(&zero_point)? == Self::TRUE && !dt.is_unsigned_integer() {
Ok(div_helper_lhs_zero_inclusive(&dt, self, rhs, &zero_point))
else if lhs_ref.contains(&zero_point)? == Self::TRUE
&& !dt.is_unsigned_integer()
{
Ok(div_helper_lhs_zero_inclusive(
&dt,
lhs_ref,
rhs_ref,
&zero_point,
))
} else {
Ok(div_helper_zero_exclusive(&dt, self, rhs, &zero_point))
Ok(div_helper_zero_exclusive(
&dt,
lhs_ref,
rhs_ref,
&zero_point,
))
}
}

Expand Down Expand Up @@ -1000,6 +978,70 @@ impl From<&ScalarValue> for Interval {
}
}

/// Coerces two intervals to a common comparison type so that lower/upper
/// bounds from each can be compared directly.
///
/// Returns `(coerced_lhs, coerced_rhs)` where each is `Some(...)` if a cast
/// was required and `None` otherwise. Returns an internal error if the two
/// types cannot be unified for comparison.
fn coerce_for_comparison(
lhs: &Interval,
rhs: &Interval,
) -> Result<(Option<Interval>, Option<Interval>)> {
let lhs_type = lhs.data_type();
let rhs_type = rhs.data_type();
if lhs_type == rhs_type {
return Ok((None, None));
}
let maybe_common = comparison_coercion(&lhs_type, &rhs_type);
assert_or_internal_err!(
maybe_common.is_some(),
"Data types must be compatible for interval comparison, lhs:{}, rhs:{}",
lhs_type,
rhs_type
);
let common = maybe_common.expect("checked for Some");
let cast_options = CastOptions::default();
let new_lhs = (lhs_type != common)
.then(|| lhs.cast_to(&common, &cast_options))
.transpose()?;
let new_rhs = (rhs_type != common)
.then(|| rhs.cast_to(&common, &cast_options))
.transpose()?;
Ok((new_lhs, new_rhs))
}

/// Coerces two intervals to a common type for the given binary `op` so that
/// downstream interval helpers can operate on a single, consistent data type.
///
/// Returns `(coerced_lhs, coerced_rhs, common_type)`. Each `coerced_*` is
/// `Some(...)` when a cast was required, and `None` when the original interval
/// already had the common type (the caller should use the original in that
/// case). The returned `common_type` is the type both (possibly cast) operands
/// share, taken from [`BinaryTypeCoercer::get_result_type`] — this mirrors
/// what arrow's numeric kernels would produce when computing the operation.
fn coerce_operands(
lhs: &Interval,
rhs: &Interval,
op: &Operator,
) -> Result<(Option<Interval>, Option<Interval>, DataType)> {
let lhs_type = lhs.data_type();
let rhs_type = rhs.data_type();
if lhs_type == rhs_type {
return Ok((None, None, lhs_type));
}
let common_type =
BinaryTypeCoercer::new(&lhs_type, op, &rhs_type).get_result_type()?;
let cast_options = CastOptions::default();
let new_lhs = (lhs_type != common_type)
.then(|| lhs.cast_to(&common_type, &cast_options))
.transpose()?;
let new_rhs = (rhs_type != common_type)
.then(|| rhs.cast_to(&common_type, &cast_options))
.transpose()?;
Ok((new_lhs, new_rhs, common_type))
}

/// Applies the given binary operator the `lhs` and `rhs` arguments.
pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result<Interval> {
match *op {
Expand Down Expand Up @@ -3794,6 +3836,39 @@ mod tests {
Ok(())
}

#[test]
fn test_mul_div_mismatched_operand_types() -> Result<()> {
// Regression test: previously `Interval::div` and `Interval::mul`
// asserted that both operands had identical data types. That broke
// interval propagation for queries like `numeric / count(*)` where
// the operands end up as different `Decimal128` precisions/scales.
// Now both operations coerce to a common type via `BinaryTypeCoercer`.

// `Decimal128(38, 10)` / `Decimal128(20, 0)` — the shape produced when
// dividing an unqualified `NUMERIC` by an `Int64` (e.g. `count(*)`).
let lhs = Interval::try_new(
ScalarValue::Decimal128(Some(0), 38, 10),
ScalarValue::Decimal128(Some(100_000_000_000), 38, 10), // 10.0
)?;
let rhs = Interval::try_new(
ScalarValue::Decimal128(Some(1), 20, 0),
ScalarValue::Decimal128(Some(10), 20, 0),
)?;
let div_result = lhs.div(&rhs)?;
assert!(matches!(div_result.data_type(), DataType::Decimal128(_, _)));
let mul_result = lhs.mul(&rhs)?;
assert!(matches!(mul_result.data_type(), DataType::Decimal128(_, _)));

// Cross-type Decimal128 / Int64 also goes through coercion.
let int_rhs = Interval::make(Some(1_i64), Some(10_i64))?;
let div_int = lhs.div(&int_rhs)?;
assert!(matches!(div_int.data_type(), DataType::Decimal128(_, _)));
let mul_int = lhs.mul(&int_rhs)?;
assert!(matches!(mul_int.data_type(), DataType::Decimal128(_, _)));

Ok(())
}

#[test]
fn test_overflow_handling() -> Result<()> {
// Test integer overflow handling:
Expand Down
25 changes: 25 additions & 0 deletions datafusion/sqllogictest/test_files/decimal.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1235,3 +1235,28 @@ NULL

query error Arrow error: Invalid argument error: 1.10 is too large to store in a Decimal128 of precision 2. Max is 0.99
select cast(1.1 as decimal(2, 2)) + 1;

# Regression test for division of two `Decimal128` values with different
# precision/scale. The combined CASE/WHERE shape triggers interval propagation
# through the optimizer; previously this hit an internal assertion in
# `Interval::div` because `numeric / count(*)` gives `Decimal128(38, 10)` /
# `Decimal128(20, 0)`.
statement ok
CREATE TABLE decimal_div_mismatch(c1 BIGINT) AS VALUES (1::bigint), (2::bigint), (10::bigint);

query IR
SELECT
c1,
CASE WHEN c1 = 0 THEN 100.0
ELSE ROUND((1.0 - (c1::numeric / c1)) * 100, 2)
END AS rate
FROM decimal_div_mismatch
WHERE (1.0 - (c1::numeric / c1)) * 100 < 95.0
ORDER BY c1;
----
1 0
2 0
10 0

statement ok
DROP TABLE decimal_div_mismatch;
Loading