diff --git a/.python-version b/.python-version index 6324d40..95ed564 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.14 +3.14.2 diff --git a/Cargo.toml b/Cargo.toml index ee49897..317cf18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ name = "pymath" author = "Jeong, YunWon " repository = "https://github.com/RustPython/pymath" description = "A binary representation compatible Rust implementation of Python's math library." -version = "0.1.1" +version = "0.1.2" edition = "2024" license = "PSF-2.0" diff --git a/README.md b/README.md index e7b5c98..a1320da 100644 --- a/README.md +++ b/README.md @@ -11,63 +11,64 @@ Every function produces identical results to Python at the binary representation ## Compatibility Status -### math (55/55) +### math (56/56) +- [x] `acos` +- [x] `acosh` +- [x] `asin` +- [x] `asinh` +- [x] `atan` +- [x] `atan2` +- [x] `atanh` +- [x] `cbrt` - [x] `ceil` - [x] `copysign` +- [x] `cos` +- [x] `cosh` +- [x] `degrees` +- [x] `dist` +- [x] `e` +- [x] `erf` +- [x] `erfc` +- [x] `exp` +- [x] `exp2` +- [x] `expm1` - [x] `fabs` - [x] `floor` +- [x] `fma` - [x] `fmod` - [x] `frexp` +- [x] `fsum` +- [x] `gamma` +- [x] `hypot` (n-dimensional) +- [x] `inf` - [x] `isclose` - [x] `isfinite` - [x] `isinf` - [x] `isnan` - [x] `ldexp` -- [x] `modf` -- [x] `nextafter` -- [x] `remainder` -- [x] `trunc` -- [x] `ulp` -- [x] `cbrt` -- [x] `exp` -- [x] `exp2` -- [x] `expm1` +- [x] `lgamma` - [x] `log` - [x] `log10` - [x] `log1p` - [x] `log2` +- [x] `modf` +- [x] `nan` +- [x] `nextafter` +- [x] `pi` - [x] `pow` -- [x] `sqrt` -- [x] `acos` -- [x] `acosh` -- [x] `asin` -- [x] `asinh` -- [x] `atan` -- [x] `atan2` -- [x] `atanh` -- [x] `cos` -- [x] `cosh` +- [x] `prod` (`prod`, `prod_int`) +- [x] `radians` +- [x] `remainder` - [x] `sin` - [x] `sinh` +- [x] `sqrt` +- [x] `sumprod` (`sumprod`, `sumprod_int`) - [x] `tan` - [x] `tanh` -- [x] `erf` -- [x] `erfc` -- [x] `gamma` -- [x] `lgamma` -- [x] `dist` -- [x] `fsum` -- [x] `hypot` -- [x] `prod` -- [x] `sumprod` -- [x] `degrees` -- [x] `radians` -- [x] `pi` -- [x] `e` - [x] `tau` -- [x] `inf` -- [x] `nan` +- [x] `trunc` +- [x] `ulp` ### math.integer (6/6, requires `num-bigint` or `malachite-bigint` feature) @@ -78,20 +79,9 @@ Every function produces identical results to Python at the binary representation - [x] `lcm` - [x] `perm` -### cmath (24/24, requires `complex` feature) +### cmath (31/31, requires `complex` feature) - [x] `abs` -- [x] `isclose` -- [x] `isfinite` -- [x] `isinf` -- [x] `isnan` -- [x] `phase` -- [x] `polar` -- [x] `rect` -- [x] `exp` -- [x] `log` -- [x] `log10` -- [x] `sqrt` - [x] `acos` - [x] `acosh` - [x] `asin` @@ -100,10 +90,28 @@ Every function produces identical results to Python at the binary representation - [x] `atanh` - [x] `cos` - [x] `cosh` +- [x] `e` +- [x] `exp` +- [x] `inf` +- [x] `infj` +- [x] `isclose` +- [x] `isfinite` +- [x] `isinf` +- [x] `isnan` +- [x] `log` +- [x] `log10` +- [x] `nan` +- [x] `nanj` +- [x] `phase` +- [x] `pi` +- [x] `polar` +- [x] `rect` - [x] `sin` - [x] `sinh` +- [x] `sqrt` - [x] `tan` - [x] `tanh` +- [x] `tau` ## Usage diff --git a/proptest-regressions/cmath/trigonometric.txt b/proptest-regressions/cmath/trigonometric.txt index 04c0b1b..4870ee5 100644 --- a/proptest-regressions/cmath/trigonometric.txt +++ b/proptest-regressions/cmath/trigonometric.txt @@ -9,3 +9,4 @@ cc 409f359905a7e77eacdbd4643c3a1483cc128e46a1c06795c246f3d57a2e61b9 # shrinks to cc 5acbcf76c3bcbaf0b2b89a987e38c75265434f96fdc11598830e2221df72fc53 # shrinks to re = 0.00010260965539526095, im = 4.984721877290597e-19 cc ac04191f916633de0408e071f5f6ada8f7fe7a1be02caca7a32f37ecb148a5ac # shrinks to re = 3.049837651806167e74, im = -2.222842222335753e-166 cc 77ec3c08d2e057fb9467eb13c3b953e4e30f2800259d3653f7bcdfcbaf53614f # shrinks to re = 7.812038268590211e52, im = -2.623972069152808e-109 +cc c8fa2873ca664202a38d40ef3766b0f9f07dc82c1d5cd08dcd6d180564a013a1 # shrinks to re = -0.0016106222874309162, im = 0.10954881424244108 diff --git a/proptest-regressions/math/integer.txt b/proptest-regressions/math/integer.txt new file mode 100644 index 0000000..a043e24 --- /dev/null +++ b/proptest-regressions/math/integer.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 79031914da5204bfc75b0d7cf66e7f76d2e455d6c2837b66cbd11ebf7225be4a # shrinks to n = 32, k = 30 diff --git a/proptest-regressions/math/misc.txt b/proptest-regressions/math/misc.txt new file mode 100644 index 0000000..020f6ec --- /dev/null +++ b/proptest-regressions/math/misc.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 46166a18f59569bb11107fd280e53c8f37902760e15afbe67d6951dcd3032c72 # shrinks to x = 0.0, y = 0.0, z = 0.0 diff --git a/src/cmath.rs b/src/cmath.rs index 90bce66..c66242f 100644 --- a/src/cmath.rs +++ b/src/cmath.rs @@ -11,11 +11,34 @@ pub use exponential::{exp, log, log10, sqrt}; pub use misc::{abs, isclose, isfinite, isinf, isnan, phase, polar, rect}; pub use trigonometric::{acos, acosh, asin, asinh, atan, atanh, cos, cosh, sin, sinh, tan, tanh}; +use num_complex::Complex64; + +// Public constants (matching Python's cmath module) + +/// The mathematical constant e = 2.718281... +pub const E: f64 = std::f64::consts::E; + +/// The mathematical constant π = 3.141592... +pub const PI: f64 = std::f64::consts::PI; + +/// The mathematical constant τ = 6.283185... +pub const TAU: f64 = std::f64::consts::TAU; + +/// Positive infinity. +pub const INF: f64 = f64::INFINITY; + +/// A floating point "not a number" (NaN) value. +pub const NAN: f64 = f64::NAN; + +/// Complex number with zero real part and positive infinity imaginary part. +pub const INFJ: Complex64 = Complex64::new(0.0, f64::INFINITY); + +/// Complex number with zero real part and NaN imaginary part. +pub const NANJ: Complex64 = Complex64::new(0.0, f64::NAN); + #[cfg(test)] use crate::Result; use crate::m; -#[cfg(test)] -use num_complex::Complex64; // Shared constants @@ -25,8 +48,6 @@ const M_LN2: f64 = core::f64::consts::LN_2; const CM_LARGE_DOUBLE: f64 = f64::MAX / 4.0; const CM_LOG_LARGE_DOUBLE: f64 = 709.0895657128241; // log(CM_LARGE_DOUBLE) -const INF: f64 = f64::INFINITY; - // Special value table constants const P: f64 = core::f64::consts::PI; const P14: f64 = 0.25 * core::f64::consts::PI; diff --git a/src/cmath/exponential.rs b/src/cmath/exponential.rs index 1fdc30f..f306065 100644 --- a/src/cmath/exponential.rs +++ b/src/cmath/exponential.rs @@ -131,9 +131,10 @@ pub fn exp(z: Complex64) -> Result { Ok(Complex64::new(r_re, r_im)) } -/// Complex natural logarithm. +/// Complex natural logarithm (ln, base e). +/// TODO: consider to expose API #[inline] -pub fn log(z: Complex64) -> Result { +pub(crate) fn ln(z: Complex64) -> Result { special_value!(z, LOG_SPECIAL_VALUES); let ax = m::fabs(z.re); @@ -149,8 +150,10 @@ pub fn log(z: Complex64) -> Result { m::ldexp(ay, f64::MANTISSA_DIGITS as i32), )) - f64::MANTISSA_DIGITS as f64 * M_LN2 } else { - // log(+/-0. +/- 0i) - return Err(Error::EDOM); + // log(+/-0. +/- 0i) - return -inf like CPython does + // Note: CPython sets errno=EDOM but still returns a value. + // When used with a base, the second c_log call clears errno. + f64::NEG_INFINITY } } else { let h = m::hypot(ax, ay); @@ -167,10 +170,90 @@ pub fn log(z: Complex64) -> Result { Ok(Complex64::new(r_re, r_im)) } +/// Complex logarithm with optional base. +/// +/// If base is None, returns the natural logarithm. +/// If base is Some(b), returns log(z) / log(b). +#[inline] +pub fn log(z: Complex64, base: Option) -> Result { + // c_log always returns a value, but sets errno for special cases. + // The error check happens at the end of cmath_log_impl. + // For log(z) without base: z=0 raises EDOM + // For log(z, base): z=0 doesn't raise because c_log(base) clears errno + let z_is_zero = z.re == 0.0 && z.im == 0.0; + match base { + None => { + // No base: raise error if z=0 + if z_is_zero { + return Err(Error::EDOM); + } + ln(z) + } + Some(b) => { + // With base: z=0 is allowed (second ln clears the "errno") + let log_z = ln(z)?; + let log_b = ln(b)?; + // Use _Py_c_quot-style division to preserve sign of zero + Ok(c_quot(log_z, log_b)) + } + } +} + +/// Complex division following _Py_c_quot algorithm. +/// This preserves the sign of zero correctly and recovers infinities +/// from NaN results per C11 Annex G.5.2. +#[inline] +fn c_quot(a: Complex64, b: Complex64) -> Complex64 { + let abs_breal = m::fabs(b.re); + let abs_bimag = m::fabs(b.im); + + let mut r = if abs_breal >= abs_bimag { + if abs_breal == 0.0 { + Complex64::new(f64::NAN, f64::NAN) + } else { + let ratio = b.im / b.re; + let denom = b.re + b.im * ratio; + Complex64::new((a.re + a.im * ratio) / denom, (a.im - a.re * ratio) / denom) + } + } else if abs_bimag >= abs_breal { + let ratio = b.re / b.im; + let denom = b.re * ratio + b.im; + Complex64::new((a.re * ratio + a.im) / denom, (a.im * ratio - a.re) / denom) + } else { + // At least one of b.re or b.im is NaN + Complex64::new(f64::NAN, f64::NAN) + }; + + // Recover infinities and zeros that computed as nan+nanj. + // See C11 Annex G.5.2, routine _Cdivd(). + if r.re.is_nan() && r.im.is_nan() { + if (a.re.is_infinite() || a.im.is_infinite()) && b.re.is_finite() && b.im.is_finite() { + let x = m::copysign(if a.re.is_infinite() { 1.0 } else { 0.0 }, a.re); + let y = m::copysign(if a.im.is_infinite() { 1.0 } else { 0.0 }, a.im); + r.re = f64::INFINITY * (x * b.re + y * b.im); + r.im = f64::INFINITY * (y * b.re - x * b.im); + } else if (abs_breal.is_infinite() || abs_bimag.is_infinite()) + && a.re.is_finite() + && a.im.is_finite() + { + let x = m::copysign(if b.re.is_infinite() { 1.0 } else { 0.0 }, b.re); + let y = m::copysign(if b.im.is_infinite() { 1.0 } else { 0.0 }, b.im); + r.re = 0.0 * (a.re * x + a.im * y); + r.im = 0.0 * (a.im * x - a.re * y); + } + } + + r +} + /// Complex base-10 logarithm. #[inline] pub fn log10(z: Complex64) -> Result { - let r = log(z)?; + // Like log(z) without base, log10(0) raises EDOM + if z.re == 0.0 && z.im == 0.0 { + return Err(Error::EDOM); + } + let r = ln(z)?; Ok(Complex64::new(r.re / M_LN10, r.im / M_LN10)) } @@ -191,13 +274,56 @@ mod tests { fn test_exp(re: f64, im: f64) { test_cmath_func("exp", exp, re, im); } - fn test_log(re: f64, im: f64) { - test_cmath_func("log", log, re, im); + fn test_log_n(re: f64, im: f64) { + test_cmath_func("log", |z| log(z, None), re, im); } fn test_log10(re: f64, im: f64) { test_cmath_func("log10", log10, re, im); } + /// Test log with base - compares with Python's cmath.log(z, base) + fn test_log(z_re: f64, z_im: f64, base_re: f64, base_im: f64) { + use pyo3::prelude::*; + + let z = Complex64::new(z_re, z_im); + let base = Complex64::new(base_re, base_im); + let rs_result = log(z, Some(base)); + + pyo3::Python::attach(|py| { + let cmath = pyo3::types::PyModule::import(py, "cmath").unwrap(); + let py_z = pyo3::types::PyComplex::from_doubles(py, z_re, z_im); + let py_base = pyo3::types::PyComplex::from_doubles(py, base_re, base_im); + let py_result = cmath.getattr("log").unwrap().call1((py_z, py_base)); + + match py_result { + Ok(result) => { + use pyo3::types::PyComplexMethods; + let c = result.cast::().unwrap(); + let py_re = c.real(); + let py_im = c.imag(); + match rs_result { + Ok(rs) => { + crate::cmath::tests::assert_complex_eq( + py_re, py_im, rs, "log", z_re, z_im, + ); + } + Err(e) => { + panic!( + "log({z_re}+{z_im}j, {base_re}+{base_im}j): py=({py_re}, {py_im}) but rs returned error {e:?}" + ); + } + } + } + Err(_) => { + assert!( + rs_result.is_err(), + "log({z_re}+{z_im}j, {base_re}+{base_im}j): py raised error but rs={rs_result:?}" + ); + } + } + }); + } + use crate::test::EDGE_VALUES; #[test] @@ -219,10 +345,10 @@ mod tests { } #[test] - fn edgetest_log() { + fn edgetest_log_n() { for &re in &EDGE_VALUES { for &im in &EDGE_VALUES { - test_log(re, im); + test_log_n(re, im); } } } @@ -236,6 +362,25 @@ mod tests { } } + #[test] + fn edgetest_log() { + // Test log with various bases - sign preservation edge cases + let bases = [0.5, 2.0, 10.0]; + let values = [0.01, 0.1, 0.5, 1.0, 2.0, 10.0, 100.0]; + for &base in &bases { + for &v in &values { + test_log(v, 0.0, base, 0.0); + } + } + // Additional edge cases with imaginary parts + for &z_re in &EDGE_VALUES { + for &z_im in &EDGE_VALUES { + test_log(z_re, z_im, 2.0, 0.0); + test_log(z_re, z_im, 0.5, 0.0); + } + } + } + proptest::proptest! { #[test] fn proptest_sqrt(re: f64, im: f64) { @@ -249,7 +394,7 @@ mod tests { #[test] fn proptest_log(re: f64, im: f64) { - test_log(re, im); + test_log_n(re, im); } #[test] diff --git a/src/cmath/misc.rs b/src/cmath/misc.rs index 9831cac..3380649 100644 --- a/src/cmath/misc.rs +++ b/src/cmath/misc.rs @@ -98,22 +98,38 @@ pub fn abs(z: Complex64) -> f64 { } /// Determine whether two complex numbers are close in value. +/// +/// Default tolerances: rel_tol = 1e-09, abs_tol = 0.0 +/// Returns Err(EDOM) if rel_tol or abs_tol is negative. #[inline] -pub fn isclose(a: Complex64, b: Complex64, rel_tol: f64, abs_tol: f64) -> bool { +pub fn isclose( + a: Complex64, + b: Complex64, + rel_tol: Option, + abs_tol: Option, +) -> Result { + let rel_tol = rel_tol.unwrap_or(1e-09); + let abs_tol = abs_tol.unwrap_or(0.0); + + // Tolerances must be non-negative + if rel_tol < 0.0 || abs_tol < 0.0 { + return Err(Error::EDOM); + } + // short circuit exact equality if a.re == b.re && a.im == b.im { - return true; + return Ok(true); } // This catches the case of two infinities of opposite sign, or // one infinity and one finite number. if a.re.is_infinite() || a.im.is_infinite() || b.re.is_infinite() || b.im.is_infinite() { - return false; + return Ok(false); } // now do the regular computation let diff = abs(Complex64::new(a.re - b.re, a.im - b.im)); - (diff <= rel_tol * abs(b)) || (diff <= rel_tol * abs(a)) || (diff <= abs_tol) + Ok((diff <= rel_tol * abs(b)) || (diff <= rel_tol * abs(a)) || (diff <= abs_tol)) } #[cfg(test)] @@ -315,39 +331,64 @@ mod tests { #[test] fn test_isclose_basic() { // Equal values - assert!(isclose( - Complex64::new(1.0, 2.0), - Complex64::new(1.0, 2.0), - 1e-9, - 0.0 - )); + assert_eq!( + isclose( + Complex64::new(1.0, 2.0), + Complex64::new(1.0, 2.0), + Some(1e-9), + Some(0.0) + ), + Ok(true) + ); // Close values - assert!(isclose( - Complex64::new(1.0, 2.0), - Complex64::new(1.0 + 1e-10, 2.0), - 1e-9, - 0.0 - )); + assert_eq!( + isclose( + Complex64::new(1.0, 2.0), + Complex64::new(1.0 + 1e-10, 2.0), + Some(1e-9), + Some(0.0) + ), + Ok(true) + ); // Not close - assert!(!isclose( - Complex64::new(1.0, 2.0), - Complex64::new(2.0, 2.0), - 1e-9, - 0.0 - )); + assert_eq!( + isclose( + Complex64::new(1.0, 2.0), + Complex64::new(2.0, 2.0), + Some(1e-9), + Some(0.0) + ), + Ok(false) + ); // Infinities - assert!(isclose( - Complex64::new(f64::INFINITY, 0.0), - Complex64::new(f64::INFINITY, 0.0), - 1e-9, - 0.0 - )); - assert!(!isclose( - Complex64::new(f64::INFINITY, 0.0), - Complex64::new(f64::NEG_INFINITY, 0.0), - 1e-9, - 0.0 - )); + assert_eq!( + isclose( + Complex64::new(f64::INFINITY, 0.0), + Complex64::new(f64::INFINITY, 0.0), + Some(1e-9), + Some(0.0) + ), + Ok(true) + ); + assert_eq!( + isclose( + Complex64::new(f64::INFINITY, 0.0), + Complex64::new(f64::NEG_INFINITY, 0.0), + Some(1e-9), + Some(0.0) + ), + Ok(false) + ); + // Negative tolerance + assert!( + isclose( + Complex64::new(1.0, 2.0), + Complex64::new(1.0, 2.0), + Some(-1.0), + Some(0.0) + ) + .is_err() + ); } proptest::proptest! { diff --git a/src/math.rs b/src/math.rs index e65dece..9b40faf 100644 --- a/src/math.rs +++ b/src/math.rs @@ -10,11 +10,11 @@ mod misc; mod trigonometric; // Re-export from submodules -pub use aggregate::{dist, fsum, prod, sumprod}; +pub use aggregate::{dist, fsum, prod, prod_int, sumprod, sumprod_int}; pub use exponential::{cbrt, exp, exp2, expm1, log, log1p, log2, log10, pow, sqrt}; pub use gamma::{erf, erfc, gamma, lgamma}; pub use misc::{ - ceil, copysign, fabs, floor, fmod, frexp, isclose, isfinite, isinf, isnan, ldexp, modf, + ceil, copysign, fabs, floor, fma, fmod, frexp, isclose, isfinite, isinf, isnan, ldexp, modf, nextafter, remainder, trunc, ulp, }; pub use trigonometric::{ @@ -96,17 +96,33 @@ pub(crate) fn math_1a(x: f64, func: fn(f64) -> f64) -> crate::Result { crate::err::is_error(r) } -/// Return the Euclidean distance, sqrt(x*x + y*y). +/// Return the Euclidean norm of n-dimensional coordinates. /// -/// Uses high-precision vector_norm algorithm instead of libm hypot() -/// for consistent results across platforms and better handling of overflow/underflow. +/// Equivalent to sqrt(sum(x**2 for x in coords)). +/// Uses high-precision vector_norm algorithm for consistent results +/// across platforms and better handling of overflow/underflow. #[inline] -pub fn hypot(x: f64, y: f64) -> f64 { - let ax = x.abs(); - let ay = y.abs(); - let max = if ax > ay { ax } else { ay }; - let found_nan = x.is_nan() || y.is_nan(); - aggregate::vector_norm_2(ax, ay, max, found_nan) +pub fn hypot(coords: &[f64]) -> f64 { + let n = coords.len(); + if n == 0 { + return 0.0; + } + + let mut max = 0.0_f64; + let mut found_nan = false; + let abs_coords: Vec = coords + .iter() + .map(|&x| { + let ax = x.abs(); + found_nan |= ax.is_nan(); + if ax > max { + max = ax; + } + ax + }) + .collect(); + + aggregate::vector_norm(&abs_coords, max, found_nan) } // Mathematical constants @@ -208,16 +224,43 @@ mod tests { assert!(NAN.is_nan()); } - // hypot tests - fn test_hypot(x: f64, y: f64) { - crate::test::test_math_2(x, y, "hypot", |x, y| Ok(hypot(x, y))); + // hypot tests - n-dimensional + fn test_hypot(coords: &[f64]) { + let rs_result = hypot(coords); + + pyo3::Python::attach(|py| { + let math = PyModule::import(py, "math").unwrap(); + let py_func = math.getattr("hypot").unwrap(); + let py_args = pyo3::types::PyTuple::new(py, coords).unwrap(); + let py_result: f64 = py_func.call1(py_args).unwrap().extract().unwrap(); + + if py_result.is_nan() && rs_result.is_nan() { + return; + } + assert_eq!( + py_result.to_bits(), + rs_result.to_bits(), + "hypot({:?}): py={} vs rs={}", + coords, + py_result, + rs_result + ); + }); + } + + #[test] + fn test_hypot_basic() { + test_hypot(&[3.0, 4.0]); // 5.0 + test_hypot(&[1.0, 2.0, 2.0]); // 3.0 + test_hypot(&[]); // 0.0 + test_hypot(&[5.0]); // 5.0 } #[test] fn edgetest_hypot() { for &x in &crate::test::EDGE_VALUES { for &y in &crate::test::EDGE_VALUES { - test_hypot(x, y); + test_hypot(&[x, y]); } } } @@ -235,7 +278,7 @@ mod tests { #[test] fn proptest_hypot(x: f64, y: f64) { - test_hypot(x, y); + test_hypot(&[x, y]); } } } diff --git a/src/math/aggregate.rs b/src/math/aggregate.rs index 2676d60..ad9219b 100644 --- a/src/math/aggregate.rs +++ b/src/math/aggregate.rs @@ -170,80 +170,8 @@ pub fn fsum(iter: impl IntoIterator) -> crate::Result { // VECTOR_NORM - for dist and hypot -/// Compute the Euclidean norm of two values with high precision. -/// Optimized version for hypot(x, y). -pub(super) fn vector_norm_2(x: f64, y: f64, max: f64, found_nan: bool) -> f64 { - // Check for infinity first (inf wins over nan) - if x.is_infinite() || y.is_infinite() { - return f64::INFINITY; - } - if found_nan { - return f64::NAN; - } - if max == 0.0 { - return 0.0; - } - // n == 1 case: only one non-zero value - if x == 0.0 || y == 0.0 { - return max; - } - - let mut max_e: i32 = 0; - crate::m::frexp(max, &mut max_e); - - if max_e < -1023 { - // When max_e < -1023, ldexp(1.0, -max_e) would overflow - return f64::MIN_POSITIVE - * vector_norm_2( - x / f64::MIN_POSITIVE, - y / f64::MIN_POSITIVE, - max / f64::MIN_POSITIVE, - found_nan, - ); - } - - let scale = crate::m::ldexp(1.0, -max_e); - debug_assert!(max * scale >= 0.5); - debug_assert!(max * scale < 1.0); - - let mut csum = 1.0; - let mut frac1 = 0.0; - let mut frac2 = 0.0; - - // Process x - let xs = x * scale; - debug_assert!(xs.abs() < 1.0); - let pr = dl_mul(xs, xs); - debug_assert!(pr.hi <= 1.0); - let sm = dl_fast_sum(csum, pr.hi); - csum = sm.hi; - frac1 += pr.lo; - frac2 += sm.lo; - - // Process y - let ys = y * scale; - debug_assert!(ys.abs() < 1.0); - let pr = dl_mul(ys, ys); - debug_assert!(pr.hi <= 1.0); - let sm = dl_fast_sum(csum, pr.hi); - csum = sm.hi; - frac1 += pr.lo; - frac2 += sm.lo; - - let mut h = (csum - 1.0 + (frac1 + frac2)).sqrt(); - let pr = dl_mul(-h, h); - let sm = dl_fast_sum(csum, pr.hi); - csum = sm.hi; - frac1 += pr.lo; - frac2 += sm.lo; - let x = csum - 1.0 + (frac1 + frac2); - h += x / (2.0 * h); // differential correction - - h / scale -} - /// Compute the Euclidean norm of a vector with high precision. -fn vector_norm(vec: &[f64], max: f64, found_nan: bool) -> f64 { +pub fn vector_norm(vec: &[f64], max: f64, found_nan: bool) -> f64 { let n = vec.len(); if max.is_infinite() { @@ -331,7 +259,7 @@ pub fn dist(p: &[f64], q: &[f64]) -> f64 { vector_norm(&diffs, max, found_nan) } -/// Return the sum of products of values from two sequences. +/// Return the sum of products of values from two sequences (float version). /// /// Uses TripleLength arithmetic for high precision. /// Equivalent to sum(p[i] * q[i] for i in range(len(p))). @@ -353,7 +281,25 @@ pub fn sumprod(p: &[f64], q: &[f64]) -> f64 { tl_to_d(flt_total) } -/// Return the product of all elements in the iterable. +/// Return the sum of products of values from two sequences (integer version). +/// +/// Uses checked arithmetic to detect overflow. +/// Returns None if overflow occurs during computation. +/// Equivalent to sum(p[i] * q[i] for i in range(len(p))). +pub fn sumprod_int(p: &[i64], q: &[i64]) -> Option { + assert_eq!(p.len(), q.len(), "Inputs are not the same length"); + + let mut total: i64 = 0; + + for (&pi, &qi) in p.iter().zip(q.iter()) { + let prod = pi.checked_mul(qi)?; + total = total.checked_add(prod)?; + } + + Some(total) +} + +/// Return the product of all elements in the iterable (float version). /// /// If start is None, uses 1.0 as the start value. pub fn prod(iter: impl IntoIterator, start: Option) -> f64 { @@ -364,6 +310,18 @@ pub fn prod(iter: impl IntoIterator, start: Option) -> f64 { result } +/// Return the product of all elements using i64 arithmetic. +/// +/// Returns None if overflow occurs during multiplication. +/// Uses checked arithmetic to detect overflow. +pub fn prod_int(iter: impl IntoIterator, start: Option) -> Option { + let mut result = start.unwrap_or(1); + for x in iter { + result = result.checked_mul(x)?; + } + Some(result) +} + #[cfg(test)] mod tests { use super::*; @@ -441,27 +399,18 @@ mod tests { } fn test_dist_impl(p: &[f64], q: &[f64]) { - let rs_result = dist(p, q); - - pyo3::Python::attach(|py| { - let math = pyo3::types::PyModule::import(py, "math").unwrap(); - let py_func = math.getattr("dist").unwrap(); + let rs = dist(p, q); + crate::test::with_py_math(|py, math| { let py_p = pyo3::types::PyList::new(py, p).unwrap(); let py_q = pyo3::types::PyList::new(py, q).unwrap(); - let py_result: f64 = py_func.call1((py_p, py_q)).unwrap().extract().unwrap(); - - if py_result.is_nan() && rs_result.is_nan() { - return; - } - assert_eq!( - py_result.to_bits(), - rs_result.to_bits(), - "dist({:?}, {:?}): py={} vs rs={}", - p, - q, - py_result, - rs_result - ); + let py: f64 = math + .getattr("dist") + .unwrap() + .call1((py_p, py_q)) + .unwrap() + .extract() + .unwrap(); + crate::test::assert_f64_eq(py, rs, format_args!("dist({p:?}, {q:?})")); }); } @@ -474,27 +423,18 @@ mod tests { } fn test_sumprod_impl(p: &[f64], q: &[f64]) { - let rs_result = sumprod(p, q); - - pyo3::Python::attach(|py| { - let math = pyo3::types::PyModule::import(py, "math").unwrap(); - let py_func = math.getattr("sumprod").unwrap(); + let rs = sumprod(p, q); + crate::test::with_py_math(|py, math| { let py_p = pyo3::types::PyList::new(py, p).unwrap(); let py_q = pyo3::types::PyList::new(py, q).unwrap(); - let py_result: f64 = py_func.call1((py_p, py_q)).unwrap().extract().unwrap(); - - if py_result.is_nan() && rs_result.is_nan() { - return; - } - assert_eq!( - py_result.to_bits(), - rs_result.to_bits(), - "sumprod({:?}, {:?}): py={} vs rs={}", - p, - q, - py_result, - rs_result - ); + let py: f64 = math + .getattr("sumprod") + .unwrap() + .call1((py_p, py_q)) + .unwrap() + .extract() + .unwrap(); + crate::test::assert_f64_eq(py, rs, format_args!("sumprod({p:?}, {q:?})")); }); } @@ -507,37 +447,22 @@ mod tests { } fn test_prod_impl(values: &[f64], start: Option) { - let rs_result = prod(values.iter().copied(), start); - - pyo3::Python::attach(|py| { - let math = pyo3::types::PyModule::import(py, "math").unwrap(); - let py_func = math.getattr("prod").unwrap(); + let rs = prod(values.iter().copied(), start); + crate::test::with_py_math(|py, math| { let py_list = pyo3::types::PyList::new(py, values).unwrap(); - let py_result: f64 = match start { + let py_func = math.getattr("prod").unwrap(); + let py: f64 = match start { Some(s) => { let kwargs = pyo3::types::PyDict::new(py); kwargs.set_item("start", s).unwrap(); - py_func - .call((py_list,), Some(&kwargs)) - .unwrap() - .extract() - .unwrap() + py_func.call((py_list,), Some(&kwargs)) } - None => py_func.call1((py_list,)).unwrap().extract().unwrap(), - }; - - if py_result.is_nan() && rs_result.is_nan() { - return; + None => py_func.call1((py_list,)), } - assert_eq!( - py_result.to_bits(), - rs_result.to_bits(), - "prod({:?}, {:?}): py={} vs rs={}", - values, - start, - py_result, - rs_result - ); + .unwrap() + .extract() + .unwrap(); + crate::test::assert_f64_eq(py, rs, format_args!("prod({values:?}, {start:?})")); }); } diff --git a/src/math/exponential.rs b/src/math/exponential.rs index 7a817a5..be2c47c 100644 --- a/src/math/exponential.rs +++ b/src/math/exponential.rs @@ -250,7 +250,7 @@ mod tests { fn test_expm1(x: f64) { crate::test::test_math_1(x, "expm1", expm1); } - fn test_log_without_base(x: f64) { + fn test_log_n(x: f64) { crate::test::test_math_1(x, "log", |x| log(x, None)); } fn test_log(x: f64, base: f64) { @@ -302,8 +302,8 @@ mod tests { } #[test] - fn proptest_log_without_base(x: f64) { - test_log_without_base(x); + fn proptest_log_n(x: f64) { + test_log_n(x); } #[test] @@ -368,9 +368,9 @@ mod tests { } #[test] - fn edgetest_log_without_base() { + fn edgetest_log_n() { for &x in &crate::test::EDGE_VALUES { - test_log_without_base(x); + test_log_n(x); } } diff --git a/src/math/integer.rs b/src/math/integer.rs index ea6148c..521997d 100644 --- a/src/math/integer.rs +++ b/src/math/integer.rs @@ -11,23 +11,43 @@ use num_bigint::{BigInt, BigUint}; use num_integer::Integer; use num_traits::{One, Signed, ToPrimitive, Zero}; -/// Return the greatest common divisor of a and b. +/// Return the greatest common divisor of the integer arguments. /// -/// Uses the optimized GCD implementation from num-integer, +/// gcd() with no arguments returns 0. +/// gcd(a) returns abs(a). #[inline] -pub fn gcd(a: &BigInt, b: &BigInt) -> BigInt { - a.gcd(b) +pub fn gcd(args: &[BigInt]) -> BigInt { + if args.is_empty() { + return BigInt::zero(); + } + let mut result = args[0].abs(); + for arg in &args[1..] { + result = result.gcd(arg); + if result.is_one() { + return result; + } + } + result } -/// Return the least common multiple of a and b. +/// Return the least common multiple of the integer arguments. +/// +/// lcm() with no arguments returns 1. +/// lcm(a) returns abs(a). #[inline] -pub fn lcm(a: &BigInt, b: &BigInt) -> BigInt { - if a.is_zero() || b.is_zero() { - return BigInt::zero(); +pub fn lcm(args: &[BigInt]) -> BigInt { + if args.is_empty() { + return BigInt::one(); } - let g = gcd(a, b); - let f = a / &g; - (&f * b).abs() + let mut result = args[0].abs(); + for arg in &args[1..] { + if result.is_zero() || arg.is_zero() { + return BigInt::zero(); + } + let g = result.gcd(arg); + result = result / &g * arg.abs(); + } + result } /// Approximate square roots for 16-bit integers. @@ -62,10 +82,22 @@ fn approximate_isqrt(n: u64) -> u32 { (u << 15) + ((n >> 17) / u as u64) as u32 } +/// Return the integer part of the square root of a non-negative integer. +/// +/// Returns Err if n is negative. +pub fn isqrt(n: &BigInt) -> Result { + if n.is_negative() { + return Err(()); + } + Ok(isqrt_unsigned(&n.magnitude()).into()) +} + /// Return the integer part of the square root of the input. /// /// This is an adaptive-precision pure-integer version of Newton's iteration. -pub fn isqrt(n: &BigUint) -> BigUint { +/// +/// TODO: regard to expose as public API +fn isqrt_unsigned(n: &BigUint) -> BigUint { if n.is_zero() { return BigUint::zero(); } @@ -174,11 +206,7 @@ fn factorial_partial_product(start: u64, stop: u64, max_bits: u32) -> BigUint { // Find midpoint of range(start, stop), rounded up to next odd number let midpoint = (start + num_operands) | 1; - let left = factorial_partial_product( - start, - midpoint, - (64 - midpoint.leading_zeros()).saturating_sub(1), - ); + let left = factorial_partial_product(start, midpoint, 64 - (midpoint - 2).leading_zeros()); let right = factorial_partial_product(midpoint, stop, max_bits); left * right } @@ -199,11 +227,7 @@ fn factorial_odd_part(n: u64) -> BigUint { let lower = upper; // (v + 1) | 1 = least odd integer strictly larger than n / 2**i upper = (v + 1) | 1; - let partial = factorial_partial_product( - lower, - upper, - (64 - (upper - 2).leading_zeros()).saturating_sub(1), - ); + let partial = factorial_partial_product(lower, upper, 64 - (upper - 2).leading_zeros()); inner *= partial; outer *= &inner; } @@ -213,18 +237,23 @@ fn factorial_odd_part(n: u64) -> BigUint { /// Return n factorial (n!). /// +/// Returns Err(()) if n is negative. /// Uses the divide-and-conquer algorithm. /// Based on: http://www.luschny.de/math/factorial/binarysplitfact.html -pub fn factorial(n: u64) -> BigUint { +pub fn factorial(n: i64) -> Result { + if n < 0 { + return Err(()); + } + let n = n as u64; // Use lookup table for small values if n < SMALL_FACTORIALS.len() as u64 { - return BigUint::from(SMALL_FACTORIALS[n as usize]); + return Ok(BigUint::from(SMALL_FACTORIALS[n as usize])); } // Express as odd_part * 2**two_valuation let odd_part = factorial_odd_part(n); let two_valuation = n - count_set_bits(n); - odd_part << two_valuation as usize + Ok(odd_part << two_valuation as usize) } // COMB / PERM @@ -393,7 +422,7 @@ const INVERTED_FACTORIAL_ODD_PART: [u64; 128] = [ 0x2b3e690621a42251, 0x4f520f00e03c04e7, 0x2edf84ee600211d3, - 0xadcaa2764aaacffd, + 0xadcaa2764aaacdfd, 0x161f4f9033f4fe63, 0x161f4f9033f4fe63, 0xbada2932ea4d3e03, @@ -652,11 +681,16 @@ fn perm_comb(n: &BigUint, k: u64, is_comb: bool) -> BigUint { /// Return the number of ways to choose k items from n items (n choose k). /// +/// Returns Err(()) if n or k is negative. /// Evaluates to n! / (k! * (n - k)!) when k <= n and evaluates /// to zero when k > n. -pub fn comb(n: u64, k: u64) -> BigUint { +pub fn comb(n: i64, k: i64) -> Result { + if n < 0 || k < 0 { + return Err(()); + } + let (n, k) = (n as u64, k as u64); if k > n { - return BigUint::zero(); + return Ok(BigUint::zero()); } // Use smaller k for efficiency @@ -664,107 +698,464 @@ pub fn comb(n: u64, k: u64) -> BigUint { if k <= 1 { if k == 0 { - return BigUint::one(); + return Ok(BigUint::one()); } - return BigUint::from(n); + return Ok(BigUint::from(n)); } - perm_comb_small(n, k, true) + Ok(perm_comb_small(n, k, true)) } /// Return the number of ways to arrange k items from n items. /// +/// Returns Err(()) if n or k is negative. /// Evaluates to n! / (n - k)! when k <= n and evaluates /// to zero when k > n. /// /// If k is not specified (None), then k defaults to n /// and the function returns n!. -pub fn perm(n: u64, k: Option) -> BigUint { - let k = k.unwrap_or(n); +pub fn perm(n: i64, k: Option) -> Result { + if n < 0 { + return Err(()); + } + let n = n as u64; + let k = match k { + Some(k) if k < 0 => return Err(()), + Some(k) => k as u64, + None => n, + }; if k > n { - return BigUint::zero(); + return Ok(BigUint::zero()); } if k == 0 { - return BigUint::one(); + return Ok(BigUint::one()); } if k == 1 { - return BigUint::from(n); + return Ok(BigUint::from(n)); } - perm_comb_small(n, k, false) + Ok(perm_comb_small(n, k, false)) } #[cfg(test)] mod tests { use super::*; + use pyo3::prelude::*; + + /// Edge i64 values for testing integer math functions (gcd, lcm, isqrt, factorial, comb, perm) + const EDGE_I64: [i64; 24] = [ + // Zero and small values + 0, + 1, + -1, + 2, + -2, + // Prime numbers + 7, + 13, + 97, + // Powers of 2 + 64, + 1024, + 65536, + // Factorial-relevant + 20, // 20! fits in u64 + 21, // 21! overflows u64 + // Large values + 1_000_000, + -1_000_000, + i32::MAX as i64, + i32::MIN as i64, + // Near i64 bounds + i64::MAX, + i64::MIN, + i64::MAX - 1, + i64::MIN + 1, + // Square root boundaries + (1i64 << 31) - 1, // sqrt fits in u32 + 1i64 << 32, // sqrt boundary + (1i64 << 62) - 1, // large but valid for isqrt + ]; + + fn test_gcd_impl(args: &[i64]) { + use std::str::FromStr; + let bigints: Vec = args.iter().map(|&x| BigInt::from(x)).collect(); + let rs = gcd(&bigints); + crate::test::with_py_math(|py, math| { + let py_args = pyo3::types::PyTuple::new(py, args).unwrap(); + let py_result = math.getattr("gcd").unwrap().call1(py_args).unwrap(); + let py_str: String = py_result.str().unwrap().extract().unwrap(); + let py = BigInt::from_str(&py_str).unwrap(); + assert_eq!(rs, py, "gcd({args:?}): py={py} vs rs={rs}"); + }); + } + + fn test_lcm_impl(args: &[i64]) { + use std::str::FromStr; + let bigints: Vec = args.iter().map(|&x| BigInt::from(x)).collect(); + let rs = lcm(&bigints); + crate::test::with_py_math(|py, math| { + let py_args = pyo3::types::PyTuple::new(py, args).unwrap(); + let py_result = math.getattr("lcm").unwrap().call1(py_args).unwrap(); + let py_str: String = py_result.str().unwrap().extract().unwrap(); + let py = BigInt::from_str(&py_str).unwrap(); + assert_eq!(rs, py, "lcm({args:?}): py={py} vs rs={rs}"); + }); + } + + fn test_isqrt_impl(n: i64) { + let rs = isqrt(&BigInt::from(n)); + crate::test::with_py_math(|_py, math| { + let py_result = math.getattr("isqrt").unwrap().call1((n,)); + match py_result { + Ok(result) => { + let py: i64 = result.extract().unwrap(); + assert_eq!(rs, Ok(BigInt::from(py)), "isqrt({n}): py={py} vs rs={rs:?}"); + } + Err(_) => { + assert!(rs.is_err(), "isqrt({n}): py raised error but rs={rs:?}"); + } + } + }); + } + + fn test_factorial_impl(n: i64) { + use std::str::FromStr; + let rs = factorial(n); + crate::test::with_py_math(|_py, math| { + let py_result = math.getattr("factorial").unwrap().call1((n,)); + match py_result { + Ok(result) => { + let py_str: String = result.str().unwrap().extract().unwrap(); + let py = BigUint::from_str(&py_str).unwrap(); + assert_eq!(rs, Ok(py.clone()), "factorial({n}): py={py} vs rs={rs:?}"); + } + Err(_) => { + assert!(rs.is_err(), "factorial({n}): py raised error but rs={rs:?}"); + } + } + }); + } + + fn test_comb_impl(n: i64, k: i64) { + use std::str::FromStr; + let rs = comb(n, k); + crate::test::with_py_math(|_py, math| { + let py_result = math.getattr("comb").unwrap().call1((n, k)); + match py_result { + Ok(result) => { + let py_str: String = result.str().unwrap().extract().unwrap(); + let py = BigUint::from_str(&py_str).unwrap(); + assert_eq!(rs, Ok(py.clone()), "comb({n}, {k}): py={py} vs rs={rs:?}"); + } + Err(_) => { + assert!(rs.is_err(), "comb({n}, {k}): py raised error but rs={rs:?}"); + } + } + }); + } + + fn test_perm_impl(n: i64, k: Option) { + use std::str::FromStr; + let rs = perm(n, k); + crate::test::with_py_math(|_py, math| { + let py_func = math.getattr("perm").unwrap(); + let py_result = match k { + Some(k) => py_func.call1((n, k)), + None => py_func.call1((n,)), + }; + match py_result { + Ok(result) => { + let py_str: String = result.str().unwrap().extract().unwrap(); + let py = BigUint::from_str(&py_str).unwrap(); + assert_eq!(rs, Ok(py.clone()), "perm({n}, {k:?}): py={py} vs rs={rs:?}"); + } + Err(_) => { + assert!( + rs.is_err(), + "perm({n}, {k:?}): py raised error but rs={rs:?}" + ); + } + } + }); + } #[test] fn test_gcd() { - assert_eq!(gcd(&12.into(), &8.into()), 4.into()); - assert_eq!(gcd(&8.into(), &12.into()), 4.into()); - assert_eq!(gcd(&0.into(), &5.into()), 5.into()); - assert_eq!(gcd(&5.into(), &0.into()), 5.into()); - assert_eq!(gcd(&0.into(), &0.into()), 0.into()); - assert_eq!(gcd(&(-12).into(), &8.into()), 4.into()); - assert_eq!(gcd(&12.into(), &(-8).into()), 4.into()); - assert_eq!(gcd(&(-12).into(), &(-8).into()), 4.into()); - assert_eq!(gcd(&17.into(), &13.into()), 1.into()); + // No arguments + test_gcd_impl(&[]); + // Single argument + test_gcd_impl(&[5]); + test_gcd_impl(&[-5]); + test_gcd_impl(&[0]); + // Two arguments + test_gcd_impl(&[12, 8]); + test_gcd_impl(&[8, 12]); + test_gcd_impl(&[0, 5]); + test_gcd_impl(&[5, 0]); + test_gcd_impl(&[0, 0]); + test_gcd_impl(&[-12, 8]); + test_gcd_impl(&[12, -8]); + test_gcd_impl(&[-12, -8]); + test_gcd_impl(&[17, 13]); + // Multiple arguments + test_gcd_impl(&[12, 8, 4]); + test_gcd_impl(&[12, 8, 6]); + test_gcd_impl(&[100, 150, 200, 250]); } #[test] fn test_lcm() { - assert_eq!(lcm(&12.into(), &8.into()), 24.into()); - assert_eq!(lcm(&8.into(), &12.into()), 24.into()); - assert_eq!(lcm(&0.into(), &5.into()), 0.into()); - assert_eq!(lcm(&5.into(), &0.into()), 0.into()); - assert_eq!(lcm(&0.into(), &0.into()), 0.into()); - assert_eq!(lcm(&(-12).into(), &8.into()), 24.into()); - assert_eq!(lcm(&12.into(), &(-8).into()), 24.into()); - assert_eq!(lcm(&17.into(), &13.into()), 221.into()); + // No arguments + test_lcm_impl(&[]); + // Single argument + test_lcm_impl(&[5]); + test_lcm_impl(&[-5]); + // Two arguments + test_lcm_impl(&[12, 8]); + test_lcm_impl(&[8, 12]); + test_lcm_impl(&[0, 5]); + test_lcm_impl(&[5, 0]); + test_lcm_impl(&[0, 0]); + test_lcm_impl(&[-12, 8]); + test_lcm_impl(&[12, -8]); + test_lcm_impl(&[17, 13]); + // Multiple arguments + test_lcm_impl(&[2, 3, 4]); + test_lcm_impl(&[6, 8, 12]); } #[test] fn test_isqrt() { - assert_eq!(isqrt(&BigUint::from(0u32)), BigUint::from(0u32)); - assert_eq!(isqrt(&BigUint::from(1u32)), BigUint::from(1u32)); - assert_eq!(isqrt(&BigUint::from(4u32)), BigUint::from(2u32)); - assert_eq!(isqrt(&BigUint::from(9u32)), BigUint::from(3u32)); - assert_eq!(isqrt(&BigUint::from(10u32)), BigUint::from(3u32)); - assert_eq!(isqrt(&BigUint::from(15u32)), BigUint::from(3u32)); - assert_eq!(isqrt(&BigUint::from(16u32)), BigUint::from(4u32)); - assert_eq!(isqrt(&BigUint::from(100u32)), BigUint::from(10u32)); - assert_eq!(isqrt(&BigUint::from(1000000u32)), BigUint::from(1000u32)); - // Test large number - let large = BigUint::from(10u64).pow(40); - assert_eq!(isqrt(&large), BigUint::from(10u64).pow(20)); + test_isqrt_impl(0); + test_isqrt_impl(1); + test_isqrt_impl(4); + test_isqrt_impl(9); + test_isqrt_impl(10); + test_isqrt_impl(15); + test_isqrt_impl(16); + test_isqrt_impl(100); + test_isqrt_impl(1000000); + test_isqrt_impl(-1); + test_isqrt_impl(-100); } #[test] fn test_factorial() { - assert_eq!(factorial(0), BigUint::from(1u32)); - assert_eq!(factorial(1), BigUint::from(1u32)); - assert_eq!(factorial(5), BigUint::from(120u32)); - assert_eq!(factorial(10), BigUint::from(3628800u32)); - assert_eq!(factorial(20), BigUint::from(2432902008176640000u64)); + test_factorial_impl(0); + test_factorial_impl(1); + test_factorial_impl(5); + test_factorial_impl(10); + test_factorial_impl(20); + test_factorial_impl(50); + test_factorial_impl(100); + test_factorial_impl(-1); + test_factorial_impl(-10); } #[test] fn test_comb() { - assert_eq!(comb(5, 0), BigUint::from(1u32)); - assert_eq!(comb(5, 5), BigUint::from(1u32)); - assert_eq!(comb(5, 2), BigUint::from(10u32)); - assert_eq!(comb(10, 3), BigUint::from(120u32)); - assert_eq!(comb(3, 5), BigUint::from(0u32)); // k > n - assert_eq!(comb(100, 30), comb(100, 70)); // symmetry: C(n, k) == C(n, n-k) + test_comb_impl(5, 0); + test_comb_impl(5, 5); + test_comb_impl(5, 2); + test_comb_impl(10, 3); + test_comb_impl(3, 5); // k > n + test_comb_impl(100, 30); + test_comb_impl(100, 70); + test_comb_impl(-1, 0); + test_comb_impl(5, -1); + test_comb_impl(-5, -3); } #[test] fn test_perm() { - assert_eq!(perm(5, Some(0)), BigUint::from(1u32)); - assert_eq!(perm(5, Some(5)), BigUint::from(120u32)); - assert_eq!(perm(5, Some(2)), BigUint::from(20u32)); - assert_eq!(perm(5, None), BigUint::from(120u32)); // 5! - assert_eq!(perm(3, Some(5)), BigUint::from(0u32)); // k > n + test_perm_impl(5, Some(0)); + test_perm_impl(5, Some(5)); + test_perm_impl(5, Some(2)); + test_perm_impl(5, None); // 5! + test_perm_impl(3, Some(5)); // k > n + test_perm_impl(10, Some(3)); + test_perm_impl(20, None); + test_perm_impl(-1, None); + test_perm_impl(5, Some(-1)); + test_perm_impl(-5, Some(-3)); + } + + // Edge case tests using EDGE_I64 + #[test] + fn edgetest_gcd() { + // Test all edge values - gcd handles arbitrary large integers + for &a in &EDGE_I64 { + test_gcd_impl(&[a]); + for &b in &EDGE_I64 { + test_gcd_impl(&[a, b]); + } + } + // Additional large value tests + test_gcd_impl(&[i64::MAX, i64::MAX - 1]); + test_gcd_impl(&[i64::MIN, i64::MIN + 1]); + test_gcd_impl(&[i64::MAX, i64::MIN]); + } + + #[test] + fn edgetest_lcm() { + // Test all edge values - lcm handles arbitrary large integers + for &a in &EDGE_I64 { + test_lcm_impl(&[a]); + for &b in &EDGE_I64 { + test_lcm_impl(&[a, b]); + } + } + // Additional boundary tests + test_lcm_impl(&[i64::MAX, 1]); + test_lcm_impl(&[i64::MIN, 1]); + test_lcm_impl(&[i64::MIN, -1]); + } + + #[test] + fn edgetest_isqrt() { + for &n in &EDGE_I64 { + test_isqrt_impl(n); + } + // Additional boundary cases + test_isqrt_impl(i64::MAX); + test_isqrt_impl((1i64 << 62) - 1); + test_isqrt_impl(1i64 << 62); + // Perfect squares + for i in 0..20 { + let sq = i * i; + test_isqrt_impl(sq); + if sq > 0 { + test_isqrt_impl(sq - 1); + test_isqrt_impl(sq + 1); + } + } + } + + #[test] + fn edgetest_factorial() { + for &n in &EDGE_I64 { + // factorial only makes sense for reasonable n values + if n >= -10 && n <= 170 { + test_factorial_impl(n); + } + } + // Boundary cases + for n in 0..=25 { + test_factorial_impl(n); + } + test_factorial_impl(50); + test_factorial_impl(100); + test_factorial_impl(170); + } + + #[test] + fn edgetest_comb() { + let vals: Vec = EDGE_I64 + .iter() + .copied() + .filter(|&x| x >= -10 && x <= 200) + .collect(); + for &n in &vals { + for &k in &vals { + test_comb_impl(n, k); + } + } + // Large comb values + test_comb_impl(1000, 3); + test_comb_impl(1000, 500); + test_comb_impl(500, 250); + test_comb_impl(200, 100); + } + + #[test] + fn edgetest_perm() { + let vals: Vec = EDGE_I64 + .iter() + .copied() + .filter(|&x| x >= -10 && x <= 100) + .collect(); + for &n in &vals { + for &k in &vals { + test_perm_impl(n, Some(k)); + } + test_perm_impl(n, None); + } + // Large perm values + test_perm_impl(100, Some(5)); + test_perm_impl(1000, Some(3)); + } + + proptest::proptest! { + #[test] + fn proptest_gcd_2(a: i64, b: i64) { + test_gcd_impl(&[a, b]); + } + + #[test] + fn proptest_gcd_3(a: i64, b: i64, c: i64) { + test_gcd_impl(&[a, b, c]); + } + + #[test] + fn proptest_lcm_2(a in -100000i64..100000i64, b in -100000i64..100000i64) { + test_lcm_impl(&[a, b]); + } + + #[test] + fn proptest_lcm_3(a in -1000i64..1000i64, b in -1000i64..1000i64, c in -1000i64..1000i64) { + test_lcm_impl(&[a, b, c]); + } + + #[test] + fn proptest_isqrt_small(n in 0i64..1_000_000i64) { + test_isqrt_impl(n); + } + + #[test] + fn proptest_isqrt_medium(n in 0i64..(1i64 << 32)) { + test_isqrt_impl(n); + } + + #[test] + fn proptest_isqrt_large(n in 0i64..i64::MAX) { + test_isqrt_impl(n); + } + + #[test] + fn proptest_isqrt_negative(n in i64::MIN..0i64) { + test_isqrt_impl(n); + } + + #[test] + fn proptest_factorial(n in -10i64..50i64) { + test_factorial_impl(n); + } + + #[test] + fn proptest_comb_small(n in -10i64..100i64, k in -10i64..100i64) { + test_comb_impl(n, k); + } + + #[test] + fn proptest_comb_large(n in 100i64..1000i64, k in 0i64..50i64) { + test_comb_impl(n, k); + } + + #[test] + fn proptest_perm_small(n in -10i64..50i64, k in -10i64..50i64) { + test_perm_impl(n, Some(k)); + } + + #[test] + fn proptest_perm_large(n in 50i64..200i64, k in 0i64..20i64) { + test_perm_impl(n, Some(k)); + } + + #[test] + fn proptest_perm_none(n in -10i64..25i64) { + test_perm_impl(n, None); + } } } diff --git a/src/math/misc.rs b/src/math/misc.rs index 9809058..42d9e5a 100644 --- a/src/math/misc.rs +++ b/src/math/misc.rs @@ -3,7 +3,27 @@ use crate::{Error, Result, m}; super::libm_simple!(@1 ceil, floor, trunc); -super::libm_simple!(@2 nextafter); + +/// Return the next floating-point value after x towards y. +/// +/// If steps is provided, move that many steps towards y. +/// Steps must be non-negative. +#[inline] +pub fn nextafter(x: f64, y: f64, steps: Option) -> f64 { + match steps { + Some(n) => { + let mut result = x; + for _ in 0..n { + result = crate::m::nextafter(result, y); + if result == y { + break; + } + } + result + } + None => crate::m::nextafter(x, y), + } +} /// Return the absolute value of x. #[inline] @@ -43,9 +63,12 @@ pub fn isnan(x: f64) -> bool { /// given absolute and relative tolerances: /// `abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)` /// +/// Default tolerances: rel_tol = 1e-09, abs_tol = 0.0 /// Returns Err(EDOM) if rel_tol or abs_tol is negative. #[inline] -pub fn isclose(a: f64, b: f64, rel_tol: f64, abs_tol: f64) -> Result { +pub fn isclose(a: f64, b: f64, rel_tol: Option, abs_tol: Option) -> Result { + let rel_tol = rel_tol.unwrap_or(1e-09); + let abs_tol = abs_tol.unwrap_or(0.0); // Tolerances must be non-negative if rel_tol < 0.0 || abs_tol < 0.0 { return Err(Error::EDOM); @@ -63,6 +86,33 @@ pub fn isclose(a: f64, b: f64, rel_tol: f64, abs_tol: f64) -> Result { Ok(diff <= abs_tol.max(rel_tol * a.abs().max(b.abs()))) } +/// Fused multiply-add operation: (x * y) + z. +/// +/// Returns EDOM for invalid operation (NaN result from non-NaN inputs). +/// Returns ERANGE for overflow (infinite result from finite inputs). +#[inline] +pub fn fma(x: f64, y: f64, z: f64) -> Result { + let r = x.mul_add(y, z); + + // Fast path: if we got a finite result, we're done. + if r.is_finite() { + return Ok(r); + } + + // Non-finite result. Raise an exception if appropriate, else return r. + if r.is_nan() { + if !x.is_nan() && !y.is_nan() && !z.is_nan() { + // NaN result from non-NaN inputs. + return Err(Error::EDOM); + } + } else if x.is_finite() && y.is_finite() && z.is_finite() { + // Infinite result from finite inputs. + return Err(Error::ERANGE); + } + + Ok(r) +} + /// Return the mantissa and exponent of x as (m, e). /// /// m is a float and e is an integer such that x == m * 2**e exactly. @@ -143,10 +193,10 @@ pub fn ulp(x: f64) -> f64 { if x.is_infinite() { return x; } - let x2 = super::nextafter(x, f64::INFINITY); + let x2 = crate::m::nextafter(x, f64::INFINITY); if x2.is_infinite() { // Special case: x is the largest positive representable float - let x2 = super::nextafter(x, f64::NEG_INFINITY); + let x2 = crate::m::nextafter(x, f64::NEG_INFINITY); return x - x2; } x2 - x @@ -156,6 +206,9 @@ pub fn ulp(x: f64) -> f64 { mod tests { use super::*; + /// Edge integer values for testing functions like ldexp + const EDGE_INTS: [i32; 9] = [0, 1, -1, 100, -100, 1024, -1024, i32::MAX, i32::MIN]; + fn test_ldexp(x: f64, i: i32) { use pyo3::prelude::*; @@ -265,7 +318,7 @@ mod tests { #[test] fn edgetest_ldexp() { for &x in &crate::test::EDGE_VALUES { - for &i in &crate::test::EDGE_INTS { + for &i in &EDGE_INTS { test_ldexp(x, i); } } @@ -377,7 +430,7 @@ mod tests { fn test_isclose_impl(a: f64, b: f64, rel_tol: f64, abs_tol: f64) { use pyo3::prelude::*; - let rs_result = isclose(a, b, rel_tol, abs_tol); + let rs_result = isclose(a, b, Some(rel_tol), Some(abs_tol)); pyo3::Python::attach(|py| { let math = pyo3::types::PyModule::import(py, "math").unwrap(); @@ -447,4 +500,91 @@ mod tests { test_isclose_impl(a, b, 1e-9, 0.0); } } + + fn test_fma_impl(x: f64, y: f64, z: f64) { + use pyo3::prelude::*; + + let rs_result = fma(x, y, z); + + pyo3::Python::attach(|py| { + let math = pyo3::types::PyModule::import(py, "math").unwrap(); + let py_func = math.getattr("fma").unwrap(); + let py_result = py_func.call1((x, y, z)); + + match py_result { + Ok(result) => { + let py_val: f64 = result.extract().unwrap(); + match rs_result { + Ok(rs_val) => { + if py_val.is_nan() && rs_val.is_nan() { + return; + } + assert_eq!( + py_val.to_bits(), + rs_val.to_bits(), + "fma({x}, {y}, {z}): py={py_val} vs rs={rs_val}" + ); + } + Err(e) => { + panic!("fma({x}, {y}, {z}): py={py_val} but rs returned error {e:?}"); + } + } + } + Err(e) => { + if e.is_instance_of::(py) { + assert_eq!( + rs_result.as_ref().err(), + Some(&Error::EDOM), + "fma({x}, {y}, {z}): py raised ValueError but rs={:?}", + rs_result + ); + } else if e.is_instance_of::(py) { + assert_eq!( + rs_result.as_ref().err(), + Some(&Error::ERANGE), + "fma({x}, {y}, {z}): py raised OverflowError but rs={:?}", + rs_result + ); + } else { + panic!("fma({x}, {y}, {z}): py raised unexpected error {e}"); + } + } + } + }); + } + + #[test] + fn test_fma() { + // Basic tests + test_fma_impl(2.0, 3.0, 4.0); // 2*3+4 = 10 + test_fma_impl(0.0, 0.0, 0.0); + test_fma_impl(1.0, 1.0, 1.0); + test_fma_impl(-1.0, 2.0, 3.0); + // Edge cases + test_fma_impl(f64::INFINITY, 1.0, 0.0); + test_fma_impl(0.0, f64::INFINITY, 0.0); // 0 * inf -> NaN -> ValueError + test_fma_impl(f64::NAN, 1.0, 0.0); + test_fma_impl(1.0, f64::NAN, 0.0); + test_fma_impl(1.0, 1.0, f64::NAN); + // Overflow cases + test_fma_impl(1e308, 10.0, 0.0); // overflow -> OverflowError + test_fma_impl(1e308, 1.0, 1e308); + } + + #[test] + fn edgetest_fma() { + for &x in &crate::test::EDGE_VALUES { + for &y in &crate::test::EDGE_VALUES { + test_fma_impl(x, y, 0.0); + test_fma_impl(x, y, 1.0); + } + } + } + + proptest::proptest! { + #[test] + fn proptest_fma(x: f64, y: f64, z: f64) { + test_fma_impl(x, y, z); + } + } } diff --git a/src/test.rs b/src/test.rs index 2902845..cf9ebf7 100644 --- a/src/test.rs +++ b/src/test.rs @@ -46,9 +46,6 @@ pub(crate) const EDGE_VALUES: [f64; 30] = [ std::f64::consts::TAU, // sin(2*PI) ≈ 0, cos(2*PI) = 1 ]; -/// Edge integer values for testing functions like ldexp -pub(crate) const EDGE_INTS: [i32; 9] = [0, 1, -1, 100, -100, 1024, -1024, i32::MAX, i32::MIN]; - pub(crate) fn unwrap<'py>( py: Python<'py>, py_v: PyResult>, @@ -94,6 +91,32 @@ pub(crate) fn test_math_1(x: f64, func_name: &str, rs_func: impl Fn(f64) -> crat }); } +/// Run a test with Python math module +pub(crate) fn with_py_math(f: F) -> R +where + F: FnOnce(Python, &pyo3::Bound) -> R, +{ + pyo3::Python::attach(|py| { + let math = pyo3::types::PyModule::import(py, "math").unwrap(); + f(py, &math) + }) +} + +/// Assert two f64 values are equal (handles NaN) +pub(crate) fn assert_f64_eq(py: f64, rs: f64, context: impl std::fmt::Display) { + if py.is_nan() && rs.is_nan() { + return; + } + assert_eq!( + py.to_bits(), + rs.to_bits(), + "{}: py={} vs rs={}", + context, + py, + rs + ); +} + /// Test a 2-argument function that returns Result pub(crate) fn test_math_2( x: f64,