Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Enzyme = "0.13.118"
EnzymeTestUtils = "0.2.5"
GenericLinearAlgebra = "0.3.19"
GenericSchur = "0.5.6"
JET = "0.9, 0.10"
JET = "0.9, 0.10, 0.11"
LinearAlgebra = "1"
Mooncake = "0.5"
ParallelTestRunner = "2"
Expand Down
6 changes: 4 additions & 2 deletions src/common/view.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ diagonal(v::AbstractVector) = Diagonal(v)
function lowertriangularind(A::AbstractMatrix)
Base.require_one_based_indexing(A)
m, n = size(A)
I = Vector{Int}(undef, div(m * (m - 1), 2) + m * (n - m))
minmn = min(m, n)
I = Vector{Int}(undef, div(minmn * (minmn - 1), 2) + minmn * (m - minmn))
offset = 0
for j in 1:n
r = (j + 1):m
Expand All @@ -37,7 +38,8 @@ end
function uppertriangularind(A::AbstractMatrix)
Base.require_one_based_indexing(A)
m, n = size(A)
I = Vector{Int}(undef, div(m * (m - 1), 2) + m * (n - m))
minmn = min(m, n)
I = Vector{Int}(undef, div(minmn * (minmn - 1), 2) + minmn * (n - minmn))
offset = 0
for i in 1:m
r = (i + 1):n
Expand Down
72 changes: 29 additions & 43 deletions src/pullbacks/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,28 @@ function check_lq_cotangents(
L, Q, ΔL, ΔQ, p::Int;
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
)
# check_qr_cotangents(Q', L', ΔQ', ΔL', p; gauge_atol)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# check_qr_cotangents(Q', L', ΔQ', ΔL', p; gauge_atol)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that is a debugging leftover.

minmn = min(size(L, 1), size(Q, 2))
if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
# in this case the number of Householder reflections will
# change upon small variations, and all of the remaining
# rows of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
Δgauge_Q = norm(ΔQ2, Inf)
Δgauge = max(Δgauge, Δgauge_Q)
end
if !iszerotangent(ΔL)
ΔL22 = view(ΔL, (p + 1):size(L, 1), (p + 1):minmn)
Δgauge_L = norm(ΔL22, Inf)
Δgauge = max(Δgauge, Δgauge_L)
end
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
ΔQ₂ = view(ΔQ, (p + 1):minmn, :)
ΔQ₃ = ΔQ[(minmn + 1):size(Q, 1), :]
Δgauge_Q = norm(ΔQ₂, Inf)
Q₁ = view(Q, 1:p, :)
ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁'
mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁, -1, 1)
Δgauge_Q = max(Δgauge_Q, norm(ΔQ₃, Inf))
Δgauge = max(Δgauge, Δgauge_Q)
end
if !iszerotangent(ΔL)
ΔL22 = view(ΔL, (p + 1):size(ΔL, 1), (p + 1):minmn)
Δgauge_L = norm(view(ΔL22, lowertriangularind(ΔL22)), Inf)
Δgauge_L = max(Δgauge_L, norm(view(ΔL22, diagind(ΔL22)), Inf))
Δgauge = max(Δgauge, Δgauge_L)
end
return
end

function check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol::Real = default_pullback_gauge_atol(ΔQ2))
# in the case where A is full rank, but there are more columns in Q than in A
# (the case of `lq_full`), there is gauge-invariant information in the
# projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
# matrix. As the number of Householder reflections is fixed in the full rank
# case, Q is expected to rotate smoothly (we might even be able to predict
# how the full Q2 will change), but this we omit for now, and we consider
# Q2' * ΔQ2 as a gauge-dependent quantity.
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`lq_full` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
return
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
return nothing
end

"""
Expand Down Expand Up @@ -67,13 +54,13 @@ function lq_pullback!(
L, Q = LQ
m = size(L, 1)
n = size(Q, 2)
minmn = min(m, n)
p = lq_rank(L; rank_atol)

ΔL, ΔQ = ΔLQ

Q1 = view(Q, 1:p, :)
Q2 = view(Q, (p + 1):size(Q, 1), :)
L11 = view(L, 1:p, 1:p)
L11 = LowerTriangular(view(L, 1:p, 1:p))
ΔA1 = view(ΔA, 1:p, :)
ΔA2 = view(ΔA, (p + 1):m, :)

Expand All @@ -83,12 +70,11 @@ function lq_pullback!(
if !iszerotangent(ΔQ)
ΔQ1 = view(ΔQ, 1:p, :)
copy!(ΔQ̃, ΔQ1)
if p < size(Q, 1)
Q2 = view(Q, (p + 1):size(Q, 1), :)
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
ΔQ2Q1ᴴ = ΔQ2 * Q1'
check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol)
ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1)
if minmn < size(Q, 1)
ΔQ3 = view(ΔQ, (minmn + 1):size(ΔQ, 1), :)
Q3 = view(Q, (minmn + 1):size(Q, 1), :)
ΔQ3Q1ᴴ = ΔQ3 * Q1'
ΔQ̃ = mul!(ΔQ̃, ΔQ3Q1ᴴ', Q3, -1, 1)
end
end
if !iszerotangent(ΔL) && m > p
Expand All @@ -102,7 +88,7 @@ function lq_pullback!(
# construct M
M = zero!(similar(L, (p, p)))
if !iszerotangent(ΔL)
ΔL11 = view(ΔL, 1:p, 1:p)
ΔL11 = LowerTriangular(view(ΔL, 1:p, 1:p))
M = mul!(M, L11', ΔL11, 1, 1)
end
M = mul!(M, ΔQ̃, Q1', -1, 1)
Expand All @@ -111,8 +97,8 @@ function lq_pullback!(
Md = diagview(M)
Md .= real.(Md)
end
ldiv!(LowerTriangular(L11)', M)
ldiv!(LowerTriangular(L11)', ΔQ̃)
ldiv!(L11', M)
ldiv!(L11', ΔQ̃)
ΔA1 = mul!(ΔA1, M, Q1, +1, 1)
ΔA1 .+= ΔQ̃
return ΔA
Expand Down
73 changes: 30 additions & 43 deletions src/pullbacks/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,40 +6,26 @@ function check_qr_cotangents(
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
)
minmn = min(size(Q, 1), size(R, 2))
if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
# in this case the number of Householder reflections will
# change upon small variations, and all of the remaining
# columns of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
Δgauge_Q = norm(ΔQ2, Inf)
Δgauge = max(Δgauge, Δgauge_Q)
end
if !iszerotangent(ΔR)
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2))
Δgauge_R = norm(ΔR22, Inf)
Δgauge = max(Δgauge, Δgauge_R)
end
Δgauge ≤ gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
ΔQ₂ = view(ΔQ, :, (p + 1):minmn)
ΔQ₃ = ΔQ[:, (minmn + 1):size(Q, 2)] # extra columns in the case of qr_full
Δgauge_Q = norm(ΔQ₂, Inf)
Q₁ = view(Q, :, 1:p)
Q₁ᴴΔQ₃ = Q₁' * ΔQ₃
mul!(ΔQ₃, Q₁, Q₁ᴴΔQ₃, -1, 1)
Δgauge_Q = max(Δgauge_Q, norm(ΔQ₃, Inf))
Δgauge = max(Δgauge, Δgauge_Q)
end
if !iszerotangent(ΔR)
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2))
Δgauge_R = norm(view(ΔR22, uppertriangularind(ΔR22)), Inf)
Δgauge_R = max(Δgauge_R, norm(view(ΔR22, diagind(ΔR22)), Inf))
Δgauge = max(Δgauge, Δgauge_R)
end
return
end

function check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol::Real = default_pullback_gauge_atol(ΔQ2))
# in the case where A is full rank, but there are more columns in Q than in A
# (the case of `qr_full`), there is gauge-invariant information in the
# projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
# matrix. As the number of Householder reflections is fixed in the full rank
# case, Q is expected to rotate smoothly (we might even be able to predict
# how the full Q2 will change), but this we omit for now, and we consider
# Q2' * ΔQ2 as a gauge-dependent quantity.
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`qr_full` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
return
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
return nothing
end

"""
Expand Down Expand Up @@ -69,27 +55,28 @@ function qr_pullback!(
Q, R = QR
m = size(Q, 1)
n = size(R, 2)
minmn = min(m, n)
Rd = diagview(R)
p = qr_rank(R; rank_atol)

ΔQ, ΔR = ΔQR

Q1 = view(Q, :, 1:p)
R11 = view(R, 1:p, 1:p)
R11 = UpperTriangular(view(R, 1:p, 1:p))
ΔA1 = view(ΔA, :, 1:p)
ΔA2 = view(ΔA, :, (p + 1):n)

check_qr_cotangents(Q, R, ΔQ, ΔR, p; gauge_atol)

ΔQ̃ = zero!(similar(Q, (m, p)))
if !iszerotangent(ΔQ)
copy!(ΔQ̃, view(ΔQ, :, 1:p))
if p < size(Q, 2)
Q2 = view(Q, :, (p + 1):size(Q, 2))
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
Q1dΔQ2 = Q1' * ΔQ2
check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol)
ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1)
ΔQ₁ = view(ΔQ, :, 1:p)
copy!(ΔQ̃, ΔQ₁)
if minmn < size(Q, 2)
ΔQ3 = view(ΔQ, :, (minmn + 1):size(ΔQ, 2)) # extra columns in the case of qr_full
Q3 = view(Q, :, (minmn + 1):size(Q, 2))
Q1ᴴΔQ3 = Q1' * ΔQ3
ΔQ̃ = mul!(ΔQ̃, Q3, Q1ᴴΔQ3', -1, 1)
end
end
if !iszerotangent(ΔR) && n > p
Expand All @@ -103,7 +90,7 @@ function qr_pullback!(
# construct M
M = zero!(similar(R, (p, p)))
if !iszerotangent(ΔR)
ΔR11 = view(ΔR, 1:p, 1:p)
ΔR11 = UpperTriangular(view(ΔR, 1:p, 1:p))
M = mul!(M, ΔR11, R11', 1, 1)
end
M = mul!(M, Q1', ΔQ̃, -1, 1)
Expand All @@ -112,8 +99,8 @@ function qr_pullback!(
Md = diagview(M)
Md .= real.(Md)
end
rdiv!(M, UpperTriangular(R11)')
rdiv!(ΔQ̃, UpperTriangular(R11)')
rdiv!(M, R11') # R11 is upper triangular
rdiv!(ΔQ̃, R11')
ΔA1 = mul!(ΔA1, Q1, M, +1, 1)
ΔA1 .+= ΔQ̃
return ΔA
Expand Down
4 changes: 2 additions & 2 deletions src/pullbacks/svd.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
svd_rank(S, rank_atol) = searchsortedlast(S, rank_atol; rev = true)
"""
    svd_rank(S; rank_atol = default_pullback_rank_atol(S))

Return the index of the last entry of `S` that is not smaller than `rank_atol`, located
with a reverse-ordered `searchsortedlast`. `S` is therefore assumed to be sorted in
descending order (presumably the singular values — confirm against the caller). The
result is `0` when no entry reaches `rank_atol`.
"""
function svd_rank(S; rank_atol = default_pullback_rank_atol(S))
    return searchsortedlast(S, rank_atol; rev = true)
end

function check_svd_cotangents(aUΔU, Sr, aVΔV; degeneracy_atol = default_pullback_rank_atol(Sr), gauge_atol = default_pullback_gauge_atol(aUΔU, aVΔV))
mask = abs.(Sr' .- Sr) .< degeneracy_atol
Expand Down Expand Up @@ -43,7 +43,7 @@ function svd_pullback!(
minmn = min(m, n)
S = diagview(Smat)
length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)"))
r = svd_rank(S, rank_atol)
r = svd_rank(S; rank_atol)
Ur = view(U, :, 1:r)
Vᴴr = view(Vᴴ, 1:r, :)
Sr = view(S, 1:r)
Expand Down
2 changes: 1 addition & 1 deletion test/enzyme/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...)
TestSuite.seed_rng!(123)
TestSuite.seed_rng!(1234)
if !is_buildkite
TestSuite.test_enzyme_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
Expand Down
2 changes: 1 addition & 1 deletion test/enzyme/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...)
TestSuite.seed_rng!(123)
TestSuite.seed_rng!(1234)
if !is_buildkite
TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
Expand Down
4 changes: 2 additions & 2 deletions test/enzyme/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite
Expand All @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.seed_rng!(123)
TestSuite.seed_rng!(1234)
if !is_buildkite
TestSuite.test_enzyme_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end
Expand Down
4 changes: 2 additions & 2 deletions test/enzyme/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite
Expand All @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.seed_rng!(123)
TestSuite.seed_rng!(1234)
if !is_buildkite
TestSuite.test_enzyme_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end
Expand Down
4 changes: 2 additions & 2 deletions test/enzyme/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite
Expand All @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.seed_rng!(123)
TestSuite.seed_rng!(1234)
if !is_buildkite
TestSuite.test_enzyme_polar(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end
Expand Down
4 changes: 2 additions & 2 deletions test/enzyme/projections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite
Expand All @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...)
TestSuite.seed_rng!(123)
TestSuite.seed_rng!(1234)
atol = rtol = m * m * TestSuite.precision(T)
if !is_buildkite
TestSuite.test_enzyme_projections(T, (m, m); atol, rtol)
Expand Down
4 changes: 2 additions & 2 deletions test/enzyme/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite
Expand All @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.seed_rng!(123)
TestSuite.seed_rng!(1234)
if !is_buildkite
TestSuite.test_enzyme_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end
Expand Down
2 changes: 1 addition & 1 deletion test/mooncake/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...)
TestSuite.seed_rng!(123)
TestSuite.seed_rng!(1234)
if !is_buildkite
TestSuite.test_mooncake_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
Expand Down
Loading
Loading