diff --git a/Project.toml b/Project.toml index 4cfb0b718..f46a87300 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,6 @@ uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" version = "0.11.0" [deps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" @@ -18,15 +17,18 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" +ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [extensions] +KernelFunctionsChainRulesCoreExt = "ChainRulesCore" KernelFunctionsKroneckerExt = "Kronecker" KernelFunctionsPDMatsExt = "PDMats" +KernelFunctionsZygoteRulesExt = "ZygoteRules" [compat] ChainRulesCore = "1" diff --git a/examples/train-kernel-parameters/script.jl b/examples/train-kernel-parameters/script.jl index a1087c978..aca1298cd 100644 --- a/examples/train-kernel-parameters/script.jl +++ b/examples/train-kernel-parameters/script.jl @@ -5,6 +5,10 @@ # We load KernelFunctions and some other packages. Note that while we use `Zygote` for automatic differentiation and `Flux.optimise` for optimization, you should be able to replace them with your favourite autodiff framework or optimizer. +# !!! note +# Zygote is not expected to work on Julia ≥ 1.12. Use a different AD package for +# Julia ≥ 1.12, or use Julia 1.11 to run this example. + using KernelFunctions using LinearAlgebra using Distributions diff --git a/src/chainrules.jl b/ext/KernelFunctionsChainRulesCoreExt.jl similarity index 91% rename from src/chainrules.jl rename to ext/KernelFunctionsChainRulesCoreExt.jl index 549373876..ddbda83df 100644 --- a/src/chainrules.jl +++ b/ext/KernelFunctionsChainRulesCoreExt.jl @@ -1,9 +1,18 @@ +module KernelFunctionsChainRulesCoreExt + +using ChainRulesCore: + ChainRulesCore, Tangent, ZeroTangent, NoTangent, @thunk, ProjectTo, unthunk +using Distances: Distances, Euclidean, SqEuclidean +using IrrationalConstants: twoπ +using KernelFunctions: KernelFunctions, Delta, DotProduct, Sinus, ColVecs, RowVecs +using LinearAlgebra: dot + ## Forward Rules # Note that this is type piracy as the derivative should be NaN for x == y. function ChainRulesCore.frule( (_, Δx, Δy)::Tuple{<:Any,<:Any,<:Any}, - d::Distances.Euclidean, + d::Euclidean, x::AbstractVector, y::AbstractVector, ) @@ -116,7 +125,7 @@ function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector) gradx = π .* sinpi.(2 .* d) ./ s.r .^ 2 function evaluate_pullback(Δ::Any) r̄ = -2Δ .* abs2_sind_r ./ s.r - s̄ = ChainRulesCore.Tangent{typeof(s)}(; r=r̄) + s̄ = Tangent{typeof(s)}(; r=r̄) return s̄, Δ * gradx, -Δ * gradx end return val, evaluate_pullback @@ -150,7 +159,7 @@ function ChainRulesCore.rrule( x̄[:, j] .-= ds end end - d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄) + d̄ = Tangent{typeof(d)}(; r=r̄) return NoTangent(), d̄, @thunk(project_x(x̄)) end return Distances.pairwise(d, x; dims), pairwise_pullback @@ -166,7 +175,7 @@ function ChainRulesCore.rrule( n = size(x, dims) m = size(y, dims) x̄ = collect(zero(x)) - ȳ = collect(zero(y)) + ȳ = collect(zero(y)) r̄ = zero(d.r) if dims == 1 for j in 1:m, i in 1:n @@ -175,7 +184,7 @@ function ChainRulesCore.rrule( ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- yj)) ./ d.r .^ 2 r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3 x̄[i, :] .+= ds - ȳ[j, :] .-= ds + ȳ[j, :] .-= ds end elseif dims == 2 for j in 1:m, i in 1:n @@ -184,11 +193,11 @@ function ChainRulesCore.rrule( ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- yj)) ./ d.r .^ 2 r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3 x̄[:, i] .+= ds - ȳ[:, j] .-= ds + ȳ[:, j] .-= ds end end - d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄) - return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ)) + d̄ = Tangent{typeof(d)}(; r=r̄) + return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ)) end return Distances.pairwise(d, x, y; dims), pairwise_pullback end @@ -202,7 +211,7 @@ function ChainRulesCore.rrule( Δ = unthunk(z̄) n = size(x, 2) x̄ = collect(zero(x)) - ȳ = collect(zero(y)) + ȳ = collect(zero(y)) r̄ = zero(d.r) for i in 1:n xi = view(x, :, i) @@ -210,10 +219,10 @@ function ChainRulesCore.rrule( ds = π .* Δ[i] .* sinpi.(2 .* (xi .- yi)) ./ d.r .^ 2 r̄ .-= 2 .* Δ[i] .* sinpi.(xi .- yi) .^ 2 ./ d.r .^ 3 x̄[:, i] .+= ds - ȳ[:, i] .-= ds + ȳ[:, i] .-= ds end - d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄) - return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ)) + d̄ = Tangent{typeof(d)}(; r=r̄) + return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ)) end return Distances.colwise(d, x, y), colwise_pullback end @@ -247,3 +256,5 @@ function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix) end return RowVecs(X), RowVecs_pullback end + +end diff --git a/src/zygoterules.jl b/ext/KernelFunctionsZygoteRulesExt.jl similarity index 67% rename from src/zygoterules.jl rename to ext/KernelFunctionsZygoteRulesExt.jl index e405a4946..1f2a2aae6 100644 --- a/src/zygoterules.jl +++ b/ext/KernelFunctionsZygoteRulesExt.jl @@ -1,3 +1,8 @@ +module KernelFunctionsZygoteRulesExt + +using KernelFunctions: KernelFunctions, Transform, ColVecs, RowVecs, _map +using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield + ZygoteRules.@adjoint function Base.map(t::Transform, X::ColVecs) return ZygoteRules.pullback(_map, t, X) end @@ -11,3 +16,5 @@ function ZygoteRules._pullback( ) where {f} return ZygoteRules._pullback(cx, literal_getfield, x, Val{f}()) end + +end diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 2c5512c06..5ce131d58 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -50,8 +50,6 @@ export IndependentMOKernel, export tensor, ⊗, compose using Compat -using ChainRulesCore: ChainRulesCore, Tangent, ZeroTangent, NoTangent -using ChainRulesCore: @thunk, InplaceableThunk, ProjectTo, unthunk using CompositionsBase using Distances using FillArrays @@ -62,7 +60,6 @@ using IrrationalConstants: logtwo, twoπ, invsqrt2 using LogExpFunctions: softplus using StatsBase using TensorCore -using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield # Hack to work around Zygote type inference problems. const Distances_pairwise = Distances.pairwise @@ -123,9 +120,6 @@ include("mokernels/slfm.jl") include("mokernels/intrinsiccoregion.jl") include("mokernels/lmm.jl") -include("chainrules.jl") -include("zygoterules.jl") - include("TestUtils.jl") # Kronecker extension stubs diff --git a/test/basekernels/fbm.jl b/test/basekernels/fbm.jl index 669d9721d..5e9a960b3 100644 --- a/test/basekernels/fbm.jl +++ b/test/basekernels/fbm.jl @@ -17,10 +17,12 @@ # Tests failing for ForwardDiff and Zygote@0.6. # Related to: https://github.com/FluxML/Zygote.jl/issues/1036 - f(x, y) = x^y - @test_broken !isinf( - Zygote.gradient((x, y) -> sum(f.(x, y)), zeros(1), fill(0.9, 1))[1][1] - ) + if _TEST_ZYGOTE + f(x, y) = x^y + @test_broken !isinf( + Zygote.gradient((x, y) -> sum(f.(x, y)), zeros(1), fill(0.9, 1))[1][1] + ) + end test_params(k, ([h],)) diff --git a/test/chainrules.jl b/test/chainrules.jl index f2a4ae3a5..ea4b43ab1 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -4,20 +4,24 @@ y = rand(rng, 5) r = rand(rng, 5) - compare_gradient(:Zygote, [x, y]) do xy - Euclidean()(xy[1], xy[2]) - end - compare_gradient(:Zygote, [x, y]) do xy - SqEuclidean()(xy[1], xy[2]) - end - compare_gradient(:Zygote, [x, y]) do xy - KernelFunctions.DotProduct()(xy[1], xy[2]) - end - compare_gradient(:Zygote, [x, y]) do xy - KernelFunctions.Delta()(xy[1], xy[2]) - end - compare_gradient(:Zygote, [x, y]) do xy - KernelFunctions.Sinus(r)(xy[1], xy[2]) + if _TEST_ZYGOTE + compare_gradient(:Zygote, [x, y]) do xy + Euclidean()(xy[1], xy[2]) + end + compare_gradient(:Zygote, [x, y]) do xy + SqEuclidean()(xy[1], xy[2]) + end + compare_gradient(:Zygote, [x, y]) do xy + KernelFunctions.DotProduct()(xy[1], xy[2]) + end + compare_gradient(:Zygote, [x, y]) do xy + KernelFunctions.Delta()(xy[1], xy[2]) + end + compare_gradient(:Zygote, [x, y]) do xy + KernelFunctions.Sinus(r)(xy[1], xy[2]) + end + else + @test_broken false # Zygote not supported on Julia >= 1.12 end @testset "rrules for Sinus(r=$r)" for r in (rand(3),) dist = KernelFunctions.Sinus(r) diff --git a/test/kernels/transformedkernel.jl b/test/kernels/transformedkernel.jl index 9df8e4fdb..51d55aa1f 100644 --- a/test/kernels/transformedkernel.jl +++ b/test/kernels/transformedkernel.jl @@ -36,26 +36,30 @@ # Test implicit gradients @testset "Implicit gradients" begin - k = SqExponentialKernel() ∘ ScaleTransform(2.0) - ps = params(k) - X = rand(10, 1) - x = vec(X) - A = rand(10, 10) - # Implicit - g1 = Zygote.gradient(ps) do - tr(kernelmatrix(k, X; obsdim=1) * A) - end - # Explicit - g2 = Zygote.gradient(k) do k - tr(kernelmatrix(k, X; obsdim=1) * A) - end + if _TEST_ZYGOTE + k = SqExponentialKernel() ∘ ScaleTransform(2.0) + ps = params(k) + X = rand(10, 1) + x = vec(X) + A = rand(10, 10) + # Implicit + g1 = Zygote.gradient(ps) do + tr(kernelmatrix(k, X; obsdim=1) * A) + end + # Explicit + g2 = Zygote.gradient(k) do k + tr(kernelmatrix(k, X; obsdim=1) * A) + end - # Implicit for a vector - g3 = Zygote.gradient(ps) do - tr(kernelmatrix(k, x) * A) + # Implicit for a vector + g3 = Zygote.gradient(ps) do + tr(kernelmatrix(k, x) * A) + end + @test g1[first(ps)] ≈ first(g2).transform.s + @test g1[first(ps)] ≈ g3[first(ps)] + else + @test_broken false # Zygote not supported on Julia >= 1.12 end - @test g1[first(ps)] ≈ first(g2).transform.s - @test g1[first(ps)] ≈ g3[first(ps)] end @testset "Parameters" begin diff --git a/test/test_utils.jl b/test/test_utils.jl index b2172a20a..fdb9f2871 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -31,6 +31,8 @@ end # AD utilities +const _TEST_ZYGOTE = VERSION < v"1.12" + # Type to work around some performance issues that can happen on the reverse-pass of Zygote. # This context doesn't allow any globals. Don't use this if you use globals in your # programme. @@ -42,6 +44,9 @@ Zygote.accum_param(::NoContext, x, Δ) = Δ const FDM = FiniteDifferences.central_fdm(5, 1) +const _DEFAULT_ADS = + _TEST_ZYGOTE ? [:Zygote, :ForwardDiff, :ReverseDiff] : [:ForwardDiff, :ReverseDiff] + gradient(f, s::Symbol, args) = gradient(f, Val(s), args) function gradient(f, ::Val{:Zygote}, args) @@ -90,7 +95,7 @@ testdiagfunction(k::MOKernel, A) = sum(kernelmatrix_diag(k, A)) testdiagfunction(k::MOKernel, A, B) = sum(kernelmatrix_diag(k, A, B)) function test_ADs( - kernelfunction, args=nothing; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=[3, 3] + kernelfunction, args=nothing; ADs=_DEFAULT_ADS, dims=[3, 3] ) test_fd = test_AD(:FiniteDiff, kernelfunction, args, dims) if !test_fd.anynonpass @@ -101,6 +106,10 @@ function test_ADs( end function check_zygote_type_stability(f, args...; ctx=Zygote.Context()) + if !_TEST_ZYGOTE + @test_broken false + return nothing + end @inferred f(args...) @inferred Zygote._pullback(ctx, f, args...) out, pb = Zygote._pullback(ctx, f, args...) @@ -108,7 +117,7 @@ function check_zygote_type_stability(f, args...; ctx=Zygote.Context()) end function test_ADs( - k::MOKernel; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=(in=3, out=2, obs=3) + k::MOKernel; ADs=_DEFAULT_ADS, dims=(in=3, out=2, obs=3) ) test_fd = test_FiniteDiff(k, dims) if !test_fd.anynonpass @@ -372,6 +381,10 @@ function test_zygote_perf_heuristic( f, name::String, args1, args2, passes, Δ1=nothing, Δ2=nothing ) @testset "$name" begin + if !_TEST_ZYGOTE + @test_broken false + return nothing + end primal, fwd, pb = ad_constant_allocs_heuristic(f, args1, args2; Δ1, Δ2) if passes[1] @test primal[1] == primal[2] diff --git a/test/transform/selecttransform.jl b/test/transform/selecttransform.jl index e7af7b3c0..955a0ccec 100644 --- a/test/transform/selecttransform.jl +++ b/test/transform/selecttransform.jl @@ -103,7 +103,7 @@ @test gx ≈ ga end - @testset "$(AD)" for AD in [:ReverseDiff, :Zygote] + @testset "$(AD)" for AD in (_TEST_ZYGOTE ? [:ReverseDiff, :Zygote] : [:ReverseDiff]) @test_broken let gx = gradient(AD, X) do x testfunction(tx_row, x, 2) diff --git a/test/utils.jl b/test/utils.jl index a7cabbd05..bbd47ef36 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -53,17 +53,17 @@ KernelFunctions.pairwise!(SqEuclidean(), K, DX, DY) @test K ≈ pairwise(SqEuclidean(), X, Y; dims=2) - let - @test Zygote.pullback(ColVecs, X)[1] == DX - DX, back = Zygote.pullback(ColVecs, X) - @test back((X=ones(size(X)),))[1] == ones(size(X)) - - @test Zygote.pullback(DX -> DX.X, DX)[1] == X - X_, back = Zygote.pullback(DX -> DX.X, DX) - @test back(ones(size(X)))[1].X == ones(size(X)) - end + if _TEST_ZYGOTE + let + @test Zygote.pullback(ColVecs, X)[1] == DX + DX, back = Zygote.pullback(ColVecs, X) + @test back((X=ones(size(X)),))[1] == ones(size(X)) + + @test Zygote.pullback(DX -> DX.X, DX)[1] == X + X_, back = Zygote.pullback(DX -> DX.X, DX) + @test back(ones(size(X)))[1].X == ones(size(X)) + end - if VERSION >= v"1.7" @testset "Zygote type-inference" begin ctx = NoContext() x = ColVecs(randn(2, 4)) @@ -77,6 +77,8 @@ (x, y) -> KernelFunctions.pairwise(SqEuclidean(), x, y), x, y; ctx=ctx ) end + else + @test_broken false # Zygote not supported on Julia >= 1.12 end end @testset "RowVecs" begin @@ -111,14 +113,18 @@ KernelFunctions.pairwise!(SqEuclidean(), K, DX, DY) @test K ≈ pairwise(SqEuclidean(), X, Y; dims=1) - let - @test Zygote.pullback(RowVecs, X)[1] == DX - DX, back = Zygote.pullback(RowVecs, X) - @test back((X=ones(size(X)),))[1] == ones(size(X)) + if _TEST_ZYGOTE + let + @test Zygote.pullback(RowVecs, X)[1] == DX + DX, back = Zygote.pullback(RowVecs, X) + @test back((X=ones(size(X)),))[1] == ones(size(X)) - @test Zygote.pullback(DX -> DX.X, DX)[1] == X - X_, back = Zygote.pullback(DX -> DX.X, DX) - @test back(ones(size(X)))[1].X == ones(size(X)) + @test Zygote.pullback(DX -> DX.X, DX)[1] == X + X_, back = Zygote.pullback(DX -> DX.X, DX) + @test back(ones(size(X)))[1].X == ones(size(X)) + end + else + @test_broken false # Zygote not supported on Julia >= 1.12 end end @testset "ColVecs + RowVecs" begin