@@ -2,15 +2,18 @@ module TensorOperationsChainRulesCoreExt
22
33using TensorOperations
44using TensorOperations: numind, numin, numout, promote_contract, _needs_tangent
5+ using TensorOperations: pullback_dC, pullback_dβ,
6+ tensoradd_pullback_dA, tensoradd_pullback_dα,
7+ tensorcontract_pullback_dA, tensorcontract_pullback_dB, tensorcontract_pullback_dα,
8+ tensortrace_pullback_dA, tensortrace_pullback_dα
9+
510using TensorOperations: DefaultBackend, DefaultAllocator, _kron
611using ChainRulesCore
712using TupleTools
813using VectorInterface
914using TupleTools: invperm
1015using LinearAlgebra
1116
12- trivtuple (N) = ntuple (identity, N)
13-
1417@non_differentiable TensorOperations. tensorstructure (args... )
1518@non_differentiable TensorOperations. tensoradd_structure (args... )
1619@non_differentiable TensorOperations. tensoradd_type (args... )
@@ -74,53 +77,26 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
7477 projectα = ProjectTo (α)
7578 projectβ = ProjectTo (β)
7679
77- function pullback (ΔC′)
80+ function tensoradd_pullback (ΔC′)
7881 ΔC = unthunk (ΔC′)
79- dC = if β === Zero ()
80- ZeroTangent ()
81- else
82- @thunk projectC (scale (ΔC, conj (β)))
83- end
84- dA = @thunk let
85- ipA = invperm (linearize (pA))
86- _dA = zerovector (A, VectorInterface. promote_add (ΔC, α))
87- _dA = tensoradd! (_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj (α), Zero (), ba... )
88- projectA (_dA)
89- end
82+
83+ dC = β === Zero () ? ZeroTangent () : @thunk projectC (pullback_dC (ΔC, β))
84+ dA = @thunk projectA (tensoradd_pullback_dA (ΔC, C, A, pA, conjA, α, ba... ))
9085 dα = if _needs_tangent (α)
91- @thunk let
92- _dα = tensorscalar (
93- tensorcontract (
94- A, ((), linearize (pA)), ! conjA,
95- ΔC, (trivtuple (numind (pA)), ()), false ,
96- ((), ()), One (), ba...
97- )
98- )
99- projectα (_dα)
100- end
86+ @thunk projectα (tensoradd_pullback_dα (ΔC, C, A, pA, conjA, α, ba... ))
10187 else
10288 ZeroTangent ()
10389 end
10490 dβ = if _needs_tangent (β)
105- @thunk let
106- # TODO : consider using `inner`
107- _dβ = tensorscalar (
108- tensorcontract (
109- C, ((), trivtuple (numind (pA))), true ,
110- ΔC, (trivtuple (numind (pA)), ()), false ,
111- ((), ()), One (), ba...
112- )
113- )
114- projectβ (_dβ)
115- end
91+ @thunk projectβ (pullback_dβ (ΔC, C, β))
11692 else
11793 ZeroTangent ()
11894 end
11995 dba = map (_ -> NoTangent (), ba)
12096 return NoTangent (), dC, dA, NoTangent (), NoTangent (), dα, dβ, dba...
12197 end
12298
123- return C′, pullback
99+ return C′, tensoradd_pullback
124100end
125101
126102function ChainRulesCore. rrule (
@@ -143,84 +119,31 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
143119 projectα = ProjectTo (α)
144120 projectβ = ProjectTo (β)
145121
146- function pullback (ΔC′)
122+ function tensorcontract_pullback (ΔC′)
147123 ΔC = unthunk (ΔC′)
148- ipAB = invperm (linearize (pAB))
149- pΔC = (
150- TupleTools. getindices (ipAB, trivtuple (numout (pA))),
151- TupleTools. getindices (ipAB, numout (pA) .+ trivtuple (numin (pB))),
152- )
153- dC = if β === Zero ()
154- ZeroTangent ()
155- else
156- @thunk projectC (scale (ΔC, conj (β)))
157- end
158- dA = @thunk let
159- ipA = (invperm (linearize (pA)), ())
160- conjΔC = conjA
161- conjB′ = conjA ? conjB : ! conjB
162- _dA = zerovector (A, promote_contract (scalartype (ΔC), scalartype (B), typeof (α)))
163- _dA = tensorcontract! (
164- _dA,
165- ΔC, pΔC, conjΔC,
166- B, reverse (pB), conjB′,
167- ipA,
168- conjA ? α : conj (α), Zero (), ba...
169- )
170- projectA (_dA)
171- end
172- dB = @thunk let
173- ipB = (invperm (linearize (pB)), ())
174- conjΔC = conjB
175- conjA′ = conjB ? conjA : ! conjA
176- _dB = zerovector (B, promote_contract (scalartype (ΔC), scalartype (A), typeof (α)))
177- _dB = tensorcontract! (
178- _dB,
179- A, reverse (pA), conjA′,
180- ΔC, pΔC, conjΔC,
181- ipB,
182- conjB ? α : conj (α), Zero (), ba...
183- )
184- projectB (_dB)
185- end
124+
125+ dC = β === Zero () ? ZeroTangent () : @thunk projectC (pullback_dC (ΔC, β))
126+ dA = @thunk projectA (tensorcontract_pullback_dA (ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba... ))
127+ dB = @thunk projectB (tensorcontract_pullback_dB (ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba... ))
186128 dα = if _needs_tangent (α)
187- @thunk let
188- C_αβ = tensorcontract (A, pA, conjA, B, pB, conjB, pAB, One (), ba... )
189- # TODO : consider using `inner`
190- _dα = tensorscalar (
191- tensorcontract (
192- C_αβ, ((), trivtuple (numind (pAB))), true ,
193- ΔC, (trivtuple (numind (pAB)), ()), false ,
194- ((), ()), One (), ba...
195- )
196- )
197- projectα (_dα)
198- end
129+ @thunk projectα (tensorcontract_pullback_dα (ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba... ))
199130 else
200131 ZeroTangent ()
201132 end
202133 dβ = if _needs_tangent (β)
203- @thunk let
204- # TODO : consider using `inner`
205- _dβ = tensorscalar (
206- tensorcontract (
207- C, ((), trivtuple (numind (pAB))), true ,
208- ΔC, (trivtuple (numind (pAB)), ()), false ,
209- ((), ()), One (), ba...
210- )
211- )
212- projectβ (_dβ)
213- end
134+ @thunk projectβ (pullback_dβ (ΔC, C, β))
214135 else
215136 ZeroTangent ()
216137 end
217138 dba = map (_ -> NoTangent (), ba)
218139 return NoTangent (), dC,
219- dA, NoTangent (), NoTangent (), dB, NoTangent (), NoTangent (),
220- NoTangent (), dα, dβ, dba...
140+ dA, NoTangent (), NoTangent (),
141+ dB, NoTangent (), NoTangent (),
142+ NoTangent (),
143+ dα, dβ, dba...
221144 end
222145
223- return C′, pullback
146+ return C′, tensorcontract_pullback
224147end
225148
226149function ChainRulesCore. rrule (
@@ -239,67 +162,26 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
239162 projectα = ProjectTo (α)
240163 projectβ = ProjectTo (β)
241164
242- function pullback (ΔC′)
165+ function tensortrace_pullback (ΔC′)
243166 ΔC = unthunk (ΔC′)
244- dC = if β === Zero ()
245- ZeroTangent ()
246- else
247- @thunk projectC (scale (ΔC, conj (β)))
248- end
249- dA = @thunk let
250- ip = invperm ((linearize (p)... , q[1 ]. .. , q[2 ]. .. ))
251- Es = map (q[1 ], q[2 ]) do i1, i2
252- one (
253- TensorOperations. tensoralloc_add (
254- scalartype (A), A, ((i1,), (i2,)), conjA
255- )
256- )
257- end
258- E = _kron (Es, ba)
259- _dA = zerovector (A, VectorInterface. promote_scale (ΔC, α))
260- _dA = tensorproduct! (
261- _dA, ΔC, (trivtuple (numind (p)), ()), conjA,
262- E, ((), trivtuple (numind (q))), conjA,
263- (ip, ()),
264- conjA ? α : conj (α), Zero (), ba...
265- )
266- projectA (_dA)
267- end
167+
168+ dC = β === Zero () ? ZeroTangent () : @thunk projectC (pullback_dC (ΔC, β))
169+ dA = @thunk projectA (tensortrace_pullback_dA (ΔC, C, A, p, q, conjA, α, ba... ))
268170 dα = if _needs_tangent (α)
269- @thunk let
270- C_αβ = tensortrace (A, p, q, false , One (), ba... )
271- _dα = tensorscalar (
272- tensorcontract (
273- C_αβ, ((), trivtuple (numind (p))),
274- ! conjA,
275- ΔC, (trivtuple (numind (p)), ()), false ,
276- ((), ()), One (), ba...
277- )
278- )
279- projectα (_dα)
280- end
171+ @thunk projectα (tensortrace_pullback_dα (ΔC, C, A, p, q, conjA, α, ba... ))
281172 else
282173 ZeroTangent ()
283174 end
284175 dβ = if _needs_tangent (β)
285- @thunk let
286- _dβ = tensorscalar (
287- tensorcontract (
288- C, ((), trivtuple (numind (p))), true ,
289- ΔC, (trivtuple (numind (p)), ()), false ,
290- ((), ()), One (), ba...
291- )
292- )
293- projectβ (_dβ)
294- end
176+ @thunk projectβ (pullback_dβ (ΔC, C, β))
295177 else
296178 ZeroTangent ()
297179 end
298180 dba = map (_ -> NoTangent (), ba)
299181 return NoTangent (), dC, dA, NoTangent (), NoTangent (), NoTangent (), dα, dβ, dba...
300182 end
301183
302- return C′, pullback
184+ return C′, tensortrace_pullback
303185end
304186
305187# NCON functions
0 commit comments