-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhprod.jl
More file actions
67 lines (58 loc) · 1.99 KB
/
hprod.jl
File metadata and controls
67 lines (58 loc) · 1.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
lagrangian_hprod!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, Y::AbstractMatrix, V::AbstractMatrix; obj_weight=1.0)
Evaluate Hessian-vector products for a batch of points.
"""
function lagrangian_hprod!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, Y::AbstractMatrix, V::AbstractMatrix; obj_weight=1.0)
Hv = _maybe_view(bm, :hprod_out, X)
lagrangian_hprod!(bm, X, Θ, Y, V, Hv; obj_weight=obj_weight)
return Hv
end
"""
lagrangian_hprod!(bm::BatchModel, X::AbstractMatrix, Y::AbstractMatrix, V::AbstractMatrix; obj_weight=1.0)
Evaluate Hessian-vector products for a batch of points.
"""
function lagrangian_hprod!(bm::BatchModel, X::AbstractMatrix, Y::AbstractMatrix, V::AbstractMatrix; obj_weight=1.0)
Θ = _repeat_params(bm, X)
lagrangian_hprod!(bm, X, Θ, Y, V; obj_weight=obj_weight)
return Hv
end
"""
    lagrangian_hprod!(bm::BatchModel, X, Θ, Y, V, Hv; obj_weight=1.0)

In-place Hessian-of-the-Lagrangian–vector products for a batch of points.
`Hv` is overwritten and returned.

Every matrix has one column per batch entry. `X`, `V` and `Hv` have
`bm.model.meta.nvar` rows, `Θ` has `length(bm.model.θ)` rows, and `Y` has
`bm.model.meta.ncon` rows. Shape mismatches are rejected by `@lencheck`, and the
batch width is validated with `_assert_batch_size` against `bm.batch_size`.

The sparse Hessian values are first evaluated by `lagrangian_hessian!` into a work
buffer obtained from `_maybe_view(bm, :hprod_work, X)`. The product is then
accumulated into the zeroed `Hv` by two batched kernels:
- the first is driven by `hessptri` / `hesssparsityi`;
- the second is driven by `hessptrj` / `hesssparsityj`.

NOTE(review): this looks like a symmetric SpMV over one stored triangle plus its
mirror — confirm against the kernel definitions.
"""
function lagrangian_hprod!(
    bm::BatchModel,
    X::AbstractMatrix,
    Θ::AbstractMatrix,
    Y::AbstractMatrix,
    V::AbstractMatrix,
    Hv::AbstractMatrix;
    obj_weight=1.0,
)
    nbatch = size(X, 2)
    model = bm.model

    # Shape validation. The checked expressions are kept verbatim because
    # `@lencheck` embeds their text in its error message.
    @lencheck nbatch eachcol(X) eachcol(Θ) eachcol(Y) eachcol(V) eachcol(Hv)
    @lencheck model.meta.nvar eachrow(X) eachrow(V) eachrow(Hv)
    @lencheck length(model.θ) eachrow(Θ)
    @lencheck model.meta.ncon eachrow(Y)
    _assert_batch_size(nbatch, bm.batch_size)

    backend = _get_backend(model)
    helper = _get_prodhelper(model)

    # Nonzero Hessian values for every column of the batch.
    hess_vals = _maybe_view(bm, :hprod_work, X)
    lagrangian_hessian!(bm, X, Θ, Y, hess_vals; obj_weight=obj_weight)

    # Both kernels accumulate into `Hv`, so it must start from zero.
    fill!(Hv, zero(eltype(Hv)))

    rows_i = length(helper.hessptri) - 1
    kersyspmv_batch(backend)(
        Hv, V, helper.hesssparsityi, hess_vals, helper.hessptri;
        ndrange=(rows_i, nbatch),
    )
    # Wait before the second pass; presumably because both passes write to `Hv`.
    synchronize(backend)

    rows_j = length(helper.hessptrj) - 1
    kersyspmv2_batch(backend)(
        Hv, V, helper.hesssparsityj, hess_vals, helper.hessptrj;
        ndrange=(rows_j, nbatch),
    )
    synchronize(backend)

    return Hv
end