diff --git a/src/factorizations/pullbacks.jl b/src/factorizations/pullbacks.jl index 488089809..3af3c82b3 100644 --- a/src/factorizations/pullbacks.jl +++ b/src/factorizations/pullbacks.jl @@ -24,7 +24,6 @@ for pullback! in (:qr_null_pullback!, :lq_null_pullback!) return Δt end end - _notrunc_ind(t) = SectorDict(c => Colon() for c in blocksectors(t)) for pullback! in (:svd_pullback!, :eig_pullback!, :eigh_pullback!) @@ -51,8 +50,61 @@ for pullback_trunc! in (:svd_trunc_pullback!, :eig_trunc_pullback!, :eigh_trunc_ foreachblock(Δt, t) do c, (Δb, b) Fc = block.(F, Ref(c)) ΔFc = block.(ΔF, Ref(c)) - return MAK.$pullback_trunc!(Δb, b, Fc, ΔFc; kwargs...) + MAK.$pullback_trunc!(Δb, b, Fc, ΔFc; kwargs...) + return nothing end return Δt end end + +for f in (:qr, :lq) + remove_f_gauge_dependence! = Symbol(:remove_, f, :_gauge_dependence!) + remove_f_null_gauge_dependence! = Symbol(:remove_, f, :_null_gauge_dependence!) + @eval function MAK.$remove_f_gauge_dependence!( + ΔF₁::AbstractTensorMap, ΔF₂::AbstractTensorMap, A, F₁, F₂; + kwargs... + ) + foreachblock(ΔF₁, ΔF₂, A, F₁, F₂) do _, (Δf₁, Δf₂, a, f₁, f₂) + MAK.$remove_f_gauge_dependence!(Δf₁, Δf₂, a, f₁, f₂) + return nothing + end + return ΔF₁, ΔF₂ + end + # Already captured by MAK implementation + # @eval function MAK.$remove_f_null_gauge_dependence!(ΔN::AbstractTensorMap, A, N; kwargs...) + # foreachblock(ΔN, A, N) do _, (Δn, a, n) + # $remove_f_gauge_dependence!(Δn, a, n) + # end + # return ΔN + # end +end + +for f in (:eig, :eigh) + remove_f_gauge_dependence! = Symbol(:remove_, f, :_gauge_dependence!) + @eval function MAK.$remove_f_gauge_dependence!(ΔV::AbstractTensorMap, D, V; kwargs...) + foreachblock(ΔV, D, V) do c, (Δv, d, v) + MAK.$remove_f_gauge_dependence!(Δv, d, v; kwargs...) + return nothing + end + return ΔV + end + @eval function MAK.$remove_f_gauge_dependence!(ΔV::AbstractTensorMap, D, V, inds; kwargs...) + foreachblock(ΔV, D, V) do c, (Δv, d, v) + haskey(inds, c) || return nothing + ind = inds[c] + MAK.$remove_f_gauge_dependence!(Δv, d, v, ind; kwargs...) + return nothing + end + return ΔV + end +end +function MAK.remove_svd_gauge_dependence!( + ΔU::AbstractTensorMap, ΔVᴴ::AbstractTensorMap, U, S, Vᴴ; + kwargs... + ) + foreachblock(ΔU, ΔVᴴ, U, S, Vᴴ) do c, (Δu, Δvᴴ, u, s, vᴴ) + MAK.remove_svd_gauge_dependence!(Δu, Δvᴴ, u, s, vᴴ) + return nothing + end + return ΔU, ΔVᴴ +end diff --git a/test/chainrules/factorizations.jl b/test/chainrules/factorizations.jl index 483926efa..376a2b6f9 100644 --- a/test/chainrules/factorizations.jl +++ b/test/chainrules/factorizations.jl @@ -9,7 +9,8 @@ using LinearAlgebra using Zygote using MatrixAlgebraKit using MatrixAlgebraKit: diagview - +using MatrixAlgebraKit: remove_qr_gauge_dependence!, remove_lq_gauge_dependence!, + remove_eigh_gauge_dependence!, remove_eig_gauge_dependence!, remove_svd_gauge_dependence! # Tests # ----- @@ -52,7 +53,7 @@ for V in spacelist @test_logs (:warn, r"^`qr") match_mode = :any full_pb((ΔQ, ΔR)) end - remove_qrgauge_dependence!(ΔQ, t, Q) + remove_qr_gauge_dependence!(ΔQ, ΔR, t, Q, R) test_ad_rrule(qr_full, t; fkwargs, atol, rtol, output_tangent = (ΔQ, ΔR)) test_ad_rrule( @@ -90,7 +91,7 @@ for V in spacelist # @test_logs (:warn, r"^`lq") match_mode = :any full_pb((ΔL, ΔQ)) end - remove_lqgauge_dependence!(ΔQ, t, Q) + remove_lq_gauge_dependence!(ΔL, ΔQ, t, L, Q) test_ad_rrule(lq_full, t; fkwargs, atol, rtol, output_tangent = (ΔL, ΔQ)) test_ad_rrule( @@ -114,7 +115,7 @@ for V in spacelist Δv = rand_tangent(v) Δd = rand_tangent(d) Δd2 = randn!(similar(d, space(d))) - remove_eiggauge_dependence!(Δv, d, v) + remove_eig_gauge_dependence!(Δv, d, v) test_ad_rrule(eig_full, t; output_tangent = (Δd, Δv), atol, rtol) test_ad_rrule(first ∘ eig_full, t; output_tangent = Δd, atol, rtol) @@ -126,7 +127,7 @@ for V in spacelist Δv = rand_tangent(v) Δd = rand_tangent(d) Δd2 = randn!(similar(d, space(d))) - remove_eighgauge_dependence!(Δv, d, v) + remove_eigh_gauge_dependence!(Δv, d, v) # necessary for FiniteDifferences to not complain eigh_full′ = eigh_full ∘ project_hermitian @@ -155,7 +156,7 @@ for V in spacelist USVᴴ = svd_compact(t) ΔU, ΔS, ΔVᴴ = rand_tangent.(USVᴴ) ΔS2 = randn!(similar(ΔS, space(ΔS))) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, USVᴴ...; degeneracy_atol) + ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, USVᴴ...; degeneracy_atol) # test_ad_rrule(svd_full, t; output_tangent = (ΔU, ΔS, ΔVᴴ), atol, rtol) # test_ad_rrule(svd_full, t; output_tangent = (ΔU, ΔS2, ΔVᴴ), atol, rtol) @@ -170,7 +171,7 @@ for V in spacelist trunc = truncspace(V_trunc) USVᴴ_trunc = svd_trunc(t; trunc) ΔUSVᴴ_trunc = (rand_tangent.(Base.front(USVᴴ_trunc))..., zero(last(USVᴴ_trunc))) - remove_svdgauge_dependence!( + remove_svd_gauge_dependence!( ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], Base.front(USVᴴ_trunc)...; degeneracy_atol ) test_ad_rrule( diff --git a/test/mooncake/factorizations.jl b/test/mooncake/factorizations.jl index ae998dddb..8955c4ecf 100644 --- a/test/mooncake/factorizations.jl +++ b/test/mooncake/factorizations.jl @@ -3,6 +3,8 @@ using TensorKit using TensorOperations using VectorInterface: Zero, One using MatrixAlgebraKit +using MatrixAlgebraKit: remove_qr_gauge_dependence!, remove_lq_gauge_dependence!, + remove_eigh_gauge_dependence!, remove_eig_gauge_dependence!, remove_svd_gauge_dependence! using Mooncake using Random @@ -25,7 +27,7 @@ eltypes = (Float64, ComplexF64) # qr_full/qr_null requires being careful with gauges QR = qr_full(A) ΔQR = Mooncake.randn_tangent(rng, QR) - remove_qrgauge_dependence!(ΔQR[1], A, QR[1]) + remove_qr_gauge_dependence!(ΔQR..., A, QR...) Mooncake.TestUtils.test_rule(rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false) # TODO: # Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false) @@ -37,7 +39,7 @@ eltypes = (Float64, ComplexF64) # qr_full/qr_null requires being careful with gauges QR = qr_full(A) ΔQR = Mooncake.randn_tangent(rng, QR) - remove_qrgauge_dependence!(ΔQR[1], A, QR[1]) + remove_qr_gauge_dependence!(ΔQR..., A, QR...) Mooncake.TestUtils.test_rule(rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false) # TODO: # Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false) @@ -51,7 +53,7 @@ eltypes = (Float64, ComplexF64) # qr_full/qr_null requires being careful with gauges LQ = lq_full(A) ΔLQ = Mooncake.randn_tangent(rng, LQ) - remove_lqgauge_dependence!(ΔLQ[2], A, LQ[2]) + remove_lq_gauge_dependence!(ΔLQ..., A, LQ...) Mooncake.TestUtils.test_rule(rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false) # TODO: # Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false) @@ -63,7 +65,7 @@ eltypes = (Float64, ComplexF64) # qr_full/qr_null requires being careful with gauges LQ = lq_full(A) ΔLQ = Mooncake.randn_tangent(rng, LQ) - remove_lqgauge_dependence!(ΔLQ[2], A, LQ[2]) + remove_lq_gauge_dependence!(ΔLQ..., A, LQ...) Mooncake.TestUtils.test_rule(rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false) # TODO: # Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false) @@ -73,13 +75,13 @@ eltypes = (Float64, ComplexF64) for t in (randn(T, V[1] ← V[1]), rand(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])) DV = eig_full(t) ΔDV = Mooncake.randn_tangent(rng, DV) - remove_eiggauge_dependence!(ΔDV[2], DV...) + remove_eig_gauge_dependence!(ΔDV[2], DV...) Mooncake.TestUtils.test_rule(rng, eig_full, t; output_tangent = ΔDV, atol, rtol, mode, is_primitive = false) th = project_hermitian(t) DV = eigh_full(th) ΔDV = Mooncake.randn_tangent(rng, DV) - remove_eighgauge_dependence!(ΔDV[2], DV...) + remove_eigh_gauge_dependence!(ΔDV[2], DV...) Mooncake.TestUtils.test_rule(rng, eigh_full ∘ project_hermitian, th; output_tangent = ΔDV, atol, rtol, mode, is_primitive = false) end end @@ -88,20 +90,20 @@ eltypes = (Float64, ComplexF64) for t in (randn(T, V[1] ← V[1]), randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')) USVᴴ = svd_compact(t) ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ) - remove_svdgauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...) + remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...) Mooncake.TestUtils.test_rule(rng, svd_compact, t; output_tangent = ΔUSVᴴ, atol, rtol, mode, is_primitive = false) - # USVᴴ = svd_full(t) - # ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ) - # remove_svdgauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...) - # Mooncake.TestUtils.test_rule(rng, svd_full, t; output_tangent = ΔUSVᴴ, atol, rtol, mode, is_primitive = false) + USVᴴ = svd_full(t) + ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ) + remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...) + Mooncake.TestUtils.test_rule(rng, svd_full, t; output_tangent = ΔUSVᴴ, atol, rtol, mode, is_primitive = false) V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t)) trunc = truncspace(V_trunc) alg = MatrixAlgebraKit.select_algorithm(svd_trunc, t, nothing; trunc) USVᴴtrunc = svd_trunc(t, alg) ΔUSVᴴtrunc = (Mooncake.randn_tangent(rng, Base.front(USVᴴtrunc))..., zero(last(USVᴴtrunc))) - remove_svdgauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], Base.front(USVᴴtrunc)...) + remove_svd_gauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], Base.front(USVᴴtrunc)...) Mooncake.TestUtils.test_rule(rng, svd_trunc, t, alg; output_tangent = ΔUSVᴴtrunc, atol, rtol, mode) end end diff --git a/test/setup.jl b/test/setup.jl index 106aa6c36..9c8244dab 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -7,8 +7,6 @@ export random_fusion export sectorlist, fast_sectorlist # export dim_isapprox export default_spacelist, factorization_spacelist, ad_spacelist -export remove_qrgauge_dependence!, remove_lqgauge_dependence! -export remove_eiggauge_dependence!, remove_eighgauge_dependence!, remove_svdgauge_dependence! export test_ad_rrule export _isunitary, _isone @@ -398,78 +396,6 @@ function ad_spacelist(fast_tests::Bool) return fast_tests ? (Vtr, VRepU₁, VfHubbard, VRepA4Twistedℤ₄) : (Vtr, VRepℤ₂, VRepCU₁, VfHubbard, VRepA4Twistedℤ₄, VIBMRepA4) end -# Gauge-fixing tangents for AD factorization tests -# ------------------------------------------------- -function remove_qrgauge_dependence!(ΔQ, t, Q) - for (c, b) in blocks(ΔQ) - m, n = size(block(t, c)) - minmn = min(m, n) - Qc = block(Q, c) - Q1 = view(Qc, 1:m, 1:minmn) - ΔQ2 = view(b, :, (minmn + 1):m) - mul!(ΔQ2, Q1, Q1' * ΔQ2) - end - return ΔQ -end -function remove_lqgauge_dependence!(ΔQ, t, Q) - for (c, b) in blocks(ΔQ) - m, n = size(block(t, c)) - minmn = min(m, n) - Qc = block(Q, c) - Q1 = view(Qc, 1:minmn, 1:n) - ΔQ2 = view(b, (minmn + 1):n, :) - mul!(ΔQ2, ΔQ2 * Q1', Q1) - end - return ΔQ -end -function remove_eiggauge_dependence!( - ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D) - ) - gaugepart = V' * ΔV - for (c, b) in blocks(gaugepart) - Dc = diagview(block(D, c)) - # for some reason this fails only on tests, and I cannot reproduce it in an - # interactive session. - # b[abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol] .= 0 - for j in axes(b, 2), i in axes(b, 1) - abs(Dc[i] - Dc[j]) >= degeneracy_atol && (b[i, j] = 0) - end - end - mul!(ΔV, V / (V' * V), gaugepart, -1, 1) - return ΔV -end -function remove_eighgauge_dependence!( - ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D) - ) - gaugepart = project_antihermitian!(V' * ΔV) - for (c, b) in blocks(gaugepart) - Dc = diagview(block(D, c)) - # for some reason this fails only on tests, and I cannot reproduce it in an - # interactive session. - # b[abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol] .= 0 - for j in axes(b, 2), i in axes(b, 1) - abs(Dc[i] - Dc[j]) >= degeneracy_atol && (b[i, j] = 0) - end - end - mul!(ΔV, V, gaugepart, -1, 1) - return ΔV -end -function remove_svdgauge_dependence!( - ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(S) - ) - gaugepart = project_antihermitian!(U' * ΔU + Vᴴ * ΔVᴴ') - for (c, b) in blocks(gaugepart) - Sd = diagview(block(S, c)) - # for some reason this fails only on tests, and I cannot reproduce it in an - # interactive session. - # b[abs.(transpose(diagview(Sc)) .- diagview(Sc)) .>= degeneracy_atol] .= 0 - for j in axes(b, 2), i in axes(b, 1) - abs(Sd[i] - Sd[j]) >= degeneracy_atol && (b[i, j] = 0) - end - end - mul!(ΔU, U, gaugepart, -1, 1) - return ΔU, ΔVᴴ -end # ChainRules test utilities # -------------------------