|
| 1 | +module TensorOperationsMooncakeExt |
| 2 | + |
| 3 | +using TensorOperations |
| 4 | +# Mooncake imports ChainRulesCore as CRC to avoid name conflicts |
| 5 | +# here we import it ourselves to ensure the rules from the ChainRulesCore |
| 6 | +# extension are in fact loaded |
| 7 | +using Mooncake, Mooncake.CRC |
| 8 | +using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator |
| 9 | +using TensorOperations: tensoralloc, tensoradd!, tensorcontract!, tensortrace!, _kron, numind, _needs_tangent, numin, numout |
| 10 | +using Mooncake: ReverseMode, DefaultCtx, CoDual, NoRData, arrayify, @zero_derivative, primal, tangent |
| 11 | +using VectorInterface, TupleTools |
| 12 | + |
# Register `Mooncake.NoTangent` as the tangent type of the index-tuple alias,
# of every backend type, and of the three allocator types.
for T in (:Index2Tuple, :DefaultAllocator, :CUDAAllocator, :ManualAllocator)
    # Each of these methods matches the named type exactly (`Type{T}`).
    @eval Mooncake.tangent_type(::Type{$T}) = Mooncake.NoTangent
end
# Backends are matched by subtype rather than by exact type.
function Mooncake.tangent_type(::Type{<:AbstractBackend})
    return Mooncake.NoTangent
end
| 18 | + |
# Return the tuple `(1, 2, …, N)`; `N == 0` gives the empty tuple.
function trivtuple(N)
    return ntuple(identity, N)
end
| 20 | + |
# Helpers declared to have zero derivative under Mooncake's default context.
# NOTE(review): every signature below is `Tuple{typeof(f), Any}`, i.e. it only
# matches calls with exactly one argument — confirm this is the intended arity.
Mooncake.@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorstructure), Any}
Mooncake.@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoradd_structure), Any}
Mooncake.@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoradd_type), Any}
Mooncake.@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoralloc_add), Any}
Mooncake.@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorcontract_structure), Any}
Mooncake.@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorcontract_type), Any}
Mooncake.@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoralloc_contract), Any}
Mooncake.@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.promote_contract), Any}
Mooncake.@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.promote_add), Any}

# Allocation and deallocation reuse the existing ChainRulesCore `rrule`s.
Mooncake.@from_rrule DefaultCtx Tuple{typeof(TensorOperations.tensorfree!), Any}
Mooncake.@from_rrule DefaultCtx Tuple{typeof(TensorOperations.tensoralloc), Any, Any, Any, Any}
| 33 | + |
Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(tensorcontract!), AbstractArray, AbstractArray, Index2Tuple, Bool, AbstractArray, Index2Tuple, Bool, Index2Tuple, Number, Number, Vararg{Any}}
# Reverse-mode rule for the in-place `tensorcontract!`.
#
# Forward pass: unwraps primals and tangent arrays from the `CoDual`s, keeps a
# copy of `C`, runs the primal `tensorcontract!` and returns `C_dC` unchanged
# together with the pullback.
#
# Pullback (`contract_pb`): writes the saved copy back into `C`, accumulates
# into the tangent arrays `dA` and `dB`, computes reverse data for `α` and `β`
# when `_needs_tangent` says so, and finally rescales `dC`. It returns one
# reverse-data entry per argument (function, C, A, pA, conjA, B, pB, conjB,
# pAB, α, β, then one `NoRData()` per trailing backend/allocator argument).
function Mooncake.rrule!!(
        ::CoDual{typeof(tensorcontract!)},
        C_dC::CoDual{<:AbstractArray{TC}},
        A_dA::CoDual{<:AbstractArray{TA}},
        pA_dpA::CoDual{<:Index2Tuple},
        conjA_dconjA::CoDual{Bool},
        B_dB::CoDual{<:AbstractArray{TB}},
        pB_dpB::CoDual{<:Index2Tuple},
        conjB_dconjB::CoDual{Bool},
        pAB_dpAB::CoDual{<:Index2Tuple},
        α_dα::CoDual{Tα},
        β_dβ::CoDual{Tβ},
        ba_dba::CoDual...,
    ) where {Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number}
    C, dC = arrayify(C_dC)
    A, dA = arrayify(A_dA)
    B, dB = arrayify(B_dB)
    pA = primal(pA_dpA)
    pB = primal(pB_dpB)
    pAB = primal(pAB_dpAB)
    conjA = primal(conjA_dconjA)
    conjB = primal(conjB_dconjB)
    α = primal(α_dα)
    β = primal(β_dβ)
    ba = primal.(ba_dba)
    # Saved so that the pullback can put the pre-call contents back into `C`.
    C_cache = copy(C)
    TensorOperations.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...)
    function contract_pb(::NoRData)
        # Write the saved pre-call values back into `C`.
        scale!(C, C_cache, One())
        if Tα == Zero && Tβ == Zero
            # Both scalars are the static `Zero` type: only `dC` is zeroed and
            # every argument gets `NoRData()`.
            scale!(dC, zero(TC))
            return ntuple(i -> NoRData(), 11 + length(ba))
        end
        ipAB = invperm(linearize(pAB))
        pdC = (
            TupleTools.getindices(ipAB, trivtuple(numout(pA))),
            TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))),
        )
        ipA = (invperm(linearize(pA)), ())
        ipB = (invperm(linearize(pB)), ())

        # The in-place calls below are used only for their effect on `dA` and
        # `dB`. Their return values are deliberately NOT assigned back to
        # `dA`/`dB`: assigning to a captured outer local inside the closure
        # would force Julia to box that variable and lose its inferred type.
        conjΔC = conjA
        conjB′ = conjA ? conjB : !conjB
        tensorcontract!(
            dA,
            dC, pdC, conjΔC,
            B, reverse(pB), conjB′,
            ipA,
            conjA ? α : conj(α), One(), ba...
        )
        conjΔC = conjB
        conjA′ = conjB ? conjA : !conjA
        tensorcontract!(
            dB,
            A, reverse(pA), conjA′,
            dC, pdC, conjΔC,
            ipB,
            conjB ? α : conj(α), One(), ba...
        )
        dα = if _needs_tangent(Tα)
            # Unscaled contraction of A and B, then fully contracted (with
            # conjugation) against `dC` to produce a scalar.
            C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
            # TODO: consider using `inner`
            Mooncake._rdata(
                tensorscalar(
                    tensorcontract(
                        C_αβ, ((), trivtuple(numind(pAB))), true,
                        dC, (trivtuple(numind(pAB)), ()), false,
                        ((), ()), One(), ba...
                    )
                )
            )
        else
            NoRData()
        end
        dβ = if _needs_tangent(Tβ)
            # `C` holds the saved pre-call values again at this point.
            # TODO: consider using `inner`
            Mooncake._rdata(
                tensorscalar(
                    tensorcontract(
                        C, ((), trivtuple(numind(pAB))), true,
                        dC, (trivtuple(numind(pAB)), ()), false,
                        ((), ()), One(), ba...
                    )
                )
            )
        else
            NoRData()
        end
        # Rescale `dC` last: everything above reads the unscaled `dC`.
        if β === Zero()
            scale!(dC, β)
        else
            scale!(dC, conj(β))
        end
        return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)...
    end
    return C_dC, contract_pb
end
| 131 | + |
Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(tensoradd!), AbstractArray, AbstractArray, Index2Tuple, Bool, Number, Number, Vararg{Any}}
# Reverse-mode rule for the in-place `tensoradd!`.
#
# Forward pass: unwraps primals and tangent arrays from the `CoDual`s, keeps a
# copy of `C`, runs the primal `tensoradd!` and returns `C_dC` unchanged
# together with the pullback.
#
# Pullback (`add_pb`): writes the saved copy back into `C`, accumulates into
# the tangent array `dA`, computes reverse data for `α` and `β` when
# `_needs_tangent` says so, and finally rescales `dC`. It returns one
# reverse-data entry per argument (function, C, A, pA, conjA, α, β, then one
# `NoRData()` per trailing backend/allocator argument).
function Mooncake.rrule!!(
        ::CoDual{typeof(tensoradd!)},
        C_dC::CoDual{<:AbstractArray{TC}},
        A_dA::CoDual{<:AbstractArray{TA}},
        pA_dpA::CoDual{<:Index2Tuple},
        conjA_dconjA::CoDual{Bool},
        α_dα::CoDual{Tα},
        β_dβ::CoDual{Tβ},
        ba_dba::CoDual...,
    ) where {Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
    C, dC = arrayify(C_dC)
    A, dA = arrayify(A_dA)
    pA = primal(pA_dpA)
    conjA = primal(conjA_dconjA)
    α = primal(α_dα)
    β = primal(β_dβ)
    ba = primal.(ba_dba)
    # Saved so that the pullback can put the pre-call contents back into `C`.
    C_cache = copy(C)
    TensorOperations.tensoradd!(C, A, pA, conjA, α, β, ba...)
    function add_pb(::NoRData)
        # Write the saved pre-call values back into `C`.
        scale!(C, C_cache, One())
        ipA = invperm(linearize(pA))
        # Used only for its in-place effect on `dA`. The return value is
        # deliberately NOT assigned back to `dA`: assigning to a captured outer
        # local inside the closure would force Julia to box that variable.
        tensoradd!(dA, dC, (ipA, ()), conjA, conjA ? α : conj(α), One(), ba...)
        # The scalar results go through `Mooncake._rdata`, exactly as in the
        # `tensorcontract!` and `tensortrace!` rules of this module.
        dα = if _needs_tangent(Tα)
            Mooncake._rdata(
                tensorscalar(
                    tensorcontract(
                        A, ((), linearize(pA)), !conjA,
                        dC, (trivtuple(numind(pA)), ()), false,
                        ((), ()), One(), ba...
                    )
                )
            )
        else
            NoRData()
        end
        dβ = if _needs_tangent(Tβ)
            # `C` holds the saved pre-call values again at this point.
            Mooncake._rdata(
                tensorscalar(
                    tensorcontract(
                        C, ((), trivtuple(numind(pA))), true,
                        dC, (trivtuple(numind(pA)), ()), false,
                        ((), ()), One(), ba...
                    )
                )
            )
        else
            NoRData()
        end
        # Rescale `dC` last: everything above reads the unscaled `dC`.
        if β === Zero()
            scale!(dC, β)
        else
            scale!(dC, conj(β))
        end
        return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)...
    end
    return C_dC, add_pb
end
| 187 | + |
Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(tensortrace!), AbstractArray, AbstractArray, Index2Tuple, Index2Tuple, Bool, Number, Number, Vararg{Any}}
# Reverse-mode rule for the in-place `tensortrace!`.
#
# Forward pass: unwraps primals and tangent arrays from the `CoDual`s, keeps a
# copy of `C`, runs the primal `tensortrace!` and returns `C_dC` unchanged
# together with the pullback.
#
# Pullback (`trace_pb`): writes the saved copy back into `C`, accumulates into
# the tangent array `dA` via a tensor product of `dC` with `E`, computes
# reverse data for `α` and `β` when `_needs_tangent` says so, and finally
# rescales `dC`. It returns one reverse-data entry per argument (function, C,
# A, p, q, conjA, α, β, then one `NoRData()` per trailing backend/allocator
# argument).
function Mooncake.rrule!!(
        ::CoDual{typeof(tensortrace!)},
        C_dC::CoDual{<:AbstractArray{TC}},
        A_dA::CoDual{<:AbstractArray{TA}},
        p_dp::CoDual{<:Index2Tuple},
        q_dq::CoDual{<:Index2Tuple},
        conjA_dconjA::CoDual{Bool},
        α_dα::CoDual{Tα},
        β_dβ::CoDual{Tβ},
        ba_dba::CoDual...,
    ) where {Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
    C, dC = arrayify(C_dC)
    A, dA = arrayify(A_dA)
    p = primal(p_dp)
    q = primal(q_dq)
    conjA = primal(conjA_dconjA)
    α = primal(α_dα)
    β = primal(β_dβ)
    ba = primal.(ba_dba)
    # Saved so that the pullback can put the pre-call contents back into `C`.
    C_cache = copy(C)
    TensorOperations.tensortrace!(C, A, p, q, conjA, α, β, ba...)
    function trace_pb(::NoRData)
        # Write the saved pre-call values back into `C`.
        scale!(C, C_cache, One())
        ip = invperm((linearize(p)..., q[1]..., q[2]...))
        # One `one(...)` array per traced index pair, combined by `_kron`.
        Es = map(q[1], q[2]) do i1, i2
            one(
                TensorOperations.tensoralloc_add(
                    TensorOperations.scalartype(A), A, ((i1,), (i2,)), conjA
                )
            )
        end
        E = _kron(Es, ba)
        # Used only for its in-place effect on `dA`. The return value is
        # deliberately NOT assigned back to `dA`: assigning to a captured outer
        # local inside the closure would force Julia to box that variable.
        tensorproduct!(
            dA, dC, (trivtuple(numind(p)), ()), conjA,
            E, ((), trivtuple(numind(q))), conjA,
            (ip, ()),
            conjA ? α : conj(α), One(), ba...
        )
        dα = if _needs_tangent(Tα)
            # The unscaled trace is only needed for `dα`, so it is computed
            # inside this branch rather than unconditionally.
            C_αβ = tensortrace(A, p, q, false, One(), ba...)
            Mooncake._rdata(
                tensorscalar(
                    tensorcontract(
                        C_αβ, ((), trivtuple(numind(p))),
                        !conjA,
                        dC, (trivtuple(numind(p)), ()), false,
                        ((), ()), One(), ba...
                    )
                )
            )
        else
            NoRData()
        end
        dβ = if _needs_tangent(Tβ)
            # `C` holds the saved pre-call values again at this point.
            Mooncake._rdata(
                tensorscalar(
                    tensorcontract(
                        C, ((), trivtuple(numind(p))), true,
                        dC, (trivtuple(numind(p)), ()), false,
                        ((), ()), One(), ba...
                    )
                )
            )
        else
            NoRData()
        end
        # Rescale `dC` last: everything above reads the unscaled `dC`.
        if β === Zero()
            scale!(dC, β)
        else
            scale!(dC, conj(β))
        end
        return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)...
    end
    return C_dC, trace_pb
end
| 264 | + |
| 265 | +end |
0 commit comments