From f52b6f2d7211471338138d9314c5524472136d46 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 May 2026 09:48:01 -0400 Subject: [PATCH 1/2] Test polar AD for CUDA --- test/mooncake/polar.jl | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/test/mooncake/polar.jl b/test/mooncake/polar.jl index 6442b7b33..ba0446f52 100644 --- a/test/mooncake/polar.jl +++ b/test/mooncake/polar.jl @@ -19,8 +19,17 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) n >= m && TestSuite.test_mooncake_right_polar(T, (m, n); atol, rtol) #=if m == n AT = Diagonal{T, Vector{T}} - TestSuite.test_mooncake_left_polar(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - TestSuite.test_mooncake_right_polar(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - end=# # broken due to pullback + TestSuite.test_mooncake_left_polar(AT, m; atol, rtol) + TestSuite.test_mooncake_right_polar(AT, m; atol, rtol) + end=# + end + if T ∈ BLASFloats && CUDA.functional() + m >= n && TestSuite.test_mooncake_left_polar(CuMatrix{T}, (m, n); atol, rtol) + n >= m && TestSuite.test_mooncake_right_polar(CuMatrix{T}, (m, n); atol, rtol) + #=if m == n + AT = Diagonal{T, CuVector{T}} + TestSuite.test_mooncake_left_polar(AT, m; atol, rtol) + TestSuite.test_mooncake_right_polar(AT, m; atol, rtol) + end=# end end From 399b2a2d6ed6b599378d1af5714d79c45a6931f6 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 May 2026 10:05:57 -0400 Subject: [PATCH 2/2] Move atol and rtol --- test/mooncake/polar.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mooncake/polar.jl b/test/mooncake/polar.jl index ba0446f52..ea7509426 100644 --- a/test/mooncake/polar.jl +++ b/test/mooncake/polar.jl @@ -13,8 +13,8 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(1234) + atol = rtol = m * n * TestSuite.precision(T) if !is_buildkite - atol = rtol = m * n * TestSuite.precision(T) m >= n && TestSuite.test_mooncake_left_polar(T, (m, n); atol, rtol) n >= m && TestSuite.test_mooncake_right_polar(T, (m, n); atol, rtol) #=if m == n