-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgrad.jl
More file actions
80 lines (68 loc) · 2.24 KB
/
grad.jl
File metadata and controls
80 lines (68 loc) · 2.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""
    objective_gradient!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)

Evaluate the objective gradient for the batch of points `X` (one point per column),
each paired with the matching column of the parameter matrix `Θ`.

The result is written into the output buffer obtained from
`_maybe_view(bm, :grad_out, X)`, and that same buffer is returned.
NOTE(review): the buffer presumably belongs to `bm` and is reused across calls;
copy the result if it must outlive the next evaluation — confirm.
"""
function objective_gradient!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)
    grad_out = _maybe_view(bm, :grad_out, X)
    objective_gradient!(bm, X, Θ, grad_out)
    return grad_out
end
"""
    objective_gradient!(bm::BatchModel, X::AbstractMatrix)

Evaluate the objective gradient for the batch of points `X` (one point per column).

The parameter matrix is built with `_repeat_params(bm, X)`. It is presumably the
model's parameter vector repeated once per column of `X` — confirm. Evaluation is
then delegated to `objective_gradient!(bm, X, Θ)`, and its result is returned.
"""
function objective_gradient!(bm::BatchModel, X::AbstractMatrix)
    params = _repeat_params(bm, X)
    return objective_gradient!(bm, X, params)
end
# Walk the chain of objective terms, which is linked through `objs.inner` and
# terminated by `ExaModels.ObjectiveNull`. For each term, launch its batched
# gradient kernel into `grad_work` via `sgradient_batch!` with a unit scale factor.
# `backend` is synchronized as each level of the recursion returns, so all
# launched kernels have completed when the outermost call returns.
function _objective_gradient!(backend, grad_work, objs, X, Θ)
    unit_scale = one(eltype(grad_work))
    sgradient_batch!(backend, grad_work, objs, X, Θ, unit_scale)
    # Recurse into the remaining terms before synchronizing.
    _objective_gradient!(backend, grad_work, objs.inner, X, Θ)
    return synchronize(backend)
end

# Base case: the end of the objective chain launches nothing.
_objective_gradient!(backend, grad_work, ::ExaModels.ObjectiveNull, X, Θ) = nothing
"""
    sgradient_batch!(backend, Y, f, X, Θ, adj)

Launch the batched gradient kernel `kerg_batch` for a single objective term `f`,
writing (presumably accumulating) its contributions into `Y`.

The launch spans `length(f.itr) × size(X, 2)` work items, one axis over the term's
iterator and one over the batch columns. `adj` is forwarded to the kernel unchanged;
the caller in this file passes `one(eltype(Y))`. Nothing is launched when `f.itr`
is empty. This function does not synchronize `backend`.
"""
function sgradient_batch!(
    backend::B,
    Y,
    f,
    X,
    Θ,
    adj,
) where {B<:KernelAbstractions.Backend}
    isempty(f.itr) && return nothing
    nbatch = size(X, 2)
    kernel! = kerg_batch(backend)
    return kernel!(Y, f.f, f.itr, X, Θ, adj; ndrange = (length(f.itr), nbatch))
end
"""
    objective_gradient!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, G::AbstractMatrix)

Evaluate the objective gradient for a batch of points, each with its own parameters.
The dense gradients are written into `G`, which is returned.

Column `j` of `X` is the `j`-th point and column `j` of `Θ` holds its parameters.
Column `j` of `G` receives the corresponding gradient.

# Arguments
- `bm`: batch model. The batch size is validated against `bm.batch_size` by
  `_assert_batch_size`.
- `X`: `nvar × nbatch` matrix of points.
- `Θ`: `length(bm.model.θ) × nbatch` matrix of parameters.
- `G`: `nvar × nbatch` output matrix, overwritten in full.

# Returns
`G`. It is all zeros when the model has no objective-gradient entries.

# Throws
Dimension mismatches between `X`, `Θ`, `G` and the model are rejected by `@lencheck`.
"""
function objective_gradient!(
    bm::BatchModel,
    X::AbstractMatrix,
    Θ::AbstractMatrix,
    G::AbstractMatrix,
)
    batch_size = size(X, 2)
    @lencheck batch_size eachcol(X) eachcol(Θ) eachcol(G)
    @lencheck bm.model.meta.nvar eachrow(X) eachrow(G)
    @lencheck length(bm.model.θ) eachrow(Θ) # FIXME
    _assert_batch_size(batch_size, bm.batch_size)
    backend = _get_backend(bm.model)

    # Zero the output unconditionally. When the model has no objective-gradient
    # entries (`grad_work` is empty), the gradient is identically zero, and `G`
    # must not be handed back with stale or uninitialized contents.
    fill!(G, zero(eltype(G)))

    grad_work = _maybe_view(bm, :grad_work, X)
    isempty(grad_work) && return G

    # Evaluate every objective term into the cleared work buffer.
    fill!(grad_work, zero(eltype(grad_work)))
    _objective_gradient!(backend, grad_work, bm.model.objs, X, Θ)

    # Compress the work-buffer entries into the dense output using the model's
    # `gptr`/`gsparsity` layout, over (length(gptr) - 1) × batch_size work items.
    compress_to_dense_batch(backend)(
        G,
        grad_work,
        bm.model.ext.gptr,
        bm.model.ext.gsparsity;
        ndrange = (length(bm.model.ext.gptr) - 1, batch_size),
    )
    synchronize(backend)
    return G
end