-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathopt_common.jl
More file actions
138 lines (119 loc) · 4.12 KB
/
opt_common.jl
File metadata and controls
138 lines (119 loc) · 4.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
using GenericLinearAlgebra
export adjust_for_errtype!, solve_linlsqr!
"""
    adjust_for_errtype!(A, b, objfun_vals, errtype)

Adjust the matrix `A` and the vector `b` in place to the error type `errtype`.
`A` is a matrix, e.g., the Jacobian or system matrix in a least squares problem.
`b` is a vector, e.g., the residuals or the right-hand side in a least squares
problem. `objfun_vals` is a vector of objective function values, and `errtype`
can be `:abserr` or `:relerr`, i.e., absolute error or relative error.

For `:abserr`, `A` and `b` are left unchanged. For `:relerr`, row `i` of `A` and
entry `i` of `b` are scaled by `1 / objfun_vals[i]`. Zero entries of
`objfun_vals` are replaced by `100 * eps()` for this scaling to avoid division
by zero. `objfun_vals` itself is not modified.

Returns the tuple `(A, b)`. Any other `errtype` raises an error.
"""
function adjust_for_errtype!(A, b, objfun_vals, errtype)
    # A and b are assumed to be computed for the absolute error. For :relerr they
    # are rescaled so that the objective becomes
    #   1/2 sum_i (Z(x_i) - f(x_i))^2 / f(x_i)^2,
    # i.e., row i is weighted by 1 / f(x_i).
    if errtype == :abserr
        # Nothing to do.
    elseif errtype == :relerr
        # Work on a copy so that the caller's objective values are not altered.
        weights = copy(objfun_vals)
        weights[weights .== 0] .= eps() * 100  # Avoid division by zero
        D = Diagonal(1 ./ weights)
        A .= D * A
        b .= D * b
    else
        error("Unknown errtype '", errtype, "'.")
    end
    return (A, b)
end
"""
    LinLsqrSolve

Abstract supertype of the linear least squares solvers accepted by
`solve_linlsqr!(A, b, solver)`. Each concrete subtype selects one solution
method through dispatch.
"""
abstract type LinLsqrSolve end

"""
    BackslashLinLsqrSolve()

Solve the least squares problem with the backslash operator, `A \\ b`.
"""
struct BackslashLinLsqrSolve <: LinLsqrSolve end
function solve_linlsqr!(A, b, ::BackslashLinLsqrSolve)
    return A \ b
end

"""
    RealBackslashLinLsqrSolve()

Solve for a real-valued `d` by applying backslash to the stacked real system
`[real(A); imag(A)] d = [real(b); imag(b)]`.
"""
struct RealBackslashLinLsqrSolve <: LinLsqrSolve end
function solve_linlsqr!(A, b, ::RealBackslashLinLsqrSolve)
    return vcat(real(A), imag(A)) \ vcat(real(b), imag(b))
end

"""
    NormEqLinLsqrSolve()

Solve via the normal equations `(A' * A) d = A' * b`.
"""
struct NormEqLinLsqrSolve <: LinLsqrSolve end
function solve_linlsqr!(A, b, ::NormEqLinLsqrSolve)
    return (A' * A) \ (A' * b)
end
# Forwarding alias for the former (misspelled) name of the method above.
solve_linsqr!(A, b, solver::NormEqLinLsqrSolve) = solve_linlsqr!(A, b, solver)

"""
    RealNormEqLinLsqrSolve()

Solve for a real-valued `d` via the normal equations of the stacked real system,
`(real(A)' * real(A) + imag(A)' * imag(A)) d = real(A)' * real(b) + imag(A)' * imag(b)`,
formed without explicitly stacking the matrices.
"""
struct RealNormEqLinLsqrSolve <: LinLsqrSolve end
function solve_linlsqr!(A, b, ::RealNormEqLinLsqrSolve)
    Ar = real(A)
    Ai = imag(A)
    br = real(b)
    bi = imag(b)
    return (Ar' * Ar + Ai' * Ai) \ (Ar' * br + Ai' * bi)
end

"""
    SVDLsqrSolve(tp, droptol, fixed_rank)

Solve with a truncated SVD (pseudoinverse). If `tp == :real_svd`, the real and
imaginary parts of `A` and `b` are stacked so that `d` is real; any other `tp`
uses the SVD of `A` itself. Singular values that are zero, or whose ratio to the
largest singular value is below `droptol`, are disregarded. At most `fixed_rank`
of the largest singular values are used (e.g. `Inf` for no limit). For BigFloat
input, `A` is overwritten.
"""
struct SVDLsqrSolve <: LinLsqrSolve
    tp          # :real_svd, or anything else for the plain SVD
    droptol     # relative threshold for dropping small singular values
    fixed_rank  # upper bound on the number of singular values used
end
function solve_linlsqr!(A, b, solver::SVDLsqrSolve)
    if solver.tp == :real_svd
        A = vcat(real(A), imag(A))
        b = vcat(real(b), imag(b))
    end
    # LAPACK does not support BigFloat; `svd!` with `alg = nothing` is presumably
    # served by GenericLinearAlgebra's generic (in-place) SVD.
    if eltype(A) == BigFloat || eltype(A) == Complex{BigFloat}
        Sfact = svd!(A; full = false, alg = nothing)
    else
        Sfact = svd(A)
    end
    sv = Sfact.S  # singular values, in decreasing order
    svmax = isempty(sv) ? zero(eltype(sv)) : first(sv)
    # Since sv is sorted, the retained singular values form a prefix of sv.
    # Testing iszero first drops exact zeros and avoids 0/0 for a zero matrix.
    nkeep = count(s -> !(iszero(s) || s / svmax < solver.droptol), sv)
    nuse = Int(min(solver.fixed_rank, nkeep))
    dinv = zeros(eltype(sv), size(sv))
    dinv[1:nuse] = 1 ./ sv[1:nuse]
    # Apply V * Diagonal(dinv) * U' * b without forming the pseudoinverse.
    return Sfact.V * (dinv .* (Sfact.U' * b))
end
"""
    d = solve_linlsqr!(A, b, linlsqr, droptol)

Solve the linear least squares problem `A*d = b` and return `d`.

The argument `linlsqr` determines how the problem is solved. It can be
`:backslash`, `:real_backslash`, `:nrmeq`, `:real_nrmeq`, `:svd`, or `:real_svd`.
For the latter two options, singular values whose ratio to the largest singular
value is below `droptol` are disregarded, as are exactly zero singular values.
In effect a truncated pseudoinverse is applied. The `:real_X` options optimize
`d` in the space of real vectors. The input matrix `A` is sometimes overwritten
(the BigFloat SVD path works in place).

Any other value of `linlsqr` raises an error.
"""
function solve_linlsqr!(A, b, linlsqr, droptol)
    if linlsqr == :backslash
        d = A \ b
    elseif linlsqr == :real_backslash
        # Stack real and imaginary parts to obtain a real-valued solution.
        d = vcat(real(A), imag(A)) \ vcat(real(b), imag(b))
    elseif linlsqr == :nrmeq
        d = (A' * A) \ (A' * b)
    elseif linlsqr == :real_nrmeq
        # Normal equations of the stacked real problem, formed without stacking.
        Ar = real(A)
        Ai = imag(A)
        br = real(b)
        bi = imag(b)
        d = (Ar' * Ar + Ai' * Ai) \ (Ar' * br + Ai' * bi)
    elseif linlsqr == :svd || linlsqr == :real_svd
        if linlsqr == :real_svd
            A = vcat(real(A), imag(A))
            b = vcat(real(b), imag(b))
        end
        # LAPACK does not support BigFloat; `svd!` with `alg = nothing` is
        # presumably served by GenericLinearAlgebra's generic (in-place) SVD.
        if eltype(A) == BigFloat || eltype(A) == Complex{BigFloat}
            Sfact = svd!(A; full = false, alg = nothing)
        else
            Sfact = svd(A)
        end
        sv = Sfact.S  # singular values, in decreasing order
        svmax = isempty(sv) ? zero(eltype(sv)) : first(sv)
        # Truncated pseudoinverse: drop zero singular values and those small
        # relative to the largest one. Testing iszero first also avoids 0/0
        # (and hence NaN/Inf results) when A is the zero matrix.
        drop = [iszero(s) || s / svmax < droptol for s in sv]
        dinv = 1 ./ sv
        dinv[drop] .= 0
        # Apply V * Diagonal(dinv) * U' * b without forming the pseudoinverse.
        d = Sfact.V * (dinv .* (Sfact.U' * b))
    else
        error("Unknown linlsqr '", linlsqr, "'.")
    end
    return d
end