-
Notifications
You must be signed in to change notification settings - Fork 158
Expand file tree
/
Copy pathDerivativeTest.jl
More file actions
123 lines (97 loc) · 3.75 KB
/
DerivativeTest.jl
File metadata and controls
123 lines (97 loc) · 3.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Tests for `ForwardDiff.derivative` / `ForwardDiff.derivative!`.
# Results are checked against finite-difference derivatives from Calculus.jl over the
# DiffTests.jl function suites, followed by a few hand-written special cases.
module DerivativeTest

import Calculus
import NaNMath

using Test
using Random
using ForwardDiff
using DiffTests

# Shared test helpers. `FINITEDIFF_ERROR` (used as the finite-difference tolerance
# below) is not defined in this file; presumably it comes from here — TODO confirm.
include(joinpath(dirname(@__FILE__), "utils.jl"))

# Seed the global RNG so any random draws are reproducible.
Random.seed!(1)

########################
# test vs. Calculus.jl #
########################

# Scalar input point shared by the Calculus.jl comparison testsets below.
const x = 1

# f(x::Number)::Number
@testset "$f" for f in DiffTests.NUMBER_TO_NUMBER_FUNCS
    v = f(x)
    d = ForwardDiff.derivative(f, x)
    @test isapprox(d, Calculus.derivative(f, x), atol=FINITEDIFF_ERROR)

    # The DiffResult form must report the same value and derivative.
    out = DiffResults.DiffResult(zero(v), zero(v))
    out = ForwardDiff.derivative!(out, f, x)
    @test isapprox(DiffResults.value(out), v)
    @test isapprox(DiffResults.derivative(out), d)
end

# f(x::Number)::AbstractArray
@testset "$f" for f in DiffTests.NUMBER_TO_ARRAY_FUNCS
    v = f(x)
    d = ForwardDiff.derivative(f, x)
    # The returned derivative must hold plain numbers, not leaked Dual values.
    @test !(eltype(d) <: ForwardDiff.Dual)
    @test isapprox(d, Calculus.derivative(f, x), atol=FINITEDIFF_ERROR)

    # Plain array output buffer.
    out = similar(v)
    out = ForwardDiff.derivative!(out, f, x)
    @test isapprox(out, d)

    # DiffResult output buffer: value and derivative are both reported.
    out = DiffResults.DiffResult(similar(v), similar(d))
    out = ForwardDiff.derivative!(out, f, x)
    @test isapprox(DiffResults.value(out), v)
    @test isapprox(DiffResults.derivative(out), d)
end

# f!(y::AbstractArray, x::Number), with and without a precomputed DerivativeConfig.
# `y` is zeroed before every mutating call. Otherwise the `isapprox(v, y)` checks
# could pass on a value left in `y` by an earlier call, instead of proving that
# this call wrote the function value into `y`.
@testset "$(f!)" for f! in DiffTests.INPLACE_NUMBER_TO_ARRAY_FUNCS
    m, n = 3, 2
    y = fill(0.0, m, n)
    # Out-of-place wrapper around f!, used to compute the reference value `v` and
    # derivative `d`. The buffer eltype is promoted with typeof(x) so it can hold
    # whatever number type is passed in (e.g. ForwardDiff's Dual numbers).
    f = x -> (tmp = similar(y, promote_type(eltype(y), typeof(x)), m, n); f!(tmp, x); tmp)
    v = f(x)
    cfg = ForwardDiff.DerivativeConfig(f!, y, x)
    d = ForwardDiff.derivative(f, x)

    # Non-mutating derivative, without / with config; `y` receives the value.
    fill!(y, 0.0)
    @test isapprox(ForwardDiff.derivative(f!, y, x), d)
    @test isapprox(v, y)

    fill!(y, 0.0)
    @test isapprox(ForwardDiff.derivative(f!, y, x, cfg), d)
    @test isapprox(v, y)

    # derivative! into a plain array buffer, without / with config.
    out = similar(v)
    fill!(y, 0.0)
    ForwardDiff.derivative!(out, f!, y, x)
    @test isapprox(out, d)
    @test isapprox(v, y)

    out = similar(v)
    fill!(y, 0.0)
    ForwardDiff.derivative!(out, f!, y, x, cfg)
    @test isapprox(out, d)
    @test isapprox(v, y)

    # derivative! into a DiffResult, without / with config.
    out = DiffResults.DiffResult(similar(v), similar(d))
    fill!(y, 0.0)
    out = ForwardDiff.derivative!(out, f!, y, x)
    @test isapprox(v, y)
    @test isapprox(DiffResults.value(out), v)
    @test isapprox(DiffResults.derivative(out), d)

    out = DiffResults.DiffResult(similar(v), similar(d))
    fill!(y, 0.0)
    out = ForwardDiff.derivative!(out, f!, y, x, cfg)
    @test isapprox(v, y)
    @test isapprox(DiffResults.value(out), v)
    @test isapprox(DiffResults.derivative(out), d)
end

# d/dy 0^y: expected to be -Inf for y <= 0 and exactly 0.0 for y > 0.
@testset "exponential function at base zero" begin
    @test (x -> ForwardDiff.derivative(y -> x^y, -0.5))(0.0) === -Inf
    @test (x -> ForwardDiff.derivative(y -> x^y, 0.0))(0.0) === -Inf
    @test (x -> ForwardDiff.derivative(y -> x^y, 0.5))(0.0) === 0.0
    @test (x -> ForwardDiff.derivative(y -> x^y, 1.5))(0.0) === 0.0
end

# NaN inputs yield NaN derivatives. NaNMath.pow gives NaN for a negative base
# with a fractional exponent, whereas plain `^` throws a DomainError.
@testset "exponentiation with NaNMath" begin
    @test isnan(ForwardDiff.derivative(x -> NaNMath.pow(NaN, x), 1.0))
    @test isnan(ForwardDiff.derivative(x -> NaNMath.pow(x, NaN), 1.0))
    @test !isnan(ForwardDiff.derivative(x -> NaNMath.pow(1.0, x), 1.0))
    @test isnan(ForwardDiff.derivative(x -> NaNMath.pow(x, 0.5), -1.0))
    @test isnan(ForwardDiff.derivative(x -> x^NaN, 2.0))
    @test ForwardDiff.derivative(x -> x^2.0, 2.0) == 4.0
    @test_throws DomainError ForwardDiff.derivative(x -> x^0.5, -1.0)
end

# `derivative` takes a scalar input; an array input must be rejected.
@testset "dimension error for derivative" begin
    @test_throws DimensionMismatch ForwardDiff.derivative(sum, fill(2pi, 3))
end

# A complex-valued function of a real variable differentiates componentwise.
@testset "complex output" begin
    @test ForwardDiff.derivative(x -> (1+im)*x, 0) == (1+im)
end

# Complex-valued analytic functions, including a nested (second) derivative:
# d²/dx² exp(i·x) at x = 0 equals i² = -1.
@testset "analytic functions" begin
    dexp(x) = ForwardDiff.derivative(y -> exp(complex(0, y)), x)
    @test ForwardDiff.derivative(dexp, 0.0) ≈ -1
    @test ForwardDiff.derivative(x -> exp(1im*x), 0.7) ≈ im * cis(0.7)
    @test ForwardDiff.derivative(x -> sqrt(im + (1+im) * x), 1.23) ≈ (1+im) / (2 * sqrt(im + (1+im)*1.23))
end

end # module