-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy pathtransformedkernel.jl
More file actions
133 lines (104 loc) · 4.6 KB
/
transformedkernel.jl
File metadata and controls
133 lines (104 loc) · 4.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
    TransformedKernel(k::Kernel, t::Transform)

Kernel derived from `k` for which inputs are transformed via a [`Transform`](@ref) `t`.

The preferred way to create kernels with input transformations is to use the composition
operator [`∘`](@ref) or its alias `compose` instead of `TransformedKernel` directly since
this allows optimized implementations for specific kernels and transformations.

See also: [`∘`](@ref)
"""
struct TransformedKernel{Tk<:Kernel,Tr<:Transform} <: Kernel
    kernel::Tk     # base kernel, evaluated on the transformed inputs
    transform::Tr  # transform applied to each input before `kernel` is evaluated
end

@functor TransformedKernel
function ParameterHandling.flatten(::Type{T}, k::TransformedKernel) where {T<:Real}
    # Flatten both components; the combined vector is laid out as
    # [kernel parameters; transform parameters].
    kvec, unflatten_kernel = flatten(T, k.kernel)
    tvec, unflatten_transform = flatten(T, k.transform)
    nk = length(kvec)
    params = vcat(kvec, tvec)
    ntotal = length(params)

    # Inverse of the flattening above: split `v` at `nk` and rebuild each component.
    function unflatten_to_transformedkernel(v::Vector{T})
        if length(v) != ntotal
            error("incorrect number of parameters")
        end
        return TransformedKernel(
            unflatten_kernel(v[1:nk]), unflatten_transform(v[(nk + 1):end])
        )
    end

    return params, unflatten_to_transformedkernel
end
# Generic evaluation: transform both inputs, then evaluate the wrapped kernel on them.
function (k::TransformedKernel)(x, y)
    tx = k.transform(x)
    ty = k.transform(y)
    return k.kernel(tx, ty)
end
# Optimizations for scale transforms of simple kernels to save allocations:
# instead of multiplying every element of the inputs before evaluating the metric,
# we multiply the metric evaluated on the original inputs by a scalar, if possible.
function (k::TransformedKernel{<:SimpleKernel,<:ScaleTransform})(
    x::AbstractVector{<:Real}, y::AbstractVector{<:Real}
)
    # `_scale` returns the kernel's metric as it would be evaluated on the scaled
    # inputs, avoiding materializing the scaled vectors where a shortcut exists.
    d = _scale(k.transform, metric(k.kernel), x, y)
    return kappa(k.kernel, d)
end
# Euclidean distance is absolutely homogeneous: ‖s⋅x - s⋅y‖ = |s|⋅‖x - y‖.
# Taking `abs` keeps this shortcut consistent with evaluating the metric on the
# explicitly scaled inputs (the generic path) when the scale factor is negative;
# otherwise a negative "distance" would be passed on to `kappa`.
# `only` throws if `t.s` does not hold exactly one scale factor.
function _scale(t::ScaleTransform, metric::Euclidean, x, y)
    return abs(only(t.s)) * evaluate(metric, x, y)
end
# These metrics scale quadratically: multiplying both inputs by `s` multiplies the
# metric value by `s^2`, so the scaled inputs never need to be formed.
function _scale(t::ScaleTransform, metric::Union{SqEuclidean,DotProduct}, x, y)
    s = only(t.s)
    return s^2 * evaluate(metric, x, y)
end
# Fallback for any other metric: apply the scale transform explicitly to both inputs.
function _scale(t::ScaleTransform, metric, x, y)
    return evaluate(metric, t(x), t(y))
end
"""
    kernel ∘ transform
    ∘(kernel, transform)
    compose(kernel, transform)

Compose a `kernel` with a transformation `transform` of its inputs.

The prefix forms support chains of multiple transformations:
`∘(kernel, transform1, transform2) = kernel ∘ transform1 ∘ transform2`.

# Definition

For inputs ``x, x'``, the transformed kernel ``\\widetilde{k}`` derived from kernel ``k`` by
input transformation ``t`` is defined as
```math
\\widetilde{k}(x, x'; k, t) = k\\big(t(x), t(x')\\big).
```

# Examples

```jldoctest
julia> (SqExponentialKernel() ∘ ScaleTransform(0.5))(0, 2) == exp(-0.5)
true

julia> ∘(ExponentialKernel(), ScaleTransform(2), ScaleTransform(0.5))(1, 2) == exp(-1)
true
```

See also: [`TransformedKernel`](@ref)
"""
Base.:∘(k::Kernel, t::Transform) = TransformedKernel(k, t)
# Composing an already-transformed kernel merges the transforms instead of nesting
# wrappers: `(k ∘ t1) ∘ t2` becomes `TransformedKernel(k, t1 ∘ t2)`.
function Base.:∘(k::TransformedKernel, t::Transform)
    combined = k.transform ∘ t
    return TransformedKernel(k.kernel, combined)
end
# Composing any kernel with an identity transform of the inputs returns it unchanged.
# The `TransformedKernel` method is needed in addition to the generic `Kernel` one:
# without it, a call with a `TransformedKernel` and an `IdentityTransform` would be
# ambiguous with `∘(::TransformedKernel, ::Transform)`.
function Base.:∘(k::Kernel, ::IdentityTransform)
    return k
end
function Base.:∘(k::TransformedKernel, ::IdentityTransform)
    return k
end
# Display a transformed kernel starting at nesting level zero.
function Base.show(io::IO, κ::TransformedKernel)
    return printshifted(io, κ, 0)
end
function printshifted(io::IO, κ::TransformedKernel, shift::Int)
    # Print the wrapped kernel at the current nesting level, then list the transform
    # on its own line, indented one tab deeper than `shift`.
    printshifted(io, κ.kernel, shift)
    indent = "\t"^(shift + 1)
    return print(io, "\n", indent, "- $(κ.transform)")
end
# Kernel matrix operations
# In-place diagonal: transform the inputs once up front, then delegate to the
# wrapped kernel's implementation.
function kernelmatrix_diag!(K::AbstractVector, κ::TransformedKernel, x::AbstractVector)
    tx = _map(κ.transform, x)
    return kernelmatrix_diag!(K, κ.kernel, tx)
end
function kernelmatrix_diag!(
    K::AbstractVector, κ::TransformedKernel, x::AbstractVector, y::AbstractVector
)
    t = κ.transform
    tx, ty = _map(t, x), _map(t, y)
    return kernelmatrix_diag!(K, κ.kernel, tx, ty)
end
# In-place kernel matrix: transform the inputs once up front, then delegate to the
# wrapped kernel's implementation.
function kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector)
    tx = _map(κ.transform, x)
    return kernelmatrix!(K, κ.kernel, tx)
end
function kernelmatrix!(
    K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector, y::AbstractVector
)
    t = κ.transform
    tx, ty = _map(t, x), _map(t, y)
    return kernelmatrix!(K, κ.kernel, tx, ty)
end
# Allocating diagonal: transform the inputs once up front, then delegate to the
# wrapped kernel's implementation.
function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector)
    tx = _map(κ.transform, x)
    return kernelmatrix_diag(κ.kernel, tx)
end
function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector, y::AbstractVector)
    t = κ.transform
    tx, ty = _map(t, x), _map(t, y)
    return kernelmatrix_diag(κ.kernel, tx, ty)
end
# Allocating kernel matrix: transform the inputs once up front, then delegate to the
# wrapped kernel's implementation.
function kernelmatrix(κ::TransformedKernel, x::AbstractVector)
    tx = _map(κ.transform, x)
    return kernelmatrix(κ.kernel, tx)
end
function kernelmatrix(κ::TransformedKernel, x::AbstractVector, y::AbstractVector)
    t = κ.transform
    tx, ty = _map(t, x), _map(t, y)
    return kernelmatrix(κ.kernel, tx, ty)
end