From 5ed39f6ec5cf3b2115cbd043ea497361178291de Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sat, 14 Mar 2026 00:27:23 +0100 Subject: [PATCH 1/4] Implement DifferentiationInterfaceTest infrastructure --- test/Project.toml | 6 + test/basekernels/exponential.jl | 3 + test/basekernels/fbm.jl | 1 + test/basekernels/matern.jl | 2 + test/kernels/scaledkernel.jl | 1 + test/mokernels/intrinsiccoregion.jl | 1 + test/runtests.jl | 3 + test/test_utils.jl | 278 +++++++++++++++++++++++++++- 8 files changed, 288 insertions(+), 7 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index d560ec4e8..1bcf56600 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,8 +1,11 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" @@ -22,8 +25,11 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] +ADTypes = "1" AxisArrays = "0.4.3" Compat = "3, 4" +DifferentiationInterface = "0.7" +DifferentiationInterfaceTest = "0.11" Distances = "0.10" Documenter = "0.25, 0.26, 0.27, 1" FiniteDifferences = "0.10.8, 0.11, 0.12" diff --git a/test/basekernels/exponential.jl b/test/basekernels/exponential.jl index 1b0bb5fc9..e1ea375f4 100644 --- a/test/basekernels/exponential.jl +++ b/test/basekernels/exponential.jl @@ -22,6 +22,7 @@ # Standardised tests. TestUtils.test_interface(k) test_ADs(SEKernel) + test_ADs_DIT(SEKernel) test_interface_ad_perf(_ -> SEKernel(), nothing, StableRNG(123456)) end @testset "ExponentialKernel" begin @@ -41,6 +42,7 @@ # Standardised tests. TestUtils.test_interface(k) test_ADs(ExponentialKernel) + test_ADs_DIT(ExponentialKernel) test_interface_ad_perf(_ -> ExponentialKernel(), nothing, StableRNG(123456)) end @testset "GammaExponentialKernel" begin @@ -59,6 +61,7 @@ @test k2(v1, v2) ≈ k(v1, v2) test_ADs(γ -> GammaExponentialKernel(; gamma=only(γ)), [1 + 0.5 * rand()]) + test_ADs_DIT(γ -> GammaExponentialKernel(; gamma=only(γ)), [1 + 0.5 * rand()]) test_params(k, ([γ],)) TestUtils.test_interface(GammaExponentialKernel(; γ=1.36)) diff --git a/test/basekernels/fbm.jl b/test/basekernels/fbm.jl index 5e9a960b3..079c6dee3 100644 --- a/test/basekernels/fbm.jl +++ b/test/basekernels/fbm.jl @@ -14,6 +14,7 @@ test_interface(k; rtol=1e-5) @test repr(k) == "Fractional Brownian Motion Kernel (h = $(h))" test_ADs(FBMKernel; ADs=[:ReverseDiff]) + test_ADs_DIT(FBMKernel; ADs=[:ReverseDiff]) # Tests failing for ForwardDiff and Zygote@0.6. # Related to: https://github.com/FluxML/Zygote.jl/issues/1036 diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index 2167b0d9d..16a2adab5 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -21,6 +21,7 @@ # Standardised tests. TestUtils.test_interface(k, Float64) test_ADs(() -> MaternKernel(; nu=ν)) + test_ADs_DIT(() -> MaternKernel(; nu=ν)) test_params(k, ([ν],)) @@ -60,6 +61,7 @@ # Standardised tests. TestUtils.test_interface(k, Float64) test_ADs(Matern32Kernel) + test_ADs_DIT(Matern32Kernel) test_interface_ad_perf(_ -> Matern32Kernel(), nothing, StableRNG(123456)) end @testset "Matern52Kernel" begin diff --git a/test/kernels/scaledkernel.jl b/test/kernels/scaledkernel.jl index b40aee674..3dce5b5d6 100644 --- a/test/kernels/scaledkernel.jl +++ b/test/kernels/scaledkernel.jl @@ -12,6 +12,7 @@ # Standardised tests. TestUtils.test_interface(ks, Float64) test_ADs(x -> exp(x[1]) * SqExponentialKernel(), rand(1)) + test_ADs_DIT(x -> exp(x[1]) * SqExponentialKernel(), rand(1)) test_interface_ad_perf(c -> c * SEKernel(), 0.3, StableRNG(123456)) test_params(s * k, (k, [s])) diff --git a/test/mokernels/intrinsiccoregion.jl b/test/mokernels/intrinsiccoregion.jl index b08930441..40346afc7 100644 --- a/test/mokernels/intrinsiccoregion.jl +++ b/test/mokernels/intrinsiccoregion.jl @@ -41,6 +41,7 @@ ) test_ADs(icoregionkernel; dims=dims) + test_ADs_DIT(icoregionkernel; dims=dims) @test string(icoregionkernel) == string("Intrinsic Coregion Kernel: ", kernel, " with ", dims.out, " outputs") diff --git a/test/runtests.jl b/test/runtests.jl index e054b992a..a04ef02a7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,6 +20,9 @@ using ForwardDiff: ForwardDiff using ReverseDiff: ReverseDiff using FiniteDifferences: FiniteDifferences using Compat: only +using ADTypes: AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoFiniteDifferences +using DifferentiationInterface: DifferentiationInterface as DI +using DifferentiationInterfaceTest: Scenario, test_differentiation using KernelFunctions: SimpleKernel, metric, kappa, ColVecs, RowVecs, TestUtils diff --git a/test/test_utils.jl b/test/test_utils.jl index fdb9f2871..624b0b7b7 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -94,9 +94,7 @@ testfunction(k::MOKernel, A) = sum(kernelmatrix(k, A)) testdiagfunction(k::MOKernel, A) = sum(kernelmatrix_diag(k, A)) testdiagfunction(k::MOKernel, A, B) = sum(kernelmatrix_diag(k, A, B)) -function test_ADs( - kernelfunction, args=nothing; ADs=_DEFAULT_ADS, dims=[3, 3] -) +function test_ADs(kernelfunction, args=nothing; ADs=_DEFAULT_ADS, dims=[3, 3]) test_fd = test_AD(:FiniteDiff, kernelfunction, args, dims) if !test_fd.anynonpass for AD in ADs @@ -116,9 +114,7 @@ function check_zygote_type_stability(f, args...; ctx=Zygote.Context()) @inferred collect(pb(out)) end -function test_ADs( - k::MOKernel; ADs=_DEFAULT_ADS, dims=(in=3, out=2, obs=3) -) +function test_ADs(k::MOKernel; ADs=_DEFAULT_ADS, dims=(in=3, out=2, obs=3)) test_fd = test_FiniteDiff(k, dims) if !test_fd.anynonpass for AD in ADs @@ -372,7 +368,7 @@ of the results. for `f(args1...)` etc, `passes[2]` for `Zygote.pullback(f, args1...)`. Let ```julia out, pb = Zygote.pullback(f, args1...) -```` +``` then `passes[3]` indicates whether `pb(out)` checks should pass. This is useful when it is known that some of the tests fail and a fix isn't immediately available. @@ -490,3 +486,271 @@ function __default_input_types() Vector{Float64}, ColVecs{Float64,Matrix{Float64}}, RowVecs{Float64,Matrix{Float64}} ] end + +# ============================================================================ +# DifferentiationInterfaceTest-based AD testing infrastructure +# ============================================================================ + +const FD_BACKEND = AutoFiniteDifferences(; fdm=FDM) + +const _DEFAULT_BACKENDS = let + backends = [AutoForwardDiff(), AutoReverseDiff()] + _TEST_ZYGOTE && pushfirst!(backends, AutoZygote()) + backends +end + +const _BACKEND_MAP = Dict{Symbol,Any}( + :ForwardDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :FiniteDiff => FD_BACKEND, +) + +# Custom isapprox that handles Zygote returning `nothing` for zero gradients +function _isapprox_nothing(a, b; kwargs...) + a_val = isnothing(a) ? zero(b) : a + b_val = isnothing(b) ? zero(a) : b + return Base.isapprox(a_val, b_val; kwargs...) +end + +function _resolve_backends(ADs) + if ADs === nothing + return _DEFAULT_BACKENDS + else + return [_BACKEND_MAP[ad] for ad in ADs] + end +end + +""" + test_ADs_DIT(kernelfunction, args=nothing; ADs=nothing, dims=[3, 3]) + +DIT-based version of `test_ADs`. Tests gradient correctness of kernel functions +across multiple AD backends using DifferentiationInterfaceTest.jl. + +Uses FiniteDifferences as reference backend. +""" +function test_ADs_DIT(kernelfunction, args=nothing; ADs=nothing, dims=[3, 3]) + backends = _resolve_backends(ADs) + k = args === nothing ? kernelfunction() : kernelfunction(args) + rng = MersenneTwister(42) + + # First check that FiniteDifferences works (skip AD tests if not) + scenarios_smoke = _build_kernel_scenarios(k, kernelfunction, args, dims, rng) + fd_ok = true + @testset "FiniteDifferences (DIT)" begin + for (_, f, x) in scenarios_smoke + try + DI.gradient(f, FD_BACKEND, x) + catch + fd_ok = false + @test_broken false # Mark as known failure + end + end + end + + fd_ok || return nothing + + # Build scenarios with reference results + scenarios = Scenario[] + for (name, f, x) in scenarios_smoke + res1 = DI.gradient(f, FD_BACKEND, x) + push!(scenarios, Scenario{:gradient,:out}(f, x; res1=res1, name=name)) + end + + @testset "AD correctness (DIT)" begin + test_differentiation( + backends, + scenarios; + correctness=true, + isapprox=_isapprox_nothing, + atol=1e-8, + rtol=1e-5, + ) + end +end + +""" +Build a list of (name, f, x) tuples describing all gradient tests for a kernel. +""" +function _build_kernel_scenarios(k, kernelfunction, args, dims, rng) + scenarios = Tuple{String,Any,Any}[] + + # 1. kappa tests for SimpleKernels + if k isa SimpleKernel + for d in log.([eps(), rand(rng)]) + f = let k = k + x -> kappa(k, exp(x[1])) + end + push!(scenarios, ("kappa", f, [d])) + end + end + + # 2. kernel evaluation tests + x = rand(rng, dims[1]) + y = rand(rng, dims[1]) + + let k = k, y = y + push!(scenarios, ("k(x,y) w.r.t. x", x -> k(x, y), copy(x))) + end + let k = k, x = x + push!(scenarios, ("k(x,y) w.r.t. y", y -> k(x, y), copy(y))) + end + + # 3. hyperparameter tests + if args !== nothing + let kernelfunction = kernelfunction, x = x, y = y + push!( + scenarios, ("hyperparams (eval)", p -> kernelfunction(p)(x, y), copy(args)) + ) + end + end + + # 4. kernel matrix tests + A = rand(rng, dims...) + B = rand(rng, dims...) + for _testfn in (testfunction, testdiagfunction) + fname = _testfn === testfunction ? "kernelmatrix" : "kernelmatrix_diag" + for dim in 1:2 + let k = k, B = B, dim = dim + push!(scenarios, ("$fname(k,A,$dim)", a -> _testfn(k, a, dim), copy(A))) + end + let k = k, B = B, dim = dim + push!( + scenarios, + ("$fname(k,A,B,$dim) w.r.t. A", a -> _testfn(k, a, B, dim), copy(A)), + ) + end + let k = k, A = A, dim = dim + push!( + scenarios, + ("$fname(k,A,B,$dim) w.r.t. B", b -> _testfn(k, A, b, dim), copy(B)), + ) + end + if args !== nothing + let kernelfunction = kernelfunction, A = A, dim = dim + push!( + scenarios, + ( + "$fname hyperparams unary", + p -> _testfn(kernelfunction(p), A, dim), + copy(args), + ), + ) + end + let kernelfunction = kernelfunction, A = A, B = B, dim = dim + push!( + scenarios, + ( + "$fname hyperparams binary", + p -> _testfn(kernelfunction(p), A, B, dim), + copy(args), + ), + ) + end + end + end + end + + return scenarios +end + +""" + test_ADs_DIT(k::MOKernel; ADs=nothing, dims=(in=3, out=2, obs=3)) + +DIT-based version of `test_ADs` for multi-output kernels. +""" +function test_ADs_DIT(k::MOKernel; ADs=nothing, dims=(in=3, out=2, obs=3)) + backends = _resolve_backends(ADs) + rng = MersenneTwister(42) + + scenarios_smoke = _build_mokernel_scenarios(k, dims, rng) + fd_ok = true + @testset "FiniteDifferences (DIT)" begin + for (_, f, x) in scenarios_smoke + try + DI.gradient(f, FD_BACKEND, x) + catch + fd_ok = false + @test_broken false + end + end + end + + if !fd_ok + return nothing + end + + scenarios = Scenario[] + for (name, f, x) in scenarios_smoke + res1 = DI.gradient(f, FD_BACKEND, x) + push!(scenarios, Scenario{:gradient,:out}(f, x; res1=res1, name=name)) + end + + @testset "AD correctness (DIT)" begin + test_differentiation( + backends, + scenarios; + correctness=true, + isapprox=_isapprox_nothing, + atol=1e-8, + rtol=1e-5, + ) + end +end + +""" +Build (name, f, x) tuples for MOKernel gradient tests. +""" +function _build_mokernel_scenarios(k::MOKernel, dims, rng) + scenarios = Tuple{String,Any,Any}[] + + x = (rand(rng, dims.in), rand(rng, 1:(dims.out))) + y = (rand(rng, dims.in), rand(rng, 1:(dims.out))) + + # Kernel eval w.r.t. continuous components + let k = k, x = x, y = y + push!(scenarios, ("MOKernel eval w.r.t. x", a -> k((a, x[2]), y), copy(x[1]))) + end + let k = k, x = x, y = y + push!(scenarios, ("MOKernel eval w.r.t. y", b -> k(x, (b, y[2])), copy(y[1]))) + end + + # Kernel matrices + A = [(randn(rng, dims.in), rand(rng, 1:(dims.out))) for _ in 1:(dims.obs)] + B = [(randn(rng, dims.in), rand(rng, 1:(dims.out))) for _ in 1:(dims.obs)] + A_mat = reduce(hcat, first.(A)) + B_mat = reduce(hcat, first.(B)) + + for _testfn in (testfunction, testdiagfunction) + fname = _testfn === testfunction ? "kernelmatrix" : "kernelmatrix_diag" + + # Unary + let k = k, A = A + f = a -> begin + A_local = tuple.(eachcol(a), last.(A)) + _testfn(k, A_local) + end + push!(scenarios, ("$fname MO unary", f, copy(A_mat))) + end + + # Binary w.r.t. A + let k = k, A = A, B = B + f = a -> begin + A_local = tuple.(eachcol(a), last.(A)) + _testfn(k, A_local, B) + end + push!(scenarios, ("$fname MO binary w.r.t. A", f, copy(A_mat))) + end + + # Binary w.r.t. B + let k = k, A = A, B = B + f = b -> begin + B_local = tuple.(eachcol(b), last.(B)) + _testfn(k, A, B_local) + end + push!(scenarios, ("$fname MO binary w.r.t. B", f, copy(B_mat))) + end + end + + return scenarios +end From 91600bbda67064217db950eff819f844aa77d620 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sat, 14 Mar 2026 12:33:00 +0100 Subject: [PATCH 2/4] Remove old test infrastructure --- test/Project.toml | 4 - test/basekernels/constant.jl | 5 - test/basekernels/cosine.jl | 1 - test/basekernels/exponential.jl | 6 - test/basekernels/exponentiated.jl | 1 - test/basekernels/fbm.jl | 5 - test/basekernels/gabor.jl | 6 - test/basekernels/matern.jl | 28 -- test/basekernels/nn.jl | 1 - test/basekernels/periodic.jl | 1 - test/basekernels/piecewisepolynomial.jl | 24 -- test/basekernels/polynomial.jl | 6 - test/basekernels/rational.jl | 16 +- test/kernels/kernelproduct.jl | 5 - test/kernels/kernelsum.jl | 5 - test/kernels/kerneltensorproduct.jl | 6 - test/kernels/normalizedkernel.jl | 5 - test/kernels/scaledkernel.jl | 4 - test/kernels/transformedkernel.jl | 42 -- test/mokernels/intrinsiccoregion.jl | 1 - test/runtests.jl | 4 +- test/test_utils.jl | 506 ++---------------------- test/transform/ardtransform.jl | 7 - test/transform/chaintransform.jl | 3 - test/transform/functiontransform.jl | 3 - test/transform/lineartransform.jl | 5 - test/transform/periodic_transform.jl | 3 - test/transform/scaletransform.jl | 3 - test/transform/selecttransform.jl | 6 - test/utils.jl | 13 - 30 files changed, 25 insertions(+), 700 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 1bcf56600..09c2ff006 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -10,7 +10,6 @@ Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" @@ -18,7 +17,6 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" -StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -34,13 +32,11 @@ Distances = "0.10" Documenter = "0.25, 0.26, 0.27, 1" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10" -Functors = "0.2, 0.3, 0.4, 0.5" Kronecker = "0.4, 0.5" LogExpFunctions = "0.2, 0.3" PDMats = "0.9, 0.10, 0.11" ReverseDiff = "1.2" SpecialFunctions = "0.10, 1, 2" -StableRNGs = "1" StaticArrays = "1" Statistics = "1" Zygote = "0.6.38" diff --git a/test/basekernels/constant.jl b/test/basekernels/constant.jl index 435a1678a..e4db7bec1 100644 --- a/test/basekernels/constant.jl +++ b/test/basekernels/constant.jl @@ -10,7 +10,6 @@ TestUtils.test_interface(k, Float64) TestUtils.test_interface(k, Vector{String}) test_ADs(ZeroKernel) - test_interface_ad_perf(_ -> k, nothing, StableRNG(123456)) end @testset "WhiteKernel" begin k = WhiteKernel() @@ -25,7 +24,6 @@ TestUtils.test_interface(k, Float64) TestUtils.test_interface(k, Vector{String}) test_ADs(WhiteKernel) - test_interface_ad_perf(_ -> k, nothing, StableRNG(123456)) end @testset "ConstantKernel" begin c = 2.0 @@ -36,12 +34,9 @@ @test metric(ConstantKernel()) == KernelFunctions.Delta() @test metric(ConstantKernel(; c=2.0)) == KernelFunctions.Delta() @test repr(k) == "Constant Kernel (c = $(c))" - test_params(k, ([c],)) - # Standardised tests. TestUtils.test_interface(k, Float64) TestUtils.test_interface(k, Vector{String}) test_ADs(c -> ConstantKernel(; c=only(c)), [c]) - test_interface_ad_perf(c -> ConstantKernel(; c=c), c, StableRNG(123456)) end end diff --git a/test/basekernels/cosine.jl b/test/basekernels/cosine.jl index 2c083ae8d..ed24e4923 100644 --- a/test/basekernels/cosine.jl +++ b/test/basekernels/cosine.jl @@ -20,5 +20,4 @@ # Standardised tests. TestUtils.test_interface(k, Vector{Float64}) test_ADs(CosineKernel) - test_interface_ad_perf(_ -> CosineKernel(), nothing, StableRNG(123456)) end diff --git a/test/basekernels/exponential.jl b/test/basekernels/exponential.jl index e1ea375f4..f3764454d 100644 --- a/test/basekernels/exponential.jl +++ b/test/basekernels/exponential.jl @@ -22,8 +22,6 @@ # Standardised tests. TestUtils.test_interface(k) test_ADs(SEKernel) - test_ADs_DIT(SEKernel) - test_interface_ad_perf(_ -> SEKernel(), nothing, StableRNG(123456)) end @testset "ExponentialKernel" begin k = ExponentialKernel() @@ -42,8 +40,6 @@ # Standardised tests. TestUtils.test_interface(k) test_ADs(ExponentialKernel) - test_ADs_DIT(ExponentialKernel) - test_interface_ad_perf(_ -> ExponentialKernel(), nothing, StableRNG(123456)) end @testset "GammaExponentialKernel" begin γ = 1.0 @@ -61,8 +57,6 @@ @test k2(v1, v2) ≈ k(v1, v2) test_ADs(γ -> GammaExponentialKernel(; gamma=only(γ)), [1 + 0.5 * rand()]) - test_ADs_DIT(γ -> GammaExponentialKernel(; gamma=only(γ)), [1 + 0.5 * rand()]) - test_params(k, ([γ],)) TestUtils.test_interface(GammaExponentialKernel(; γ=1.36)) #Coherence : diff --git a/test/basekernels/exponentiated.jl b/test/basekernels/exponentiated.jl index 14cc6d0d6..1d209dc49 100644 --- a/test/basekernels/exponentiated.jl +++ b/test/basekernels/exponentiated.jl @@ -14,5 +14,4 @@ # Standardised tests. This kernel appears to be fairly numerically unstable. TestUtils.test_interface(k; atol=1e-3) test_ADs(ExponentiatedKernel) - test_interface_ad_perf(_ -> ExponentiatedKernel(), nothing, StableRNG(123456)) end diff --git a/test/basekernels/fbm.jl b/test/basekernels/fbm.jl index 079c6dee3..d89741878 100644 --- a/test/basekernels/fbm.jl +++ b/test/basekernels/fbm.jl @@ -14,7 +14,6 @@ test_interface(k; rtol=1e-5) @test repr(k) == "Fractional Brownian Motion Kernel (h = $(h))" test_ADs(FBMKernel; ADs=[:ReverseDiff]) - test_ADs_DIT(FBMKernel; ADs=[:ReverseDiff]) # Tests failing for ForwardDiff and Zygote@0.6. # Related to: https://github.com/FluxML/Zygote.jl/issues/1036 @@ -24,8 +23,4 @@ Zygote.gradient((x, y) -> sum(f.(x, y)), zeros(1), fill(0.9, 1))[1][1] ) end - - test_params(k, ([h],)) - - test_interface_ad_perf(h -> FBMKernel(; h=h), h, StableRNG(123456)) end diff --git a/test/basekernels/gabor.jl b/test/basekernels/gabor.jl index 69dee1b9a..aa3047387 100644 --- a/test/basekernels/gabor.jl +++ b/test/basekernels/gabor.jl @@ -32,10 +32,4 @@ ), [ell, p], ) - test_interface_ad_perf((ell, p), StableRNG(123456)) do θ - gaborkernel(; - sqexponential_transform=ScaleTransform(θ[1]), - cosine_transform=ScaleTransform(θ[2]), - ) - end end diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index 16a2adab5..8db4d6223 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -21,30 +21,6 @@ # Standardised tests. TestUtils.test_interface(k, Float64) test_ADs(() -> MaternKernel(; nu=ν)) - test_ADs_DIT(() -> MaternKernel(; nu=ν)) - - test_params(k, ([ν],)) - - # The performance of this kernel varies quite a lot from method to method, so - # requires us to specify whether performance tests pass or not. - @testset "performance ($T)" for T in [ - Vector{Float64}, - ColVecs{Float64,Matrix{Float64}}, - RowVecs{Float64,Matrix{Float64}}, - ] - xs = example_inputs(StableRNG(123456), Vector{Float64}) - test_interface_ad_perf( - ν -> MaternKernel(; nu=ν), - ν, - xs...; - passes=( - unary=(false, false, false), - binary=(false, false, false), - diag_unary=(true, false, false), - diag_binary=(true, false, false), - ), - ) - end end @testset "Matern32Kernel" begin k = Matern32Kernel() @@ -61,8 +37,6 @@ # Standardised tests. TestUtils.test_interface(k, Float64) test_ADs(Matern32Kernel) - test_ADs_DIT(Matern32Kernel) - test_interface_ad_perf(_ -> Matern32Kernel(), nothing, StableRNG(123456)) end @testset "Matern52Kernel" begin k = Matern52Kernel() @@ -82,7 +56,6 @@ # Standardised tests. TestUtils.test_interface(k, Float64) test_ADs(Matern52Kernel) - test_interface_ad_perf(_ -> Matern52Kernel(), nothing, StableRNG(123456)) end @testset "Matern72Kernel" begin k = Matern72Kernel() @@ -106,7 +79,6 @@ # Standardised tests. TestUtils.test_interface(k, Float64) test_ADs(Matern72Kernel) - test_interface_ad_perf(_ -> Matern72Kernel(), nothing, StableRNG(123456)) end @testset "Coherence Materns" begin @test kappa(MaternKernel(; ν=0.5), x) ≈ kappa(ExponentialKernel(), x) diff --git a/test/basekernels/nn.jl b/test/basekernels/nn.jl index ee4863356..c9dabeb69 100644 --- a/test/basekernels/nn.jl +++ b/test/basekernels/nn.jl @@ -8,5 +8,4 @@ # Standardised tests. TestUtils.test_interface(k, Float64) test_ADs(NeuralNetworkKernel) - test_interface_ad_perf(_ -> NeuralNetworkKernel(), nothing, StableRNG(123456)) end diff --git a/test/basekernels/periodic.jl b/test/basekernels/periodic.jl index c33569714..d2cb4047f 100644 --- a/test/basekernels/periodic.jl +++ b/test/basekernels/periodic.jl @@ -17,5 +17,4 @@ TestUtils.test_interface(PeriodicKernel(; r=[0.8, 0.7]), RowVecs{Float64}) test_ADs(r -> PeriodicKernel(; r=exp.(r)), log.(r)) - test_params(k, (r,)) end diff --git a/test/basekernels/piecewisepolynomial.jl b/test/basekernels/piecewisepolynomial.jl index 4d9979e1a..abc561dd6 100644 --- a/test/basekernels/piecewisepolynomial.jl +++ b/test/basekernels/piecewisepolynomial.jl @@ -33,28 +33,4 @@ TestUtils.test_interface(k, RowVecs{Float64}; dim_in=2) test_ADs(() -> PiecewisePolynomialKernel{degree}(; dim=D)) - test_params(k, ()) - if VERSION >= v"1.8.0" - test_interface_ad_perf(nothing, StableRNG(123456)) do _ - PiecewisePolynomialKernel{degree}(; dim=D) - end - else - @testset "AD Alloc Performance ($T)" for T in [ - Vector{Float64}, - ColVecs{Float64,Matrix{Float64}}, - RowVecs{Float64,Matrix{Float64}}, - ] - test_interface_ad_perf( - _ -> PiecewisePolynomialKernel{degree}(; dim=D), - nothing, - example_inputs(StableRNG(123456), T)...; - passes=( - unary=(true, true, true), - binary=(true, true, true), - diag_unary=(true, true, true), - diag_binary=(true, true, true), - ), - ) - end - end end diff --git a/test/basekernels/polynomial.jl b/test/basekernels/polynomial.jl index 80a3b1bb8..2dc3a0797 100644 --- a/test/basekernels/polynomial.jl +++ b/test/basekernels/polynomial.jl @@ -19,8 +19,6 @@ # Standardised tests. TestUtils.test_interface(k, Float64) test_ADs(x -> LinearKernel(; c=x[1]), [c]) - test_params(LinearKernel(; c=c), ([c],)) - test_interface_ad_perf(c -> LinearKernel(; c=c), c, StableRNG(123456)) end @testset "PolynomialKernel" begin k = PolynomialKernel() @@ -42,9 +40,5 @@ # Standardised tests. TestUtils.test_interface(k, Float64) test_ADs(x -> PolynomialKernel(; c=x[1]), [c]) - test_params(PolynomialKernel(; c=c), ([c],)) - test_interface_ad_perf( - c -> PolynomialKernel(; degree=2, c=c), 0.3, StableRNG(123456) - ) end end diff --git a/test/basekernels/rational.jl b/test/basekernels/rational.jl index 956295e04..210276b3c 100644 --- a/test/basekernels/rational.jl +++ b/test/basekernels/rational.jl @@ -28,8 +28,6 @@ # Standardised tests. TestUtils.test_interface(k, Float64) test_ADs(x -> RationalKernel(; alpha=exp(x[1])), [α]) - test_params(k, ([α],)) - test_interface_ad_perf(α -> RationalKernel(; alpha=α), α, StableRNG(123456)) end @testset "RationalQuadraticKernel" begin @@ -56,12 +54,8 @@ # Standardised tests. TestUtils.test_interface(k, Float64) # test_ADs(x -> RationalQuadraticKernel(; alpha=exp(x[1])), [α]) - test_params(k, ([α],)) - test_interface_ad_perf(α, StableRNG(123456)) do α - RationalQuadraticKernel(; alpha=α) - end - # Check correctness and performance with non-Euclidean metrics. + # Check correctness with non-Euclidean metrics. TestUtils.test_interface( RationalQuadraticKernel(; alpha=α, metric=WeightedEuclidean([1.0, 2.0])), ColVecs{Float64}, @@ -70,10 +64,6 @@ RationalQuadraticKernel(; alpha=α, metric=WeightedEuclidean([1.0, 2.0])), RowVecs{Float64}, ) - types = [ColVecs{Float64,Matrix{Float64}}, RowVecs{Float64,Matrix{Float64}}] - test_interface_ad_perf(α, StableRNG(123456)) do α - RationalQuadraticKernel(; alpha=α, metric=KernelFunctions.DotProduct()) - end end @testset "GammaRationalKernel" begin @@ -147,9 +137,5 @@ TestUtils.test_interface(k, Float64) a = 1.0 + rand() test_ADs(x -> GammaRationalKernel(; α=x[1], γ=x[2]), [a, 1 + 0.5 * rand()]) - test_params(GammaRationalKernel(; α=a, γ=x), ([a], [x])) - test_interface_ad_perf((2.0, 1.5), StableRNG(123456)) do θ - GammaRationalKernel(; α=θ[1], γ=θ[2]) - end end end diff --git a/test/kernels/kernelproduct.jl b/test/kernels/kernelproduct.jl index 94e816ade..1f63155f2 100644 --- a/test/kernels/kernelproduct.jl +++ b/test/kernels/kernelproduct.jl @@ -21,11 +21,6 @@ test_ADs( x -> KernelProduct(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))), rand(1) ) - test_interface_ad_perf(2.4, StableRNG(123456)) do c - KernelProduct(SqExponentialKernel(), LinearKernel(; c=c)) - end - test_params(k1 * k2, (k1, k2)) - nested_k = RBFKernel() * (LinearKernel() + CosineKernel() * RBFKernel()) test_type_stability(nested_k) end diff --git a/test/kernels/kernelsum.jl b/test/kernels/kernelsum.jl index eb64aeee8..deb92b775 100644 --- a/test/kernels/kernelsum.jl +++ b/test/kernels/kernelsum.jl @@ -22,11 +22,6 @@ test_interface(kvec, Float64) test_interface(ConstantKernel(; c=1.5) + WhiteKernel(), Vector{String}) test_ADs(x -> KernelSum(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))), rand(1)) - test_interface_ad_perf(2.4, StableRNG(123456)) do c - KernelSum(SqExponentialKernel(), LinearKernel(; c=c)) - end - - test_params(k1 + k2, (k1, k2)) # Regression tests for https://github.com//issues/458 @testset for k in ( diff --git a/test/kernels/kerneltensorproduct.jl b/test/kernels/kerneltensorproduct.jl index a59c972ee..05fe3a37f 100644 --- a/test/kernels/kerneltensorproduct.jl +++ b/test/kernels/kerneltensorproduct.jl @@ -46,12 +46,6 @@ rand(1); dims=[2, 2], ) - types = [ColVecs{Float64,Matrix{Float64}}, RowVecs{Float64,Matrix{Float64}}] - test_interface_ad_perf(2.1, StableRNG(123456), types) do c - KernelTensorProduct(SqExponentialKernel(), LinearKernel(; c=c)) - end - test_params(KernelTensorProduct(k1, k2), (k1, k2)) - @testset "single kernel" begin kernel = KernelTensorProduct(k1) @test length(kernel) == 1 diff --git a/test/kernels/normalizedkernel.jl b/test/kernels/normalizedkernel.jl index fc020fe4c..cf7cd2bbd 100644 --- a/test/kernels/normalizedkernel.jl +++ b/test/kernels/normalizedkernel.jl @@ -11,9 +11,4 @@ # Standardised tests. TestUtils.test_interface(kn, Float64) test_ADs(x -> NormalizedKernel(exp(x[1]) * SqExponentialKernel()), rand(1)) - test_interface_ad_perf(0.3, StableRNG(123456)) do c - NormalizedKernel(c * SqExponentialKernel()) - end - - test_params(kn, k) end diff --git a/test/kernels/scaledkernel.jl b/test/kernels/scaledkernel.jl index 3dce5b5d6..57214110e 100644 --- a/test/kernels/scaledkernel.jl +++ b/test/kernels/scaledkernel.jl @@ -12,8 +12,4 @@ # Standardised tests. TestUtils.test_interface(ks, Float64) test_ADs(x -> exp(x[1]) * SqExponentialKernel(), rand(1)) - test_ADs_DIT(x -> exp(x[1]) * SqExponentialKernel(), rand(1)) - test_interface_ad_perf(c -> c * SEKernel(), 0.3, StableRNG(123456)) - - test_params(s * k, (k, [s])) end diff --git a/test/kernels/transformedkernel.jl b/test/kernels/transformedkernel.jl index 51d55aa1f..6b3ca9b9a 100644 --- a/test/kernels/transformedkernel.jl +++ b/test/kernels/transformedkernel.jl @@ -30,46 +30,4 @@ Vector{String}, ) test_ADs(x -> SqExponentialKernel() ∘ ScaleTransform(x[1]), rand(1)) - test_interface_ad_perf(0.35, StableRNG(123456)) do λ - SqExponentialKernel() ∘ ScaleTransform(λ) - end - - # Test implicit gradients - @testset "Implicit gradients" begin - if _TEST_ZYGOTE - k = SqExponentialKernel() ∘ ScaleTransform(2.0) - ps = params(k) - X = rand(10, 1) - x = vec(X) - A = rand(10, 10) - # Implicit - g1 = Zygote.gradient(ps) do - tr(kernelmatrix(k, X; obsdim=1) * A) - end - # Explicit - g2 = Zygote.gradient(k) do k - tr(kernelmatrix(k, X; obsdim=1) * A) - end - - # Implicit for a vector - g3 = Zygote.gradient(ps) do - tr(kernelmatrix(k, x) * A) - end - @test g1[first(ps)] ≈ first(g2).transform.s - @test g1[first(ps)] ≈ g3[first(ps)] - else - @test_broken false # Zygote not supported on Julia >= 1.12 - end - end - - @testset "Parameters" begin - k = ConstantKernel(; c=rand(rng)) - # c = Chain(Dense(3, 2)) - - test_params(k ∘ ScaleTransform(s), (k, [s])) - test_params(k ∘ ARDTransform(v), (k, v)) - test_params(k ∘ LinearTransform(P), (k, P)) - test_params(k ∘ LinearTransform(P) ∘ ScaleTransform(s), (k, [s], P)) - # test_params(k ∘ FunctionTransform(c), (k, c)) - end end diff --git a/test/mokernels/intrinsiccoregion.jl b/test/mokernels/intrinsiccoregion.jl index 40346afc7..b08930441 100644 --- a/test/mokernels/intrinsiccoregion.jl +++ b/test/mokernels/intrinsiccoregion.jl @@ -41,7 +41,6 @@ ) test_ADs(icoregionkernel; dims=dims) - test_ADs_DIT(icoregionkernel; dims=dims) @test string(icoregionkernel) == string("Intrinsic Coregion Kernel: ", kernel, " with ", dims.out, " outputs") diff --git a/test/runtests.jl b/test/runtests.jl index a04ef02a7..842705a21 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,14 +4,12 @@ using ChainRulesCore using ChainRulesTestUtils using Distances using Documenter -using Functors: functor using Kronecker: Kronecker using LinearAlgebra using LogExpFunctions using PDMats using Random using SpecialFunctions -using StableRNGs using StaticArrays using Statistics using Test @@ -26,7 +24,7 @@ using DifferentiationInterfaceTest: Scenario, test_differentiation using KernelFunctions: SimpleKernel, metric, kappa, ColVecs, RowVecs, TestUtils -using KernelFunctions.TestUtils: test_interface, test_type_stability, example_inputs +using KernelFunctions.TestUtils: test_interface, test_type_stability # The GROUP is used to run different sets of tests in parallel on the GitHub Actions CI. # If you want to introduce a new group, ensure you also add it to .github/workflows/ci.yml diff --git a/test/test_utils.jl b/test/test_utils.jl index 624b0b7b7..ec9bf2e79 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -1,89 +1,13 @@ # More test utilities. Can't be included in KernelFunctions because they introduce a number # of additional deps that we don't want to have in the main package. -# Check parameters of kernels. `trainable`, `params!`, and `params` are taken directly from -# Flux.jl so as to avoid having to depend on Flux at test-time. -trainable(m) = functor(m)[1] - -params!(p::Zygote.Params, x::AbstractArray{<:Number}, seen=Zygote.IdSet()) = push!(p, x) - -function params!(p::Zygote.Params, x, seen=Zygote.IdSet()) - x in seen && return nothing - push!(seen, x) - for child in trainable(x) - params!(p, child, seen) - end -end - -function params(m...) - ps = Zygote.Params() - params!(ps, m) - return ps -end - -function test_params(kernel, reference) - params_kernel = params(kernel) - params_reference = params(reference) - - @test length(params_kernel) == length(params_reference) - @test all(p == q for (p, q) in zip(params_kernel, params_reference)) -end - # AD utilities const _TEST_ZYGOTE = VERSION < v"1.12" -# Type to work around some performance issues that can happen on the reverse-pass of Zygote. -# This context doesn't allow any globals. Don't use this if you use globals in your -# programme. -struct NoContext <: Zygote.AContext end - -Zygote.cache(cx::NoContext) = (cache_fields = nothing) -Base.haskey(cx::NoContext, x) = false -Zygote.accum_param(::NoContext, x, Δ) = Δ - const FDM = FiniteDifferences.central_fdm(5, 1) -const _DEFAULT_ADS = - _TEST_ZYGOTE ? [:Zygote, :ForwardDiff, :ReverseDiff] : [:ForwardDiff, :ReverseDiff] - -gradient(f, s::Symbol, args) = gradient(f, Val(s), args) - -function gradient(f, ::Val{:Zygote}, args) - g = only(Zygote.gradient(f, args)) - if isnothing(g) - if args isa AbstractArray{<:Real} - return zeros(size(args)) # To respect the same output as other ADs - else - return zeros.(size.(args)) - end - else - return g - end -end - -function gradient(f, ::Val{:ForwardDiff}, args) - return ForwardDiff.gradient(f, args) -end - -function gradient(f, ::Val{:ReverseDiff}, args) - return ReverseDiff.gradient(f, args) -end - -function gradient(f, ::Val{:FiniteDiff}, args) - return only(FiniteDifferences.grad(FDM, f, args)) -end - -function compare_gradient(f, ::Val{:FiniteDiff}, args) - @test_nowarn gradient(f, :FiniteDiff, args) -end - -function compare_gradient(f, AD::Symbol, args) - grad_AD = gradient(f, AD, args) - grad_FD = gradient(f, :FiniteDiff, args) - @test grad_AD ≈ grad_FD atol = 1e-8 rtol = 1e-5 -end - +# Helper functions for kernel matrix tests testfunction(k, A, B, dim) = sum(kernelmatrix(k, A, B; obsdim=dim)) testfunction(k, A, dim) = sum(kernelmatrix(k, A; obsdim=dim)) testdiagfunction(k, A, dim) = sum(kernelmatrix_diag(k, A; obsdim=dim)) @@ -94,399 +18,6 @@ testfunction(k::MOKernel, A) = sum(kernelmatrix(k, A)) testdiagfunction(k::MOKernel, A) = sum(kernelmatrix_diag(k, A)) testdiagfunction(k::MOKernel, A, B) = sum(kernelmatrix_diag(k, A, B)) -function test_ADs(kernelfunction, args=nothing; ADs=_DEFAULT_ADS, dims=[3, 3]) - test_fd = test_AD(:FiniteDiff, kernelfunction, args, dims) - if !test_fd.anynonpass - for AD in ADs - test_AD(AD, kernelfunction, args, dims) - end - end -end - -function check_zygote_type_stability(f, args...; ctx=Zygote.Context()) - if !_TEST_ZYGOTE - @test_broken false - return nothing - end - @inferred f(args...) - @inferred Zygote._pullback(ctx, f, args...) - out, pb = Zygote._pullback(ctx, f, args...) - @inferred collect(pb(out)) -end - -function test_ADs(k::MOKernel; ADs=_DEFAULT_ADS, dims=(in=3, out=2, obs=3)) - test_fd = test_FiniteDiff(k, dims) - if !test_fd.anynonpass - for AD in ADs - test_AD(AD, k, dims) - end - end -end - -function test_FiniteDiff(k::MOKernel, dims=(in=3, out=2, obs=3)) - rng = MersenneTwister(42) - @testset "FiniteDifferences" begin - ## Testing Kernel Functions - x = (rand(rng, dims.in), rand(rng, 1:(dims.out))) - y = (rand(rng, dims.in), rand(rng, 1:(dims.out))) - - @test_nowarn gradient(:FiniteDiff, x[1]) do a - k((a, x[2]), y) - end - - ## Testing Kernel Matrices - - A = [(randn(rng, dims.in), rand(rng, 1:(dims.out))) for i in 1:(dims.obs)] - B = [(randn(rng, dims.in), rand(rng, 1:(dims.out))) for i in 1:(dims.obs)] - - @test_nowarn gradient(:FiniteDiff, reduce(hcat, first.(A))) do a - A = tuple.(eachcol(a), last.(A)) - testfunction(k, A) - end - @test_nowarn gradient(:FiniteDiff, reduce(hcat, first.(A))) do a - A = tuple.(eachcol(a), last.(A)) - testfunction(k, A, B) - end - @test_nowarn gradient(:FiniteDiff, reduce(hcat, first.(B))) do b - B = tuple.(eachcol(b), last.(B)) - testfunction(k, A, B) - end - - @test_nowarn gradient(:FiniteDiff, reduce(hcat, first.(A))) do a - A = tuple.(eachcol(a), last.(A)) - testdiagfunction(k, A) - end - @test_nowarn gradient(:FiniteDiff, reduce(hcat, first.(A))) do a - A = tuple.(eachcol(a), last.(A)) - testdiagfunction(k, A, B) - end - @test_nowarn gradient(:FiniteDiff, reduce(hcat, first.(B))) do b - B = tuple.(eachcol(b), last.(B)) - testdiagfunction(k, A, B) - end - end -end - -function test_AD(AD::Symbol, kernelfunction, args=nothing, dims=[3, 3]) - @testset "$(AD)" begin - k = if args === nothing - kernelfunction() - else - kernelfunction(args) - end - rng = MersenneTwister(42) - - if k isa SimpleKernel - @testset "kappa function" begin - for d in log.([eps(), rand(rng)]) - compare_gradient(AD, [d]) do x - kappa(k, exp(x[1])) - end - end - end - end - - @testset "kernel evaluations" begin - x = rand(rng, dims[1]) - y = rand(rng, dims[1]) - compare_gradient(AD, x) do x - k(x, y) - end - compare_gradient(AD, y) do y - k(x, y) - end - if !(args === nothing) - @testset "hyperparameters" begin - compare_gradient(AD, args) do p - kernelfunction(p)(x, y) - end - end - end - end - - @testset "kernel matrices" begin - A = rand(rng, dims...) - B = rand(rng, dims...) - @testset "$(_testfn)" for _testfn in (testfunction, testdiagfunction) - for dim in 1:2 - compare_gradient(AD, A) do a - _testfn(k, a, dim) - end - compare_gradient(AD, A) do a - _testfn(k, a, B, dim) - end - compare_gradient(AD, B) do b - _testfn(k, A, b, dim) - end - if !(args === nothing) - @testset "hyperparameters" begin - compare_gradient(AD, args) do p - _testfn(kernelfunction(p), A, dim) - end - compare_gradient(AD, args) do p - _testfn(kernelfunction(p), A, B, dim) - end - end - end - end - end - end # kernel matrices - end -end - -function test_AD(AD::Symbol, k::MOKernel, dims=(in=3, out=2, obs=3)) - @testset "$(AD)" begin - rng = MersenneTwister(42) - - # Testing kernel evaluations - x = (rand(rng, dims.in), rand(rng, 1:(dims.out))) - y = (rand(rng, dims.in), rand(rng, 1:(dims.out))) - - compare_gradient(AD, x[1]) do a - k((a, x[2]), y) - end - compare_gradient(AD, y[1]) do b - k(x, (b, y[2])) - end - - # Testing kernel matrices - A = [(randn(rng, dims.in), rand(rng, 1:(dims.out))) for i in 1:(dims.obs)] - B = [(randn(rng, dims.in), rand(rng, 1:(dims.out))) for i in 1:(dims.obs)] - - compare_gradient(AD, reduce(hcat, first.(A))) do a - A = tuple.(eachcol(a), last.(A)) - testfunction(k, A) - end - compare_gradient(AD, reduce(hcat, first.(A))) do a - A = tuple.(eachcol(a), last.(A)) - testfunction(k, A, B) - end - compare_gradient(AD, reduce(hcat, first.(B))) do b - B = tuple.(eachcol(b), last.(B)) - testfunction(k, A, B) - end - compare_gradient(AD, reduce(hcat, first.(A))) do a - A = tuple.(eachcol(a), last.(A)) - testdiagfunction(k, A) - end - compare_gradient(AD, reduce(hcat, first.(A))) do a - A = tuple.(eachcol(a), last.(A)) - testdiagfunction(k, A, B) - end - compare_gradient(AD, reduce(hcat, first.(B))) do b - B = tuple.(eachcol(b), last.(B)) - testdiagfunction(k, A, B) - end - end -end - -function count_allocs(f, args...) - stats = @timed f(args...) - return Base.gc_alloc_count(stats.gcstats) -end - -""" - constant_allocs_heuristic(f, args1::T, args2::T) where {T} - -True if number of allocations associated with evaluating `f(args1...)` is equal to those -required to evaluate `f(args2...)`. Runs `f` beforehand to ensure that compilation-related -allocations are not included. - -Why is this a good test? In lots of situations it will be the case that the total amount of -memory allocated by a function will vary as the input sizes vary, but the total _number_ -of allocations ought to be constant. A common performance bug is that the number of -allocations actually does scale with the size of the inputs (e.g. due to a type -instability), and we would very much like to know if this is happening. - -Typically this kind of condition is not a sufficient condition for good performance, but it -is certainly a necessary condition. - -This kind of test is very quick to conduct (just requires running `f` 4 times). It's also -easier to write than simply checking that the total number of allocations used to execute -a function is below some arbitrary `f`-dependent threshold. -""" -function constant_allocs_heuristic(f, args1::T, args2::T) where {T} - - # Ensure that we're not counting allocations associated with compilation. - f(args1...) - f(args2...) - - allocs_1 = count_allocs(f, args1...) - allocs_2 = count_allocs(f, args2...) - return (allocs_1, allocs_2) -end - -""" - ad_constant_allocs_heuristic(f, args1::T, args2::T; Δ1=nothing, Δ2=nothing) where {T} - -Assesses `constant_allocs_heuristic` for `f`, `Zygote.pullback(f, args...)` and its -pullback for both of `args1` and `args2`. - -`Δ1` and `Δ2` are passed to the pullback associated with `Zygote.pullback(f, args1...)` -and `Zygote.pullback(f, args2...)` respectively. If left as `nothing`, it is assumed that -the output of the primal is an acceptable cotangent to be passed to the corresponding -pullback. -""" -function ad_constant_allocs_heuristic( - f, args1::T, args2::T; Δ1=nothing, Δ2=nothing -) where {T} - - # Check that primal has constant allocations. - primal_heuristic = constant_allocs_heuristic(f, args1, args2) - - # Check that forwards-pass has constant allocations. - forwards_heuristic = constant_allocs_heuristic( - (args...) -> Zygote.pullback(f, args...), args1, args2 - ) - - # Check that pullback has constant allocations for both arguments. Run twice to remove - # compilation-related allocations. - - # First thing - out1, pb1 = Zygote.pullback(f, args1...) - Δ1_val = Δ1 === nothing ? out1 : Δ1 - pb1(Δ1_val) - allocs_1 = count_allocs(pb1, Δ1_val) - - # Second thing - out2, pb2 = Zygote.pullback(f, args2...) - Δ2_val = Δ2 === nothing ? out2 : Δ2 - pb2(Δ2_val) - allocs_2 = count_allocs(pb2, Δ2 === nothing ? out2 : Δ2) - - return primal_heuristic, forwards_heuristic, (allocs_1, allocs_2) -end - -""" - test_zygote_perf_heuristic( - f, name::String, args1, args2, passes, Δ1=nothing, Δ2=nothing - ) - -Executes `ad_constant_allocs_heuristic(f, args1, args2; Δ1, Δ2)` and creates a testset out -of the results. -`passes` is a 3-tuple of booleans. `passes[1]` indicates whether the checks should pass -for `f(args1...)` etc, `passes[2]` for `Zygote.pullback(f, args1...)`. Let -```julia -out, pb = Zygote.pullback(f, args1...) -``` -then `passes[3]` indicates whether `pb(out)` checks should pass. -This is useful when it is known that -some of the tests fail and a fix isn't immediately available. -""" -function test_zygote_perf_heuristic( - f, name::String, args1, args2, passes, Δ1=nothing, Δ2=nothing -) - @testset "$name" begin - if !_TEST_ZYGOTE - @test_broken false - return nothing - end - primal, fwd, pb = ad_constant_allocs_heuristic(f, args1, args2; Δ1, Δ2) - if passes[1] - @test primal[1] == primal[2] - else - @test_broken primal[1] == primal[2] - end - if passes[2] - @test abs(fwd[1] - fwd[2]) <= 1 - else - @test_broken fwd[1] == fwd[2] - end - if passes[3] - @test abs(pb[1] - pb[2]) ≤ 1 - else - @test_broken pb[1] == pb[2] - end - end -end - -""" - test_interface_ad_perf( - f, - θ, - x1::AbstractVector, - x2::AbstractVector, - x3::AbstractVector, - x4::AbstractVector; - passes=( - unary=(true, true, true), - binary=(true, true, true), - diag_unary=(true, true, true), - diag_binary=(true, true, true), - ), - ) - -Runs `test_zygote_perf_heuristic` for unary / binary methods of `kernelmatrix` and -`kernelmatrix_diag`. `f(θ)` must, therefore, output a `Kernel`. -`passes` should be a `NamedTuple` of the form above, providing the `passes` argument -for `test_zygote_perf_heuristic` for each of the methods of `kernelmatrix` and -`kernelmatrix_diag`. -""" -function test_interface_ad_perf( - f, - θ, - x1::AbstractVector, - x2::AbstractVector, - x3::AbstractVector, - x4::AbstractVector; - passes=( - unary=(true, true, true), - binary=(true, true, true), - diag_unary=(true, true, true), - diag_binary=(true, true, true), - ), -) - test_zygote_perf_heuristic( - (θ, x) -> kernelmatrix(f(θ), x), - "kernelmatrix (unary)", - (θ, x1), - (θ, x2), - passes.unary, - ) - test_zygote_perf_heuristic( - (θ, x, y) -> kernelmatrix(f(θ), x, y), - "kernelmatrix (binary)", - (θ, x1, x2), - (θ, x3, x4), - passes.binary, - ) - test_zygote_perf_heuristic( - (θ, x) -> kernelmatrix_diag(f(θ), x), - "kernelmatrix_diag (unary)", - (θ, x1), - (θ, x2), - passes.diag_unary, - ) - return test_zygote_perf_heuristic( - (θ, x) -> kernelmatrix_diag(f(θ), x, x), - "kernelmatrix_diag (binary)", - (θ, x1), - (θ, x2), - passes.diag_binary, - ) -end - -""" - test_interface_ad_perf(f, θ, rng::AbstractRNG, types=__default_input_types()) - -Runs `test_interface_ad_perf` for each of the types in `types`. -Often a good idea to just provide `f`, `θ` and `rng`, as `__default_input_types()` is -intended to cover commonly used types. -Sometimes it's necessary to specify that only a subset should be used for a particular -kernel e.g. where it's only valid for 1-dimensional inputs. -""" -function test_interface_ad_perf(f, θ, rng::AbstractRNG, types=__default_input_types()) - @testset "AD Alloc Performance ($T)" for T in types - test_interface_ad_perf(f, θ, example_inputs(rng, T)...) - end -end - -function __default_input_types() - return [ - Vector{Float64}, ColVecs{Float64,Matrix{Float64}}, RowVecs{Float64,Matrix{Float64}} - ] -end - # ============================================================================ # DifferentiationInterfaceTest-based AD testing infrastructure # ============================================================================ @@ -521,15 +52,24 @@ function _resolve_backends(ADs) end end -""" - test_ADs_DIT(kernelfunction, args=nothing; ADs=nothing, dims=[3, 3]) +# Thin wrappers for direct gradient computation (used by chainrules.jl and selecttransform.jl) +function gradient(f, s::Symbol, args) + return DI.gradient(f, _BACKEND_MAP[s], args) +end + +function compare_gradient(f, AD::Symbol, args) + grad_AD = DI.gradient(f, _BACKEND_MAP[AD], args) + grad_FD = DI.gradient(f, FD_BACKEND, args) + @test _isapprox_nothing(grad_AD, grad_FD; atol=1e-8, rtol=1e-5) +end -DIT-based version of `test_ADs`. Tests gradient correctness of kernel functions -across multiple AD backends using DifferentiationInterfaceTest.jl. +""" + test_ADs(kernelfunction, args=nothing; ADs=nothing, dims=[3, 3]) -Uses FiniteDifferences as reference backend. +Tests gradient correctness of kernel functions across multiple AD backends +using DifferentiationInterfaceTest.jl. Uses FiniteDifferences as reference backend. """ -function test_ADs_DIT(kernelfunction, args=nothing; ADs=nothing, dims=[3, 3]) +function test_ADs(kernelfunction, args=nothing; ADs=nothing, dims=[3, 3]) backends = _resolve_backends(ADs) k = args === nothing ? kernelfunction() : kernelfunction(args) rng = MersenneTwister(42) @@ -537,7 +77,7 @@ function test_ADs_DIT(kernelfunction, args=nothing; ADs=nothing, dims=[3, 3]) # First check that FiniteDifferences works (skip AD tests if not) scenarios_smoke = _build_kernel_scenarios(k, kernelfunction, args, dims, rng) fd_ok = true - @testset "FiniteDifferences (DIT)" begin + @testset "FiniteDifferences" begin for (_, f, x) in scenarios_smoke try DI.gradient(f, FD_BACKEND, x) @@ -557,7 +97,7 @@ function test_ADs_DIT(kernelfunction, args=nothing; ADs=nothing, dims=[3, 3]) push!(scenarios, Scenario{:gradient,:out}(f, x; res1=res1, name=name)) end - @testset "AD correctness (DIT)" begin + @testset "AD correctness" begin test_differentiation( backends, scenarios; @@ -655,17 +195,17 @@ function _build_kernel_scenarios(k, kernelfunction, args, dims, rng) end """ - test_ADs_DIT(k::MOKernel; ADs=nothing, dims=(in=3, out=2, obs=3)) + test_ADs(k::MOKernel; ADs=nothing, dims=(in=3, out=2, obs=3)) -DIT-based version of `test_ADs` for multi-output kernels. +Tests gradient correctness of multi-output kernel functions across multiple AD backends. """ -function test_ADs_DIT(k::MOKernel; ADs=nothing, dims=(in=3, out=2, obs=3)) +function test_ADs(k::MOKernel; ADs=nothing, dims=(in=3, out=2, obs=3)) backends = _resolve_backends(ADs) rng = MersenneTwister(42) scenarios_smoke = _build_mokernel_scenarios(k, dims, rng) fd_ok = true - @testset "FiniteDifferences (DIT)" begin + @testset "FiniteDifferences" begin for (_, f, x) in scenarios_smoke try DI.gradient(f, FD_BACKEND, x) @@ -686,7 +226,7 @@ function test_ADs_DIT(k::MOKernel; ADs=nothing, dims=(in=3, out=2, obs=3)) push!(scenarios, Scenario{:gradient,:out}(f, x; res1=res1, name=name)) end - @testset "AD correctness (DIT)" begin + @testset "AD correctness" begin test_differentiation( backends, scenarios; diff --git a/test/transform/ardtransform.jl b/test/transform/ardtransform.jl index 95631a300..26c9969c4 100644 --- a/test/transform/ardtransform.jl +++ b/test/transform/ardtransform.jl @@ -43,11 +43,4 @@ @test repr(t) == "ARD Transform (dims: $D)" test_ADs(x -> SEKernel() ∘ ARDTransform(exp.(x)), randn(rng, 3)) - types = [ColVecs{Float64,Matrix{Float64}}, RowVecs{Float64,Matrix{Float64}}] - test_interface_ad_perf([1.0, 2.0], StableRNG(123456), types) do ls - SEKernel() ∘ ARDTransform(ls) - end - test_interface_ad_perf([1.0], StableRNG(123456), [Vector{Float64}]) do ls - SEKernel() ∘ ARDTransform(ls) - end end diff --git a/test/transform/chaintransform.jl b/test/transform/chaintransform.jl index 8d7442e05..8c034f0d5 100644 --- a/test/transform/chaintransform.jl +++ b/test/transform/chaintransform.jl @@ -28,7 +28,4 @@ randn(rng, 4); ADs=[:ForwardDiff, :ReverseDiff], # explicitly pass ADs to exclude :Zygote ) - test_interface_ad_perf((1.0, 2.0), StableRNG(123456), [Vector{Float64}]) do θ - SEKernel() ∘ (ScaleTransform(θ[1]) ∘ PeriodicTransform(θ[2])) - end end diff --git a/test/transform/functiontransform.jl b/test/transform/functiontransform.jl index e11be6f8b..164983bc2 100644 --- a/test/transform/functiontransform.jl +++ b/test/transform/functiontransform.jl @@ -38,7 +38,4 @@ @test repr(FunctionTransform(sin)) == "Function Transform: $(sin)" f(a, x) = sin.(a .* x) test_ADs(x -> SEKernel() ∘ FunctionTransform(y -> f(x, y)), randn(rng, 3)) - test_interface_ad_perf(nothing, StableRNG(123456), [Vector{Float64}]) do _ - SEKernel() ∘ FunctionTransform(sin) - end end diff --git a/test/transform/lineartransform.jl b/test/transform/lineartransform.jl index 7b4836532..5b22a4d2d 100644 --- a/test/transform/lineartransform.jl +++ b/test/transform/lineartransform.jl @@ -43,9 +43,4 @@ @test repr(t) == "Linear transform (size(A) = ($Dout, $Din))" test_ADs(x -> SEKernel() ∘ LinearTransform(x), randn(rng, 3, 3)) - rng = StableRNG(123456) - types = [ColVecs{Float64,Matrix{Float64}}, RowVecs{Float64,Matrix{Float64}}] - test_interface_ad_perf(randn(rng, 3, 2), rng, types) do θ - SEKernel() ∘ LinearTransform(θ) - end end diff --git a/test/transform/periodic_transform.jl b/test/transform/periodic_transform.jl index 9f7d7e45a..59bba1ea9 100644 --- a/test/transform/periodic_transform.jl +++ b/test/transform/periodic_transform.jl @@ -12,7 +12,4 @@ @test kernelmatrix(k_eq_periodic, x) ≈ kernelmatrix(k_eq_transform, x) end - test_interface_ad_perf(0.95, StableRNG(123456), [Vector{Float64}]) do θ - SEKernel() ∘ PeriodicTransform(θ) - end end diff --git a/test/transform/scaletransform.jl b/test/transform/scaletransform.jl index 1d0988441..2143b93bd 100644 --- a/test/transform/scaletransform.jl +++ b/test/transform/scaletransform.jl @@ -20,9 +20,6 @@ @test isequal(ScaleTransform(s), ScaleTransform(s)) @test repr(t) == "Scale Transform (s = $(s2))" test_ADs(x -> SEKernel() ∘ ScaleTransform(exp(x[1])), randn(rng, 1)) - test_interface_ad_perf(0.3, StableRNG(123456)) do c - SEKernel() ∘ ScaleTransform(c) - end @testset "median heuristic" begin for x in (x, XV, XC, XR), dist in (Euclidean(), Cityblock()) diff --git a/test/transform/selecttransform.jl b/test/transform/selecttransform.jl index 955a0ccec..b06307a09 100644 --- a/test/transform/selecttransform.jl +++ b/test/transform/selecttransform.jl @@ -47,12 +47,6 @@ @test repr(ts) == "Select Transform (dims: $(select_symbols2))" test_ADs(() -> SEKernel() ∘ SelectTransform([1, 2])) - test_interface_ad_perf( - _ -> SEKernel(), - nothing, - StableRNG(123456), - [ColVecs{Float64,Matrix{Float64}}, RowVecs{Float64,Matrix{Float64}}], - ) X = randn(rng, (4, 3)) A = AxisArray(X; row=[:a, :b, :c, :d], col=[:x, :y, :z]) diff --git a/test/utils.jl b/test/utils.jl index bbd47ef36..cde61bd5f 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -64,19 +64,6 @@ @test back(ones(size(X)))[1].X == ones(size(X)) end - @testset "Zygote type-inference" begin - ctx = NoContext() - x = ColVecs(randn(2, 4)) - y = ColVecs(randn(2, 3)) - - # Ensure KernelFunctions.pairwise rather than Distances.pairwise is used. - check_zygote_type_stability( - x -> KernelFunctions.pairwise(SqEuclidean(), x), x; ctx=ctx - ) - check_zygote_type_stability( - (x, y) -> KernelFunctions.pairwise(SqEuclidean(), x, y), x, y; ctx=ctx - ) - end else @test_broken false # Zygote not supported on Julia >= 1.12 end From b033b92cc5e0192c4f45449bb980777a1de3db09 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Mon, 16 Mar 2026 00:08:46 +0100 Subject: [PATCH 3/4] Fix formatting errors --- test/basekernels/piecewisepolynomial.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/basekernels/piecewisepolynomial.jl b/test/basekernels/piecewisepolynomial.jl index abc561dd6..205058d41 100644 --- a/test/basekernels/piecewisepolynomial.jl +++ b/test/basekernels/piecewisepolynomial.jl @@ -32,5 +32,4 @@ TestUtils.test_interface(k, ColVecs{Float64}; dim_in=2) TestUtils.test_interface(k, RowVecs{Float64}; dim_in=2) test_ADs(() -> PiecewisePolynomialKernel{degree}(; dim=D)) - end From fa88cb9c7521f4848cb9faebe6dbc1a524c05a86 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Mon, 16 Mar 2026 00:09:01 +0100 Subject: [PATCH 4/4] Fix Julia LTS failure --- test/test_utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_utils.jl b/test/test_utils.jl index ec9bf2e79..eade8ece0 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -39,8 +39,8 @@ const _BACKEND_MAP = Dict{Symbol,Any}( # Custom isapprox that handles Zygote returning `nothing` for zero gradients function _isapprox_nothing(a, b; kwargs...) - a_val = isnothing(a) ? zero(b) : a - b_val = isnothing(b) ? zero(a) : b + a_val = isnothing(a) ? zero.(b) : a + b_val = isnothing(b) ? zero.(a) : b return Base.isapprox(a_val, b_val; kwargs...) end