Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ for f in (:eig, :eigh)
_warn_pullback_truncerror(dϵ)

# compute pullbacks
$f_pullback!(dA, Ac, DVc, dDVtrunc, ind)
$f_pullback!(dA, Ac, DV, dDVtrunc, ind)
Comment thread
Jutho marked this conversation as resolved.
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required

# restore state
Expand Down Expand Up @@ -351,8 +351,8 @@ for f in (:eig, :eigh)
dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc)))
function $f_adjoint!(::NoRData)
# compute pullbacks
$f_pullback!(dA, Ac, DVc, dDVtrunc, ind)
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
$f_pullback!(dA, Ac, DV, dDVtrunc, ind)
zero!.(dDV)

# restore state
copy!(A, Ac)
Expand Down Expand Up @@ -425,7 +425,7 @@ for (f!, f) in (
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
USVᴴc = copy.(USVᴴ)
output = $f!(A, Mooncake.primal(alg_dalg))
output = $f!(A, USVᴴ, Mooncake.primal(alg_dalg))
function svd_adjoint(::NoRData)
copy!(A, Ac)
if $(f! == svd_compact!)
Expand Down Expand Up @@ -590,7 +590,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
_warn_pullback_truncerror(dϵ)

# compute pullbacks
svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
svd_pullback!(dA, Ac, USVᴴ, dUSVᴴtrunc, ind)
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
zero!.(dUSVᴴ)

Expand Down Expand Up @@ -717,8 +717,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U
dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))
function svd_trunc_adjoint(::NoRData)
# compute pullbacks
svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
svd_pullback!(dA, Ac, USVᴴ, dUSVᴴtrunc, ind)
Comment thread
Jutho marked this conversation as resolved.
zero!.(dUSVᴴ)

# restore state
Expand Down
22 changes: 15 additions & 7 deletions src/pullbacks/qr.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
function check_qr_cotangents(Q, R, ΔQ, ΔR, minmn::Int, p::Int; gauge_atol::Real = default_pullback_gauge_atol(ΔQ))
qr_rank(R; rank_atol = default_pullback_rank_atol(R)) =
@something findlast(>=(rank_atol) ∘ abs, diagview(R)) 0

function check_qr_cotangents(
Q, R, ΔQ, ΔR, p::Int;
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
)
minmn = min(size(Q, 1), size(R, 2))
if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
Expand All @@ -7,11 +14,13 @@ function check_qr_cotangents(Q, R, ΔQ, ΔR, minmn::Int, p::Int; gauge_atol::Rea
# columns of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
Δgauge_Q = norm(ΔQ2, Inf)
Δgauge = max(Δgauge, Δgauge_Q)
end
if !iszerotangent(ΔR)
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2))
Δgauge = max(Δgauge, norm(ΔR22, Inf))
Δgauge_R = norm(ΔR22, Inf)
Δgauge = max(Δgauge, Δgauge_R)
end
Δgauge ≤ gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
Expand All @@ -29,7 +38,7 @@ function check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol::Real = default_
# Q2' * ΔQ2 as a gauge dependent quantity.
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
@warn "`qr_full` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
return
end

Expand Down Expand Up @@ -60,9 +69,8 @@ function qr_pullback!(
Q, R = QR
m = size(Q, 1)
n = size(R, 2)
minmn = min(m, n)
Rd = diagview(R)
p = @something findlast(>=(rank_atol) ∘ abs, Rd) 0
p = qr_rank(R)

ΔQ, ΔR = ΔQR

Expand All @@ -72,7 +80,7 @@ function qr_pullback!(
ΔA1 = view(ΔA, :, 1:p)
ΔA2 = view(ΔA, :, (p + 1):n)

check_qr_cotangents(Q, R, ΔQ, ΔR, minmn, p; gauge_atol)
check_qr_cotangents(Q, R, ΔQ, ΔR, p; gauge_atol)

ΔQ̃ = zero!(similar(Q, (m, p)))
if !iszerotangent(ΔQ)
Expand Down
29 changes: 0 additions & 29 deletions test/mooncake.jl

This file was deleted.

19 changes: 19 additions & 0 deletions test/mooncake/eig.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...)
TestSuite.seed_rng!(123)
if !is_buildkite
TestSuite.test_mooncake_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
end
19 changes: 19 additions & 0 deletions test/mooncake/eigh.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...)
TestSuite.seed_rng!(123)
if !is_buildkite
TestSuite.test_mooncake_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
end
19 changes: 19 additions & 0 deletions test/mooncake/lq.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.seed_rng!(123)
if !is_buildkite
TestSuite.test_mooncake_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end
end
19 changes: 19 additions & 0 deletions test/mooncake/orthnull.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.seed_rng!(123)
if !is_buildkite
TestSuite.test_mooncake_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end
end
19 changes: 19 additions & 0 deletions test/mooncake/polar.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.seed_rng!(123)
if !is_buildkite
TestSuite.test_mooncake_polar(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end
end
19 changes: 19 additions & 0 deletions test/mooncake/qr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.seed_rng!(123)
if !is_buildkite
TestSuite.test_mooncake_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end
end
19 changes: 19 additions & 0 deletions test/mooncake/svd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.seed_rng!(123)
if !is_buildkite
TestSuite.test_mooncake_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ if filter_tests!(testsuite, args)
is_apple_ci = Sys.isapple() && get(ENV, "CI", "false") == "true"
if is_apple_ci
delete!(testsuite, "enzyme")
delete!(testsuite, "mooncake")
filter!(p -> !startswith(first(p), "mooncake/"), testsuite)
delete!(testsuite, "chainrules")
end
Sys.iswindows() && delete!(testsuite, "enzyme")
Expand Down
33 changes: 24 additions & 9 deletions test/testsuite/TestSuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using MatrixAlgebraKit
using MatrixAlgebraKit: diagview
using LinearAlgebra: Diagonal, norm, istriu, istril, I
using Random, StableRNGs
using Mooncake
using AMDGPU, CUDA

const tests = Dict()
Expand Down Expand Up @@ -86,16 +87,30 @@ instantiate_unitary(::Type{<:Diagonal}, A, sz) = Diagonal(fill!(similar(parent(A

include("ad_utils.jl")

include("qr.jl")
include("lq.jl")
include("polar.jl")
include("projections.jl")
include("schur.jl")
include("eig.jl")
include("eigh.jl")
include("orthnull.jl")
include("svd.jl")
include("mooncake.jl")

# Decompositions
# --------------
include("decompositions/qr.jl")
include("decompositions/lq.jl")
include("decompositions/polar.jl")
include("decompositions/schur.jl")
include("decompositions/eig.jl")
include("decompositions/eigh.jl")
include("decompositions/orthnull.jl")
include("decompositions/svd.jl")

# Mooncake
# --------
include("mooncake/mooncake.jl")
include("mooncake/qr.jl")
include("mooncake/lq.jl")
include("mooncake/eig.jl")
include("mooncake/eigh.jl")
include("mooncake/svd.jl")
include("mooncake/polar.jl")
include("mooncake/orthnull.jl")

include("enzyme.jl")
include("chainrules.jl")

Expand Down
Loading