Skip to content

Commit 86c7d43

Browse files
adriangbclaude
andcommitted
fix: coerce operand types in Interval::mul/div/intersect/union/contains
Previously these methods asserted that both intervals shared an identical data type, causing internal errors during interval propagation for queries like `numeric / count(*)` where the operands end up as different `Decimal128` precisions/scales (e.g. `Decimal128(38, 10) / Decimal128(20, 0)`). `Interval::add` and `Interval::sub` already used `BinaryTypeCoercer` to find a common arithmetic type. This change brings `mul`/`div` in line, and similarly relaxes `intersect`/`union`/`contains` to coerce via `comparison_coercion`, since CP-solver propagation feeds the result of an arithmetic op into `intersect` with a child interval whose type may differ. Adds a unit test in `interval_arithmetic.rs` for mismatched `Decimal128` mul/div, and an end-to-end `decimal.slt` regression covering the `numeric / bigint` shape from the failing customer query. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent f043092 commit 86c7d43

2 files changed

Lines changed: 189 additions & 89 deletions

File tree

datafusion/expr-common/src/interval_arithmetic.rs

Lines changed: 164 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -646,31 +646,25 @@ impl Interval {
646646
/// Compute the intersection of this interval with the given interval.
647647
/// If the intersection is empty, return `None`.
648648
///
649-
/// NOTE: This function only works with intervals of the same data type.
650-
/// Attempting to compare intervals of different data types will lead
651-
/// to an error.
649+
/// If the two intervals have different data types, both are coerced to a
650+
/// common comparison type via [`comparison_coercion`] before computing the
651+
/// intersection.
652652
pub fn intersect<T: Borrow<Self>>(&self, other: T) -> Result<Option<Self>> {
653653
let rhs = other.borrow();
654-
let lhs_type = self.data_type();
655-
let rhs_type = rhs.data_type();
656-
assert_eq_or_internal_err!(
657-
lhs_type,
658-
rhs_type,
659-
"Only intervals with the same data type are intersectable, lhs:{}, rhs:{}",
660-
self.data_type(),
661-
rhs.data_type()
662-
);
654+
let (lhs_owned, rhs_owned) = coerce_for_comparison(self, rhs)?;
655+
let lhs = lhs_owned.as_ref().unwrap_or(self);
656+
let rhs = rhs_owned.as_ref().unwrap_or(rhs);
663657

664658
// If it is evident that the result is an empty interval, short-circuit
665659
// and directly return `None`.
666-
if (!(self.lower.is_null() || rhs.upper.is_null()) && self.lower > rhs.upper)
667-
|| (!(self.upper.is_null() || rhs.lower.is_null()) && self.upper < rhs.lower)
660+
if (!(lhs.lower.is_null() || rhs.upper.is_null()) && lhs.lower > rhs.upper)
661+
|| (!(lhs.upper.is_null() || rhs.lower.is_null()) && lhs.upper < rhs.lower)
668662
{
669663
return Ok(None);
670664
}
671665

672-
let lower = max_of_bounds(&self.lower, &rhs.lower);
673-
let upper = min_of_bounds(&self.upper, &rhs.upper);
666+
let lower = max_of_bounds(&lhs.lower, &rhs.lower);
667+
let upper = min_of_bounds(&lhs.upper, &rhs.upper);
674668

675669
// New lower and upper bounds must always construct a valid interval.
676670
debug_assert!(
@@ -683,35 +677,27 @@ impl Interval {
683677

684678
/// Compute the union of this interval with the given interval.
685679
///
686-
/// NOTE: This function only works with intervals of the same data type.
687-
/// Attempting to compare intervals of different data types will lead
688-
/// to an error.
680+
/// If the two intervals have different data types, both are coerced to a
681+
/// common comparison type via [`comparison_coercion`] before computing the
682+
/// union.
689683
pub fn union<T: Borrow<Self>>(&self, other: T) -> Result<Self> {
690684
let rhs = other.borrow();
691-
let lhs_type = self.data_type();
692-
let rhs_type = rhs.data_type();
693-
assert_eq_or_internal_err!(
694-
lhs_type,
695-
rhs_type,
696-
"Cannot calculate the union of intervals with different data types, lhs:{}, rhs:{}",
697-
self.data_type(),
698-
rhs.data_type()
699-
);
685+
let (lhs_owned, rhs_owned) = coerce_for_comparison(self, rhs)?;
686+
let lhs = lhs_owned.as_ref().unwrap_or(self);
687+
let rhs = rhs_owned.as_ref().unwrap_or(rhs);
700688

701-
let lower = if self.lower.is_null()
702-
|| (!rhs.lower.is_null() && self.lower <= rhs.lower)
703-
{
704-
self.lower.clone()
705-
} else {
706-
rhs.lower.clone()
707-
};
708-
let upper = if self.upper.is_null()
709-
|| (!rhs.upper.is_null() && self.upper >= rhs.upper)
710-
{
711-
self.upper.clone()
712-
} else {
713-
rhs.upper.clone()
714-
};
689+
let lower =
690+
if lhs.lower.is_null() || (!rhs.lower.is_null() && lhs.lower <= rhs.lower) {
691+
lhs.lower.clone()
692+
} else {
693+
rhs.lower.clone()
694+
};
695+
let upper =
696+
if lhs.upper.is_null() || (!rhs.upper.is_null() && lhs.upper >= rhs.upper) {
697+
lhs.upper.clone()
698+
} else {
699+
rhs.upper.clone()
700+
};
715701

716702
// New lower and upper bounds must always construct a valid interval.
717703
debug_assert!(
@@ -754,22 +740,16 @@ impl Interval {
754740
/// disjoint with `other` by returning `[true, true]`, `[false, true]` or
755741
/// `[false, false]` respectively.
756742
///
757-
/// NOTE: This function only works with intervals of the same data type.
758-
/// Attempting to compare intervals of different data types will lead
759-
/// to an error.
743+
/// If the two intervals have different data types, both are coerced to a
744+
/// common comparison type via [`comparison_coercion`] before checking
745+
/// containment.
760746
pub fn contains<T: Borrow<Self>>(&self, other: T) -> Result<Self> {
761747
let rhs = other.borrow();
762-
let lhs_type = self.data_type();
763-
let rhs_type = rhs.data_type();
764-
assert_eq_or_internal_err!(
765-
lhs_type,
766-
rhs_type,
767-
"Interval data types must match for containment checks, lhs:{}, rhs:{}",
768-
self.data_type(),
769-
rhs.data_type()
770-
);
748+
let (lhs_owned, rhs_owned) = coerce_for_comparison(self, rhs)?;
749+
let lhs = lhs_owned.as_ref().unwrap_or(self);
750+
let rhs = rhs_owned.as_ref().unwrap_or(rhs);
771751

772-
match self.intersect(rhs)? {
752+
match lhs.intersect(rhs)? {
773753
Some(intersection) => {
774754
if &intersection == rhs {
775755
Ok(Self::TRUE)
@@ -830,36 +810,29 @@ impl Interval {
830810
/// Note that this represents all possible values the product can take if
831811
/// one can choose single values arbitrarily from each of the operands.
832812
///
833-
/// NOTE: This function only works with intervals of the same data type.
834-
/// Attempting to compare intervals of different data types will lead
835-
/// to an error.
813+
/// If the two intervals have different data types, both are coerced to a
814+
/// common type via [`BinaryTypeCoercer`] before computing the product.
836815
pub fn mul<T: Borrow<Self>>(&self, other: T) -> Result<Self> {
837816
let rhs = other.borrow();
838-
let dt = self.data_type();
839-
let rhs_type = rhs.data_type();
840-
assert_eq_or_internal_err!(
841-
dt.clone(),
842-
rhs_type.clone(),
843-
"Intervals must have the same data type for multiplication, lhs:{}, rhs:{}",
844-
dt.clone(),
845-
rhs_type.clone()
846-
);
817+
let (lhs_owned, rhs_owned, dt) = coerce_operands(self, rhs, &Operator::Multiply)?;
818+
let lhs_ref = lhs_owned.as_ref().unwrap_or(self);
819+
let rhs_ref = rhs_owned.as_ref().unwrap_or(rhs);
847820

848821
let zero = ScalarValue::new_zero(&dt)?;
849822

850823
let result = match (
851-
self.contains_value(&zero)?,
852-
rhs.contains_value(&zero)?,
824+
lhs_ref.contains_value(&zero)?,
825+
rhs_ref.contains_value(&zero)?,
853826
dt.is_unsigned_integer(),
854827
) {
855-
(true, true, false) => mul_helper_multi_zero_inclusive(&dt, self, rhs),
828+
(true, true, false) => mul_helper_multi_zero_inclusive(&dt, lhs_ref, rhs_ref),
856829
(true, false, false) => {
857-
mul_helper_single_zero_inclusive(&dt, self, rhs, &zero)
830+
mul_helper_single_zero_inclusive(&dt, lhs_ref, rhs_ref, &zero)
858831
}
859832
(false, true, false) => {
860-
mul_helper_single_zero_inclusive(&dt, rhs, self, &zero)
833+
mul_helper_single_zero_inclusive(&dt, rhs_ref, lhs_ref, &zero)
861834
}
862-
_ => mul_helper_zero_exclusive(&dt, self, rhs, &zero),
835+
_ => mul_helper_zero_exclusive(&dt, lhs_ref, rhs_ref, &zero),
863836
};
864837
Ok(result)
865838
}
@@ -870,23 +843,16 @@ impl Interval {
870843
/// all possible values the quotient can take if one can choose single values
871844
/// arbitrarily from each of the operands.
872845
///
873-
/// NOTE: This function only works with intervals of the same data type.
874-
/// Attempting to compare intervals of different data types will lead
875-
/// to an error.
846+
/// If the two intervals have different data types, both are coerced to a
847+
/// common type via [`BinaryTypeCoercer`] before computing the quotient.
876848
///
877849
/// **TODO**: Once interval sets are supported, cases where the divisor contains
878850
/// zero should result in an interval set, not the universal set.
879851
pub fn div<T: Borrow<Self>>(&self, other: T) -> Result<Self> {
880852
let rhs = other.borrow();
881-
let dt = self.data_type();
882-
let rhs_type = rhs.data_type();
883-
assert_eq_or_internal_err!(
884-
dt.clone(),
885-
rhs_type.clone(),
886-
"Intervals must have the same data type for division, lhs:{}, rhs:{}",
887-
dt.clone(),
888-
rhs_type.clone()
889-
);
853+
let (lhs_owned, rhs_owned, dt) = coerce_operands(self, rhs, &Operator::Divide)?;
854+
let lhs_ref = lhs_owned.as_ref().unwrap_or(self);
855+
let rhs_ref = rhs_owned.as_ref().unwrap_or(rhs);
890856

891857
let zero = ScalarValue::new_zero(&dt)?;
892858
// We want 0 to be approachable from both negative and positive sides.
@@ -897,15 +863,27 @@ impl Interval {
897863

898864
// Exit early with an unbounded interval if zero is strictly inside the
899865
// right hand side:
900-
if rhs.contains(&zero_point)? == Self::TRUE && !dt.is_unsigned_integer() {
866+
if rhs_ref.contains(&zero_point)? == Self::TRUE && !dt.is_unsigned_integer() {
901867
Self::make_unbounded(&dt)
902868
}
903869
// At this point, we know that only one endpoint of the right hand side
904870
// can be zero.
905-
else if self.contains(&zero_point)? == Self::TRUE && !dt.is_unsigned_integer() {
906-
Ok(div_helper_lhs_zero_inclusive(&dt, self, rhs, &zero_point))
871+
else if lhs_ref.contains(&zero_point)? == Self::TRUE
872+
&& !dt.is_unsigned_integer()
873+
{
874+
Ok(div_helper_lhs_zero_inclusive(
875+
&dt,
876+
lhs_ref,
877+
rhs_ref,
878+
&zero_point,
879+
))
907880
} else {
908-
Ok(div_helper_zero_exclusive(&dt, self, rhs, &zero_point))
881+
Ok(div_helper_zero_exclusive(
882+
&dt,
883+
lhs_ref,
884+
rhs_ref,
885+
&zero_point,
886+
))
909887
}
910888
}
911889

@@ -1000,6 +978,70 @@ impl From<&ScalarValue> for Interval {
1000978
}
1001979
}
1002980

981+
/// Coerces two intervals to a common comparison type so that lower/upper
982+
/// bounds from each can be compared directly.
983+
///
984+
/// Returns `(coerced_lhs, coerced_rhs)` where each is `Some(...)` if a cast
985+
/// was required and `None` otherwise. Returns an internal error if the two
986+
/// types cannot be unified for comparison.
987+
fn coerce_for_comparison(
988+
lhs: &Interval,
989+
rhs: &Interval,
990+
) -> Result<(Option<Interval>, Option<Interval>)> {
991+
let lhs_type = lhs.data_type();
992+
let rhs_type = rhs.data_type();
993+
if lhs_type == rhs_type {
994+
return Ok((None, None));
995+
}
996+
let maybe_common = comparison_coercion(&lhs_type, &rhs_type);
997+
assert_or_internal_err!(
998+
maybe_common.is_some(),
999+
"Data types must be compatible for interval comparison, lhs:{}, rhs:{}",
1000+
lhs_type,
1001+
rhs_type
1002+
);
1003+
let common = maybe_common.expect("checked for Some");
1004+
let cast_options = CastOptions::default();
1005+
let new_lhs = (lhs_type != common)
1006+
.then(|| lhs.cast_to(&common, &cast_options))
1007+
.transpose()?;
1008+
let new_rhs = (rhs_type != common)
1009+
.then(|| rhs.cast_to(&common, &cast_options))
1010+
.transpose()?;
1011+
Ok((new_lhs, new_rhs))
1012+
}
1013+
1014+
/// Coerces two intervals to a common type for the given binary `op` so that
1015+
/// downstream interval helpers can operate on a single, consistent data type.
1016+
///
1017+
/// Returns `(coerced_lhs, coerced_rhs, common_type)`. Each `coerced_*` is
1018+
/// `Some(...)` when a cast was required, and `None` when the original interval
1019+
/// already had the common type (the caller should use the original in that
1020+
/// case). The returned `common_type` is the type both (possibly cast) operands
1021+
/// share, taken from [`BinaryTypeCoercer::get_result_type`] — this mirrors
1022+
/// what arrow's numeric kernels would produce when computing the operation.
1023+
fn coerce_operands(
1024+
lhs: &Interval,
1025+
rhs: &Interval,
1026+
op: &Operator,
1027+
) -> Result<(Option<Interval>, Option<Interval>, DataType)> {
1028+
let lhs_type = lhs.data_type();
1029+
let rhs_type = rhs.data_type();
1030+
if lhs_type == rhs_type {
1031+
return Ok((None, None, lhs_type));
1032+
}
1033+
let common_type =
1034+
BinaryTypeCoercer::new(&lhs_type, op, &rhs_type).get_result_type()?;
1035+
let cast_options = CastOptions::default();
1036+
let new_lhs = (lhs_type != common_type)
1037+
.then(|| lhs.cast_to(&common_type, &cast_options))
1038+
.transpose()?;
1039+
let new_rhs = (rhs_type != common_type)
1040+
.then(|| rhs.cast_to(&common_type, &cast_options))
1041+
.transpose()?;
1042+
Ok((new_lhs, new_rhs, common_type))
1043+
}
1044+
10031045
/// Applies the given binary operator the `lhs` and `rhs` arguments.
10041046
pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result<Interval> {
10051047
match *op {
@@ -3794,6 +3836,39 @@ mod tests {
37943836
Ok(())
37953837
}
37963838

3839+
#[test]
3840+
fn test_mul_div_mismatched_operand_types() -> Result<()> {
3841+
// Regression test: previously `Interval::div` and `Interval::mul`
3842+
// asserted that both operands had identical data types. That broke
3843+
// interval propagation for queries like `numeric / count(*)` where
3844+
// the operands end up as different `Decimal128` precisions/scales.
3845+
// Now both operations coerce to a common type via `BinaryTypeCoercer`.
3846+
3847+
// `Decimal128(38, 10)` / `Decimal128(20, 0)` — the shape from the
3848+
// failing alert query in pydantic/platform#21153.
3849+
let lhs = Interval::try_new(
3850+
ScalarValue::Decimal128(Some(0), 38, 10),
3851+
ScalarValue::Decimal128(Some(100_000_000_000), 38, 10), // 10.0
3852+
)?;
3853+
let rhs = Interval::try_new(
3854+
ScalarValue::Decimal128(Some(1), 20, 0),
3855+
ScalarValue::Decimal128(Some(10), 20, 0),
3856+
)?;
3857+
let div_result = lhs.div(&rhs)?;
3858+
assert!(matches!(div_result.data_type(), DataType::Decimal128(_, _)));
3859+
let mul_result = lhs.mul(&rhs)?;
3860+
assert!(matches!(mul_result.data_type(), DataType::Decimal128(_, _)));
3861+
3862+
// Cross-type Decimal128 / Int64 also goes through coercion.
3863+
let int_rhs = Interval::make(Some(1_i64), Some(10_i64))?;
3864+
let div_int = lhs.div(&int_rhs)?;
3865+
assert!(matches!(div_int.data_type(), DataType::Decimal128(_, _)));
3866+
let mul_int = lhs.mul(&int_rhs)?;
3867+
assert!(matches!(mul_int.data_type(), DataType::Decimal128(_, _)));
3868+
3869+
Ok(())
3870+
}
3871+
37973872
#[test]
37983873
fn test_overflow_handling() -> Result<()> {
37993874
// Test integer overflow handling:

datafusion/sqllogictest/test_files/decimal.slt

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,3 +1235,28 @@ NULL
12351235

12361236
query error Arrow error: Invalid argument error: 1.10 is too large to store in a Decimal128 of precision 2. Max is 0.99
12371237
select cast(1.1 as decimal(2, 2)) + 1;
1238+
1239+
# Regression test for division of two `Decimal128` values with different
1240+
# precision/scale. The combined CASE/WHERE shape triggers interval propagation
1241+
# through the optimizer; previously this hit an internal assertion in
1242+
# `Interval::div` because `numeric / count(*)` gives `Decimal128(38, 10)` /
1243+
# `Decimal128(20, 0)`.
1244+
statement ok
1245+
CREATE TABLE decimal_div_mismatch(c1 BIGINT) AS VALUES (1::bigint), (2::bigint), (10::bigint);
1246+
1247+
query IR
1248+
SELECT
1249+
c1,
1250+
CASE WHEN c1 = 0 THEN 100.0
1251+
ELSE ROUND((1.0 - (c1::numeric / c1)) * 100, 2)
1252+
END AS rate
1253+
FROM decimal_div_mismatch
1254+
WHERE (1.0 - (c1::numeric / c1)) * 100 < 95.0
1255+
ORDER BY c1;
1256+
----
1257+
1 0
1258+
2 0
1259+
10 0
1260+
1261+
statement ok
1262+
DROP TABLE decimal_div_mismatch;

0 commit comments

Comments
 (0)