Skip to content

Commit 0ad41f2

Browse files
committed
bugfix
1 parent 20faba4 commit 0ad41f2

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

ext/TensorKitMooncakeExt/tangent.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,8 @@ Mooncake.frule!!(::Dual{typeof(getfield)}, t_dt::Dual{<:DiagOrTensorMap}, f_df::
198198

199199
# rrules
200200
function _rrule_getfield_common(t_dt::CoDual{<:DiagOrTensorMap}, field_sym::Symbol, n_args::Int)
201-
t = primal(t)
202-
dt = tangent(t)
201+
t = primal(t_dt)
202+
dt = tangent(t_dt)
203203

204204
value_primal = getfield(t, field_sym)
205205
value_dvalue = Mooncake.CoDual(
@@ -224,13 +224,13 @@ function _rrule_getfield_common(t_dt::CoDual{<:DiagOrTensorMap}, field_sym::Symb
224224
end
225225

226226
Mooncake.rrule!!(::CoDual{typeof(Mooncake.lgetfield)}, t_dt::CoDual{<:DiagOrTensorMap}, f_df::CoDual) =
227-
_rrule_getfield_common(t_dt, _field_symbol(primal(f_df)), 3)
227+
_rrule_getfield_common(t_dt, _field_symbol(primal(t_dt), primal(f_df)), 3)
228228
Mooncake.rrule!!(::CoDual{typeof(Mooncake.lgetfield)}, t_dt::CoDual{<:DiagOrTensorMap}, f_df::CoDual, o_do::CoDual) =
229-
_rrule_getfield_common(t_dt, _field_symbol(primal(f_df)), 4)
229+
_rrule_getfield_common(t_dt, _field_symbol(primal(t_dt), primal(f_df)), 4)
230230
Mooncake.rrule!!(::CoDual{typeof(getfield)}, t_dt::CoDual{<:DiagOrTensorMap}, f_df::CoDual) =
231-
_rrule_getfield_common(t_dt, _field_symbol(primal(f_df)), 3)
231+
_rrule_getfield_common(t_dt, _field_symbol(primal(t_dt), primal(f_df)), 3)
232232
Mooncake.rrule!!(::CoDual{typeof(getfield)}, t_dt::CoDual{<:DiagOrTensorMap}, f_df::CoDual, o_do::CoDual) =
233-
_rrule_getfield_common(t_dt, _field_symbol(primal(f_df)), 4)
233+
_rrule_getfield_common(t_dt, _field_symbol(primal(t_dt), primal(f_df)), 4)
234234

235235

236236
# Custom rules for constructors

0 commit comments

Comments
 (0)