From 2df4bce756570d3c722bdc0cd18907b11663e3f0 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Fri, 27 Feb 2026 12:25:45 +0100 Subject: [PATCH 1/6] some changes to the ad test utils --- Project.toml | 2 +- src/pullbacks/svd.jl | 4 +- test/testsuite/TestSuite.jl | 6 +- test/testsuite/ad_utils.jl | 360 +++++++++++++++-------------------- test/testsuite/chainrules.jl | 8 +- 5 files changed, 159 insertions(+), 221 deletions(-) diff --git a/Project.toml b/Project.toml index e079639a..e851d669 100644 --- a/Project.toml +++ b/Project.toml @@ -34,7 +34,7 @@ Enzyme = "0.13.118" EnzymeTestUtils = "0.2.5" GenericLinearAlgebra = "0.3.19" GenericSchur = "0.5.6" -JET = "0.9, 0.10" +JET = "0.9, 0.10, 0.11" LinearAlgebra = "1" Mooncake = "0.5" ParallelTestRunner = "2" diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index 9c131464..01fdc4f7 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -1,4 +1,4 @@ -svd_rank(S, rank_atol) = searchsortedlast(S, rank_atol; rev = true) +svd_rank(S; rank_atol = default_pullback_rank_atol(S)) = searchsortedlast(S, rank_atol; rev = true) function check_svd_cotangents(aUΔU, Sr, aVΔV; degeneracy_atol = default_pullback_rank_atol(Sr), gauge_atol = default_pullback_gauge_atol(aUΔU, aVΔV)) mask = abs.(Sr' .- Sr) .< degeneracy_atol @@ -43,7 +43,7 @@ function svd_pullback!( minmn = min(m, n) S = diagview(Smat) length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)")) - r = svd_rank(S, rank_atol) + r = svd_rank(S; rank_atol) Ur = view(U, :, 1:r) Vᴴr = view(Vᴴ, 1:r, :) Sr = view(S, 1:r) diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index 0e64d9e9..c662a067 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -11,7 +11,7 @@ module TestSuite using Test using MatrixAlgebraKit using MatrixAlgebraKit: diagview -using LinearAlgebra: Diagonal, norm, istriu, istril, I +using LinearAlgebra: Diagonal, norm, istriu, istril, 
I, mul! using Random, StableRNGs using Mooncake using AMDGPU, CUDA @@ -86,9 +86,9 @@ function instantiate_unitary(T, A::ROCMatrix{<:Complex}, sz) end instantiate_unitary(::Type{<:Diagonal}, A, sz) = Diagonal(fill!(similar(parent(A), eltype(A), sz), one(eltype(A)))) -function instantiate_rank_deficient_matrix(T, sz; trunc = trunctol(rtol = 0.5)) +function instantiate_rank_deficient_matrix(T, sz; trunc = truncrank(div(min(sz...), 2))) A = instantiate_matrix(T, sz) - V, C = left_orth!(A; trunc = trunctol(rtol = 0.5)) + V, C = left_orth!(A; trunc) return mul!(A, V, C) end diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index 07528257..3b32cfc9 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -1,9 +1,13 @@ +structured_randn!(A::AbstractMatrix) = randn!(A) +structured_randn!(A::Diagonal) = (randn!(diagview(A)); return A) + """ - remove_eig_gauge_dependence!(ΔV, D, V) + remove_eig_gauge_dependence!(ΔV, D, V; degeneracy_atol = ...) Remove the gauge-dependent part from the cotangent `ΔV` of the eigenvector matrix `V`. The -eigenvectors are only determined up to complex phase (and unitary mixing for degenerate -eigenvalues), so the corresponding components of `ΔV` are projected out. +eigenvectors are only determined up to a scalar factor (or an arbitrary linear transformation +across eigenvectors associated with degenerate eigenvalues), so the corresponding components of +`ΔV` are projected out. """ function remove_eig_gauge_dependence!( ΔV, D, V; @@ -16,12 +20,12 @@ function remove_eig_gauge_dependence!( end """ - remove_eigh_gauge_dependence!(ΔV, D, V) + remove_eigh_gauge_dependence!(ΔV, D, V; degeneracy_atol = ...) Remove the gauge-dependent part from the cotangent `ΔV` of the Hermitian eigenvector matrix -`V`. The eigenvectors are only determined up to complex phase (and unitary mixing for -degenerate eigenvalues), so the corresponding anti-Hermitian components of `V' * ΔV` are -projected out. +`V`.
The eigenvectors are only determined up to a complex phase (or a unitary transformation +across eigenvectors associated with degenerate eigenvalues), so the corresponding anti-Hermitian +components of `V' * ΔV` are projected out. """ function remove_eigh_gauge_dependence!( ΔV, D, V; @@ -35,47 +39,51 @@ function remove_eigh_gauge_dependence!( end """ - remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = ..., rank_atol = ...) Remove the gauge-dependent part from the cotangents `ΔU` and `ΔVᴴ` of the SVD factors. The -singular vectors are only determined up to a common complex phase per singular value (and -unitary mixing for degenerate singular values), so the corresponding anti-Hermitian components -of `U₁' * ΔU₁ + Vᴴ₁ * ΔVᴴ₁'` are projected out. For the full SVD, the extra columns of `U` -and rows of `Vᴴ` beyond `min(m, n)` are additionally zeroed out. +singular vectors are only determined up to a common complex phase per singular value (or a +unitary transformation across singular vectors associated with degenerate singular values), +so the corresponding anti-Hermitian components of `U₁' * ΔU₁ + Vᴴ₁ * ΔVᴴ₁'` are projected out. +For the full SVD, the extra columns of `U` and rows of `Vᴴ` beyond the rank `r` are +additionally zeroed out, where `r = count(diagview(S) .> rank_atol)`. 
""" function remove_svd_gauge_dependence!( ΔU, ΔVᴴ, U, S, Vᴴ; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S) + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S), + rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(S) ) - minmn = length(diagview(S)) - U₁ = view(U, :, 1:minmn) - Vᴴ₁ = view(Vᴴ, 1:minmn, :) - ΔU₁ = view(ΔU, :, 1:minmn) - ΔVᴴ₁ = view(ΔVᴴ, 1:minmn, :) + r = MatrixAlgebraKit.svd_rank(diagview(S); rank_atol) + U₁ = view(U, :, 1:r) + Vᴴ₁ = view(Vᴴ, 1:r, :) + ΔU₁ = view(ΔU, :, 1:r) + ΔVᴴ₁ = view(ΔVᴴ, 1:r, :) Sdiag = diagview(S) gaugepart = mul!(U₁' * ΔU₁, Vᴴ₁, ΔVᴴ₁', true, true) gaugepart = project_antihermitian!(gaugepart) gaugepart[abs.(transpose(Sdiag) .- Sdiag) .>= degeneracy_atol] .= 0 mul!(ΔU₁, U₁, gaugepart, -1, 1) - ΔU[:, (minmn + 1):end] .= 0 - ΔVᴴ[(minmn + 1):end, :] .= 0 + ΔU[:, (r + 1):end] .= 0 + ΔVᴴ[(r + 1):end, :] .= 0 return ΔU, ΔVᴴ end """ - remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R) + remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = ...) Remove the gauge-dependent part from the cotangents `ΔQ` and `ΔR` of the QR factors `Q` and `R`. For the full QR decomposition, the extra columns of `Q` beyond the rank `r` are not uniquely determined by `A`, so the corresponding part of `ΔQ` is projected to remove this ambiguity. Additionally, rows of `ΔR` beyond the rank are zeroed out. 
""" -function remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R) - r = MatrixAlgebraKit.qr_rank(R) +function remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(R)) + r = MatrixAlgebraKit.qr_rank(R; rank_atol) Q₁ = @view Q[:, 1:r] ΔQ₂ = @view ΔQ[:, (r + 1):end] - Q₁ᴴΔQ₂ = Q₁' * ΔQ₂ - mul!(ΔQ₂, Q₁, Q₁ᴴΔQ₂) + ΔQ₂ .= 0 + # TODO: refine this by differentiating between rank deficiency and qr_full cases + # Q₁ᴴΔQ₂ = Q₁' * ΔQ₂ + # mul!(ΔQ₂, Q₁, Q₁ᴴΔQ₂) view(ΔR, (r + 1):size(ΔR, 1), :) .= 0 return ΔQ, ΔR end @@ -93,19 +101,21 @@ function remove_qr_null_gauge_dependence!(ΔN, A, N) end """ - remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q) + remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = ...) Remove the gauge-dependent part from the cotangents `ΔL` and `ΔQ` of the LQ factors `L` and `Q`. For the full LQ decomposition, the extra rows of `Q` beyond the rank `r` are not uniquely determined by `A`, so the corresponding part of `ΔQ` is projected to remove this ambiguity. Additionally, columns of `ΔL` beyond the rank are zeroed out. """ -function remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q) - r = MatrixAlgebraKit.lq_rank(L) +function remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(L)) + r = MatrixAlgebraKit.lq_rank(L; rank_atol) Q₁ = @view Q[1:r, :] ΔQ₂ = @view ΔQ[(r + 1):end, :] - ΔQ₂Q₁ᴴ = ΔQ₂ * Q₁' - mul!(ΔQ₂, ΔQ₂Q₁ᴴ, Q₁) + ΔQ₂ .= 0 + # TODO: refine this by differentiating between rank deficiency and lq_full cases + # ΔQ₂Q₁ᴴ = ΔQ₂ * Q₁' + # mul!(ΔQ₂, ΔQ₂Q₁ᴴ, Q₁) view(ΔL, :, (r + 1):size(ΔL, 2)) .= 0 return ΔL, ΔQ end @@ -130,11 +140,7 @@ Remove the gauge-dependent part from the cotangent `ΔN` of the left null space space basis is only determined up to a unitary rotation, so `ΔN` is projected onto the column span of the compact QR factor `Q₁` of `A`. 
""" -function remove_left_null_gauge_dependence!(ΔN, A, N) - Q, _ = qr_compact(A) - mul!(ΔN, Q, Q' * ΔN) - return ΔN -end +remove_left_null_gauge_dependence!(ΔN, A, N) = remove_qr_null_gauge_dependence!(ΔN, A, N) """ remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) @@ -143,11 +149,7 @@ Remove the gauge-dependent part from the cotangent `ΔNᴴ` of the right null sp null space basis is only determined up to a unitary rotation, so `ΔNᴴ` is projected onto the row span of the compact LQ factor `Q₁` of `A`. """ -function remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) - _, Q = lq_compact(A) - mul!(ΔNᴴ, ΔNᴴ * Q', Q) - return ΔNᴴ -end +remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) = remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) """ call_and_zero!(f!, A, alg) @@ -240,21 +242,11 @@ end function ad_qr_compact_setup(A) QR = qr_compact(A) - ΔQR = randn!.(copy.(QR)) - remove_qr_gauge_dependence!(ΔQR..., A, QR...) + ΔQR = structured_randn!.(copy.(QR)) + A isa Diagonal || remove_qr_gauge_dependence!(ΔQR..., A, QR...) return QR, ΔQR end -function ad_qr_compact_setup(A::Diagonal) - m, n = size(A) - minmn = min(m, n) - QR = qr_compact(A) - T = eltype(A) - ΔQ = Diagonal(randn!(similar(A.diag, T, m))) - ΔR = Diagonal(randn!(similar(A.diag, T, m))) - return QR, (ΔQ, ΔR) -end - function ad_qr_null_setup(A) N = qr_null(A) ΔN = randn!(copy(N)) @@ -264,53 +256,51 @@ end function ad_qr_full_setup(A) QR = qr_full(A) - ΔQR = randn!.(copy.(QR)) - remove_qr_gauge_dependence!(ΔQR..., A, QR...) + ΔQR = structured_randn!.(copy.(QR)) + A isa Diagonal || remove_qr_gauge_dependence!(ΔQR..., A, QR...) 
return QR, ΔQR end -ad_qr_full_setup(A::Diagonal) = ad_qr_compact_setup(A) -function ad_qr_rank_deficient_compact_setup(A) - m, n = size(A) - minmn = min(m, n) - T = eltype(A) - r = minmn - 5 - Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) - Q, R = qr_compact(Ard) - QR = (Q, R) - ΔQ = randn!(similar(A, T, m, minmn)) - Q1 = view(Q, 1:m, 1:r) - Q2 = view(Q, 1:m, (r + 1):minmn) - ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) - MatrixAlgebraKit.zero!(ΔQ2) - ΔR = randn!(similar(A, T, minmn, n)) - view(ΔR, (r + 1):minmn, :) .= 0 - return (Q, R), (ΔQ, ΔR) -end - -function ad_qr_rank_deficient_compact_setup(A::Diagonal) - m, n = size(A) - minmn = min(m, n) - T = eltype(A) - r = minmn - 5 - Ard_ = randn!(similar(A, T, m)) - MatrixAlgebraKit.zero!(view(Ard_, (r + 1):m)) - Ard = Diagonal(Ard_) - Q, R = qr_compact(Ard) - ΔQ = Diagonal(randn!(similar(diagview(A), T, m))) - ΔR = Diagonal(randn!(similar(diagview(A), T, m))) - MatrixAlgebraKit.zero!(view(diagview(ΔQ), (r + 1):m)) - MatrixAlgebraKit.zero!(view(diagview(ΔR), (r + 1):m)) - return (Q, R), (ΔQ, ΔR) -end +# function ad_qr_rank_deficient_compact_setup(A) +# m, n = size(A) +# minmn = min(m, n) +# T = eltype(A) +# r = minmn - 5 +# Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) +# Q, R = qr_compact(Ard) +# QR = (Q, R) +# ΔQ = randn!(similar(A, T, m, minmn)) +# Q1 = view(Q, 1:m, 1:r) +# Q2 = view(Q, 1:m, (r + 1):minmn) +# ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) +# MatrixAlgebraKit.zero!(ΔQ2) +# ΔR = randn!(similar(A, T, minmn, n)) +# view(ΔR, (r + 1):minmn, :) .= 0 +# return (Q, R), (ΔQ, ΔR) +# end + +# function ad_qr_rank_deficient_compact_setup(A::Diagonal) +# m, n = size(A) +# minmn = min(m, n) +# T = eltype(A) +# r = minmn - 5 +# Ard_ = randn!(similar(A, T, m)) +# MatrixAlgebraKit.zero!(view(Ard_, (r + 1):m)) +# Ard = Diagonal(Ard_) +# Q, R = qr_compact(Ard) +# ΔQ = Diagonal(randn!(similar(diagview(A), T, m))) +# ΔR = Diagonal(randn!(similar(diagview(A), T, m))) +# 
MatrixAlgebraKit.zero!(view(diagview(ΔQ), (r + 1):m)) +# MatrixAlgebraKit.zero!(view(diagview(ΔR), (r + 1):m)) +# return (Q, R), (ΔQ, ΔR) +# end function ad_lq_compact_setup(A) LQ = lq_compact(A) - ΔLQ = randn!.(copy.(LQ)) - remove_lq_gauge_dependence!(ΔLQ..., A, LQ...) + ΔLQ = structured_randn!.(copy.(LQ)) + A isa Diagonal || remove_lq_gauge_dependence!(ΔLQ..., A, LQ...) return LQ, ΔLQ end -ad_lq_compact_setup(A::Diagonal) = ad_qr_compact_setup(A) function ad_lq_null_setup(A) T = eltype(A) @@ -322,67 +312,60 @@ end function ad_lq_full_setup(A) LQ = lq_full(A) - ΔLQ = randn!.(copy.(LQ)) - remove_lq_gauge_dependence!(ΔLQ..., A, LQ...) + ΔLQ = structured_randn!.(copy.(LQ)) + A isa Diagonal || remove_lq_gauge_dependence!(ΔLQ..., A, LQ...) return LQ, ΔLQ end -ad_lq_full_setup(A::Diagonal) = ad_qr_full_setup(A) -function ad_lq_rank_deficient_compact_setup(A) - m, n = size(A) - minmn = min(m, n) - T = eltype(A) - r = minmn - 5 - Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) - L, Q = lq_compact(Ard) - ΔL = randn!(similar(A, T, m, minmn)) - ΔQ = randn!(similar(A, T, minmn, n)) - Q1 = view(Q, 1:r, 1:n) - Q2 = view(Q, (r + 1):minmn, 1:n) - ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) - ΔQ2 .= 0 - view(ΔL, :, (r + 1):minmn) .= 0 - return (L, Q), (ΔL, ΔQ) -end -ad_lq_rank_deficient_compact_setup(A::Diagonal) = ad_qr_rank_deficient_compact_setup(A) +# function ad_lq_rank_deficient_compact_setup(A) +# m, n = size(A) +# minmn = min(m, n) +# T = eltype(A) +# r = minmn - 5 +# Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) +# L, Q = lq_compact(Ard) +# ΔL = randn!(similar(A, T, m, minmn)) +# ΔQ = randn!(similar(A, T, minmn, n)) +# Q1 = view(Q, 1:r, 1:n) +# Q2 = view(Q, (r + 1):minmn, 1:n) +# ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) +# ΔQ2 .= 0 +# view(ΔL, :, (r + 1):minmn) .= 0 +# return (L, Q), (ΔL, ΔQ) +# end +# ad_lq_rank_deficient_compact_setup(A::Diagonal) = ad_qr_rank_deficient_compact_setup(A) function ad_eig_full_setup(A) - m, n = size(A) - T = eltype(A) - 
DV = eig_full(A) - D, V = DV - ΔV = randn!(similar(A, complex(T), m, m)) + D, V = eig_full(A) + ΔD, ΔV = structured_randn!.(similar.((D, V))) ΔV = remove_eig_gauge_dependence!(ΔV, D, V) - ΔD = Diagonal(randn!(similar(A, complex(T), m))) - return DV, (ΔD, ΔV) + return (D, V), (ΔD, ΔV) end -function ad_eig_full_setup(A::Diagonal) - m, n = size(A) - T = complex(eltype(A)) - DV = eig_full(A) - D, V = DV - ΔV = randn!(similar(A.diag, T, m, m)) - ΔV = remove_eig_gauge_dependence!(ΔV, D, V) - ΔD = Diagonal(randn!(similar(A.diag, T, m))) - return DV, (ΔD, ΔV) -end +# function ad_eig_full_setup(A::Diagonal) +# m, n = size(A) +# T = complex(eltype(A)) +# DV = eig_full(A) +# D, V = DV +# ΔV = randn!(similar(A.diag, T, m, m)) +# ΔV = remove_eig_gauge_dependence!(ΔV, D, V) +# ΔD = Diagonal(randn!(similar(A.diag, T, m))) +# return DV, (ΔD, ΔV) +# end function ad_eig_vals_setup(A) - m, n = size(A) - T = complex(eltype(A)) D = eig_vals(A) - ΔD = randn!(similar(A, complex(T), m)) + ΔD = randn!(similar(D)) return D, ΔD end -function ad_eig_vals_setup(A::Diagonal) - m, n = size(A) - T = complex(eltype(A)) - D = eig_vals(A) - ΔD = randn!(similar(A.diag, T, m)) - return D, ΔD -end +# function ad_eig_vals_setup(A::Diagonal) +# m, n = size(A) +# T = complex(eltype(A)) +# D = eig_vals(A) +# ΔD = randn!(similar(A.diag, T, m)) +# return D, ΔD +# end function ad_eig_trunc_setup(A, truncalg) DV, ΔDV = ad_eig_full_setup(A) @@ -395,21 +378,15 @@ function ad_eig_trunc_setup(A, truncalg) end function ad_eigh_full_setup(A) - m, n = size(A) - T = eltype(A) - DV = eigh_full(A) - D, V = DV - ΔV = randn!(similar(A, T, m, m)) + D, V = eigh_full(A) + ΔD, ΔV = structured_randn!.(similar.((D, V))) ΔV = remove_eigh_gauge_dependence!(ΔV, D, V) - ΔD = Diagonal(randn!(similar(A, real(T), m))) - return DV, (ΔD, ΔV) + return (D, V), (ΔD, ΔV) end function ad_eigh_vals_setup(A) - m, n = size(A) - T = eltype(A) D = eigh_vals(A) - ΔD = randn!(similar(A, real(T), m)) + ΔD = randn!(similar(D)) return D, ΔD end @@ 
-424,55 +401,39 @@ function ad_eigh_trunc_setup(A, truncalg) end function ad_svd_compact_setup(A) - m, n = size(A) - T = eltype(A) - minmn = min(m, n) - ΔU = randn!(similar(A, T, m, minmn)) - ΔS = Diagonal(randn!(similar(A, real(T), minmn))) - ΔVᴴ = randn!(similar(A, T, minmn, n)) U, S, Vᴴ = svd_compact(A) + ΔU, ΔS, ΔVᴴ = structured_randn!.(similar.((U, S, Vᴴ))) ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ) end -function ad_svd_compact_setup(A::Diagonal) - m, n = size(A) - T = eltype(A) - minmn = min(m, n) - ΔU = randn!(similar(A.diag, T, m, n)) - ΔS = Diagonal(randn!(similar(A.diag, real(T), minmn))) - ΔVᴴ = randn!(similar(A.diag, T, m, n)) - U, S, Vᴴ = svd_compact(A) - ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) - return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ) -end +# function ad_svd_compact_setup(A::Diagonal) +# m, n = size(A) +# T = eltype(A) +# minmn = min(m, n) +# ΔU = randn!(similar(A.diag, T, m, n)) +# ΔS = Diagonal(randn!(similar(A.diag, real(T), minmn))) +# ΔVᴴ = randn!(similar(A.diag, T, m, n)) +# U, S, Vᴴ = svd_compact(A) +# ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) +# return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ) +# end function ad_svd_full_setup(A) - m, n = size(A) - T = eltype(A) - minmn = min(m, n) - (_, _, _), (ΔU, ΔS, ΔVᴴ) = ad_svd_compact_setup(A) - ΔUfull = similar(A, T, m, m) - ΔUfull .= zero(T) - ΔSfull = similar(A, real(T), m, n) - ΔSfull .= zero(real(T)) - ΔVᴴfull = similar(A, T, n, n) - ΔVᴴfull .= zero(T) U, S, Vᴴ = svd_full(A) - view(ΔUfull, :, 1:minmn) .= ΔU - view(ΔVᴴfull, 1:minmn, :) .= ΔVᴴ - diagview(ΔSfull)[1:minmn] .= diagview(ΔS) - return (U, S, Vᴴ), (ΔUfull, ΔSfull, ΔVᴴfull) + ΔU = structured_randn!(similar(U)) + ΔVᴴ = structured_randn!(similar(Vᴴ)) + ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + ΔS = zero(S) + randn!(diagview(ΔS)) + return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ) end -ad_svd_full_setup(A::Diagonal) = ad_svd_compact_setup(A) +# ad_svd_full_setup(A::Diagonal) = 
ad_svd_compact_setup(A) function ad_svd_vals_setup(A) - m, n = size(A) - minmn = min(m, n) - T = eltype(A) S = svd_vals(A) - ΔS = randn!(similar(A, real(T), minmn)) + ΔS = randn!(similar(S)) return S, ΔS end @@ -489,48 +450,25 @@ function ad_svd_trunc_setup(A, truncalg) end function ad_left_polar_setup(A) - m, n = size(A) - T = eltype(A) WP = left_polar(A) - ΔWP = (randn!(similar(A, T, m, n)), randn!(similar(A, T, n, n))) - return WP, ΔWP -end - -function ad_left_polar_setup(A::Diagonal) - m, n = size(A) - T = eltype(A) - WP = left_polar(A) - ΔWP = (Diagonal(randn!(similar(A.diag))), randn!(similar(WP[2]))) + ΔWP = structured_randn!.(similar.(WP)) return WP, ΔWP end function ad_right_polar_setup(A) - m, n = size(A) - T = eltype(A) - PWᴴ = right_polar(A) - ΔPWᴴ = (randn!(similar(A, T, m, m)), randn!(similar(A, T, m, n))) - return PWᴴ, ΔPWᴴ -end -function ad_right_polar_setup(A::Diagonal) - m, n = size(A) - T = eltype(A) PWᴴ = right_polar(A) - ΔPWᴴ = (randn!(similar(PWᴴ[1])), Diagonal(randn!(similar(A.diag)))) + ΔPWᴴ = structured_randn!.(similar.(PWᴴ)) return PWᴴ, ΔPWᴴ end function ad_left_orth_setup(A) - m, n = size(A) - T = eltype(A) VC = left_orth(A) - ΔVC = (randn!(similar(A, T, size(VC[1])...)), randn!(similar(A, T, size(VC[2])...))) + ΔVC = structured_randn!.(similar.(VC)) return VC, ΔVC end function ad_left_orth_setup(A::Diagonal) - m, n = size(A) - T = eltype(A) VC = left_orth(A) - ΔVC = (Diagonal(randn!(similar(A.diag, T, m))), Diagonal(randn!(similar(A.diag, T, m)))) + ΔVC = structured_randn!.(similar.(VC)) return VC, ΔVC end diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl index 52806750..c722fdb6 100644 --- a/test/testsuite/chainrules.jl +++ b/test/testsuite/chainrules.jl @@ -17,14 +17,14 @@ for f in @eval begin function $copy_f(input, alg) if $_hermitian - input = (input + input') / 2 + input = project_hermitian(input) end return $f(input, alg) end function ChainRulesCore.rrule(::typeof($copy_f), input, alg) output = 
MatrixAlgebraKit.initialize_output($f!, input, alg) if $_hermitian - input = (input + input') / 2 + input = project_hermitian(input) else input = copy(input) end @@ -113,7 +113,7 @@ function test_chainrules_qr( m, n = size(A) r = min(m, n) - 5 Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - QR, ΔQR = ad_qr_rank_deficient_compact_setup(Ard) + QR, ΔQR = ad_qr_compact_setup(Ard) ΔQ, ΔR = ΔQR test_rrule( cr_copy_qr_compact, Ard, alg ⊢ NoTangent(); @@ -190,7 +190,7 @@ function test_chainrules_lq( m, n = size(A) r = min(m, n) - 5 Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - LQ, ΔLQ = ad_lq_rank_deficient_compact_setup(Ard) + LQ, ΔLQ = ad_lq_compact_setup(Ard) test_rrule( cr_copy_lq_compact, Ard, alg ⊢ NoTangent(); output_tangent = ΔLQ, atol = atol, rtol = rtol From 01f1be7577de08cf164081162cb28908e5b03c1d Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Sat, 7 Mar 2026 02:01:54 +0100 Subject: [PATCH 2/6] some more changes / lq tests are failing for unknown reasons --- src/common/view.jl | 6 ++-- src/pullbacks/lq.jl | 70 ++++++++++++++---------------------- src/pullbacks/qr.jl | 72 +++++++++++++++----------------------- test/testsuite/ad_utils.jl | 49 ++++++++++++++------------ 4 files changed, 86 insertions(+), 111 deletions(-) diff --git a/src/common/view.jl b/src/common/view.jl index e03bfb88..8cd989ea 100644 --- a/src/common/view.jl +++ b/src/common/view.jl @@ -24,7 +24,8 @@ diagonal(v::AbstractVector) = Diagonal(v) function lowertriangularind(A::AbstractMatrix) Base.require_one_based_indexing(A) m, n = size(A) - I = Vector{Int}(undef, div(m * (m - 1), 2) + m * (n - m)) + minmn = min(m, n) + I = Vector{Int}(undef, div(minmn * (minmn - 1), 2) + minmn * (m - minmn)) offset = 0 for j in 1:n r = (j + 1):m @@ -37,7 +38,8 @@ end function uppertriangularind(A::AbstractMatrix) Base.require_one_based_indexing(A) m, n = size(A) - I = Vector{Int}(undef, div(m * (m - 1), 2) + m * (n - m)) + minmn = min(m, n) + I = 
Vector{Int}(undef, div(minmn * (minmn - 1), 2) + minmn * (n - minmn)) offset = 0 for i in 1:m r = (i + 1):n diff --git a/src/pullbacks/lq.jl b/src/pullbacks/lq.jl index 452eda92..977929f1 100644 --- a/src/pullbacks/lq.jl +++ b/src/pullbacks/lq.jl @@ -5,40 +5,25 @@ function check_lq_cotangents( gauge_atol::Real = default_pullback_gauge_atol(ΔQ) ) minmn = min(size(L, 1), size(Q, 2)) - if minmn > p # case where A is rank-deficient - Δgauge = abs(zero(eltype(Q))) - if !iszerotangent(ΔQ) - # in this case the number Householder reflections will - # change upon small variations, and all of the remaining - # rows of ΔQ should be zero for a gauge-invariant - # cost function - ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :) - Δgauge_Q = norm(ΔQ2, Inf) - Δgauge = max(Δgauge, Δgauge_Q) - end - if !iszerotangent(ΔL) - ΔL22 = view(ΔL, (p + 1):size(L, 1), (p + 1):minmn) - Δgauge_L = norm(ΔL22, Inf) - Δgauge = max(Δgauge, Δgauge_L) - end - Δgauge ≤ gauge_atol || - @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + Δgauge = abs(zero(eltype(Q))) + if !iszerotangent(ΔQ) + ΔQ₂ = view(ΔQ, (p + 1):minmn, :) + ΔQ₃ = ΔQ[(minmn + 1):size(Q, 1), :] + Δgauge_Q = norm(ΔQ₂, Inf) + Q₁ = view(Q, 1:p, :) + ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁' + mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁, -1, 1) + Δgauge_Q = max(Δgauge_Q, norm(ΔQ₃, Inf)) + Δgauge = max(Δgauge, Δgauge_Q) + end + if !iszerotangent(ΔL) + ΔL22 = view(ΔL, (p + 1):size(ΔL, 1), (p + 1):minmn) + Δgauge_L = norm(view(ΔL22, lowertriangularind(ΔL22)), Inf) + Δgauge = max(Δgauge, Δgauge_L) end - return -end - -function check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol::Real = default_pullback_gauge_atol(ΔQ2)) - # in the case where A is full rank, but there are more columns in Q than in A - # (the case of `lq_full`), there is gauge-invariant information in the - # projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary - # matrix. 
As the number of Householder reflections is in fixed in the full rank - # case, Q is expected to rotate smoothly (we might even be able to predict) also - # how the full Q2 will change, but this we omit for now, and we consider - # Q2' * ΔQ2 as a gauge dependent quantity. - Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf) Δgauge ≤ gauge_atol || - @warn "`lq_full` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - return + @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + return nothing end """ @@ -67,13 +52,13 @@ function lq_pullback!( L, Q = LQ m = size(L, 1) n = size(Q, 2) + minmn = min(m, n) p = lq_rank(L; rank_atol) ΔL, ΔQ = ΔLQ Q1 = view(Q, 1:p, :) - Q2 = view(Q, (p + 1):size(Q, 1), :) - L11 = view(L, 1:p, 1:p) + L11 = LowerTriangular(view(L, 1:p, 1:p)) ΔA1 = view(ΔA, 1:p, :) ΔA2 = view(ΔA, (p + 1):m, :) @@ -83,12 +68,11 @@ function lq_pullback!( if !iszerotangent(ΔQ) ΔQ1 = view(ΔQ, 1:p, :) copy!(ΔQ̃, ΔQ1) - if p < size(Q, 1) - Q2 = view(Q, (p + 1):size(Q, 1), :) - ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :) - ΔQ2Q1ᴴ = ΔQ2 * Q1' - check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol) - ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1) + if minmn < size(Q, 1) + ΔQ3 = view(ΔQ, (minmn + 1):size(ΔQ, 1), :) + Q3 = view(Q, (minmn + 1):size(Q, 1), :) + ΔQ3Q1ᴴ = ΔQ3 * Q1' + ΔQ̃ = mul!(ΔQ̃, ΔQ3Q1ᴴ', Q3, -1, 1) end end if !iszerotangent(ΔL) && m > p @@ -102,7 +86,7 @@ function lq_pullback!( # construct M M = zero!(similar(L, (p, p))) if !iszerotangent(ΔL) - ΔL11 = view(ΔL, 1:p, 1:p) + ΔL11 = LowerTriangular(view(ΔL, 1:p, 1:p)) M = mul!(M, L11', ΔL11, 1, 1) end M = mul!(M, ΔQ̃, Q1', -1, 1) @@ -111,8 +95,8 @@ function lq_pullback!( Md = diagview(M) Md .= real.(Md) end - ldiv!(LowerTriangular(L11)', M) - ldiv!(LowerTriangular(L11)', ΔQ̃) + ldiv!(L11', M) + ldiv!(L11', ΔQ̃) ΔA1 = mul!(ΔA1, M, Q1, +1, 1) ΔA1 .+= ΔQ̃ return ΔA diff --git a/src/pullbacks/qr.jl b/src/pullbacks/qr.jl index 643b2c68..3a40460c 100644 --- a/src/pullbacks/qr.jl +++ 
b/src/pullbacks/qr.jl @@ -6,40 +6,25 @@ function check_qr_cotangents( gauge_atol::Real = default_pullback_gauge_atol(ΔQ) ) minmn = min(size(Q, 1), size(R, 2)) - if minmn > p # case where A is rank-deficient - Δgauge = abs(zero(eltype(Q))) - if !iszerotangent(ΔQ) - # in this case the number Householder reflections will - # change upon small variations, and all of the remaining - # columns of ΔQ should be zero for a gauge-invariant - # cost function - ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2)) - Δgauge_Q = norm(ΔQ2, Inf) - Δgauge = max(Δgauge, Δgauge_Q) - end - if !iszerotangent(ΔR) - ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2)) - Δgauge_R = norm(ΔR22, Inf) - Δgauge = max(Δgauge, Δgauge_R) - end - Δgauge ≤ gauge_atol || - @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + Δgauge = abs(zero(eltype(Q))) + if !iszerotangent(ΔQ) + ΔQ₂ = view(ΔQ, :, (p + 1):minmn) + ΔQ₃ = ΔQ[:, (minmn + 1):size(Q, 2)] # extra columns in the case of qr_full + Δgauge_Q = norm(ΔQ₂, Inf) + Q₁ = view(Q, :, 1:p) + Q₁ᴴΔQ₃ = Q₁' * ΔQ₃ + mul!(ΔQ₃, Q₁, Q₁ᴴΔQ₃, -1, 1) + Δgauge_Q = max(Δgauge_Q, norm(ΔQ₃, Inf)) + Δgauge = max(Δgauge, Δgauge_Q) + end + if !iszerotangent(ΔR) + ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2)) + Δgauge_R = norm(view(ΔR22, uppertriangularind(ΔR22)), Inf) + Δgauge = max(Δgauge, Δgauge_R) end - return -end - -function check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol::Real = default_pullback_gauge_atol(ΔQ2)) - # in the case where A is full rank, but there are more columns in Q than in A - # (the case of `qr_full`), there is gauge-invariant information in the - # projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary - # matrix. As the number of Householder reflections is in fixed in the full rank - # case, Q is expected to rotate smoothly (we might even be able to predict) also - # how the full Q2 will change, but this we omit for now, and we consider - # Q2' * ΔQ2 as a gauge dependent quantity. 
- Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf) Δgauge ≤ gauge_atol || - @warn "`qr_full` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - return + @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + return nothing end """ @@ -69,13 +54,14 @@ function qr_pullback!( Q, R = QR m = size(Q, 1) n = size(R, 2) + minmn = min(m, n) Rd = diagview(R) p = qr_rank(R; rank_atol) ΔQ, ΔR = ΔQR Q1 = view(Q, :, 1:p) - R11 = view(R, 1:p, 1:p) + R11 = UpperTriangular(view(R, 1:p, 1:p)) ΔA1 = view(ΔA, :, 1:p) ΔA2 = view(ΔA, :, (p + 1):n) @@ -83,13 +69,13 @@ function qr_pullback!( ΔQ̃ = zero!(similar(Q, (m, p))) if !iszerotangent(ΔQ) - copy!(ΔQ̃, view(ΔQ, :, 1:p)) - if p < size(Q, 2) - Q2 = view(Q, :, (p + 1):size(Q, 2)) - ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2)) - Q1dΔQ2 = Q1' * ΔQ2 - check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol) - ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1) + ΔQ₁ = view(ΔQ, :, 1:p) + copy!(ΔQ̃, ΔQ₁) + if minmn < size(Q, 2) + ΔQ3 = view(ΔQ, :, (minmn + 1):size(ΔQ, 2)) # extra columns in the case of qr_full + Q3 = view(Q, :, (minmn + 1):size(Q, 2)) + Q1ᴴΔQ3 = Q1' * ΔQ3 + ΔQ̃ = mul!(ΔQ̃, Q3, Q1ᴴΔQ3', -1, 1) end end if !iszerotangent(ΔR) && n > p @@ -103,7 +89,7 @@ function qr_pullback!( # construct M M = zero!(similar(R, (p, p))) if !iszerotangent(ΔR) - ΔR11 = view(ΔR, 1:p, 1:p) + ΔR11 = UpperTriangular(view(ΔR, 1:p, 1:p)) M = mul!(M, ΔR11, R11', 1, 1) end M = mul!(M, Q1', ΔQ̃, -1, 1) @@ -112,8 +98,8 @@ function qr_pullback!( Md = diagview(M) Md .= real.(Md) end - rdiv!(M, UpperTriangular(R11)') - rdiv!(ΔQ̃, UpperTriangular(R11)') + rdiv!(M, R11') # R11 is upper triangular + rdiv!(ΔQ̃, R11') ΔA1 = mul!(ΔA1, Q1, M, +1, 1) ΔA1 .+= ΔQ̃ return ΔA diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index 3b32cfc9..54d754ee 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -78,13 +78,15 @@ ambiguity. Additionally, rows of `ΔR` beyond the rank are zeroed out. 
""" function remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(R)) r = MatrixAlgebraKit.qr_rank(R; rank_atol) - Q₁ = @view Q[:, 1:r] - ΔQ₂ = @view ΔQ[:, (r + 1):end] + minmn = min(size(A)...) + Q₁ = view(Q, :, 1:r) + ΔQ₂ = view(ΔQ, :, (r + 1):minmn) ΔQ₂ .= 0 - # TODO: refine this by differentiating between rank deficiency and qr_full cases - # Q₁ᴴΔQ₂ = Q₁' * ΔQ₂ - # mul!(ΔQ₂, Q₁, Q₁ᴴΔQ₂) - view(ΔR, (r + 1):size(ΔR, 1), :) .= 0 + ΔQ₃ = view(ΔQ, :, (minmn + 1):size(ΔQ, 2)) # extra columns in the case of qr_full + Q₁ᴴΔQ₃ = Q₁' * ΔQ₃ + mul!(ΔQ₃, Q₁, Q₁ᴴΔQ₃) + ΔR22 = view(ΔR, (r + 1):minmn, (r + 1):size(R, 2)) + view(ΔR22, MatrixAlgebraKit.uppertriangularind(ΔR22)) .= 0 return ΔQ, ΔR end @@ -110,13 +112,15 @@ Additionally, columns of `ΔL` beyond the rank are zeroed out. """ function remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(L)) r = MatrixAlgebraKit.lq_rank(L; rank_atol) - Q₁ = @view Q[1:r, :] - ΔQ₂ = @view ΔQ[(r + 1):end, :] + minmn = min(size(A)...) + Q₁ = view(Q, 1:r, :) + ΔQ₂ = view(ΔQ, (r + 1):minmn, :) ΔQ₂ .= 0 - # TODO: refine this by differentiating between rank deficiency and lq_full cases - # ΔQ₂Q₁ᴴ = ΔQ₂ * Q₁' - # mul!(ΔQ₂, ΔQ₂Q₁ᴴ, Q₁) - view(ΔL, :, (r + 1):size(ΔL, 2)) .= 0 + ΔQ₃ = view(ΔQ, (minmn + 1):size(ΔQ, 1), :) # extra rows in the case of lq_full + ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁' + mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁) + ΔL22 = view(ΔL, (r + 1):size(ΔL, 1), (r + 1):minmn) + view(ΔL22, MatrixAlgebraKit.lowertriangularind(ΔL22)) .= 0 return ΔL, ΔQ end @@ -242,22 +246,22 @@ end function ad_qr_compact_setup(A) QR = qr_compact(A) - ΔQR = structured_randn!.(copy.(QR)) - A isa Diagonal || remove_qr_gauge_dependence!(ΔQR..., A, QR...) + ΔQR = structured_randn!.(similar.(QR)) + remove_qr_gauge_dependence!(ΔQR..., A, QR...) 
return QR, ΔQR end function ad_qr_null_setup(A) N = qr_null(A) - ΔN = randn!(copy(N)) + ΔN = structured_randn!(similar(N)) remove_qr_null_gauge_dependence!(ΔN, A, N) return N, ΔN end function ad_qr_full_setup(A) QR = qr_full(A) - ΔQR = structured_randn!.(copy.(QR)) - A isa Diagonal || remove_qr_gauge_dependence!(ΔQR..., A, QR...) + ΔQR = structured_randn!.(similar.(QR)) + remove_qr_gauge_dependence!(ΔQR..., A, QR...) return QR, ΔQR end @@ -297,23 +301,22 @@ end function ad_lq_compact_setup(A) LQ = lq_compact(A) - ΔLQ = structured_randn!.(copy.(LQ)) - A isa Diagonal || remove_lq_gauge_dependence!(ΔLQ..., A, LQ...) + ΔLQ = structured_randn!.(similar.(LQ)) + remove_lq_gauge_dependence!(ΔLQ..., A, LQ...) return LQ, ΔLQ end function ad_lq_null_setup(A) - T = eltype(A) Nᴴ = lq_null(A) - ΔNᴴ = randn!(similar(A, T, size(Nᴴ)...)) + ΔNᴴ = structured_randn!(similar(Nᴴ)) remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) return Nᴴ, ΔNᴴ end function ad_lq_full_setup(A) LQ = lq_full(A) - ΔLQ = structured_randn!.(copy.(LQ)) - A isa Diagonal || remove_lq_gauge_dependence!(ΔLQ..., A, LQ...) + ΔLQ = structured_randn!.(similar.(LQ)) + remove_lq_gauge_dependence!(ΔLQ..., A, LQ...) 
return LQ, ΔLQ end From cb2f4868586d9ec2a7afcbac1e95f349b93e8bde Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Sun, 8 Mar 2026 22:33:59 +0100 Subject: [PATCH 3/6] one more attempt --- src/pullbacks/lq.jl | 2 ++ src/pullbacks/qr.jl | 1 + test/mooncake/eig.jl | 2 +- test/mooncake/eigh.jl | 2 +- test/mooncake/lq.jl | 4 ++-- test/mooncake/orthnull.jl | 2 +- test/mooncake/polar.jl | 2 +- test/mooncake/qr.jl | 4 ++-- test/mooncake/svd.jl | 2 +- test/testsuite/ad_utils.jl | 2 ++ 10 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/pullbacks/lq.jl b/src/pullbacks/lq.jl index 977929f1..885d0102 100644 --- a/src/pullbacks/lq.jl +++ b/src/pullbacks/lq.jl @@ -4,6 +4,7 @@ function check_lq_cotangents( L, Q, ΔL, ΔQ, p::Int; gauge_atol::Real = default_pullback_gauge_atol(ΔQ) ) + # check_qr_cotangents(Q', L', ΔQ', ΔL', p; gauge_atol) minmn = min(size(L, 1), size(Q, 2)) Δgauge = abs(zero(eltype(Q))) if !iszerotangent(ΔQ) @@ -19,6 +20,7 @@ function check_lq_cotangents( if !iszerotangent(ΔL) ΔL22 = view(ΔL, (p + 1):size(ΔL, 1), (p + 1):minmn) Δgauge_L = norm(view(ΔL22, lowertriangularind(ΔL22)), Inf) + Δgauge_L = max(Δgauge_L, norm(view(ΔL22, diagind(ΔL22)), Inf)) Δgauge = max(Δgauge, Δgauge_L) end Δgauge ≤ gauge_atol || diff --git a/src/pullbacks/qr.jl b/src/pullbacks/qr.jl index 3a40460c..7cb4c06d 100644 --- a/src/pullbacks/qr.jl +++ b/src/pullbacks/qr.jl @@ -20,6 +20,7 @@ function check_qr_cotangents( if !iszerotangent(ΔR) ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2)) Δgauge_R = norm(view(ΔR22, uppertriangularind(ΔR22)), Inf) + Δgauge_R = max(Δgauge_R, norm(view(ΔR22, diagind(ΔR22)), Inf)) Δgauge = max(Δgauge, Δgauge_R) end Δgauge ≤ gauge_atol || diff --git a/test/mooncake/eig.jl b/test/mooncake/eig.jl index 2e8a8606..b313f9b2 100644 --- a/test/mooncake/eig.jl +++ b/test/mooncake/eig.jl @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...) 
- TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end diff --git a/test/mooncake/eigh.jl b/test/mooncake/eigh.jl index 5528af0f..800dbaa0 100644 --- a/test/mooncake/eigh.jl +++ b/test/mooncake/eigh.jl @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end diff --git a/test/mooncake/lq.jl b/test/mooncake/lq.jl index 6c9f8fd4..0f05f85a 100644 --- a/test/mooncake/lq.jl +++ b/test/mooncake/lq.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, AMDGPU -BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI GenericFloats = () @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") using .TestSuite @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end diff --git a/test/mooncake/orthnull.jl b/test/mooncake/orthnull.jl index 6f8dac9a..09e3a28c 100644 --- a/test/mooncake/orthnull.jl +++ b/test/mooncake/orthnull.jl @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end diff --git a/test/mooncake/polar.jl 
b/test/mooncake/polar.jl index 9e4e366e..1faf3c10 100644 --- a/test/mooncake/polar.jl +++ b/test/mooncake/polar.jl @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite atol = rtol = m * n * TestSuite.precision(T) m >= n && TestSuite.test_mooncake_left_polar(T, (m, n); atol, rtol) diff --git a/test/mooncake/qr.jl b/test/mooncake/qr.jl index 9ffc4798..17415e8d 100644 --- a/test/mooncake/qr.jl +++ b/test/mooncake/qr.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, AMDGPU -BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI GenericFloats = () @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") using .TestSuite @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end diff --git a/test/mooncake/svd.jl b/test/mooncake/svd.jl index 982ec040..d2d40df4 100644 --- a/test/mooncake/svd.jl +++ b/test/mooncake/svd.jl @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index 54d754ee..d20920a5 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -86,6 +86,7 @@ function remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = MatrixAlgebr Q₁ᴴΔQ₃ = Q₁' * ΔQ₃ 
mul!(ΔQ₃, Q₁, Q₁ᴴΔQ₃) ΔR22 = view(ΔR, (r + 1):minmn, (r + 1):size(R, 2)) + MatrixAlgebraKit.diagview(ΔR22) .= 0 view(ΔR22, MatrixAlgebraKit.uppertriangularind(ΔR22)) .= 0 return ΔQ, ΔR end @@ -120,6 +121,7 @@ function remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = MatrixAlgebr ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁' mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁) ΔL22 = view(ΔL, (r + 1):size(ΔL, 1), (r + 1):minmn) + MatrixAlgebraKit.diagview(ΔL22) .= 0 view(ΔL22, MatrixAlgebraKit.lowertriangularind(ΔL22)) .= 0 return ΔL, ΔQ end From b235cef8dc7ce27442cbcf6660e681b426475594 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Mon, 9 Mar 2026 00:04:16 +0100 Subject: [PATCH 4/6] update enzyme --- test/enzyme/eig.jl | 2 +- test/enzyme/eigh.jl | 2 +- test/enzyme/lq.jl | 4 ++-- test/enzyme/orthnull.jl | 4 ++-- test/enzyme/polar.jl | 4 ++-- test/enzyme/projections.jl | 4 ++-- test/enzyme/qr.jl | 4 ++-- test/testsuite/enzyme/lq.jl | 2 +- test/testsuite/enzyme/qr.jl | 2 +- 9 files changed, 14 insertions(+), 14 deletions(-) diff --git a/test/enzyme/eig.jl b/test/enzyme/eig.jl index 7ec7e1de..6536a47b 100644 --- a/test/enzyme/eig.jl +++ b/test/enzyme/eig.jl @@ -14,7 +14,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_enzyme_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end diff --git a/test/enzyme/eigh.jl b/test/enzyme/eigh.jl index e40adb50..9c862d27 100644 --- a/test/enzyme/eigh.jl +++ b/test/enzyme/eigh.jl @@ -14,7 +14,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...) 
- TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end diff --git a/test/enzyme/lq.jl b/test/enzyme/lq.jl index 6699ddfa..f7ae2ebf 100644 --- a/test/enzyme/lq.jl +++ b/test/enzyme/lq.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, AMDGPU -BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI GenericFloats = () @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") using .TestSuite @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_enzyme_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end diff --git a/test/enzyme/orthnull.jl b/test/enzyme/orthnull.jl index 2e7a554d..eaeae840 100644 --- a/test/enzyme/orthnull.jl +++ b/test/enzyme/orthnull.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, AMDGPU -BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI GenericFloats = () @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") using .TestSuite @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_enzyme_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end diff --git a/test/enzyme/polar.jl b/test/enzyme/polar.jl index 31b89907..6ab965ac 100644 --- a/test/enzyme/polar.jl +++ b/test/enzyme/polar.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, 
AMDGPU -BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI GenericFloats = () @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") using .TestSuite @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_enzyme_polar(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end diff --git a/test/enzyme/projections.jl b/test/enzyme/projections.jl index d9a230d6..52b222a5 100644 --- a/test/enzyme/projections.jl +++ b/test/enzyme/projections.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, AMDGPU -BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI GenericFloats = () @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") using .TestSuite @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...) 
- TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) atol = rtol = m * m * TestSuite.precision(T) if !is_buildkite TestSuite.test_enzyme_projections(T, (m, m); atol, rtol) diff --git a/test/enzyme/qr.jl b/test/enzyme/qr.jl index 62a169c7..728e267d 100644 --- a/test/enzyme/qr.jl +++ b/test/enzyme/qr.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, AMDGPU -BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI GenericFloats = () @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") using .TestSuite @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_enzyme_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end diff --git a/test/testsuite/enzyme/lq.jl b/test/testsuite/enzyme/lq.jl index f84933a5..e4aa8d8e 100644 --- a/test/testsuite/enzyme/lq.jl +++ b/test/testsuite/enzyme/lq.jl @@ -38,7 +38,7 @@ function test_enzyme_lq_compact_rank_deficient( r = min(m, n) - 5 A = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) alg = MatrixAlgebraKit.select_algorithm(lq_compact, A) - LQ, ΔLQ = ad_lq_rank_deficient_compact_setup(A) + LQ, ΔLQ = ad_lq_compact_setup(A) test_reverse(lq_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) test_reverse(call_and_zero!, RT, (lq_compact!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) end diff --git a/test/testsuite/enzyme/qr.jl b/test/testsuite/enzyme/qr.jl index 08d1abdb..1d9f33a5 100644 --- a/test/testsuite/enzyme/qr.jl +++ b/test/testsuite/enzyme/qr.jl @@ -38,7 +38,7 @@ function test_enzyme_qr_compact_rank_deficient( r = min(m, n) - 5 A = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) alg = 
MatrixAlgebraKit.select_algorithm(qr_compact, A) - QR, ΔQR = ad_qr_rank_deficient_compact_setup(A) + QR, ΔQR = ad_qr_compact_setup(A) test_reverse(qr_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔQR, fdm) test_reverse(call_and_zero!, RT, (qr_compact!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔQR, fdm) end From 550e7619640a02cd75c5efd461640a1593d9f25a Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 10 Mar 2026 11:32:47 -0400 Subject: [PATCH 5/6] unicode subscripts --- src/pullbacks/lq.jl | 44 +++++++++++++++++++------------------- src/pullbacks/qr.jl | 40 +++++++++++++++++----------------- test/testsuite/ad_utils.jl | 12 +++++------ 3 files changed, 48 insertions(+), 48 deletions(-) diff --git a/src/pullbacks/lq.jl b/src/pullbacks/lq.jl index 885d0102..ec7a8a1b 100644 --- a/src/pullbacks/lq.jl +++ b/src/pullbacks/lq.jl @@ -59,48 +59,48 @@ function lq_pullback!( ΔL, ΔQ = ΔLQ - Q1 = view(Q, 1:p, :) - L11 = LowerTriangular(view(L, 1:p, 1:p)) - ΔA1 = view(ΔA, 1:p, :) - ΔA2 = view(ΔA, (p + 1):m, :) + Q₁ = view(Q, 1:p, :) + L₁₁ = LowerTriangular(view(L, 1:p, 1:p)) + ΔA₁ = view(ΔA, 1:p, :) + ΔA₂ = view(ΔA, (p + 1):m, :) check_lq_cotangents(L, Q, ΔL, ΔQ, p; gauge_atol) ΔQ̃ = zero!(similar(Q, (p, n))) if !iszerotangent(ΔQ) - ΔQ1 = view(ΔQ, 1:p, :) - copy!(ΔQ̃, ΔQ1) + ΔQ₁ = view(ΔQ, 1:p, :) + copy!(ΔQ̃, ΔQ₁) if minmn < size(Q, 1) - ΔQ3 = view(ΔQ, (minmn + 1):size(ΔQ, 1), :) - Q3 = view(Q, (minmn + 1):size(Q, 1), :) - ΔQ3Q1ᴴ = ΔQ3 * Q1' - ΔQ̃ = mul!(ΔQ̃, ΔQ3Q1ᴴ', Q3, -1, 1) + ΔQ₃ = view(ΔQ, (minmn + 1):size(ΔQ, 1), :) + Q₃ = view(Q, (minmn + 1):size(Q, 1), :) + ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁' + ΔQ̃ = mul!(ΔQ̃, ΔQ₃Q₁ᴴ', Q₃, -1, 1) end end if !iszerotangent(ΔL) && m > p - L21 = view(L, (p + 1):m, 1:p) - ΔL21 = view(ΔL, (p + 1):m, 1:p) - ΔQ̃ = mul!(ΔQ̃, L21' * ΔL21, Q1, -1, 1) - # Adding ΔA2 contribution - ΔA2 = mul!(ΔA2, ΔL21, Q1, 1, 1) + L₂₁ = view(L, (p + 1):m, 1:p) + ΔL₂₁ = view(ΔL, (p + 1):m, 1:p) + ΔQ̃ = mul!(ΔQ̃, L₂₁' * ΔL₂₁, Q₁, -1, 1) + # 
Adding ΔA₂ contribution + ΔA₂ = mul!(ΔA₂, ΔL₂₁, Q₁, 1, 1) end # construct M M = zero!(similar(L, (p, p))) if !iszerotangent(ΔL) - ΔL11 = LowerTriangular(view(ΔL, 1:p, 1:p)) - M = mul!(M, L11', ΔL11, 1, 1) + ΔL₁₁ = LowerTriangular(view(ΔL, 1:p, 1:p)) + M = mul!(M, L₁₁', ΔL₁₁, 1, 1) end - M = mul!(M, ΔQ̃, Q1', -1, 1) + M = mul!(M, ΔQ̃, Q₁', -1, 1) view(M, uppertriangularind(M)) .= conj.(view(M, lowertriangularind(M))) if eltype(M) <: Complex Md = diagview(M) Md .= real.(Md) end - ldiv!(L11', M) - ldiv!(L11', ΔQ̃) - ΔA1 = mul!(ΔA1, M, Q1, +1, 1) - ΔA1 .+= ΔQ̃ + ldiv!(L₁₁', M) + ldiv!(L₁₁', ΔQ̃) + ΔA₁ = mul!(ΔA₁, M, Q₁, +1, 1) + ΔA₁ .+= ΔQ̃ return ΔA end diff --git a/src/pullbacks/qr.jl b/src/pullbacks/qr.jl index 7cb4c06d..70c5aa89 100644 --- a/src/pullbacks/qr.jl +++ b/src/pullbacks/qr.jl @@ -61,10 +61,10 @@ function qr_pullback!( ΔQ, ΔR = ΔQR - Q1 = view(Q, :, 1:p) - R11 = UpperTriangular(view(R, 1:p, 1:p)) - ΔA1 = view(ΔA, :, 1:p) - ΔA2 = view(ΔA, :, (p + 1):n) + Q₁ = view(Q, :, 1:p) + R₁₁ = UpperTriangular(view(R, 1:p, 1:p)) + ΔA₁ = view(ΔA, :, 1:p) + ΔA₂ = view(ΔA, :, (p + 1):n) check_qr_cotangents(Q, R, ΔQ, ΔR, p; gauge_atol) @@ -73,36 +73,36 @@ function qr_pullback!( ΔQ₁ = view(ΔQ, :, 1:p) copy!(ΔQ̃, ΔQ₁) if minmn < size(Q, 2) - ΔQ3 = view(ΔQ, :, (minmn + 1):size(ΔQ, 2)) # extra columns in the case of qr_full - Q3 = view(Q, :, (minmn + 1):size(Q, 2)) - Q1ᴴΔQ3 = Q1' * ΔQ3 - ΔQ̃ = mul!(ΔQ̃, Q3, Q1ᴴΔQ3', -1, 1) + ΔQ₃ = view(ΔQ, :, (minmn + 1):size(ΔQ, 2)) # extra columns in the case of qr_full + Q₃ = view(Q, :, (minmn + 1):size(Q, 2)) + Q₁ᴴΔQ₃ = Q₁' * ΔQ₃ + ΔQ̃ = mul!(ΔQ̃, Q₃, Q₁ᴴΔQ₃', -1, 1) end end if !iszerotangent(ΔR) && n > p - R12 = view(R, 1:p, (p + 1):n) - ΔR12 = view(ΔR, 1:p, (p + 1):n) - ΔQ̃ = mul!(ΔQ̃, Q1, ΔR12 * R12', -1, 1) - # Adding ΔA2 contribution - ΔA2 = mul!(ΔA2, Q1, ΔR12, 1, 1) + R₁₂ = view(R, 1:p, (p + 1):n) + ΔR₁₂ = view(ΔR, 1:p, (p + 1):n) + ΔQ̃ = mul!(ΔQ̃, Q₁, ΔR₁₂ * R₁₂', -1, 1) + # Adding ΔA₂ contribution + ΔA₂ = mul!(ΔA₂, Q₁, ΔR₁₂, 1, 1) 
end # construct M M = zero!(similar(R, (p, p))) if !iszerotangent(ΔR) - ΔR11 = UpperTriangular(view(ΔR, 1:p, 1:p)) - M = mul!(M, ΔR11, R11', 1, 1) + ΔR₁₁ = UpperTriangular(view(ΔR, 1:p, 1:p)) + M = mul!(M, ΔR₁₁, R₁₁', 1, 1) end - M = mul!(M, Q1', ΔQ̃, -1, 1) + M = mul!(M, Q₁', ΔQ̃, -1, 1) view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M))) if eltype(M) <: Complex Md = diagview(M) Md .= real.(Md) end - rdiv!(M, R11') # R11 is upper triangular - rdiv!(ΔQ̃, R11') - ΔA1 = mul!(ΔA1, Q1, M, +1, 1) - ΔA1 .+= ΔQ̃ + rdiv!(M, R₁₁') # R₁₁ is upper triangular + rdiv!(ΔQ̃, R₁₁') + ΔA₁ = mul!(ΔA₁, Q₁, M, +1, 1) + ΔA₁ .+= ΔQ̃ return ΔA end diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index d20920a5..44184d1a 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -85,9 +85,9 @@ function remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = MatrixAlgebr ΔQ₃ = view(ΔQ, :, (minmn + 1):size(ΔQ, 2)) # extra columns in the case of qr_full Q₁ᴴΔQ₃ = Q₁' * ΔQ₃ mul!(ΔQ₃, Q₁, Q₁ᴴΔQ₃) - ΔR22 = view(ΔR, (r + 1):minmn, (r + 1):size(R, 2)) - MatrixAlgebraKit.diagview(ΔR22) .= 0 - view(ΔR22, MatrixAlgebraKit.uppertriangularind(ΔR22)) .= 0 + ΔR₂₂ = view(ΔR, (r + 1):minmn, (r + 1):size(R, 2)) + MatrixAlgebraKit.diagview(ΔR₂₂) .= 0 + view(ΔR₂₂, MatrixAlgebraKit.uppertriangularind(ΔR₂₂)) .= 0 return ΔQ, ΔR end @@ -120,9 +120,9 @@ function remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = MatrixAlgebr ΔQ₃ = view(ΔQ, (minmn + 1):size(ΔQ, 1), :) # extra rows in the case of lq_full ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁' mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁) - ΔL22 = view(ΔL, (r + 1):size(ΔL, 1), (r + 1):minmn) - MatrixAlgebraKit.diagview(ΔL22) .= 0 - view(ΔL22, MatrixAlgebraKit.lowertriangularind(ΔL22)) .= 0 + ΔL₂₂ = view(ΔL, (r + 1):size(ΔL, 1), (r + 1):minmn) + MatrixAlgebraKit.diagview(ΔL₂₂) .= 0 + view(ΔL₂₂, MatrixAlgebraKit.lowertriangularind(ΔL₂₂)) .= 0 return ΔL, ΔQ end From 740d4738938383239890165c146e8196a0b600d0 Mon Sep 17 00:00:00 2001 From: Lukas Devos 
Date: Tue, 10 Mar 2026 11:33:23 -0400 Subject: [PATCH 6/6] remove debug code --- src/pullbacks/lq.jl | 1 - test/testsuite/ad_utils.jl | 85 -------------------------------------- 2 files changed, 86 deletions(-) diff --git a/src/pullbacks/lq.jl b/src/pullbacks/lq.jl index ec7a8a1b..1a41c246 100644 --- a/src/pullbacks/lq.jl +++ b/src/pullbacks/lq.jl @@ -4,7 +4,6 @@ function check_lq_cotangents( L, Q, ΔL, ΔQ, p::Int; gauge_atol::Real = default_pullback_gauge_atol(ΔQ) ) - # check_qr_cotangents(Q', L', ΔQ', ΔL', p; gauge_atol) minmn = min(size(L, 1), size(Q, 2)) Δgauge = abs(zero(eltype(Q))) if !iszerotangent(ΔQ) diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index 44184d1a..32fc1748 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -267,40 +267,6 @@ function ad_qr_full_setup(A) return QR, ΔQR end -# function ad_qr_rank_deficient_compact_setup(A) -# m, n = size(A) -# minmn = min(m, n) -# T = eltype(A) -# r = minmn - 5 -# Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) -# Q, R = qr_compact(Ard) -# QR = (Q, R) -# ΔQ = randn!(similar(A, T, m, minmn)) -# Q1 = view(Q, 1:m, 1:r) -# Q2 = view(Q, 1:m, (r + 1):minmn) -# ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) -# MatrixAlgebraKit.zero!(ΔQ2) -# ΔR = randn!(similar(A, T, minmn, n)) -# view(ΔR, (r + 1):minmn, :) .= 0 -# return (Q, R), (ΔQ, ΔR) -# end - -# function ad_qr_rank_deficient_compact_setup(A::Diagonal) -# m, n = size(A) -# minmn = min(m, n) -# T = eltype(A) -# r = minmn - 5 -# Ard_ = randn!(similar(A, T, m)) -# MatrixAlgebraKit.zero!(view(Ard_, (r + 1):m)) -# Ard = Diagonal(Ard_) -# Q, R = qr_compact(Ard) -# ΔQ = Diagonal(randn!(similar(diagview(A), T, m))) -# ΔR = Diagonal(randn!(similar(diagview(A), T, m))) -# MatrixAlgebraKit.zero!(view(diagview(ΔQ), (r + 1):m)) -# MatrixAlgebraKit.zero!(view(diagview(ΔR), (r + 1):m)) -# return (Q, R), (ΔQ, ΔR) -# end - function ad_lq_compact_setup(A) LQ = lq_compact(A) ΔLQ = structured_randn!.(similar.(LQ)) @@ -322,24 +288,6 @@ 
function ad_lq_full_setup(A) return LQ, ΔLQ end -# function ad_lq_rank_deficient_compact_setup(A) -# m, n = size(A) -# minmn = min(m, n) -# T = eltype(A) -# r = minmn - 5 -# Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) -# L, Q = lq_compact(Ard) -# ΔL = randn!(similar(A, T, m, minmn)) -# ΔQ = randn!(similar(A, T, minmn, n)) -# Q1 = view(Q, 1:r, 1:n) -# Q2 = view(Q, (r + 1):minmn, 1:n) -# ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) -# ΔQ2 .= 0 -# view(ΔL, :, (r + 1):minmn) .= 0 -# return (L, Q), (ΔL, ΔQ) -# end -# ad_lq_rank_deficient_compact_setup(A::Diagonal) = ad_qr_rank_deficient_compact_setup(A) - function ad_eig_full_setup(A) D, V = eig_full(A) ΔD, ΔV = structured_randn!.(similar.((D, V))) @@ -347,31 +295,12 @@ function ad_eig_full_setup(A) return (D, V), (ΔD, ΔV) end -# function ad_eig_full_setup(A::Diagonal) -# m, n = size(A) -# T = complex(eltype(A)) -# DV = eig_full(A) -# D, V = DV -# ΔV = randn!(similar(A.diag, T, m, m)) -# ΔV = remove_eig_gauge_dependence!(ΔV, D, V) -# ΔD = Diagonal(randn!(similar(A.diag, T, m))) -# return DV, (ΔD, ΔV) -# end - function ad_eig_vals_setup(A) D = eig_vals(A) ΔD = randn!(similar(D)) return D, ΔD end -# function ad_eig_vals_setup(A::Diagonal) -# m, n = size(A) -# T = complex(eltype(A)) -# D = eig_vals(A) -# ΔD = randn!(similar(A.diag, T, m)) -# return D, ΔD -# end - function ad_eig_trunc_setup(A, truncalg) DV, ΔDV = ad_eig_full_setup(A) ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) @@ -412,18 +341,6 @@ function ad_svd_compact_setup(A) return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ) end -# function ad_svd_compact_setup(A::Diagonal) -# m, n = size(A) -# T = eltype(A) -# minmn = min(m, n) -# ΔU = randn!(similar(A.diag, T, m, n)) -# ΔS = Diagonal(randn!(similar(A.diag, real(T), minmn))) -# ΔVᴴ = randn!(similar(A.diag, T, m, n)) -# U, S, Vᴴ = svd_compact(A) -# ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) -# return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ) -# end - function ad_svd_full_setup(A) U, S, Vᴴ = 
svd_full(A) ΔU = structured_randn!(similar(U)) @@ -434,8 +351,6 @@ function ad_svd_full_setup(A) return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ) end -# ad_svd_full_setup(A::Diagonal) = ad_svd_compact_setup(A) - function ad_svd_vals_setup(A) S = svd_vals(A) ΔS = randn!(similar(S))