diff --git a/datafusion/spark/src/function/math/round.rs b/datafusion/spark/src/function/math/round.rs index 05745666183d3..c385057678873 100644 --- a/datafusion/spark/src/function/math/round.rs +++ b/datafusion/spark/src/function/math/round.rs @@ -275,6 +275,34 @@ fn round_integer(value: i64, scale: i32, enable_ansi_mode: bool) -> Result Ok(result) } +fn round_unsigned_integer(value: u64, scale: i32, enable_ansi_mode: bool) -> Result { + if scale >= 0 { + return Ok(value); + } + let abs_scale = (-scale) as u32; + let Some(factor) = 10_u64.checked_pow(abs_scale) else { + return Ok(0); + }; + let remainder = value % factor; + let result = if remainder >= factor / 2 { + if enable_ansi_mode { + value + .checked_sub(remainder) + .and_then(|v| v.checked_add(factor)) + .ok_or_else(|| { + (exec_err!("UInt64 overflow on round({value}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + value.wrapping_sub(remainder).wrapping_add(factor) + } + } else { + value - remainder + }; + Ok(result) +} + // --------------------------------------------------------------------------- // Decimal rounding using ArrowNativeTypeOp (HALF_UP) // --------------------------------------------------------------------------- @@ -463,16 +491,8 @@ fn spark_round(args: &[ColumnarValue], enable_ansi_mode: bool) -> Result { let array = array.as_primitive::(); - let result: PrimitiveArray = array.try_unary(|x| { - let v_i64 = i64::try_from(x).map_err(|_| { - (exec_err!( - "round: UInt64 value {x} exceeds i64::MAX and cannot be rounded" - ) as Result<(), _>) - .unwrap_err() - })?; - round_integer(v_i64, scale, enable_ansi_mode) - .map(|v| v as u64) - })?; + let result: PrimitiveArray = array + .try_unary(|x| round_unsigned_integer(x, scale, enable_ansi_mode))?; Ok(ColumnarValue::Array(Arc::new(result))) } @@ -588,16 +608,8 @@ fn spark_round(args: &[ColumnarValue], enable_ansi_mode: bool) -> Result { - let v_i64 = i64::try_from(*v).map_err(|_| { - (exec_err!( - "round: UInt64 value {v} exceeds i64::MAX and cannot be rounded" - ) as Result<(), _>) - .unwrap_err() - })?; - let result = round_integer(v_i64, scale, enable_ansi_mode)?; - Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some( - result as u64, - )))) + let result = round_unsigned_integer(*v, scale, enable_ansi_mode)?; + Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some(result)))) } // Float scalars @@ -652,3 +664,26 @@ fn spark_round(args: &[ColumnarValue], enable_ansi_mode: bool) -> Result