-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy pathtransformedkernel.jl
More file actions
75 lines (67 loc) · 2.58 KB
/
transformedkernel.jl
File metadata and controls
75 lines (67 loc) · 2.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Tests for `TransformedKernel`: composition via `∘`, evaluation against manually
# transformed inputs, printing, the generic kernel interface, AD, and parameters.
@testset "transformedkernel" begin
    rng = MersenneTwister(123456)
    # `x` is never read below; the draw is kept so that the values subsequently
    # drawn from `rng` (`v1`, `v2`, `s`, `v`, `P`) are unchanged.
    x = rand(rng) * 2
    v1 = rand(rng, 3)
    v2 = rand(rng, 3)
    s = rand(rng)
    v = rand(rng, 3)
    P = rand(rng, 3, 2)
    k = SqExponentialKernel()

    # Composing with the identity transform is expected to give the kernel back
    # unchanged (checked with `===`).
    @test k ∘ IdentityTransform() === k

    kt = TransformedKernel(k, ScaleTransform(s))
    ktard = TransformedKernel(k, ARDTransform(v))
    @test kt ∘ IdentityTransform() === kt
    @test ktard ∘ IdentityTransform() === ktard

    # The explicit constructor and `∘` are expected to build the same kernel, so
    # those comparisons are exact; comparisons against hand-transformed inputs are
    # floating-point computations and use `≈` with a tolerance.
    @test kt(v1, v2) == (k ∘ ScaleTransform(s))(v1, v2)
    @test kt(v1, v2) ≈ k(s * v1, s * v2) atol = 1e-5
    @test ktard(v1, v2) == (k ∘ ARDTransform(v))(v1, v2)
    @test ktard(v1, v2) ≈ k(v .* v1, v .* v2) atol = 1e-5

    # Chaining transforms must give the same value however the `∘` is grouped.
    @test (k ∘ LinearTransform(P') ∘ ScaleTransform(s))(v1, v2) ==
        ((k ∘ LinearTransform(P')) ∘ ScaleTransform(s))(v1, v2) ==
        (k ∘ (LinearTransform(P') ∘ ScaleTransform(s)))(v1, v2)

    # Printing: the base kernel's repr followed by the transform on an indented line.
    @test repr(kt) == repr(k) * "\n\t- " * repr(ScaleTransform(s))

    # Generic kernel interface: run it on the transformed kernel itself, not only
    # on the untransformed base kernel.
    TestUtils.test_interface(k, Float64)
    TestUtils.test_interface(kt, Float64)
    # Non-numeric inputs: the transform appends "hi" to each input string.
    TestUtils.test_interface(
        TransformedKernel(ConstantKernel(; c=1.5), FunctionTransform(x -> x * "hi")),
        Vector{String},
    )

    test_ADs(x -> SqExponentialKernel() ∘ ScaleTransform(x[1]), rand(1))
    # NOTE(review): 0.35 is presumably the parameter value passed to the closure
    # as `λ` — confirm against `test_interface_ad_perf`.
    test_interface_ad_perf(0.35, StableRNG(123456)) do λ
        SqExponentialKernel() ∘ ScaleTransform(λ)
    end

    # Test implicit gradients
    @testset "Implicit gradients" begin
        if _TEST_ZYGOTE
            # Distinct names are used here so this nested testset does not rebind
            # the enclosing testset's `k` and `x`.
            kscaled = SqExponentialKernel() ∘ ScaleTransform(2.0)
            ps = params(kscaled)
            X = rand(10, 1)
            xvec = vec(X)
            A = rand(10, 10)

            # Implicit: gradient w.r.t. the parameters collected in `ps`.
            g1 = Zygote.gradient(ps) do
                tr(kernelmatrix(kscaled, X; obsdim=1) * A)
            end
            # Explicit: gradient w.r.t. the kernel passed as an argument.
            g2 = Zygote.gradient(kscaled) do k
                tr(kernelmatrix(k, X; obsdim=1) * A)
            end
            # Implicit for a vector input (same data as `X`, flattened).
            g3 = Zygote.gradient(ps) do
                tr(kernelmatrix(kscaled, xvec) * A)
            end

            @test g1[first(ps)] ≈ first(g2).transform.s
            @test g1[first(ps)] ≈ g3[first(ps)]
        else
            @test_broken false # Zygote not supported on Julia >= 1.12
        end
    end

    @testset "Parameters" begin
        # A separate name keeps the enclosing testset's `k` untouched.
        kconst = ConstantKernel(; c=rand(rng))
        # NOTE(review): the two commented-out lines below are a disabled
        # `FunctionTransform` parameter check; `Chain`/`Dense` are not otherwise
        # referenced in this file — confirm whether it should be restored or removed.
        # c = Chain(Dense(3, 2))

        test_params(kconst ∘ ScaleTransform(s), (kconst, [s]))
        test_params(kconst ∘ ARDTransform(v), (kconst, v))
        test_params(kconst ∘ LinearTransform(P), (kconst, P))
        test_params(kconst ∘ LinearTransform(P) ∘ ScaleTransform(s), (kconst, [s], P))
        # test_params(k ∘ FunctionTransform(c), (k, c))
    end
end