Skip to content

Commit a9b9a4b

Browse files
kshyatt and lkdvos authored
Split pullbacks into separate files (#244)
* Split pullbacks into separate files
* Fix StridedViews version for now
* Fix ordering
* Format
* Fix Project.toml
* index magic
* ChainRules using common pullback functions
* only check public functions are included in docs
* update docstrings
* cleaner syntax
* support complex _needs_tangent

---------

Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 22ef6dc commit a9b9a4b

11 files changed

Lines changed: 397 additions & 303 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@ Preferences = "1.4"
4848
PtrArrays = "1.2"
4949
Random = "1"
5050
Strided = "2.2"
51-
StridedViews = "0.3, 0.4"
51+
StridedViews = "0.3, 0.4, ~0.4.2"
5252
Test = "1"
5353
TupleTools = "1.6"
5454
VectorInterface = "0.4.1,0.5"

docs/make.jl

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,8 @@ makedocs(;
1818
"man/precompilation.md",
1919
],
2020
"Index" => "index/index.md",
21-
]
21+
],
22+
checkdocs = :public
2223
)
2324

2425
# Documenter can also automatically deploy documentation to gh-pages.

ext/TensorOperationsChainRulesCoreExt.jl

Lines changed: 31 additions & 149 deletions
Original file line number | Diff line number | Diff line change
@@ -2,15 +2,18 @@ module TensorOperationsChainRulesCoreExt
22

33
using TensorOperations
44
using TensorOperations: numind, numin, numout, promote_contract, _needs_tangent
5+
using TensorOperations: pullback_dC, pullback_dβ,
6+
tensoradd_pullback_dA, tensoradd_pullback_dα,
7+
tensorcontract_pullback_dA, tensorcontract_pullback_dB, tensorcontract_pullback_dα,
8+
tensortrace_pullback_dA, tensortrace_pullback_dα
9+
510
using TensorOperations: DefaultBackend, DefaultAllocator, _kron
611
using ChainRulesCore
712
using TupleTools
813
using VectorInterface
914
using TupleTools: invperm
1015
using LinearAlgebra
1116

12-
trivtuple(N) = ntuple(identity, N)
13-
1417
@non_differentiable TensorOperations.tensorstructure(args...)
1518
@non_differentiable TensorOperations.tensoradd_structure(args...)
1619
@non_differentiable TensorOperations.tensoradd_type(args...)
@@ -74,53 +77,26 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
7477
projectα = ProjectTo(α)
7578
projectβ = ProjectTo(β)
7679

77-
function pullback(ΔC′)
80+
function tensoradd_pullback(ΔC′)
7881
ΔC = unthunk(ΔC′)
79-
dC = if β === Zero()
80-
ZeroTangent()
81-
else
82-
@thunk projectC(scale(ΔC, conj(β)))
83-
end
84-
dA = @thunk let
85-
ipA = invperm(linearize(pA))
86-
_dA = zerovector(A, VectorInterface.promote_add(ΔC, α))
87-
_dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...)
88-
projectA(_dA)
89-
end
82+
83+
dC = β === Zero() ? ZeroTangent() : @thunk projectC(pullback_dC(ΔC, β))
84+
dA = @thunk projectA(tensoradd_pullback_dA(ΔC, C, A, pA, conjA, α, ba...))
9085
dα = if _needs_tangent(α)
91-
@thunk let
92-
_dα = tensorscalar(
93-
tensorcontract(
94-
A, ((), linearize(pA)), !conjA,
95-
ΔC, (trivtuple(numind(pA)), ()), false,
96-
((), ()), One(), ba...
97-
)
98-
)
99-
projectα(_dα)
100-
end
86+
@thunk projectα(tensoradd_pullback_dα(ΔC, C, A, pA, conjA, α, ba...))
10187
else
10288
ZeroTangent()
10389
end
10490
dβ = if _needs_tangent(β)
105-
@thunk let
106-
# TODO: consider using `inner`
107-
_dβ = tensorscalar(
108-
tensorcontract(
109-
C, ((), trivtuple(numind(pA))), true,
110-
ΔC, (trivtuple(numind(pA)), ()), false,
111-
((), ()), One(), ba...
112-
)
113-
)
114-
projectβ(_dβ)
115-
end
91+
@thunk projectβ(pullback_dβ(ΔC, C, β))
11692
else
11793
ZeroTangent()
11894
end
11995
dba = map(_ -> NoTangent(), ba)
12096
return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba...
12197
end
12298

123-
return C′, pullback
99+
return C′, tensoradd_pullback
124100
end
125101

126102
function ChainRulesCore.rrule(
@@ -143,84 +119,31 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
143119
projectα = ProjectTo(α)
144120
projectβ = ProjectTo(β)
145121

146-
function pullback(ΔC′)
122+
function tensorcontract_pullback(ΔC′)
147123
ΔC = unthunk(ΔC′)
148-
ipAB = invperm(linearize(pAB))
149-
pΔC = (
150-
TupleTools.getindices(ipAB, trivtuple(numout(pA))),
151-
TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))),
152-
)
153-
dC = if β === Zero()
154-
ZeroTangent()
155-
else
156-
@thunk projectC(scale(ΔC, conj(β)))
157-
end
158-
dA = @thunk let
159-
ipA = (invperm(linearize(pA)), ())
160-
conjΔC = conjA
161-
conjB′ = conjA ? conjB : !conjB
162-
_dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B), typeof(α)))
163-
_dA = tensorcontract!(
164-
_dA,
165-
ΔC, pΔC, conjΔC,
166-
B, reverse(pB), conjB′,
167-
ipA,
168-
conjA ? α : conj(α), Zero(), ba...
169-
)
170-
projectA(_dA)
171-
end
172-
dB = @thunk let
173-
ipB = (invperm(linearize(pB)), ())
174-
conjΔC = conjB
175-
conjA′ = conjB ? conjA : !conjA
176-
_dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A), typeof(α)))
177-
_dB = tensorcontract!(
178-
_dB,
179-
A, reverse(pA), conjA′,
180-
ΔC, pΔC, conjΔC,
181-
ipB,
182-
conjB ? α : conj(α), Zero(), ba...
183-
)
184-
projectB(_dB)
185-
end
124+
125+
dC = β === Zero() ? ZeroTangent() : @thunk projectC(pullback_dC(ΔC, β))
126+
dA = @thunk projectA(tensorcontract_pullback_dA(ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...))
127+
dB = @thunk projectB(tensorcontract_pullback_dB(ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...))
186128
dα = if _needs_tangent(α)
187-
@thunk let
188-
C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
189-
# TODO: consider using `inner`
190-
_dα = tensorscalar(
191-
tensorcontract(
192-
C_αβ, ((), trivtuple(numind(pAB))), true,
193-
ΔC, (trivtuple(numind(pAB)), ()), false,
194-
((), ()), One(), ba...
195-
)
196-
)
197-
projectα(_dα)
198-
end
129+
@thunk projectα(tensorcontract_pullback_dα(ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...))
199130
else
200131
ZeroTangent()
201132
end
202133
dβ = if _needs_tangent(β)
203-
@thunk let
204-
# TODO: consider using `inner`
205-
_dβ = tensorscalar(
206-
tensorcontract(
207-
C, ((), trivtuple(numind(pAB))), true,
208-
ΔC, (trivtuple(numind(pAB)), ()), false,
209-
((), ()), One(), ba...
210-
)
211-
)
212-
projectβ(_dβ)
213-
end
134+
@thunk projectβ(pullback_dβ(ΔC, C, β))
214135
else
215136
ZeroTangent()
216137
end
217138
dba = map(_ -> NoTangent(), ba)
218139
return NoTangent(), dC,
219-
dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(),
220-
NoTangent(), dα, dβ, dba...
140+
dA, NoTangent(), NoTangent(),
141+
dB, NoTangent(), NoTangent(),
142+
NoTangent(),
143+
dα, dβ, dba...
221144
end
222145

223-
return C′, pullback
146+
return C′, tensorcontract_pullback
224147
end
225148

226149
function ChainRulesCore.rrule(
@@ -239,67 +162,26 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
239162
projectα = ProjectTo(α)
240163
projectβ = ProjectTo(β)
241164

242-
function pullback(ΔC′)
165+
function tensortrace_pullback(ΔC′)
243166
ΔC = unthunk(ΔC′)
244-
dC = if β === Zero()
245-
ZeroTangent()
246-
else
247-
@thunk projectC(scale(ΔC, conj(β)))
248-
end
249-
dA = @thunk let
250-
ip = invperm((linearize(p)..., q[1]..., q[2]...))
251-
Es = map(q[1], q[2]) do i1, i2
252-
one(
253-
TensorOperations.tensoralloc_add(
254-
scalartype(A), A, ((i1,), (i2,)), conjA
255-
)
256-
)
257-
end
258-
E = _kron(Es, ba)
259-
_dA = zerovector(A, VectorInterface.promote_scale(ΔC, α))
260-
_dA = tensorproduct!(
261-
_dA, ΔC, (trivtuple(numind(p)), ()), conjA,
262-
E, ((), trivtuple(numind(q))), conjA,
263-
(ip, ()),
264-
conjA ? α : conj(α), Zero(), ba...
265-
)
266-
projectA(_dA)
267-
end
167+
168+
dC = β === Zero() ? ZeroTangent() : @thunk projectC(pullback_dC(ΔC, β))
169+
dA = @thunk projectA(tensortrace_pullback_dA(ΔC, C, A, p, q, conjA, α, ba...))
268170
dα = if _needs_tangent(α)
269-
@thunk let
270-
C_αβ = tensortrace(A, p, q, false, One(), ba...)
271-
_dα = tensorscalar(
272-
tensorcontract(
273-
C_αβ, ((), trivtuple(numind(p))),
274-
!conjA,
275-
ΔC, (trivtuple(numind(p)), ()), false,
276-
((), ()), One(), ba...
277-
)
278-
)
279-
projectα(_dα)
280-
end
171+
@thunk projectα(tensortrace_pullback_dα(ΔC, C, A, p, q, conjA, α, ba...))
281172
else
282173
ZeroTangent()
283174
end
284175
dβ = if _needs_tangent(β)
285-
@thunk let
286-
_dβ = tensorscalar(
287-
tensorcontract(
288-
C, ((), trivtuple(numind(p))), true,
289-
ΔC, (trivtuple(numind(p)), ()), false,
290-
((), ()), One(), ba...
291-
)
292-
)
293-
projectβ(_dβ)
294-
end
176+
@thunk projectβ(pullback_dβ(ΔC, C, β))
295177
else
296178
ZeroTangent()
297179
end
298180
dba = map(_ -> NoTangent(), ba)
299181
return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba...
300182
end
301183

302-
return C′, pullback
184+
return C′, tensortrace_pullback
303185
end
304186

305187
# NCON functions

0 commit comments

Comments
 (0)