Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Preferences = "1.4"
PtrArrays = "1.2"
Random = "1"
Strided = "2.2"
StridedViews = "0.3, 0.4"
StridedViews = "=0.4.1"
Comment thread
kshyatt marked this conversation as resolved.
Outdated
Test = "1"
TupleTools = "1.6"
VectorInterface = "0.4.1,0.5"
Expand Down
4 changes: 1 addition & 3 deletions ext/TensorOperationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
module TensorOperationsChainRulesCoreExt

using TensorOperations
using TensorOperations: numind, numin, numout, promote_contract, _needs_tangent
using TensorOperations: numind, numin, numout, promote_contract, _needs_tangent, trivtuple
using TensorOperations: DefaultBackend, DefaultAllocator, _kron
using ChainRulesCore
using TupleTools
using VectorInterface
using TupleTools: invperm
using LinearAlgebra

trivtuple(N) = ntuple(identity, N)

@non_differentiable TensorOperations.tensorstructure(args...)
@non_differentiable TensorOperations.tensoradd_structure(args...)
@non_differentiable TensorOperations.tensoradd_type(args...)
Expand Down
153 changes: 10 additions & 143 deletions ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using TensorOperations
# extension are in fact loaded
using Mooncake, Mooncake.CRC
using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator
using TensorOperations: tensoralloc, tensoradd!, tensorcontract!, tensortrace!, _kron, numind, _needs_tangent, numin, numout
using TensorOperations: tensoralloc, tensoradd!, tensorcontract!, tensortrace!
using Mooncake: ReverseMode, DefaultCtx, CoDual, NoRData, arrayify, @zero_derivative, primal, tangent
using VectorInterface, TupleTools

Expand All @@ -16,8 +16,6 @@ Mooncake.tangent_type(::Type{DefaultAllocator}) = Mooncake.NoTangent
Mooncake.tangent_type(::Type{CUDAAllocator}) = Mooncake.NoTangent
Mooncake.tangent_type(::Type{ManualAllocator}) = Mooncake.NoTangent

trivtuple(N) = ntuple(identity, N)

@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorstructure), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoradd_structure), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoradd_type), Any}
Expand Down Expand Up @@ -61,69 +59,9 @@ function Mooncake.rrule!!(
TensorOperations.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...)
function contract_pb(::NoRData)
scale!(C, C_cache, One())
if Tα == Zero && Tβ == Zero
scale!(dC, zero(TC))
return ntuple(i -> NoRData(), 11 + length(ba))
end
ipAB = invperm(linearize(pAB))
pdC = (
TupleTools.getindices(ipAB, trivtuple(numout(pA))),
TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))),
)
ipA = (invperm(linearize(pA)), ())
ipB = (invperm(linearize(pB)), ())
conjΔC = conjA
conjB′ = conjA ? conjB : !conjB
dA = tensorcontract!(
dA,
dC, pdC, conjΔC,
B, reverse(pB), conjB′,
ipA,
conjA ? α : conj(α), One(), ba...
)
conjΔC = conjB
conjA′ = conjB ? conjA : !conjA
dB = tensorcontract!(
dB,
A, reverse(pA), conjA′,
dC, pdC, conjΔC,
ipB,
conjB ? α : conj(α), One(), ba...
)
dα = if _needs_tangent(Tα)
C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
# TODO: consider using `inner`
Mooncake._rdata(
tensorscalar(
tensorcontract(
C_αβ, ((), trivtuple(numind(pAB))), true,
dC, (trivtuple(numind(pAB)), ()), false,
((), ()), One(), ba...
)
)
)
else
NoRData()
end
dβ = if _needs_tangent(Tβ)
# TODO: consider using `inner`
Mooncake._rdata(
tensorscalar(
tensorcontract(
C, ((), trivtuple(numind(pAB))), true,
dC, (trivtuple(numind(pAB)), ()), false,
((), ()), One(), ba...
)
)
)
else
NoRData()
end
if β === Zero()
scale!(dC, β)
else
scale!(dC, conj(β))
end
dC, dA, dB, Δα, Δβ = TensorOperations.tensorcontract_pullback!(dC, dA, dB, C, A, B, α, β, pA, pB, pAB, conjA, conjB, ba...)
dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα)
dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ)
return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)...
end
return C_dC, contract_pb
Expand Down Expand Up @@ -151,35 +89,9 @@ function Mooncake.rrule!!(
TensorOperations.tensoradd!(C, A, pA, conjA, α, β, ba...)
function add_pb(::NoRData)
scale!(C, C_cache, One())
ipA = invperm(linearize(pA))
dA = tensoradd!(dA, dC, (ipA, ()), conjA, conjA ? α : conj(α), One(), ba...)
dα = if _needs_tangent(Tα)
tensorscalar(
tensorcontract(
A, ((), linearize(pA)), !conjA,
dC, (trivtuple(numind(pA)), ()), false,
((), ()), One(), ba...
)
)
else
Mooncake.NoRData()
end
dβ = if _needs_tangent(Tβ)
tensorscalar(
tensorcontract(
C, ((), trivtuple(numind(pA))), true,
dC, (trivtuple(numind(pA)), ()), false,
((), ()), One(), ba...
)
)
else
Mooncake.NoRData()
end
if β === Zero()
scale!(dC, β)
else
scale!(dC, conj(β))
end
dC, dA, Δα, Δβ = TensorOperations.tensoradd_pullback!(dC, dA, C, A, α, β, pA, conjA, ba...)
dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα)
dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ)
return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)...
end
return C_dC, add_pb
Expand Down Expand Up @@ -209,54 +121,9 @@ function Mooncake.rrule!!(
TensorOperations.tensortrace!(C, A, p, q, conjA, α, β, ba...)
function trace_pb(::NoRData)
scale!(C, C_cache, One())
ip = invperm((linearize(p)..., q[1]..., q[2]...))
Es = map(q[1], q[2]) do i1, i2
one(
TensorOperations.tensoralloc_add(
TensorOperations.scalartype(A), A, ((i1,), (i2,)), conjA
)
)
end
E = _kron(Es, ba)
dA = tensorproduct!(
dA, dC, (trivtuple(numind(p)), ()), conjA,
E, ((), trivtuple(numind(q))), conjA,
(ip, ()),
conjA ? α : conj(α), One(), ba...
)
C_αβ = tensortrace(A, p, q, false, One(), ba...)
dα = if _needs_tangent(Tα)
Mooncake._rdata(
tensorscalar(
tensorcontract(
C_αβ, ((), trivtuple(numind(p))),
!conjA,
dC, (trivtuple(numind(p)), ()), false,
((), ()), One(), ba...
)
)
)
else
NoRData()
end
dβ = if _needs_tangent(Tβ)
Mooncake._rdata(
tensorscalar(
tensorcontract(
C, ((), trivtuple(numind(p))), true,
dC, (trivtuple(numind(p)), ()), false,
((), ()), One(), ba...
)
)
)
else
NoRData()
end
if β === Zero()
scale!(dC, β)
else
scale!(dC, conj(β))
end
dC, dA, Δα, Δβ = TensorOperations.tensortrace_pullback!(dC, dA, C, A, α, β, p, q, conjA, ba...)
dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα)
dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ)
return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)...
end
return C_dC, trace_pb
Expand Down
6 changes: 6 additions & 0 deletions src/TensorOperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ include("backends.jl")
include("interface.jl")
include("utils.jl")

# Generic pullbacks for AD
#---------------------------
include("pullbacks/add.jl")
include("pullbacks/trace.jl")
include("pullbacks/contract.jl")

# Index notation via macros
#---------------------------
@nospecialize
Expand Down
32 changes: 32 additions & 0 deletions src/pullbacks/add.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
    tensoradd_pullback!(ΔC, ΔA, C, A, α, β, pA, conjA::Bool, ba...)

Accumulate the reverse-mode cotangents of `tensoradd!(C, A, pA, conjA, α, β, ba...)`.

Mutates `ΔA` in place (adds the contribution propagated back through `A`) and
rescales `ΔC` in place by `conj(β)`. Returns `(ΔC, ΔA, Δα, Δβ)`, where `Δα` and
`Δβ` are scalar cotangents, or `nothing` when `_needs_tangent` reports that the
corresponding scalar does not require a tangent.
"""
function tensoradd_pullback!(ΔC, ΔA, C, A, α, β, pA, conjA::Bool, ba...)
    ipA = invperm(linearize(pA))
    # When the output cotangent is complex but ΔA is real, tensoradd! cannot
    # write complex values into the real destination: accumulate into a complex
    # scratch buffer and fold back only the real part afterwards.
    mixedA = eltype(ΔC) <: Complex && eltype(ΔA) <: Real
    ΔAc = mixedA ? zerovector(A, VectorInterface.promote_add(ΔC, α)) : ΔA
    tensoradd!(ΔAc, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), One(), ba...)
    if mixedA
        ΔA .+= real.(ΔAc)
    end
    Δα = if _needs_tangent(α)
        tensorscalar(
            tensorcontract(
                A, ((), linearize(pA)), !conjA,
                ΔC, (trivtuple(numind(pA)), ()), false,
                ((), ()), One(), ba...
            )
        )
    else
        nothing
    end
    Δβ = if _needs_tangent(β)
        tensorscalar(
            tensorcontract(
                C, ((), trivtuple(numind(pA))), true,
                ΔC, (trivtuple(numind(pA)), ()), false,
                ((), ()), One(), ba...
            )
        )
    else
        nothing
    end
    scale!(ΔC, conj(β))
    return ΔC, ΔA, Δα, Δβ
end
62 changes: 62 additions & 0 deletions src/pullbacks/contract.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
    tensorcontract_pullback!(ΔC, ΔA, ΔB, C, A, B, α, β, pA, pB, pAB, conjA::Bool, conjB::Bool, ba...)

Accumulate the reverse-mode cotangents of
`tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...)`.

Mutates `ΔA` and `ΔB` in place (adding the contributions propagated back
through `A` and `B`) and rescales `ΔC` in place by `conj(β)`. Returns
`(ΔC, ΔA, ΔB, Δα, Δβ)`, where `Δα` and `Δβ` are scalar cotangents, or
`nothing` when `_needs_tangent` reports that the corresponding scalar does
not require a tangent.
"""
function tensorcontract_pullback!(ΔC, ΔA, ΔB, C, A, B, α, β, pA, pB, pAB, conjA::Bool, conjB::Bool, ba...)
    ipAB = invperm(linearize(pAB))
    # Permutation that maps ΔC back to the (open-A, open-B) index layout.
    pdC = (
        TupleTools.getindices(ipAB, trivtuple(numout(pA))),
        TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))),
    )
    ipA = (invperm(linearize(pA)), ())
    ipB = (invperm(linearize(pB)), ())
    # ΔA contribution: contract ΔC with B; the conjugation flags depend on
    # whether A (resp. B) entered the primal contraction conjugated.
    conjΔC = conjA
    conjB′ = conjA ? conjB : !conjB
    # Complex cotangent into a real buffer: accumulate via a complex scratch
    # tensor and fold back only the real part (in-place ops cannot write
    # complex values into a real destination).
    mixedA = eltype(ΔC) <: Complex && eltype(ΔA) <: Real
    ΔAc = mixedA ? zerovector(A, VectorInterface.promote_add(ΔC, α)) : ΔA
    tensorcontract!(
        ΔAc,
        ΔC, pdC, conjΔC,
        B, reverse(pB), conjB′,
        ipA,
        conjA ? α : conj(α), One(), ba...
    )
    if mixedA
        ΔA .+= real.(ΔAc)
    end
    # ΔB contribution, mirroring the ΔA case.
    conjΔC = conjB
    conjA′ = conjB ? conjA : !conjA
    mixedB = eltype(ΔC) <: Complex && eltype(ΔB) <: Real
    ΔBc = mixedB ? zerovector(B, VectorInterface.promote_add(ΔC, α)) : ΔB
    tensorcontract!(
        ΔBc,
        A, reverse(pA), conjA′,
        ΔC, pdC, conjΔC,
        ipB,
        conjB ? α : conj(α), One(), ba...
    )
    if mixedB
        ΔB .+= real.(ΔBc)
    end
    Δα = if _needs_tangent(α)
        # Unscaled primal contraction, only needed for the α cotangent.
        C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
        # TODO: consider using `inner`
        tensorscalar(
            tensorcontract(
                C_αβ, ((), trivtuple(numind(pAB))), true,
                ΔC, (trivtuple(numind(pAB)), ()), false,
                ((), ()), One(), ba...
            )
        )
    else
        nothing
    end
    Δβ = if _needs_tangent(β)
        # TODO: consider using `inner`
        tensorscalar(
            tensorcontract(
                C, ((), trivtuple(numind(pAB))), true,
                ΔC, (trivtuple(numind(pAB)), ()), false,
                ((), ()), One(), ba...
            )
        )
    else
        nothing
    end
    scale!(ΔC, conj(β))
    return ΔC, ΔA, ΔB, Δα, Δβ
end
47 changes: 47 additions & 0 deletions src/pullbacks/trace.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
    tensortrace_pullback!(ΔC, ΔA, C, A, α, β, p, q, conjA::Bool, ba...)

Accumulate the reverse-mode cotangents of
`tensortrace!(C, A, p, q, conjA, α, β, ba...)`.

Mutates `ΔA` in place (adding the contribution propagated back through `A`)
and rescales `ΔC` in place by `conj(β)`. Returns `(ΔC, ΔA, Δα, Δβ)`, where
`Δα` and `Δβ` are scalar cotangents, or `nothing` when `_needs_tangent`
reports that the corresponding scalar does not require a tangent.
"""
function tensortrace_pullback!(ΔC, ΔA, C, A, α, β, p, q, conjA::Bool, ba...)
    ip = invperm((linearize(p)..., q[1]..., q[2]...))
    # Identity tensors along each traced index pair; their Kronecker product
    # re-expands the traced dimensions when propagating ΔC back to ΔA.
    Es = map(q[1], q[2]) do i1, i2
        one(
            TensorOperations.tensoralloc_add(
                TensorOperations.scalartype(A), A, ((i1,), (i2,)), conjA
            )
        )
    end
    E = _kron(Es, ba)
    # Complex cotangent into a real buffer: accumulate via a complex scratch
    # tensor and fold back only the real part (tensorproduct! cannot write
    # complex values into a real destination).
    mixedA = eltype(ΔC) <: Complex && eltype(ΔA) <: Real
    ΔAc = mixedA ? zerovector(A, VectorInterface.promote_add(ΔC, α)) : ΔA
    tensorproduct!(
        ΔAc, ΔC, (trivtuple(numind(p)), ()), conjA,
        E, ((), trivtuple(numind(q))), conjA,
        (ip, ()),
        conjA ? α : conj(α), One(), ba...
    )
    if mixedA
        ΔA .+= real.(ΔAc)
    end
    Δα = if _needs_tangent(α)
        # Unscaled primal trace, only needed (and now only computed) for the
        # α cotangent.
        C_αβ = tensortrace(A, p, q, false, One(), ba...)
        tensorscalar(
            tensorcontract(
                C_αβ, ((), trivtuple(numind(p))),
                !conjA,
                ΔC, (trivtuple(numind(p)), ()), false,
                ((), ()), One(), ba...
            )
        )
    else
        nothing
    end
    Δβ = if _needs_tangent(β)
        tensorscalar(
            tensorcontract(
                C, ((), trivtuple(numind(p))), true,
                ΔC, (trivtuple(numind(p)), ()), false,
                ((), ()), One(), ba...
            )
        )
    else
        nothing
    end
    scale!(ΔC, conj(β))
    return ΔC, ΔA, Δα, Δβ
end
2 changes: 2 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ _needs_tangent(x) = _needs_tangent(typeof(x))
# Generic scalars may carry a tangent in reverse-mode AD.
_needs_tangent(::Type{<:Number}) = true
# Integers (more specific than Number, so this method wins) are treated as
# non-differentiable.
_needs_tangent(::Type{<:Integer}) = false
# The VectorInterface singletons One/Zero are compile-time constants and
# never need a tangent.
_needs_tangent(::Type{<:Union{One, Zero}}) = false

"""
    trivtuple(N)

Return the trivial index permutation `(1, 2, …, N)` as an `NTuple{N, Int}`.
"""
trivtuple(N) = ntuple(i -> i, N)
Loading