Commit 22ef6dc

kshyatt and Katharine Hyatt authored: Rules with Mooncake (#240)

* Rules with Mooncake
* Bump version and minimum Julia requirement
* No custom rule for tensorscalar
* Use scale to make things more generic
* Fix bad dC behaviour
* More consistency
* Just use beta to scale
* See you later allocator
* Vararg
* Format
* Comments
* Update autodiff.md with warnings
* Fix typo

Co-authored-by: Katharine Hyatt <katharine.s.hyatt@gmail.com>

1 parent 465adff commit 22ef6dc

9 files changed: 410 additions & 27 deletions

Project.toml

Lines changed: 6 additions & 2 deletions

@@ -1,6 +1,6 @@
 name = "TensorOperations"
 uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
-version = "5.4"
+version = "5.5.0"
 authors = ["Lukas Devos <lukas.devos@ugent.be>", "Maarten Van Damme <maartenvd1994@gmail.com>", "Jutho Haegeman <jutho.haegeman@ugent.be>"]

 [deps]
@@ -23,11 +23,13 @@ Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

 [extensions]
 TensorOperationsBumperExt = "Bumper"
 TensorOperationsChainRulesCoreExt = "ChainRulesCore"
 TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]
+TensorOperationsMooncakeExt = "Mooncake"

 [compat]
 Aqua = "0.6, 0.7, 0.8"
@@ -39,6 +41,7 @@ DynamicPolynomials = "0.5, 0.6"
 LRUCache = "1"
 LinearAlgebra = "1.6"
 Logging = "1.6"
+Mooncake = "0.4.195"
 PackageExtensionCompat = "1"
 PrecompileTools = "1.1"
 Preferences = "1.4"
@@ -59,9 +62,10 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
 DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

 [targets]
-test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper"]
+test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake"]

docs/src/man/autodiff.md

Lines changed: 11 additions & 4 deletions

@@ -1,25 +1,32 @@
 # Automatic differentiation

 TensorOperations offers experimental support for reverse-mode automatic differentiation (AD)
-through the use of [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl). As the basic
+through the use of [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)
+and [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl). As the basic
 operations are multi-linear, the vector-Jacobian products thereof can all be expressed in
 terms of the operations defined in VectorInterface and TensorOperations. Thus, any custom
 type whose tangent type also supports these interfaces will automatically inherit
 reverse-mode AD support.

 As the [`@tensor`](@ref) macro rewrites everything in terms of the basic tensor operations,
-the reverse-mode rules for these methods are supplied. However, because most AD-engines do
+the reverse-mode rules for these methods are supplied. However, because ChainRules.jl does
 not support in-place mutation, effectively these operations will be replaced with a
 non-mutating version. This is similar to the behaviour found in
 [BangBang.jl](https://github.com/JuliaFolds/BangBang.jl), as the operations will be
 in-place, except for the pieces of code that are being differentiated. In effect, this
 amounts to replacing all assignments (`=`) with definitions (`:=`) within the context of
 [`@tensor`](@ref).

+Mooncake.jl *does* support in-place mutation, and as a result on the reverse pass
+all mutated input variables should be restored to their state before the forward-pass
+function was called. Currently, this is **not done** for buffers you provide to various
+TensorOperations functions, so relying on the state of the buffer (e.g. a bumper) being
+restored will **silently** return incorrect results.
+
 !!! warning "Experimental"

     While some rudimentary tests are run, the AD support is currently not incredibly
     well-tested. Because of the way it is implemented, the use of AD will tacitly replace
     mutating operations with a non-mutating variant. This might lead to unwanted bugs that
-    are hard to track down. Additionally, for mixed scalar types their also might be
-    unexpected or unwanted behaviour.
+    are hard to track down. Additionally, for mixed scalar types there also might be
+    unexpected or unwanted behaviour.
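The multi-linearity argument in the text above can be checked numerically: the vector-Jacobian products of a tensor contraction are themselves contractions. A minimal numpy sketch of that identity (illustrative only; this is not TensorOperations' API, and all variable names here are made up):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((4, 5))
dC = rng.standard_normal((3, 5))  # incoming cotangent for C = A @ B

# the VJPs of the bilinear map C = A @ B are again contractions:
dA = np.einsum("ik,jk->ij", dC, B)  # contract dC with B over the open index of C
dB = np.einsum("ij,ik->jk", A, dC)  # contract A with dC over the open index of C

# check dA against a central finite difference of the scalar loss <dC, A @ B>
eps = 1e-6
dirA = rng.standard_normal(A.shape)
loss = lambda A_: np.sum(dC * (A_ @ B))
fd = (loss(A + eps * dirA) - loss(A - eps * dirA)) / (2 * eps)
assert np.isclose(fd, np.sum(dA * dirA), rtol=1e-4)
```

Because the loss is linear in `A`, the finite difference is exact up to rounding, so the comparison is a sharp check of the pullback formula.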

ext/TensorOperationsChainRulesCoreExt.jl

Lines changed: 2 additions & 18 deletions

@@ -1,8 +1,8 @@
 module TensorOperationsChainRulesCoreExt

 using TensorOperations
-using TensorOperations: numind, numin, numout, promote_contract
-using TensorOperations: DefaultBackend, DefaultAllocator
+using TensorOperations: numind, numin, numout, promote_contract, _needs_tangent
+using TensorOperations: DefaultBackend, DefaultAllocator, _kron
 using ChainRulesCore
 using TupleTools
 using VectorInterface
@@ -55,13 +55,6 @@ function ChainRulesCore.rrule(::typeof(tensorscalar), C)
     return tensorscalar(C), tensorscalar_pullback
 end

-# To avoid computing rrules for α and β when these aren't needed, we want to have a
-# type-stable quick bail-out
-_needs_tangent(x) = _needs_tangent(typeof(x))
-_needs_tangent(::Type{<:Number}) = true
-_needs_tangent(::Type{<:Integer}) = false
-_needs_tangent(::Type{<:Union{One, Zero}}) = false
-
 # The current `rrule` design makes sure that the implementation for custom types does
 # not need to support the backend or allocator arguments
 function ChainRulesCore.rrule(
@@ -309,15 +302,6 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
     return C′, pullback
 end

-_kron(Es::NTuple{1}, ba) = Es[1]
-function _kron(Es::NTuple{N, Any}, ba) where {N}
-    E1 = Es[1]
-    E2 = _kron(Base.tail(Es), ba)
-    p2 = ((), trivtuple(2 * N - 2))
-    p = ((1, (2 .+ trivtuple(N - 1))...), (2, ((N + 1) .+ trivtuple(N - 1))...))
-    return tensorproduct(p, E1, ((1, 2), ()), false, E2, p2, false, One(), ba...)
-end
-
 # NCON functions
 @non_differentiable TensorOperations.ncontree(args...)
 @non_differentiable TensorOperations.nconoutput(args...)
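The `_kron` helper being moved out of this extension (and now shared with the Mooncake one) builds an identity tensor over several traced index pairs by recursively taking tensor products of pairwise identities, i.e. a Kronecker product of identity matrices. A numpy sketch of the idea, under the assumption of two traced pairs of dimensions 3 and 4 (names illustrative):

```python
import numpy as np

# identity over two traced index pairs: a "kron" of two identity matrices,
# reshaped so that each delta keeps its own pair of indices
E1, E2 = np.eye(3), np.eye(4)
E = np.einsum("ij,kl->ikjl", E1, E2)   # E[i,k,j,l] = delta_ij * delta_kl

# contracting E against any tensor over its second index pair is a no-op,
# which is exactly what makes it usable as the adjoint of a trace
X = np.random.default_rng(4).standard_normal((3, 4))
Y = np.einsum("ikjl,jl->ik", E, X)
assert np.allclose(X, Y)
```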
ext/TensorOperationsMooncakeExt.jl

Lines changed: 265 additions & 0 deletions (new file)

module TensorOperationsMooncakeExt

using TensorOperations
# Mooncake imports ChainRulesCore as CRC to avoid name conflicts;
# here we import it ourselves to ensure the rules from the ChainRulesCore
# extension are in fact loaded
using Mooncake, Mooncake.CRC
using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator
using TensorOperations: tensoralloc, tensoradd!, tensorcontract!, tensortrace!, _kron,
    numind, _needs_tangent, numin, numout
using Mooncake: ReverseMode, DefaultCtx, CoDual, NoRData, arrayify, @zero_derivative,
    primal, tangent
using VectorInterface, TupleTools

Mooncake.tangent_type(::Type{Index2Tuple}) = Mooncake.NoTangent
Mooncake.tangent_type(::Type{<:AbstractBackend}) = Mooncake.NoTangent
Mooncake.tangent_type(::Type{DefaultAllocator}) = Mooncake.NoTangent
Mooncake.tangent_type(::Type{CUDAAllocator}) = Mooncake.NoTangent
Mooncake.tangent_type(::Type{ManualAllocator}) = Mooncake.NoTangent

trivtuple(N) = ntuple(identity, N)

@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorstructure), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoradd_structure), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoradd_type), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoralloc_add), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorcontract_structure), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorcontract_type), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoralloc_contract), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.promote_contract), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.promote_add), Any}

Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(TensorOperations.tensorfree!), Any}
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(TensorOperations.tensoralloc), Any, Any, Any, Any}

Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{
    typeof(tensorcontract!), AbstractArray, AbstractArray, Index2Tuple, Bool,
    AbstractArray, Index2Tuple, Bool, Index2Tuple, Number, Number, Vararg{Any},
}
function Mooncake.rrule!!(
        ::CoDual{typeof(tensorcontract!)},
        C_dC::CoDual{<:AbstractArray{TC}},
        A_dA::CoDual{<:AbstractArray{TA}},
        pA_dpA::CoDual{<:Index2Tuple},
        conjA_dconjA::CoDual{Bool},
        B_dB::CoDual{<:AbstractArray{TB}},
        pB_dpB::CoDual{<:Index2Tuple},
        conjB_dconjB::CoDual{Bool},
        pAB_dpAB::CoDual{<:Index2Tuple},
        α_dα::CoDual{Tα},
        β_dβ::CoDual{Tβ},
        ba_dba::CoDual...,
    ) where {Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number}
    C, dC = arrayify(C_dC)
    A, dA = arrayify(A_dA)
    B, dB = arrayify(B_dB)
    pA = primal(pA_dpA)
    pB = primal(pB_dpB)
    pAB = primal(pAB_dpAB)
    conjA = primal(conjA_dconjA)
    conjB = primal(conjB_dconjB)
    α = primal(α_dα)
    β = primal(β_dβ)
    ba = primal.(ba_dba)
    C_cache = copy(C)
    TensorOperations.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...)
    function contract_pb(::NoRData)
        # restore the primal C to its state before the forward pass
        scale!(C, C_cache, One())
        if Tα == Zero && Tβ == Zero
            scale!(dC, zero(TC))
            return ntuple(i -> NoRData(), 11 + length(ba))
        end
        ipAB = invperm(linearize(pAB))
        pdC = (
            TupleTools.getindices(ipAB, trivtuple(numout(pA))),
            TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))),
        )
        ipA = (invperm(linearize(pA)), ())
        ipB = (invperm(linearize(pB)), ())
        conjΔC = conjA
        conjB′ = conjA ? conjB : !conjB
        dA = tensorcontract!(
            dA,
            dC, pdC, conjΔC,
            B, reverse(pB), conjB′,
            ipA,
            conjA ? α : conj(α), One(), ba...
        )
        conjΔC = conjB
        conjA′ = conjB ? conjA : !conjA
        dB = tensorcontract!(
            dB,
            A, reverse(pA), conjA′,
            dC, pdC, conjΔC,
            ipB,
            conjB ? α : conj(α), One(), ba...
        )
        dα = if _needs_tangent(Tα)
            C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
            # TODO: consider using `inner`
            Mooncake._rdata(
                tensorscalar(
                    tensorcontract(
                        C_αβ, ((), trivtuple(numind(pAB))), true,
                        dC, (trivtuple(numind(pAB)), ()), false,
                        ((), ()), One(), ba...
                    )
                )
            )
        else
            NoRData()
        end
        dβ = if _needs_tangent(Tβ)
            # TODO: consider using `inner`
            Mooncake._rdata(
                tensorscalar(
                    tensorcontract(
                        C, ((), trivtuple(numind(pAB))), true,
                        dC, (trivtuple(numind(pAB)), ()), false,
                        ((), ()), One(), ba...
                    )
                )
            )
        else
            NoRData()
        end
        if β === Zero()
            scale!(dC, β)
        else
            scale!(dC, conj(β))
        end
        return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(),
            NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)...
    end
    return C_dC, contract_pb
end
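The scalar cotangents in the pullback above follow from linearity: for the update C ← β·C₀ + α·(A·B), one has dα = ⟨A·B, dC⟩ and dβ = ⟨C₀, dC⟩, while dA and dB pick up a factor α. A numpy sketch of these identities for the real case, ignoring conjugation and permutations (all names illustrative, not the extension's API):

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((4, 5))
C0 = rng.standard_normal((3, 5))   # value of C before the in-place update
dC = rng.standard_normal((3, 5))   # cotangent of the updated C
alpha, beta = 0.7, -1.3

out = lambda a, b: b * C0 + a * (A @ B)   # C <- beta*C0 + alpha*(A @ B)

# pullback formulas mirroring the rule (real-valued case):
dalpha = np.sum((A @ B) * dC)   # <A.B, dC>
dbeta = np.sum(C0 * dC)         # <C0, dC>
dA = alpha * dC @ B.T
dB = alpha * A.T @ dC

eps = 1e-6
fd_alpha = np.sum(dC * (out(alpha + eps, beta) - out(alpha - eps, beta))) / (2 * eps)
fd_beta = np.sum(dC * (out(alpha, beta + eps) - out(alpha, beta - eps))) / (2 * eps)
assert np.isclose(fd_alpha, dalpha, rtol=1e-5)
assert np.isclose(fd_beta, dbeta, rtol=1e-5)
```

Note that the Julia rule additionally accumulates into the existing dA and dB (the `One()` in the β slot of the inner calls) and rescales dC by conj(β) at the end, reflecting that the primal C was overwritten in place.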
Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{
    typeof(tensoradd!), AbstractArray, AbstractArray, Index2Tuple, Bool,
    Number, Number, Vararg{Any},
}
function Mooncake.rrule!!(
        ::CoDual{typeof(tensoradd!)},
        C_dC::CoDual{<:AbstractArray{TC}},
        A_dA::CoDual{<:AbstractArray{TA}},
        pA_dpA::CoDual{<:Index2Tuple},
        conjA_dconjA::CoDual{Bool},
        α_dα::CoDual{Tα},
        β_dβ::CoDual{Tβ},
        ba_dba::CoDual...,
    ) where {Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
    C, dC = arrayify(C_dC)
    A, dA = arrayify(A_dA)
    pA = primal(pA_dpA)
    conjA = primal(conjA_dconjA)
    α = primal(α_dα)
    β = primal(β_dβ)
    ba = primal.(ba_dba)
    C_cache = copy(C)
    TensorOperations.tensoradd!(C, A, pA, conjA, α, β, ba...)
    function add_pb(::NoRData)
        # restore the primal C to its state before the forward pass
        scale!(C, C_cache, One())
        ipA = invperm(linearize(pA))
        dA = tensoradd!(dA, dC, (ipA, ()), conjA, conjA ? α : conj(α), One(), ba...)
        dα = if _needs_tangent(Tα)
            tensorscalar(
                tensorcontract(
                    A, ((), linearize(pA)), !conjA,
                    dC, (trivtuple(numind(pA)), ()), false,
                    ((), ()), One(), ba...
                )
            )
        else
            Mooncake.NoRData()
        end
        dβ = if _needs_tangent(Tβ)
            tensorscalar(
                tensorcontract(
                    C, ((), trivtuple(numind(pA))), true,
                    dC, (trivtuple(numind(pA)), ()), false,
                    ((), ()), One(), ba...
                )
            )
        else
            Mooncake.NoRData()
        end
        if β === Zero()
            scale!(dC, β)
        else
            scale!(dC, conj(β))
        end
        return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ,
            map(ba_ -> NoRData(), ba)...
    end
    return C_dC, add_pb
end
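The core of the tensoradd! pullback is that the adjoint of a permuted copy is the inversely permuted cotangent, scaled by α. A numpy sketch of that identity for the real case (permutation and names chosen for illustration only):

```python
import numpy as np

rng = np.random.default_rng(2)
A = rng.standard_normal((2, 3, 4))
perm = (2, 0, 1)                    # forward permutation of A's axes
C0 = rng.standard_normal(np.transpose(A, perm).shape)
dC = rng.standard_normal(C0.shape)  # cotangent for C <- beta*C0 + alpha*permute(A)
alpha, beta = 0.5, 2.0

inv = tuple(np.argsort(perm))       # the invperm(linearize(pA)) of the rule
dA = alpha * np.transpose(dC, inv)  # pull the cotangent back through the permutation

eps = 1e-6
dirA = rng.standard_normal(A.shape)
out = lambda A_: beta * C0 + alpha * np.transpose(A_, perm)
fd = np.sum(dC * (out(A + eps * dirA) - out(A - eps * dirA))) / (2 * eps)
assert np.isclose(fd, np.sum(dA * dirA), rtol=1e-4)
```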
Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{
    typeof(tensortrace!), AbstractArray, AbstractArray, Index2Tuple, Index2Tuple,
    Bool, Number, Number, Vararg{Any},
}
function Mooncake.rrule!!(
        ::CoDual{typeof(tensortrace!)},
        C_dC::CoDual{<:AbstractArray{TC}},
        A_dA::CoDual{<:AbstractArray{TA}},
        p_dp::CoDual{<:Index2Tuple},
        q_dq::CoDual{<:Index2Tuple},
        conjA_dconjA::CoDual{Bool},
        α_dα::CoDual{Tα},
        β_dβ::CoDual{Tβ},
        ba_dba::CoDual...,
    ) where {Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
    C, dC = arrayify(C_dC)
    A, dA = arrayify(A_dA)
    p = primal(p_dp)
    q = primal(q_dq)
    conjA = primal(conjA_dconjA)
    α = primal(α_dα)
    β = primal(β_dβ)
    ba = primal.(ba_dba)
    C_cache = copy(C)
    TensorOperations.tensortrace!(C, A, p, q, conjA, α, β, ba...)
    function trace_pb(::NoRData)
        # restore the primal C to its state before the forward pass
        scale!(C, C_cache, One())
        ip = invperm((linearize(p)..., q[1]..., q[2]...))
        Es = map(q[1], q[2]) do i1, i2
            one(
                TensorOperations.tensoralloc_add(
                    TensorOperations.scalartype(A), A, ((i1,), (i2,)), conjA
                )
            )
        end
        E = _kron(Es, ba)
        dA = tensorproduct!(
            dA, dC, (trivtuple(numind(p)), ()), conjA,
            E, ((), trivtuple(numind(q))), conjA,
            (ip, ()),
            conjA ? α : conj(α), One(), ba...
        )
        C_αβ = tensortrace(A, p, q, false, One(), ba...)
        dα = if _needs_tangent(Tα)
            Mooncake._rdata(
                tensorscalar(
                    tensorcontract(
                        C_αβ, ((), trivtuple(numind(p))), !conjA,
                        dC, (trivtuple(numind(p)), ()), false,
                        ((), ()), One(), ba...
                    )
                )
            )
        else
            NoRData()
        end
        dβ = if _needs_tangent(Tβ)
            Mooncake._rdata(
                tensorscalar(
                    tensorcontract(
                        C, ((), trivtuple(numind(p))), true,
                        dC, (trivtuple(numind(p)), ()), false,
                        ((), ()), One(), ba...
                    )
                )
            )
        else
            NoRData()
        end
        if β === Zero()
            scale!(dC, β)
        else
            scale!(dC, conj(β))
        end
        return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ,
            map(ba_ -> NoRData(), ba)...
    end
    return C_dC, trace_pb
end
end # module
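The trace pullback builds the identity tensor `E` over the traced index pairs and tensor-multiplies it with dC, because the adjoint of a partial trace scatters the cotangent along those identities. A numpy sketch for a single traced pair in the real case (shapes and names illustrative):

```python
import numpy as np

rng = np.random.default_rng(3)
A = rng.standard_normal((3, 4, 4))  # trace over axes 1 and 2
dC = rng.standard_normal(3)         # cotangent of C[i] = sum_j A[i, j, j]

# adjoint of the partial trace: broadcast dC along an identity on the traced
# pair, which is the role played by E = _kron(Es, ba) in the rule above
E = np.eye(4)
dA = np.einsum("i,jk->ijk", dC, E)  # dA[i,j,k] = dC[i] * delta_jk

eps = 1e-6
dirA = rng.standard_normal(A.shape)
trace = lambda A_: np.einsum("ijj->i", A_)
fd = np.sum(dC * (trace(A + eps * dirA) - trace(A - eps * dirA))) / (2 * eps)
assert np.isclose(fd, np.sum(dA * dirA), rtol=1e-4)
```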

src/TensorOperations.jl

Lines changed: 1 addition & 0 deletions

@@ -34,6 +34,7 @@ export checkcontractible, tensorcost
 include("indices.jl")
 include("backends.jl")
 include("interface.jl")
+include("utils.jl")

 # Index notation via macros
 #---------------------------

0 commit comments