-
Notifications
You must be signed in to change notification settings - Fork 66
Expand file tree
/
Copy pathcomposite.jl
More file actions
306 lines (258 loc) · 10.6 KB
/
composite.jl
File metadata and controls
306 lines (258 loc) · 10.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
"""
Composite{P, T} <: AbstractDifferential
This type represents the differential for a `struct`/`NamedTuple`, or `Tuple`.
`P` is the the corresponding primal type that this is a differential for.
`Composite{P}` should have fields (technically properties), that match to a subset of the
fields of the primal type; and each should be a differential type matching to the primal
type of that field.
Fields of the P that are not present in the Composite are treated as `Zero`.
`T` is an implementation detail representing the backing data structure.
For Tuple it will be a Tuple, and for everything else it will be a `NamedTuple`.
It should not be passed in by user.
For `Composite`s of `Tuple`s, `iterate` and `getindex` are overloaded to behave similarly
to for a tuple.
For `Composite`s of `struct`s, `getproperty` is overloaded to allow for accessing values
via `comp.fieldname`.
Any fields not explictly present in the `Composite` are treated as being set to `Zero()`.
To make a `Composite` have all the fields of the primal the [`canonicalize`](@ref)
function is provided.
"""
struct Composite{P, T} <: AbstractDifferential
# Note: If T is a Tuple/Dict, then P is also a Tuple/Dict
# (but potentially a different one, as it doesn't contain differentials)
backing::T
end
# Keyword form: the keyword arguments become a `NamedTuple` backing
# (the representation used for `struct` and `NamedTuple` primals).
function Composite{P}(; kwargs...) where {P}
    named = (; kwargs...)
    return Composite{P, typeof(named)}(named)
end

# Positional form: the argument tuple itself is used as the backing (for `Tuple` primals).
Composite{P}(args...) where {P} = Composite{P, typeof(args)}(args)

# With no arguments, a `Tuple` primal gets an empty `Tuple` backing rather than the empty
# `NamedTuple` that the keyword form would otherwise produce.
Composite{P}() where {P<:Tuple} = Composite{P, Tuple{}}(())

# A `Dict` differential wraps the given `Dict` directly as its backing.
Composite{P}(d::Dict) where {P<:Dict} = Composite{P, typeof(d)}(d)
# Equality of `Composite`s.
# Fast path: when both sides share the same backing type, compare the backings directly.
function Base.:(==)(a::Composite{P, T}, b::Composite{P, T}) where {P, T}
    return backing(a) == backing(b)
end
# Same primal type but different backing types: for example, a field present on only one
# side (implicitly `Zero()` on the other), or `Float64` vs `Int` values. Compare the
# *backings* of the canonical forms. Comparing the canonicalized `Composite`s with `==`
# would dispatch straight back to this method whenever the canonical backing types still
# differ; `canonicalize` is the identity on `Tuple`/`Dict` backings, so that recursion
# never terminates and overflows the stack.
function Base.:(==)(a::Composite{P}, b::Composite{P}) where {P}
    return backing(canonicalize(a)) == backing(canonicalize(b))
end
# Differentials of different primal types are never equal.
Base.:(==)(a::Composite{P}, b::Composite{Q}) where {P, Q} = false
# Hash the canonical backing, so that `Composite`s differing only in omitted
# (implicitly `Zero()`) fields hash alike, matching the canonical comparison above.
Base.hash(a::Composite, h::UInt) = Base.hash(backing(canonicalize(a)), h)
"""
Print a `Composite` as `Composite{P}` followed by its backing, e.g. `Composite{Foo}(x = 1,)`.
"""
function Base.show(io::IO, comp::Composite{P}) where {P}
    # Header: the literal prefix plus the primal type rendered by its own `show`.
    print(io, "Composite{")
    show(io, P)
    print(io, "}")
    # The backing's own `show` supplies the brackets and separators.
    inner = backing(comp)
    return show(io, inner)
end
# Converting a `Composite` to the container type of its backing unwraps that backing.
function Base.convert(::Type{<:NamedTuple}, comp::Composite{<:Any, <:NamedTuple})
    return backing(comp)
end
function Base.convert(::Type{<:Tuple}, comp::Composite{<:Any, <:Tuple})
    return backing(comp)
end
function Base.convert(::Type{<:Dict}, comp::Composite{<:Dict, <:Dict})
    return backing(comp)
end
# Indexing forwards to the backing and unthunks the stored value.
function Base.getindex(comp::Composite, idx)
    stored = getindex(backing(comp), idx)
    return unthunk(stored)
end

# Integer property access, used for `Tuple` backings.
function Base.getproperty(comp::Composite, idx::Int)
    stored = getproperty(backing(comp), idx)
    return unthunk(stored)
end

# Named property access for `NamedTuple` backings; names not stored are implicitly `Zero()`.
function Base.getproperty(comp::Composite{P, <:NamedTuple{L}}, idx::Symbol) where {P, L}
    # Test membership against the type parameter `L` (not the value) so it constant-folds.
    if idx ∈ L
        return unthunk(getproperty(backing(comp), idx))
    end
    return Zero()
end

# Only the explicitly stored entries are reported as keys / property names.
Base.keys(comp::Composite) = keys(backing(comp))
Base.propertynames(comp::Composite) = propertynames(backing(comp))
"""
Iterate over the stored entries of a `Composite`, unthunking each yielded value.

For `Dict` backings the yielded items are `key => value` pairs and only the value is
unthunked; for other backings the element itself is unthunked.
"""
function Base.iterate(comp::Composite, args...)
    step = iterate(backing(comp), args...)
    step === nothing && return nothing

    element, state = step
    item = if comp isa Composite{<:Dict, <:Dict}
        Pair(element.first, unthunk(element.second))
    else
        unthunk(element)
    end
    return (item, state)
end
# Number of explicitly stored entries (implicit `Zero()` fields are not counted).
function Base.length(comp::Composite)
    return length(backing(comp))
end
# Element type as reported by the backing type `T`.
# NOTE(review): `iterate` unthunks elements, so this may not match what iteration actually
# yields when thunks are stored — confirm whether that matters to callers.
function Base.eltype(::Type{<:Composite{<:Any, T}}) where {T}
    return eltype(T)
end
# `map` applies `f` to each stored value and rebuilds a `Composite` with the same primal
# type `P` and the same kind of backing (positions, field names, or keys are preserved).
function Base.map(f, comp::Composite{P, <:Tuple}) where {P}
    mapped = map(f, backing(comp))::Tuple
    return Composite{P, typeof(mapped)}(mapped)
end

function Base.map(f, comp::Composite{P, <:NamedTuple{L}}) where {P, L}
    # Map over the bare values, then reattach the field names `L`.
    mapped = map(f, Tuple(backing(comp)))
    NT = NamedTuple{L, typeof(mapped)}
    return Composite{P, NT}(NT(mapped))
end

function Base.map(f, comp::Composite{P, <:Dict}) where {P<:Dict}
    mapped = Dict(key => f(val) for (key, val) in backing(comp))
    return Composite{P}(mapped)
end
# Complex conjugation is applied to every stored value.
function Base.conj(comp::Composite)
    return map(conj, comp)
end
# Apply `extern` to every stored value and return the bare backing
# (a `NamedTuple`, `Tuple`, or `Dict`, depending on the backing type).
function extern(comp::Composite)
    externed = map(extern, comp)
    return backing(externed)
end
"""
backing(x)
Accesses the backing field of a `Composite`,
or destructures any other composite type into a `NamedTuple`.
Identity function on `Tuple`. and `NamedTuple`s.
This is an internal function used to simplify operations between `Composite`s and the
primal types.
"""
backing(x::Tuple) = x
backing(x::NamedTuple) = x
backing(x::Dict) = x
backing(x::Composite) = getfield(x, :backing)
function backing(x::T)::NamedTuple where T
# note: all computation outside the if @generated happens at runtime.
# so the first 4 lines of the branchs look the same, but can not be moved out.
# see https://github.com/JuliaLang/julia/issues/34283
if @generated
!isstructtype(T) && throw(DomainError(T, "backing can only be use on composite types"))
nfields = fieldcount(T)
names = fieldnames(T)
types = fieldtypes(T)
vals = Expr(:tuple, ntuple(ii->:(getfield(x, $ii)), nfields)...)
return :(NamedTuple{$names, Tuple{$(types...)}}($vals))
else
!isstructtype(T) && throw(DomainError(T, "backing can only be use on composite types"))
nfields = fieldcount(T)
names = fieldnames(T)
types = fieldtypes(T)
vals = ntuple(ii->getfield(x, ii), nfields)
return NamedTuple{names, Tuple{types...}}(vals)
end
end
"""
canonicalize(comp::Composite{P}) -> Composite{P}
Return the canonical `Composite` for the primal type `P`.
The property names of the returned `Composite` match the field names of the primal,
and all fields of `P` not present in the input `comp` are explictly set to `Zero()`.
"""
function canonicalize(comp::Composite{P, <:NamedTuple{L}}) where {P,L}
nil = _zeroed_backing(P)
combined = merge(nil, backing(comp))
if length(combined) !== fieldcount(P)
throw(ArgumentError(
"Composite fields do not match primal fields.\n" *
"Composite fields: $L. Primal ($P) fields: $(fieldnames(P))"
))
end
return Composite{P, typeof(combined)}(combined)
end
# Tuple composites are always in their canonical form
canonicalize(comp::Composite{<:Tuple, <:Tuple}) = comp
# Dict composite are always in their canonical form.
canonicalize(comp::Composite{<:Any, <:AbstractDict}) = comp
"""
_zeroed_backing(P)
Returns a NamedTuple with same fields as `P`, and all values `Zero()`.
"""
@generated function _zeroed_backing(::Type{P}) where P
nil_base = ntuple(fieldcount(P)) do i
(fieldname(P, i), Zero())
end
return (; nil_base...)
end
"""
construct(::Type{T}, fields::[NamedTuple|Tuple])
Constructs an object of type `T`, with the given fields.
Fields must be correct in name and type, and `T` must have a default constructor.
This internally is called to construct structs of the primal type `T`,
after an operation such as the addition of a primal to a composite.
It should be overloaded, if `T` does not have a default constructor,
or if `T` needs to maintain some invarients between its fields.
"""
function construct(::Type{T}, fields::NamedTuple{L}) where {T, L}
# Tested and verified that that this avoids a ton of allocations
if length(L) !== fieldcount(T)
# if length is equal but names differ then we will catch that below anyway.
throw(ArgumentError("Unmatched fields. Type: $(fieldnames(T)), NamedTuple: $L"))
end
if @generated
vals = (:(getproperty(fields, $(QuoteNode(fname)))) for fname in fieldnames(T))
return :(T($(vals...)))
else
return T((getproperty(fields, fname) for fname in fieldnames(T))...)
end
end
construct(::Type{T}, fields::T) where T<:NamedTuple = fields
construct(::Type{T}, fields::T) where T<:Tuple = fields
# Positional addition for `Tuple` backings.
elementwise_add(a::Tuple, b::Tuple) = map(+, a, b)

"""
    elementwise_add(a::NamedTuple, b::NamedTuple)

Add two `NamedTuple` backings field by field. The result has the field names of
`Base.merge_names(an, bn)`. A field present on both sides is summed with `+`; a field
present on only one side is copied unchanged, because a missing field is an implicit
hard zero (the rule of `Composite` addition).
"""
function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn}
    # Modelled on `merge(::NamedTuple, ::NamedTuple)` from Base:
    # https://github.com/JuliaLang/julia/blob/592748adb25301a45bd6edef3ac0a93eed069852/base/namedtuple.jl#L220-L231
    if @generated
        all_names = Base.merge_names(an, bn)
        field_exprs = map(all_names) do name
            in_a = Base.sym_in(name, an)
            in_b = Base.sym_in(name, bn)
            from_a = :(getproperty(a, $(QuoteNode(name))))
            from_b = :(getproperty(b, $(QuoteNode(name))))
            in_a && in_b && return :($from_a + $from_b)
            return in_a ? from_a : from_b
        end
        return :(NamedTuple{$all_names}(($(field_exprs...),)))
    else
        all_names = Base.merge_names(an, bn)
        summed = map(all_names) do name
            in_a = Base.sym_in(name, an)
            in_b = Base.sym_in(name, bn)
            if in_a && in_b
                getproperty(a, name) + getproperty(b, name)
            elseif in_a
                getproperty(a, name)
            else
                getproperty(b, name)
            end
        end
        return NamedTuple{all_names}(summed)
    end
end

# Key-wise addition for `Dict` backings; keys present in only one `Dict` are kept as-is.
elementwise_add(a::Dict, b::Dict) = merge(+, a, b)
"""
    PrimalAdditionFailedException{P} <: Exception

Error reporting that an object of primal type `P` could not be constructed after adding a
`Composite{P}` differential to it; its `showerror` suggests defining a default
constructor or overloading `construct` / `+`.

NOTE(review): the site that throws this is not in this part of the file — presumably the
primal + `Composite` addition path that calls `construct`; confirm.
"""
struct PrimalAdditionFailedException{P} <: Exception
    primal::P  # the primal operand of the failed addition
    differential::Composite{P}  # the differential operand; its property names appear in the message
    original::Exception  # the underlying exception, shown under "Original Exception:"
end
"""
Explain why the primal `P` could not be rebuilt after addition, suggest the three possible
remedies (default constructor, `construct` overload, `+` overload), and then show the
message of the exception that caused the failure.
"""
function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P}
    println(io, "Could not construct $P after addition.")
    println(io, "This probably means no default constructor is defined.")
    println(io, "Either define a default constructor")
    printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")", color=:blue)
    println(io, "\nor overload")
    printstyled(io,
        "ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.differential)))";
        color=:blue
    )
    println(io, "\nor overload")
    printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color=:blue)
    println(io, "\nOriginal Exception:")
    # Render the wrapped exception via `showerror`, i.e. its human-readable message.
    # Passing it to `printstyled` directly would `print` its raw struct representation
    # (e.g. `MethodError(f, (1, 2), 0x…)`). `context=io` keeps the IO properties.
    original_msg = sprint(showerror, err.original; context=io)
    printstyled(io, original_msg; color=:yellow)
    println(io)
    return nothing
end
"""
NO_FIELDS
Constant for the reverse-mode derivative with respect to a structure that has no fields.
The most notable use for this is for the reverse-mode derivative with respect to the
function itself, when that function is not a closure.
"""
const NO_FIELDS = Zero()