diff --git a/Project.toml b/Project.toml index e079639a..e851d669 100644 --- a/Project.toml +++ b/Project.toml @@ -34,7 +34,7 @@ Enzyme = "0.13.118" EnzymeTestUtils = "0.2.5" GenericLinearAlgebra = "0.3.19" GenericSchur = "0.5.6" -JET = "0.9, 0.10" +JET = "0.9, 0.10, 0.11" LinearAlgebra = "1" Mooncake = "0.5" ParallelTestRunner = "2" diff --git a/src/common/view.jl b/src/common/view.jl index e03bfb88..8cd989ea 100644 --- a/src/common/view.jl +++ b/src/common/view.jl @@ -24,7 +24,8 @@ diagonal(v::AbstractVector) = Diagonal(v) function lowertriangularind(A::AbstractMatrix) Base.require_one_based_indexing(A) m, n = size(A) - I = Vector{Int}(undef, div(m * (m - 1), 2) + m * (n - m)) + minmn = min(m, n) + I = Vector{Int}(undef, div(minmn * (minmn - 1), 2) + minmn * (m - minmn)) offset = 0 for j in 1:n r = (j + 1):m @@ -37,7 +38,8 @@ end function uppertriangularind(A::AbstractMatrix) Base.require_one_based_indexing(A) m, n = size(A) - I = Vector{Int}(undef, div(m * (m - 1), 2) + m * (n - m)) + minmn = min(m, n) + I = Vector{Int}(undef, div(minmn * (minmn - 1), 2) + minmn * (n - minmn)) offset = 0 for i in 1:m r = (i + 1):n diff --git a/src/pullbacks/lq.jl b/src/pullbacks/lq.jl index 452eda92..885d0102 100644 --- a/src/pullbacks/lq.jl +++ b/src/pullbacks/lq.jl @@ -4,41 +4,28 @@ function check_lq_cotangents( L, Q, ΔL, ΔQ, p::Int; gauge_atol::Real = default_pullback_gauge_atol(ΔQ) ) + # check_qr_cotangents(Q', L', ΔQ', ΔL', p; gauge_atol) minmn = min(size(L, 1), size(Q, 2)) - if minmn > p # case where A is rank-deficient - Δgauge = abs(zero(eltype(Q))) - if !iszerotangent(ΔQ) - # in this case the number Householder reflections will - # change upon small variations, and all of the remaining - # rows of ΔQ should be zero for a gauge-invariant - # cost function - ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :) - Δgauge_Q = norm(ΔQ2, Inf) - Δgauge = max(Δgauge, Δgauge_Q) - end - if !iszerotangent(ΔL) - ΔL22 = view(ΔL, (p + 1):size(L, 1), (p + 1):minmn) - Δgauge_L = norm(ΔL22, Inf) - 
Δgauge = max(Δgauge, Δgauge_L) - end - Δgauge ≤ gauge_atol || - @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + Δgauge = abs(zero(eltype(Q))) + if !iszerotangent(ΔQ) + ΔQ₂ = view(ΔQ, (p + 1):minmn, :) + ΔQ₃ = ΔQ[(minmn + 1):size(Q, 1), :] + Δgauge_Q = norm(ΔQ₂, Inf) + Q₁ = view(Q, 1:p, :) + ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁' + mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁, -1, 1) + Δgauge_Q = max(Δgauge_Q, norm(ΔQ₃, Inf)) + Δgauge = max(Δgauge, Δgauge_Q) + end + if !iszerotangent(ΔL) + ΔL22 = view(ΔL, (p + 1):size(ΔL, 1), (p + 1):minmn) + Δgauge_L = norm(view(ΔL22, lowertriangularind(ΔL22)), Inf) + Δgauge_L = max(Δgauge_L, norm(view(ΔL22, diagind(ΔL22)), Inf)) + Δgauge = max(Δgauge, Δgauge_L) end - return -end - -function check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol::Real = default_pullback_gauge_atol(ΔQ2)) - # in the case where A is full rank, but there are more columns in Q than in A - # (the case of `lq_full`), there is gauge-invariant information in the - # projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary - # matrix. As the number of Householder reflections is in fixed in the full rank - # case, Q is expected to rotate smoothly (we might even be able to predict) also - # how the full Q2 will change, but this we omit for now, and we consider - # Q2' * ΔQ2 as a gauge dependent quantity. 
- Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf) Δgauge ≤ gauge_atol || - @warn "`lq_full` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - return + @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + return nothing end """ @@ -67,13 +54,13 @@ function lq_pullback!( L, Q = LQ m = size(L, 1) n = size(Q, 2) + minmn = min(m, n) p = lq_rank(L; rank_atol) ΔL, ΔQ = ΔLQ Q1 = view(Q, 1:p, :) - Q2 = view(Q, (p + 1):size(Q, 1), :) - L11 = view(L, 1:p, 1:p) + L11 = LowerTriangular(view(L, 1:p, 1:p)) ΔA1 = view(ΔA, 1:p, :) ΔA2 = view(ΔA, (p + 1):m, :) @@ -83,12 +70,11 @@ function lq_pullback!( if !iszerotangent(ΔQ) ΔQ1 = view(ΔQ, 1:p, :) copy!(ΔQ̃, ΔQ1) - if p < size(Q, 1) - Q2 = view(Q, (p + 1):size(Q, 1), :) - ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :) - ΔQ2Q1ᴴ = ΔQ2 * Q1' - check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol) - ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1) + if minmn < size(Q, 1) + ΔQ3 = view(ΔQ, (minmn + 1):size(ΔQ, 1), :) + Q3 = view(Q, (minmn + 1):size(Q, 1), :) + ΔQ3Q1ᴴ = ΔQ3 * Q1' + ΔQ̃ = mul!(ΔQ̃, ΔQ3Q1ᴴ', Q3, -1, 1) end end if !iszerotangent(ΔL) && m > p @@ -102,7 +88,7 @@ function lq_pullback!( # construct M M = zero!(similar(L, (p, p))) if !iszerotangent(ΔL) - ΔL11 = view(ΔL, 1:p, 1:p) + ΔL11 = LowerTriangular(view(ΔL, 1:p, 1:p)) M = mul!(M, L11', ΔL11, 1, 1) end M = mul!(M, ΔQ̃, Q1', -1, 1) @@ -111,8 +97,8 @@ function lq_pullback!( Md = diagview(M) Md .= real.(Md) end - ldiv!(LowerTriangular(L11)', M) - ldiv!(LowerTriangular(L11)', ΔQ̃) + ldiv!(L11', M) + ldiv!(L11', ΔQ̃) ΔA1 = mul!(ΔA1, M, Q1, +1, 1) ΔA1 .+= ΔQ̃ return ΔA diff --git a/src/pullbacks/qr.jl b/src/pullbacks/qr.jl index 643b2c68..7cb4c06d 100644 --- a/src/pullbacks/qr.jl +++ b/src/pullbacks/qr.jl @@ -6,40 +6,26 @@ function check_qr_cotangents( gauge_atol::Real = default_pullback_gauge_atol(ΔQ) ) minmn = min(size(Q, 1), size(R, 2)) - if minmn > p # case where A is rank-deficient - Δgauge = abs(zero(eltype(Q))) - if !iszerotangent(ΔQ) - # in this case the 
number Householder reflections will - # change upon small variations, and all of the remaining - # columns of ΔQ should be zero for a gauge-invariant - # cost function - ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2)) - Δgauge_Q = norm(ΔQ2, Inf) - Δgauge = max(Δgauge, Δgauge_Q) - end - if !iszerotangent(ΔR) - ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2)) - Δgauge_R = norm(ΔR22, Inf) - Δgauge = max(Δgauge, Δgauge_R) - end - Δgauge ≤ gauge_atol || - @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + Δgauge = abs(zero(eltype(Q))) + if !iszerotangent(ΔQ) + ΔQ₂ = view(ΔQ, :, (p + 1):minmn) + ΔQ₃ = ΔQ[:, (minmn + 1):size(Q, 2)] # extra columns in the case of qr_full + Δgauge_Q = norm(ΔQ₂, Inf) + Q₁ = view(Q, :, 1:p) + Q₁ᴴΔQ₃ = Q₁' * ΔQ₃ + mul!(ΔQ₃, Q₁, Q₁ᴴΔQ₃, -1, 1) + Δgauge_Q = max(Δgauge_Q, norm(ΔQ₃, Inf)) + Δgauge = max(Δgauge, Δgauge_Q) + end + if !iszerotangent(ΔR) + ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2)) + Δgauge_R = norm(view(ΔR22, uppertriangularind(ΔR22)), Inf) + Δgauge_R = max(Δgauge_R, norm(view(ΔR22, diagind(ΔR22)), Inf)) + Δgauge = max(Δgauge, Δgauge_R) end - return -end - -function check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol::Real = default_pullback_gauge_atol(ΔQ2)) - # in the case where A is full rank, but there are more columns in Q than in A - # (the case of `qr_full`), there is gauge-invariant information in the - # projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary - # matrix. As the number of Householder reflections is in fixed in the full rank - # case, Q is expected to rotate smoothly (we might even be able to predict) also - # how the full Q2 will change, but this we omit for now, and we consider - # Q2' * ΔQ2 as a gauge dependent quantity. 
- Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf) Δgauge ≤ gauge_atol || - @warn "`qr_full` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - return + @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + return nothing end """ @@ -69,13 +55,14 @@ function qr_pullback!( Q, R = QR m = size(Q, 1) n = size(R, 2) + minmn = min(m, n) Rd = diagview(R) p = qr_rank(R; rank_atol) ΔQ, ΔR = ΔQR Q1 = view(Q, :, 1:p) - R11 = view(R, 1:p, 1:p) + R11 = UpperTriangular(view(R, 1:p, 1:p)) ΔA1 = view(ΔA, :, 1:p) ΔA2 = view(ΔA, :, (p + 1):n) @@ -83,13 +70,13 @@ function qr_pullback!( ΔQ̃ = zero!(similar(Q, (m, p))) if !iszerotangent(ΔQ) - copy!(ΔQ̃, view(ΔQ, :, 1:p)) - if p < size(Q, 2) - Q2 = view(Q, :, (p + 1):size(Q, 2)) - ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2)) - Q1dΔQ2 = Q1' * ΔQ2 - check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol) - ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1) + ΔQ₁ = view(ΔQ, :, 1:p) + copy!(ΔQ̃, ΔQ₁) + if minmn < size(Q, 2) + ΔQ3 = view(ΔQ, :, (minmn + 1):size(ΔQ, 2)) # extra columns in the case of qr_full + Q3 = view(Q, :, (minmn + 1):size(Q, 2)) + Q1ᴴΔQ3 = Q1' * ΔQ3 + ΔQ̃ = mul!(ΔQ̃, Q3, Q1ᴴΔQ3', -1, 1) end end if !iszerotangent(ΔR) && n > p @@ -103,7 +90,7 @@ function qr_pullback!( # construct M M = zero!(similar(R, (p, p))) if !iszerotangent(ΔR) - ΔR11 = view(ΔR, 1:p, 1:p) + ΔR11 = UpperTriangular(view(ΔR, 1:p, 1:p)) M = mul!(M, ΔR11, R11', 1, 1) end M = mul!(M, Q1', ΔQ̃, -1, 1) @@ -112,8 +99,8 @@ function qr_pullback!( Md = diagview(M) Md .= real.(Md) end - rdiv!(M, UpperTriangular(R11)') - rdiv!(ΔQ̃, UpperTriangular(R11)') + rdiv!(M, R11') # R11 is upper triangular + rdiv!(ΔQ̃, R11') ΔA1 = mul!(ΔA1, Q1, M, +1, 1) ΔA1 .+= ΔQ̃ return ΔA diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index 9c131464..01fdc4f7 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -1,4 +1,4 @@ -svd_rank(S, rank_atol) = searchsortedlast(S, rank_atol; rev = true) +svd_rank(S; rank_atol = default_pullback_rank_atol(S)) = 
searchsortedlast(S, rank_atol; rev = true) function check_svd_cotangents(aUΔU, Sr, aVΔV; degeneracy_atol = default_pullback_rank_atol(Sr), gauge_atol = default_pullback_gauge_atol(aUΔU, aVΔV)) mask = abs.(Sr' .- Sr) .< degeneracy_atol @@ -43,7 +43,7 @@ function svd_pullback!( minmn = min(m, n) S = diagview(Smat) length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)")) - r = svd_rank(S, rank_atol) + r = svd_rank(S; rank_atol) Ur = view(U, :, 1:r) Vᴴr = view(Vᴴ, 1:r, :) Sr = view(S, 1:r) diff --git a/test/enzyme/eig.jl b/test/enzyme/eig.jl index 7ec7e1de..6536a47b 100644 --- a/test/enzyme/eig.jl +++ b/test/enzyme/eig.jl @@ -14,7 +14,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_enzyme_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end diff --git a/test/enzyme/eigh.jl b/test/enzyme/eigh.jl index e40adb50..9c862d27 100644 --- a/test/enzyme/eigh.jl +++ b/test/enzyme/eigh.jl @@ -14,7 +14,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...) 
- TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end diff --git a/test/enzyme/lq.jl b/test/enzyme/lq.jl index 6699ddfa..f7ae2ebf 100644 --- a/test/enzyme/lq.jl +++ b/test/enzyme/lq.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, AMDGPU -BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI GenericFloats = () @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") using .TestSuite @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_enzyme_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end diff --git a/test/enzyme/orthnull.jl b/test/enzyme/orthnull.jl index 2e7a554d..eaeae840 100644 --- a/test/enzyme/orthnull.jl +++ b/test/enzyme/orthnull.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, AMDGPU -BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI GenericFloats = () @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") using .TestSuite @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_enzyme_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end diff --git a/test/enzyme/polar.jl b/test/enzyme/polar.jl index 31b89907..6ab965ac 100644 --- a/test/enzyme/polar.jl +++ b/test/enzyme/polar.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, 
AMDGPU -BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI GenericFloats = () @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") using .TestSuite @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_enzyme_polar(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end diff --git a/test/enzyme/projections.jl b/test/enzyme/projections.jl index d9a230d6..52b222a5 100644 --- a/test/enzyme/projections.jl +++ b/test/enzyme/projections.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, AMDGPU -BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI GenericFloats = () @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") using .TestSuite @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...) 
- TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) atol = rtol = m * m * TestSuite.precision(T) if !is_buildkite TestSuite.test_enzyme_projections(T, (m, m); atol, rtol) diff --git a/test/enzyme/qr.jl b/test/enzyme/qr.jl index 62a169c7..728e267d 100644 --- a/test/enzyme/qr.jl +++ b/test/enzyme/qr.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, AMDGPU -BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI GenericFloats = () @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") using .TestSuite @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_enzyme_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end diff --git a/test/mooncake/eig.jl b/test/mooncake/eig.jl index 2e8a8606..b313f9b2 100644 --- a/test/mooncake/eig.jl +++ b/test/mooncake/eig.jl @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end diff --git a/test/mooncake/eigh.jl b/test/mooncake/eigh.jl index 5528af0f..800dbaa0 100644 --- a/test/mooncake/eigh.jl +++ b/test/mooncake/eigh.jl @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...) 
- TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end diff --git a/test/mooncake/lq.jl b/test/mooncake/lq.jl index 6c9f8fd4..0f05f85a 100644 --- a/test/mooncake/lq.jl +++ b/test/mooncake/lq.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, AMDGPU -BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI GenericFloats = () @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") using .TestSuite @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end diff --git a/test/mooncake/orthnull.jl b/test/mooncake/orthnull.jl index 6f8dac9a..09e3a28c 100644 --- a/test/mooncake/orthnull.jl +++ b/test/mooncake/orthnull.jl @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end diff --git a/test/mooncake/polar.jl b/test/mooncake/polar.jl index 9e4e366e..1faf3c10 100644 --- a/test/mooncake/polar.jl +++ b/test/mooncake/polar.jl @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite atol = rtol = m * n * TestSuite.precision(T) m >= n && TestSuite.test_mooncake_left_polar(T, (m, n); atol, rtol) diff --git a/test/mooncake/qr.jl 
b/test/mooncake/qr.jl index 9ffc4798..17415e8d 100644 --- a/test/mooncake/qr.jl +++ b/test/mooncake/qr.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, AMDGPU -BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI GenericFloats = () @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") using .TestSuite @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end diff --git a/test/mooncake/svd.jl b/test/mooncake/svd.jl index 982ec040..d2d40df4 100644 --- a/test/mooncake/svd.jl +++ b/test/mooncake/svd.jl @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) - TestSuite.seed_rng!(123) + TestSuite.seed_rng!(1234) if !is_buildkite TestSuite.test_mooncake_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index 0e64d9e9..c662a067 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -11,7 +11,7 @@ module TestSuite using Test using MatrixAlgebraKit using MatrixAlgebraKit: diagview -using LinearAlgebra: Diagonal, norm, istriu, istril, I +using LinearAlgebra: Diagonal, norm, istriu, istril, I, mul! 
using Random, StableRNGs using Mooncake using AMDGPU, CUDA @@ -86,9 +86,9 @@ function instantiate_unitary(T, A::ROCMatrix{<:Complex}, sz) end instantiate_unitary(::Type{<:Diagonal}, A, sz) = Diagonal(fill!(similar(parent(A), eltype(A), sz), one(eltype(A)))) -function instantiate_rank_deficient_matrix(T, sz; trunc = trunctol(rtol = 0.5)) +function instantiate_rank_deficient_matrix(T, sz; trunc = truncrank(div(min(sz...), 2))) A = instantiate_matrix(T, sz) - V, C = left_orth!(A; trunc = trunctol(rtol = 0.5)) + V, C = left_orth!(A; trunc) return mul!(A, V, C) end diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index 07528257..d20920a5 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -1,9 +1,13 @@ +structured_randn!(A::AbstractMatrix) = randn!(A) +structured_randn!(A::Diagonal) = (randn!(diagview(A)); return A) + """ - remove_eig_gauge_dependence!(ΔV, D, V) + remove_eig_gauge_dependence!(ΔV, D, V; degeneracy_atol = ...) Remove the gauge-dependent part from the cotangent `ΔV` of the eigenvector matrix `V`. The -eigenvectors are only determined up to complex phase (and unitary mixing for degenerate -eigenvalues), so the corresponding components of `ΔV` are projected out. +eigenvectors are only determined up to a scalar factor (or an arbitrary linear transformation +across eigenvectors associated with degenerate eigenvalues), so the corresponding components of +`ΔV` are projected out. """ function remove_eig_gauge_dependence!( ΔV, D, V; @@ -16,12 +20,12 @@ function remove_eig_gauge_dependence!( end """ - remove_eigh_gauge_dependence!(ΔV, D, V) + remove_eigh_gauge_dependence!(ΔV, D, V; degeneracy_atol = ...) Remove the gauge-dependent part from the cotangent `ΔV` of the Hermitian eigenvector matrix -`V`. The eigenvectors are only determined up to complex phase (and unitary mixing for -degenerate eigenvalues), so the corresponding anti-Hermitian components of `V' * ΔV` are -projected out. +`V`. 
The eigenvectors are only determined up to a complex phase (or a unitary transformation +across eigenvectors associated with degenerate eigenvalues), so the corresponding anti-Hermitian +components of `V' * ΔV` are projected out. """ function remove_eigh_gauge_dependence!( ΔV, D, V; @@ -35,48 +39,55 @@ function remove_eigh_gauge_dependence!( end """ - remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = ..., rank_atol = ...) Remove the gauge-dependent part from the cotangents `ΔU` and `ΔVᴴ` of the SVD factors. The -singular vectors are only determined up to a common complex phase per singular value (and -unitary mixing for degenerate singular values), so the corresponding anti-Hermitian components -of `U₁' * ΔU₁ + Vᴴ₁ * ΔVᴴ₁'` are projected out. For the full SVD, the extra columns of `U` -and rows of `Vᴴ` beyond `min(m, n)` are additionally zeroed out. +singular vectors are only determined up to a common complex phase per singular value (or a +unitary transformation across singular vectors associated with degenerate singular values), +so the corresponding anti-Hermitian components of `U₁' * ΔU₁ + Vᴴ₁ * ΔVᴴ₁'` are projected out. +For the full SVD, the extra columns of `U` and rows of `Vᴴ` beyond the rank `r` are +additionally zeroed out, where `r = count(diagview(S) .> rank_atol)`. 
""" function remove_svd_gauge_dependence!( ΔU, ΔVᴴ, U, S, Vᴴ; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S) + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S), + rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(S) ) - minmn = length(diagview(S)) - U₁ = view(U, :, 1:minmn) - Vᴴ₁ = view(Vᴴ, 1:minmn, :) - ΔU₁ = view(ΔU, :, 1:minmn) - ΔVᴴ₁ = view(ΔVᴴ, 1:minmn, :) + r = MatrixAlgebraKit.svd_rank(diagview(S); rank_atol) + U₁ = view(U, :, 1:r) + Vᴴ₁ = view(Vᴴ, 1:r, :) + ΔU₁ = view(ΔU, :, 1:r) + ΔVᴴ₁ = view(ΔVᴴ, 1:r, :) Sdiag = diagview(S) gaugepart = mul!(U₁' * ΔU₁, Vᴴ₁, ΔVᴴ₁', true, true) gaugepart = project_antihermitian!(gaugepart) gaugepart[abs.(transpose(Sdiag) .- Sdiag) .>= degeneracy_atol] .= 0 mul!(ΔU₁, U₁, gaugepart, -1, 1) - ΔU[:, (minmn + 1):end] .= 0 - ΔVᴴ[(minmn + 1):end, :] .= 0 + ΔU[:, (r + 1):end] .= 0 + ΔVᴴ[(r + 1):end, :] .= 0 return ΔU, ΔVᴴ end """ - remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R) + remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = ...) Remove the gauge-dependent part from the cotangents `ΔQ` and `ΔR` of the QR factors `Q` and `R`. For the full QR decomposition, the extra columns of `Q` beyond the rank `r` are not uniquely determined by `A`, so the corresponding part of `ΔQ` is projected to remove this ambiguity. Additionally, rows of `ΔR` beyond the rank are zeroed out. """ -function remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R) - r = MatrixAlgebraKit.qr_rank(R) - Q₁ = @view Q[:, 1:r] - ΔQ₂ = @view ΔQ[:, (r + 1):end] - Q₁ᴴΔQ₂ = Q₁' * ΔQ₂ - mul!(ΔQ₂, Q₁, Q₁ᴴΔQ₂) - view(ΔR, (r + 1):size(ΔR, 1), :) .= 0 +function remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(R)) + r = MatrixAlgebraKit.qr_rank(R; rank_atol) + minmn = min(size(A)...) 
+ Q₁ = view(Q, :, 1:r) + ΔQ₂ = view(ΔQ, :, (r + 1):minmn) + ΔQ₂ .= 0 + ΔQ₃ = view(ΔQ, :, (minmn + 1):size(ΔQ, 2)) # extra columns in the case of qr_full + Q₁ᴴΔQ₃ = Q₁' * ΔQ₃ + mul!(ΔQ₃, Q₁, Q₁ᴴΔQ₃) + ΔR22 = view(ΔR, (r + 1):minmn, (r + 1):size(R, 2)) + MatrixAlgebraKit.diagview(ΔR22) .= 0 + view(ΔR22, MatrixAlgebraKit.uppertriangularind(ΔR22)) .= 0 return ΔQ, ΔR end @@ -93,20 +104,25 @@ function remove_qr_null_gauge_dependence!(ΔN, A, N) end """ - remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q) + remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = ...) Remove the gauge-dependent part from the cotangents `ΔL` and `ΔQ` of the LQ factors `L` and `Q`. For the full LQ decomposition, the extra rows of `Q` beyond the rank `r` are not uniquely determined by `A`, so the corresponding part of `ΔQ` is projected to remove this ambiguity. Additionally, columns of `ΔL` beyond the rank are zeroed out. """ -function remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q) - r = MatrixAlgebraKit.lq_rank(L) - Q₁ = @view Q[1:r, :] - ΔQ₂ = @view ΔQ[(r + 1):end, :] - ΔQ₂Q₁ᴴ = ΔQ₂ * Q₁' - mul!(ΔQ₂, ΔQ₂Q₁ᴴ, Q₁) - view(ΔL, :, (r + 1):size(ΔL, 2)) .= 0 +function remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(L)) + r = MatrixAlgebraKit.lq_rank(L; rank_atol) + minmn = min(size(A)...) + Q₁ = view(Q, 1:r, :) + ΔQ₂ = view(ΔQ, (r + 1):minmn, :) + ΔQ₂ .= 0 + ΔQ₃ = view(ΔQ, (minmn + 1):size(ΔQ, 1), :) # extra rows in the case of lq_full + ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁' + mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁) + ΔL22 = view(ΔL, (r + 1):size(ΔL, 1), (r + 1):minmn) + MatrixAlgebraKit.diagview(ΔL22) .= 0 + view(ΔL22, MatrixAlgebraKit.lowertriangularind(ΔL22)) .= 0 return ΔL, ΔQ end @@ -130,11 +146,7 @@ Remove the gauge-dependent part from the cotangent `ΔN` of the left null space space basis is only determined up to a unitary rotation, so `ΔN` is projected onto the column span of the compact QR factor `Q₁` of `A`. 
""" -function remove_left_null_gauge_dependence!(ΔN, A, N) - Q, _ = qr_compact(A) - mul!(ΔN, Q, Q' * ΔN) - return ΔN -end +remove_left_null_gauge_dependence!(ΔN, A, N) = remove_qr_null_gauge_dependence!(ΔN, A, N) """ remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) @@ -143,11 +155,7 @@ Remove the gauge-dependent part from the cotangent `ΔNᴴ` of the right null sp null space basis is only determined up to a unitary rotation, so `ΔNᴴ` is projected onto the row span of the compact LQ factor `Q₁` of `A`. """ -function remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) - _, Q = lq_compact(A) - mul!(ΔNᴴ, ΔNᴴ * Q', Q) - return ΔNᴴ -end +remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) = remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) """ call_and_zero!(f!, A, alg) @@ -240,149 +248,129 @@ end function ad_qr_compact_setup(A) QR = qr_compact(A) - ΔQR = randn!.(copy.(QR)) + ΔQR = structured_randn!.(similar.(QR)) remove_qr_gauge_dependence!(ΔQR..., A, QR...) return QR, ΔQR end -function ad_qr_compact_setup(A::Diagonal) - m, n = size(A) - minmn = min(m, n) - QR = qr_compact(A) - T = eltype(A) - ΔQ = Diagonal(randn!(similar(A.diag, T, m))) - ΔR = Diagonal(randn!(similar(A.diag, T, m))) - return QR, (ΔQ, ΔR) -end - function ad_qr_null_setup(A) N = qr_null(A) - ΔN = randn!(copy(N)) + ΔN = structured_randn!(similar(N)) remove_qr_null_gauge_dependence!(ΔN, A, N) return N, ΔN end function ad_qr_full_setup(A) QR = qr_full(A) - ΔQR = randn!.(copy.(QR)) + ΔQR = structured_randn!.(similar.(QR)) remove_qr_gauge_dependence!(ΔQR..., A, QR...) 
return QR, ΔQR end -ad_qr_full_setup(A::Diagonal) = ad_qr_compact_setup(A) -function ad_qr_rank_deficient_compact_setup(A) - m, n = size(A) - minmn = min(m, n) - T = eltype(A) - r = minmn - 5 - Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) - Q, R = qr_compact(Ard) - QR = (Q, R) - ΔQ = randn!(similar(A, T, m, minmn)) - Q1 = view(Q, 1:m, 1:r) - Q2 = view(Q, 1:m, (r + 1):minmn) - ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) - MatrixAlgebraKit.zero!(ΔQ2) - ΔR = randn!(similar(A, T, minmn, n)) - view(ΔR, (r + 1):minmn, :) .= 0 - return (Q, R), (ΔQ, ΔR) -end - -function ad_qr_rank_deficient_compact_setup(A::Diagonal) - m, n = size(A) - minmn = min(m, n) - T = eltype(A) - r = minmn - 5 - Ard_ = randn!(similar(A, T, m)) - MatrixAlgebraKit.zero!(view(Ard_, (r + 1):m)) - Ard = Diagonal(Ard_) - Q, R = qr_compact(Ard) - ΔQ = Diagonal(randn!(similar(diagview(A), T, m))) - ΔR = Diagonal(randn!(similar(diagview(A), T, m))) - MatrixAlgebraKit.zero!(view(diagview(ΔQ), (r + 1):m)) - MatrixAlgebraKit.zero!(view(diagview(ΔR), (r + 1):m)) - return (Q, R), (ΔQ, ΔR) -end +# function ad_qr_rank_deficient_compact_setup(A) +# m, n = size(A) +# minmn = min(m, n) +# T = eltype(A) +# r = minmn - 5 +# Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) +# Q, R = qr_compact(Ard) +# QR = (Q, R) +# ΔQ = randn!(similar(A, T, m, minmn)) +# Q1 = view(Q, 1:m, 1:r) +# Q2 = view(Q, 1:m, (r + 1):minmn) +# ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) +# MatrixAlgebraKit.zero!(ΔQ2) +# ΔR = randn!(similar(A, T, minmn, n)) +# view(ΔR, (r + 1):minmn, :) .= 0 +# return (Q, R), (ΔQ, ΔR) +# end + +# function ad_qr_rank_deficient_compact_setup(A::Diagonal) +# m, n = size(A) +# minmn = min(m, n) +# T = eltype(A) +# r = minmn - 5 +# Ard_ = randn!(similar(A, T, m)) +# MatrixAlgebraKit.zero!(view(Ard_, (r + 1):m)) +# Ard = Diagonal(Ard_) +# Q, R = qr_compact(Ard) +# ΔQ = Diagonal(randn!(similar(diagview(A), T, m))) +# ΔR = Diagonal(randn!(similar(diagview(A), T, m))) +# 
MatrixAlgebraKit.zero!(view(diagview(ΔQ), (r + 1):m)) +# MatrixAlgebraKit.zero!(view(diagview(ΔR), (r + 1):m)) +# return (Q, R), (ΔQ, ΔR) +# end function ad_lq_compact_setup(A) LQ = lq_compact(A) - ΔLQ = randn!.(copy.(LQ)) + ΔLQ = structured_randn!.(similar.(LQ)) remove_lq_gauge_dependence!(ΔLQ..., A, LQ...) return LQ, ΔLQ end -ad_lq_compact_setup(A::Diagonal) = ad_qr_compact_setup(A) function ad_lq_null_setup(A) - T = eltype(A) Nᴴ = lq_null(A) - ΔNᴴ = randn!(similar(A, T, size(Nᴴ)...)) + ΔNᴴ = structured_randn!(similar(Nᴴ)) remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) return Nᴴ, ΔNᴴ end function ad_lq_full_setup(A) LQ = lq_full(A) - ΔLQ = randn!.(copy.(LQ)) + ΔLQ = structured_randn!.(similar.(LQ)) remove_lq_gauge_dependence!(ΔLQ..., A, LQ...) return LQ, ΔLQ end -ad_lq_full_setup(A::Diagonal) = ad_qr_full_setup(A) -function ad_lq_rank_deficient_compact_setup(A) - m, n = size(A) - minmn = min(m, n) - T = eltype(A) - r = minmn - 5 - Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) - L, Q = lq_compact(Ard) - ΔL = randn!(similar(A, T, m, minmn)) - ΔQ = randn!(similar(A, T, minmn, n)) - Q1 = view(Q, 1:r, 1:n) - Q2 = view(Q, (r + 1):minmn, 1:n) - ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) - ΔQ2 .= 0 - view(ΔL, :, (r + 1):minmn) .= 0 - return (L, Q), (ΔL, ΔQ) -end -ad_lq_rank_deficient_compact_setup(A::Diagonal) = ad_qr_rank_deficient_compact_setup(A) +# function ad_lq_rank_deficient_compact_setup(A) +# m, n = size(A) +# minmn = min(m, n) +# T = eltype(A) +# r = minmn - 5 +# Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) +# L, Q = lq_compact(Ard) +# ΔL = randn!(similar(A, T, m, minmn)) +# ΔQ = randn!(similar(A, T, minmn, n)) +# Q1 = view(Q, 1:r, 1:n) +# Q2 = view(Q, (r + 1):minmn, 1:n) +# ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) +# ΔQ2 .= 0 +# view(ΔL, :, (r + 1):minmn) .= 0 +# return (L, Q), (ΔL, ΔQ) +# end +# ad_lq_rank_deficient_compact_setup(A::Diagonal) = ad_qr_rank_deficient_compact_setup(A) function ad_eig_full_setup(A) - m, n = size(A) - T = 
eltype(A) - DV = eig_full(A) - D, V = DV - ΔV = randn!(similar(A, complex(T), m, m)) + D, V = eig_full(A) + ΔD, ΔV = structured_randn!.(similar.((D, V))) ΔV = remove_eig_gauge_dependence!(ΔV, D, V) - ΔD = Diagonal(randn!(similar(A, complex(T), m))) - return DV, (ΔD, ΔV) + return (D, V), (ΔD, ΔV) end -function ad_eig_full_setup(A::Diagonal) - m, n = size(A) - T = complex(eltype(A)) - DV = eig_full(A) - D, V = DV - ΔV = randn!(similar(A.diag, T, m, m)) - ΔV = remove_eig_gauge_dependence!(ΔV, D, V) - ΔD = Diagonal(randn!(similar(A.diag, T, m))) - return DV, (ΔD, ΔV) -end +# function ad_eig_full_setup(A::Diagonal) +# m, n = size(A) +# T = complex(eltype(A)) +# DV = eig_full(A) +# D, V = DV +# ΔV = randn!(similar(A.diag, T, m, m)) +# ΔV = remove_eig_gauge_dependence!(ΔV, D, V) +# ΔD = Diagonal(randn!(similar(A.diag, T, m))) +# return DV, (ΔD, ΔV) +# end function ad_eig_vals_setup(A) - m, n = size(A) - T = complex(eltype(A)) D = eig_vals(A) - ΔD = randn!(similar(A, complex(T), m)) + ΔD = randn!(similar(D)) return D, ΔD end -function ad_eig_vals_setup(A::Diagonal) - m, n = size(A) - T = complex(eltype(A)) - D = eig_vals(A) - ΔD = randn!(similar(A.diag, T, m)) - return D, ΔD -end +# function ad_eig_vals_setup(A::Diagonal) +# m, n = size(A) +# T = complex(eltype(A)) +# D = eig_vals(A) +# ΔD = randn!(similar(A.diag, T, m)) +# return D, ΔD +# end function ad_eig_trunc_setup(A, truncalg) DV, ΔDV = ad_eig_full_setup(A) @@ -395,21 +383,15 @@ function ad_eig_trunc_setup(A, truncalg) end function ad_eigh_full_setup(A) - m, n = size(A) - T = eltype(A) - DV = eigh_full(A) - D, V = DV - ΔV = randn!(similar(A, T, m, m)) + D, V = eigh_full(A) + ΔD, ΔV = structured_randn!.(similar.((D, V))) ΔV = remove_eigh_gauge_dependence!(ΔV, D, V) - ΔD = Diagonal(randn!(similar(A, real(T), m))) - return DV, (ΔD, ΔV) + return (D, V), (ΔD, ΔV) end function ad_eigh_vals_setup(A) - m, n = size(A) - T = eltype(A) D = eigh_vals(A) - ΔD = randn!(similar(A, real(T), m)) + ΔD = randn!(similar(D)) return D, 
ΔD end @@ -424,55 +406,39 @@ function ad_eigh_trunc_setup(A, truncalg) end function ad_svd_compact_setup(A) - m, n = size(A) - T = eltype(A) - minmn = min(m, n) - ΔU = randn!(similar(A, T, m, minmn)) - ΔS = Diagonal(randn!(similar(A, real(T), minmn))) - ΔVᴴ = randn!(similar(A, T, minmn, n)) U, S, Vᴴ = svd_compact(A) + ΔU, ΔS, ΔVᴴ = structured_randn!.(similar.((U, S, Vᴴ))) ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ) end -function ad_svd_compact_setup(A::Diagonal) - m, n = size(A) - T = eltype(A) - minmn = min(m, n) - ΔU = randn!(similar(A.diag, T, m, n)) - ΔS = Diagonal(randn!(similar(A.diag, real(T), minmn))) - ΔVᴴ = randn!(similar(A.diag, T, m, n)) - U, S, Vᴴ = svd_compact(A) - ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) - return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ) -end +# function ad_svd_compact_setup(A::Diagonal) +# m, n = size(A) +# T = eltype(A) +# minmn = min(m, n) +# ΔU = randn!(similar(A.diag, T, m, n)) +# ΔS = Diagonal(randn!(similar(A.diag, real(T), minmn))) +# ΔVᴴ = randn!(similar(A.diag, T, m, n)) +# U, S, Vᴴ = svd_compact(A) +# ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) +# return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ) +# end function ad_svd_full_setup(A) - m, n = size(A) - T = eltype(A) - minmn = min(m, n) - (_, _, _), (ΔU, ΔS, ΔVᴴ) = ad_svd_compact_setup(A) - ΔUfull = similar(A, T, m, m) - ΔUfull .= zero(T) - ΔSfull = similar(A, real(T), m, n) - ΔSfull .= zero(real(T)) - ΔVᴴfull = similar(A, T, n, n) - ΔVᴴfull .= zero(T) U, S, Vᴴ = svd_full(A) - view(ΔUfull, :, 1:minmn) .= ΔU - view(ΔVᴴfull, 1:minmn, :) .= ΔVᴴ - diagview(ΔSfull)[1:minmn] .= diagview(ΔS) - return (U, S, Vᴴ), (ΔUfull, ΔSfull, ΔVᴴfull) + ΔU = structured_randn!(similar(U)) + ΔVᴴ = structured_randn!(similar(Vᴴ)) + ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + ΔS = zero(S) + randn!(diagview(ΔS)) + return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ) end -ad_svd_full_setup(A::Diagonal) = ad_svd_compact_setup(A) +# ad_svd_full_setup(A::Diagonal) 
= ad_svd_compact_setup(A) function ad_svd_vals_setup(A) - m, n = size(A) - minmn = min(m, n) - T = eltype(A) S = svd_vals(A) - ΔS = randn!(similar(A, real(T), minmn)) + ΔS = randn!(similar(S)) return S, ΔS end @@ -489,48 +455,25 @@ function ad_svd_trunc_setup(A, truncalg) end function ad_left_polar_setup(A) - m, n = size(A) - T = eltype(A) WP = left_polar(A) - ΔWP = (randn!(similar(A, T, m, n)), randn!(similar(A, T, n, n))) - return WP, ΔWP -end - -function ad_left_polar_setup(A::Diagonal) - m, n = size(A) - T = eltype(A) - WP = left_polar(A) - ΔWP = (Diagonal(randn!(similar(A.diag))), randn!(similar(WP[2]))) + ΔWP = structured_randn!.(similar.(WP)) return WP, ΔWP end function ad_right_polar_setup(A) - m, n = size(A) - T = eltype(A) - PWᴴ = right_polar(A) - ΔPWᴴ = (randn!(similar(A, T, m, m)), randn!(similar(A, T, m, n))) - return PWᴴ, ΔPWᴴ -end -function ad_right_polar_setup(A::Diagonal) - m, n = size(A) - T = eltype(A) PWᴴ = right_polar(A) - ΔPWᴴ = (randn!(similar(PWᴴ[1])), Diagonal(randn!(similar(A.diag)))) + ΔPWᴴ = structured_randn!.(similar.(PWᴴ)) return PWᴴ, ΔPWᴴ end function ad_left_orth_setup(A) - m, n = size(A) - T = eltype(A) VC = left_orth(A) - ΔVC = (randn!(similar(A, T, size(VC[1])...)), randn!(similar(A, T, size(VC[2])...))) + ΔVC = structured_randn!.(similar.(VC)) return VC, ΔVC end function ad_left_orth_setup(A::Diagonal) - m, n = size(A) - T = eltype(A) VC = left_orth(A) - ΔVC = (Diagonal(randn!(similar(A.diag, T, m))), Diagonal(randn!(similar(A.diag, T, m)))) + ΔVC = structured_randn!.(similar.(VC)) return VC, ΔVC end diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl index 52806750..c722fdb6 100644 --- a/test/testsuite/chainrules.jl +++ b/test/testsuite/chainrules.jl @@ -17,14 +17,14 @@ for f in @eval begin function $copy_f(input, alg) if $_hermitian - input = (input + input') / 2 + input = project_hermitian(input) end return $f(input, alg) end function ChainRulesCore.rrule(::typeof($copy_f), input, alg) output = 
MatrixAlgebraKit.initialize_output($f!, input, alg) if $_hermitian - input = (input + input') / 2 + input = project_hermitian(input) else input = copy(input) end @@ -113,7 +113,7 @@ function test_chainrules_qr( m, n = size(A) r = min(m, n) - 5 Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - QR, ΔQR = ad_qr_rank_deficient_compact_setup(Ard) + QR, ΔQR = ad_qr_compact_setup(Ard) ΔQ, ΔR = ΔQR test_rrule( cr_copy_qr_compact, Ard, alg ⊢ NoTangent(); @@ -190,7 +190,7 @@ function test_chainrules_lq( m, n = size(A) r = min(m, n) - 5 Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - LQ, ΔLQ = ad_lq_rank_deficient_compact_setup(Ard) + LQ, ΔLQ = ad_lq_compact_setup(Ard) test_rrule( cr_copy_lq_compact, Ard, alg ⊢ NoTangent(); output_tangent = ΔLQ, atol = atol, rtol = rtol diff --git a/test/testsuite/enzyme/lq.jl b/test/testsuite/enzyme/lq.jl index f84933a5..e4aa8d8e 100644 --- a/test/testsuite/enzyme/lq.jl +++ b/test/testsuite/enzyme/lq.jl @@ -38,7 +38,7 @@ function test_enzyme_lq_compact_rank_deficient( r = min(m, n) - 5 A = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) alg = MatrixAlgebraKit.select_algorithm(lq_compact, A) - LQ, ΔLQ = ad_lq_rank_deficient_compact_setup(A) + LQ, ΔLQ = ad_lq_compact_setup(A) test_reverse(lq_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) test_reverse(call_and_zero!, RT, (lq_compact!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) end diff --git a/test/testsuite/enzyme/qr.jl b/test/testsuite/enzyme/qr.jl index 08d1abdb..1d9f33a5 100644 --- a/test/testsuite/enzyme/qr.jl +++ b/test/testsuite/enzyme/qr.jl @@ -38,7 +38,7 @@ function test_enzyme_qr_compact_rank_deficient( r = min(m, n) - 5 A = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) alg = MatrixAlgebraKit.select_algorithm(qr_compact, A) - QR, ΔQR = ad_qr_rank_deficient_compact_setup(A) + QR, ΔQR = ad_qr_compact_setup(A) test_reverse(qr_compact, RT, (A, TA), (alg, Const); 
atol, rtol, output_tangent = ΔQR, fdm) test_reverse(call_and_zero!, RT, (qr_compact!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔQR, fdm) end