Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.11.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Expand All @@ -18,15 +17,18 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[extensions]
KernelFunctionsChainRulesCoreExt = "ChainRulesCore"
KernelFunctionsKroneckerExt = "Kronecker"
KernelFunctionsPDMatsExt = "PDMats"
KernelFunctionsZygoteRulesExt = "ZygoteRules"

[compat]
ChainRulesCore = "1"
Expand Down
4 changes: 4 additions & 0 deletions examples/train-kernel-parameters/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

# We load KernelFunctions and some other packages. Note that while we use `Zygote` for automatic differentiation and `Flux.optimise` for optimization, you should be able to replace them with your favourite autodiff framework or optimizer.

# !!! note
# Zygote is not expected to work on Julia ≥ 1.12. Use a different AD package for
# Julia ≥ 1.12, or use Julia 1.11 to run this example.

using KernelFunctions
using LinearAlgebra
using Distributions
Expand Down
35 changes: 23 additions & 12 deletions src/chainrules.jl → ext/KernelFunctionsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
module KernelFunctionsChainRulesCoreExt

using ChainRulesCore:
ChainRulesCore, Tangent, ZeroTangent, NoTangent, @thunk, ProjectTo, unthunk
using Distances: Distances, Euclidean, SqEuclidean
using IrrationalConstants: twoπ
using KernelFunctions: KernelFunctions, Delta, DotProduct, Sinus, ColVecs, RowVecs
using LinearAlgebra: dot

## Forward Rules

# Note that this is type piracy as the derivative should be NaN for x == y.
function ChainRulesCore.frule(
(_, Δx, Δy)::Tuple{<:Any,<:Any,<:Any},
d::Distances.Euclidean,
d::Euclidean,
x::AbstractVector,
y::AbstractVector,
)
Expand Down Expand Up @@ -116,7 +125,7 @@ function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector)
gradx = π .* sinpi.(2 .* d) ./ s.r .^ 2
function evaluate_pullback(Δ::Any)
r̄ = -2Δ .* abs2_sind_r ./ s.r
s̄ = ChainRulesCore.Tangent{typeof(s)}(; r=r̄)
s̄ = Tangent{typeof(s)}(; r=r̄)
return s̄, Δ * gradx, -Δ * gradx
end
return val, evaluate_pullback
Expand Down Expand Up @@ -150,7 +159,7 @@ function ChainRulesCore.rrule(
x̄[:, j] .-= ds
end
end
d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
d̄ = Tangent{typeof(d)}(; r=r̄)
return NoTangent(), d̄, @thunk(project_x(x̄))
end
return Distances.pairwise(d, x; dims), pairwise_pullback
Expand All @@ -166,7 +175,7 @@ function ChainRulesCore.rrule(
n = size(x, dims)
m = size(y, dims)
x̄ = collect(zero(x))
= collect(zero(y))
ȳ = collect(zero(y))
r̄ = zero(d.r)
if dims == 1
for j in 1:m, i in 1:n
Expand All @@ -175,7 +184,7 @@ function ChainRulesCore.rrule(
ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- yj)) ./ d.r .^ 2
r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3
x̄[i, :] .+= ds
[j, :] .-= ds
ȳ[j, :] .-= ds
end
elseif dims == 2
for j in 1:m, i in 1:n
Expand All @@ -184,11 +193,11 @@ function ChainRulesCore.rrule(
ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- yj)) ./ d.r .^ 2
r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3
x̄[:, i] .+= ds
[:, j] .-= ds
ȳ[:, j] .-= ds
end
end
d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y())
d̄ = Tangent{typeof(d)}(; r=r̄)
return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ))
end
return Distances.pairwise(d, x, y; dims), pairwise_pullback
end
Expand All @@ -202,18 +211,18 @@ function ChainRulesCore.rrule(
Δ = unthunk(z̄)
n = size(x, 2)
x̄ = collect(zero(x))
= collect(zero(y))
ȳ = collect(zero(y))
r̄ = zero(d.r)
for i in 1:n
xi = view(x, :, i)
yi = view(y, :, i)
ds = π .* Δ[i] .* sinpi.(2 .* (xi .- yi)) ./ d.r .^ 2
r̄ .-= 2 .* Δ[i] .* sinpi.(xi .- yi) .^ 2 ./ d.r .^ 3
x̄[:, i] .+= ds
[:, i] .-= ds
ȳ[:, i] .-= ds
end
d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y())
d̄ = Tangent{typeof(d)}(; r=r̄)
return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ))
end
return Distances.colwise(d, x, y), colwise_pullback
end
Expand Down Expand Up @@ -247,3 +256,5 @@ function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix)
end
return RowVecs(X), RowVecs_pullback
end

end
7 changes: 7 additions & 0 deletions src/zygoterules.jl → ext/KernelFunctionsZygoteRulesExt.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
module KernelFunctionsZygoteRulesExt

using KernelFunctions: KernelFunctions, Transform, ColVecs, RowVecs, _map
using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield

ZygoteRules.@adjoint function Base.map(t::Transform, X::ColVecs)
return ZygoteRules.pullback(_map, t, X)
end
Expand All @@ -11,3 +16,5 @@ function ZygoteRules._pullback(
) where {f}
return ZygoteRules._pullback(cx, literal_getfield, x, Val{f}())
end

end
6 changes: 0 additions & 6 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ export IndependentMOKernel,
export tensor, ⊗, compose

using Compat
using ChainRulesCore: ChainRulesCore, Tangent, ZeroTangent, NoTangent
using ChainRulesCore: @thunk, InplaceableThunk, ProjectTo, unthunk
using CompositionsBase
using Distances
using FillArrays
Expand All @@ -62,7 +60,6 @@ using IrrationalConstants: logtwo, twoπ, invsqrt2
using LogExpFunctions: softplus
using StatsBase
using TensorCore
using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield

# Hack to work around Zygote type inference problems.
const Distances_pairwise = Distances.pairwise
Expand Down Expand Up @@ -123,9 +120,6 @@ include("mokernels/slfm.jl")
include("mokernels/intrinsiccoregion.jl")
include("mokernels/lmm.jl")

include("chainrules.jl")
include("zygoterules.jl")

include("TestUtils.jl")

# Kronecker extension stubs
Expand Down
10 changes: 6 additions & 4 deletions test/basekernels/fbm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

# Tests failing for ForwardDiff and Zygote@0.6.
# Related to: https://github.com/FluxML/Zygote.jl/issues/1036
f(x, y) = x^y
@test_broken !isinf(
Zygote.gradient((x, y) -> sum(f.(x, y)), zeros(1), fill(0.9, 1))[1][1]
)
if _TEST_ZYGOTE
f(x, y) = x^y
@test_broken !isinf(
Zygote.gradient((x, y) -> sum(f.(x, y)), zeros(1), fill(0.9, 1))[1][1]
)
end

test_params(k, ([h],))

Expand Down
32 changes: 18 additions & 14 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,24 @@
y = rand(rng, 5)
r = rand(rng, 5)

compare_gradient(:Zygote, [x, y]) do xy
Euclidean()(xy[1], xy[2])
end
compare_gradient(:Zygote, [x, y]) do xy
SqEuclidean()(xy[1], xy[2])
end
compare_gradient(:Zygote, [x, y]) do xy
KernelFunctions.DotProduct()(xy[1], xy[2])
end
compare_gradient(:Zygote, [x, y]) do xy
KernelFunctions.Delta()(xy[1], xy[2])
end
compare_gradient(:Zygote, [x, y]) do xy
KernelFunctions.Sinus(r)(xy[1], xy[2])
if _TEST_ZYGOTE
compare_gradient(:Zygote, [x, y]) do xy
Euclidean()(xy[1], xy[2])
end
compare_gradient(:Zygote, [x, y]) do xy
SqEuclidean()(xy[1], xy[2])
end
compare_gradient(:Zygote, [x, y]) do xy
KernelFunctions.DotProduct()(xy[1], xy[2])
end
compare_gradient(:Zygote, [x, y]) do xy
KernelFunctions.Delta()(xy[1], xy[2])
end
compare_gradient(:Zygote, [x, y]) do xy
KernelFunctions.Sinus(r)(xy[1], xy[2])
end
else
@test_broken false # Zygote not supported on Julia >= 1.12
end
@testset "rrules for Sinus(r=$r)" for r in (rand(3),)
dist = KernelFunctions.Sinus(r)
Expand Down
40 changes: 22 additions & 18 deletions test/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,26 +36,30 @@

# Test implicit gradients
@testset "Implicit gradients" begin
k = SqExponentialKernel() ∘ ScaleTransform(2.0)
ps = params(k)
X = rand(10, 1)
x = vec(X)
A = rand(10, 10)
# Implicit
g1 = Zygote.gradient(ps) do
tr(kernelmatrix(k, X; obsdim=1) * A)
end
# Explicit
g2 = Zygote.gradient(k) do k
tr(kernelmatrix(k, X; obsdim=1) * A)
end
if _TEST_ZYGOTE
k = SqExponentialKernel() ∘ ScaleTransform(2.0)
ps = params(k)
X = rand(10, 1)
x = vec(X)
A = rand(10, 10)
# Implicit
g1 = Zygote.gradient(ps) do
tr(kernelmatrix(k, X; obsdim=1) * A)
end
# Explicit
g2 = Zygote.gradient(k) do k
tr(kernelmatrix(k, X; obsdim=1) * A)
end

# Implicit for a vector
g3 = Zygote.gradient(ps) do
tr(kernelmatrix(k, x) * A)
# Implicit for a vector
g3 = Zygote.gradient(ps) do
tr(kernelmatrix(k, x) * A)
end
@test g1[first(ps)] ≈ first(g2).transform.s
@test g1[first(ps)] ≈ g3[first(ps)]
else
@test_broken false # Zygote not supported on Julia >= 1.12
end
@test g1[first(ps)] ≈ first(g2).transform.s
@test g1[first(ps)] ≈ g3[first(ps)]
end

@testset "Parameters" begin
Expand Down
17 changes: 15 additions & 2 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ end

# AD utilities

const _TEST_ZYGOTE = VERSION < v"1.12"

# Type to work around some performance issues that can happen on the reverse-pass of Zygote.
# This context doesn't allow any globals. Don't use this if you use globals in your
# programme.
Expand All @@ -42,6 +44,9 @@ Zygote.accum_param(::NoContext, x, Δ) = Δ

const FDM = FiniteDifferences.central_fdm(5, 1)

const _DEFAULT_ADS =
_TEST_ZYGOTE ? [:Zygote, :ForwardDiff, :ReverseDiff] : [:ForwardDiff, :ReverseDiff]

gradient(f, s::Symbol, args) = gradient(f, Val(s), args)

function gradient(f, ::Val{:Zygote}, args)
Expand Down Expand Up @@ -90,7 +95,7 @@ testdiagfunction(k::MOKernel, A) = sum(kernelmatrix_diag(k, A))
testdiagfunction(k::MOKernel, A, B) = sum(kernelmatrix_diag(k, A, B))

function test_ADs(
kernelfunction, args=nothing; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=[3, 3]
kernelfunction, args=nothing; ADs=_DEFAULT_ADS, dims=[3, 3]
)
test_fd = test_AD(:FiniteDiff, kernelfunction, args, dims)
if !test_fd.anynonpass
Expand All @@ -101,14 +106,18 @@ function test_ADs(
end

function check_zygote_type_stability(f, args...; ctx=Zygote.Context())
if !_TEST_ZYGOTE
@test_broken false
return nothing
end
@inferred f(args...)
@inferred Zygote._pullback(ctx, f, args...)
out, pb = Zygote._pullback(ctx, f, args...)
@inferred collect(pb(out))
end

function test_ADs(
k::MOKernel; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=(in=3, out=2, obs=3)
k::MOKernel; ADs=_DEFAULT_ADS, dims=(in=3, out=2, obs=3)
)
test_fd = test_FiniteDiff(k, dims)
if !test_fd.anynonpass
Expand Down Expand Up @@ -372,6 +381,10 @@ function test_zygote_perf_heuristic(
f, name::String, args1, args2, passes, Δ1=nothing, Δ2=nothing
)
@testset "$name" begin
if !_TEST_ZYGOTE
@test_broken false
return nothing
end
primal, fwd, pb = ad_constant_allocs_heuristic(f, args1, args2; Δ1, Δ2)
if passes[1]
@test primal[1] == primal[2]
Expand Down
2 changes: 1 addition & 1 deletion test/transform/selecttransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@
@test gx ≈ ga
end

@testset "$(AD)" for AD in [:ReverseDiff, :Zygote]
@testset "$(AD)" for AD in (_TEST_ZYGOTE ? [:ReverseDiff, :Zygote] : [:ReverseDiff])
@test_broken let
gx = gradient(AD, X) do x
testfunction(tx_row, x, 2)
Expand Down
Loading
Loading