-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy pathchaintransform.jl
More file actions
31 lines (24 loc) · 1008 Bytes
/
chaintransform.jl
File metadata and controls
31 lines (24 loc) · 1008 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
@testset "chaintransform" begin
    rng = MersenneTwister(123546)

    # Building blocks: a linear map and an elementwise nonlinearity.
    P = rand(rng, 3, 2)
    tp = LinearTransform(P)
    f(x) = sin.(x)
    tf = FunctionTransform(f)

    # `t` applies `tp` first and then `tf`, so t(x) should equal f(P * x).
    t = ChainTransform([tp, tf])

    # Check composition constructors: `a ∘ b` applies `b` first, so the resulting
    # chain stores `b` before `a` in `.transforms`.
    @test (tf ∘ ChainTransform([tp])).transforms == [tp, tf]
    @test (ChainTransform([tf]) ∘ tp).transforms == [tp, tf]

    # Verify correctness pointwise, and that `map` over the collection agrees.
    x = ColVecs(randn(rng, 2, 3))
    x′ = map(t, x)
    @test all(t(x[n]) ≈ f(P * x[n]) for n in eachindex(x))
    @test all(t(x[n]) ≈ x′[n] for n in eachindex(x))

    # Verify printing: transforms are listed in application order (`tf` then `tp`).
    @test repr(tp ∘ tf) == "Chain of 2 transforms:\n\t - $(tf) |> $(tp)"

    # AD check on a kernel built from a nested chain; the 4 inputs are passed
    # through `exp`, i.e. they act as log-scale parameters.
    test_ADs(
        x -> SEKernel() ∘ (ScaleTransform(exp(x[1])) ∘ ARDTransform(exp.(x[2:4]))),
        randn(rng, 4);
        ADs=[:ForwardDiff, :ReverseDiff], # explicitly pass ADs to exclude :Zygote
    )
    # Zygote is excluded above, so record it as a known-broken entry in the summary.
    # `@test_skip` records `Broken` without evaluating its argument; `@test_broken`
    # on a String is reported as a non-Boolean *error* by newer versions of Test.
    @test_skip "test_AD of chain transform is currently broken in Zygote, see GitHub issue #263"
end