diff --git a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl index e9e1e395..629972cf 100644 --- a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl +++ b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl @@ -58,9 +58,9 @@ end function MatrixAlgebraKit.householder_qr!( driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; - positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1 + positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0 ) - blocksize == 1 || + blocksize <= 1 || throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition")) pivoted && throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition")) @@ -102,9 +102,9 @@ end function MatrixAlgebraKit.householder_qr_null!( driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, N::AbstractMatrix; - positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1 + positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0 ) - blocksize == 1 || + blocksize <= 1 || throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition")) pivoted && throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition")) diff --git a/src/algorithms.jl b/src/algorithms.jl index 2885e93c..629ec27f 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -19,7 +19,19 @@ See also [`@algdef`](@ref). """ struct Algorithm{name, K} <: AbstractAlgorithm kwargs::K + + # Ensure keywords are always in canonical order + function Algorithm{Name}(kwargs::NamedTuple) where {Name} + kwargs_sorted = _sortkeys(kwargs) + return new{Name, typeof(kwargs_sorted)}(kwargs_sorted) + end end +Algorithm{Name}(; kwargs...) where {Name} = Algorithm{Name}(NamedTuple(kwargs)) + +# Utility generated function to canonicalize keys in type-stable way +@generated _sortkeys(nt::NamedTuple{K}) where {K} = + :(NamedTuple{$(Tuple(sort!(collect(K))))}(nt)) + name(alg::Algorithm) = name(typeof(alg)) name(::Type{<:Algorithm{N}}) where {N} = N @@ -88,7 +100,9 @@ Finally, the same behavior is obtained when the keyword arguments are passed as the third positional argument in the form of a `NamedTuple`. """ select_algorithm -@inline function select_algorithm(f::F, A, alg::Alg = nothing; kwargs...) where {F, Alg} +# WARNING: In order to keep everything type stable, this function is marked as foldable. +# This mostly means that the `default_algorithm` implementation must be foldable as well +Base.@assume_effects :foldable function select_algorithm(f::F, A, alg::Alg = nothing; kwargs...) where {F, Alg} if isnothing(alg) return default_algorithm(f, A; kwargs...) elseif alg isa Symbol @@ -117,8 +131,10 @@ In general, this is called by [`select_algorithm`](@ref) if no algorithm is spec explicitly. New types should prefer to register their default algorithms in the type domain. """ default_algorithm -default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...) -default_algorithm(f::F, A, B; kwargs...) where {F} = default_algorithm(f, typeof(A), typeof(B); kwargs...) +@inline default_algorithm(f::F, A; kwargs...) where {F} = + default_algorithm(f, typeof(A); kwargs...) +@inline default_algorithm(f::F, A, B; kwargs...) where {F} = + default_algorithm(f, typeof(A), typeof(B); kwargs...) # avoid infinite recursion: function default_algorithm(f::F, ::Type{T}; kwargs...) where {F, T} throw(MethodError(default_algorithm, (f, T))) @@ -299,11 +315,6 @@ macro algdef(name) return esc( quote const $name{K} = Algorithm{$(QuoteNode(name)), K} - function $name(; kwargs...) - # TODO: is this necessary/useful? - kw = NamedTuple(kwargs) # normalize type - return $name{typeof(kw)}(kw) - end function Base.show(io::IO, alg::$name) return ($_show_alg)(io, alg) end diff --git a/src/implementations/lq.jl b/src/implementations/lq.jl index f27f6cde..cf64ccd6 100644 --- a/src/implementations/lq.jl +++ b/src/implementations/lq.jl @@ -120,9 +120,10 @@ householder_lq!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, L, Q; kwargs...) = lq_via_qr!(A, L, Q, Householder(; driver, kwargs...)) function householder_lq!( driver::LAPACK, A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix; - positive = true, pivoted = false, - blocksize = ((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A)) + positive = true, pivoted = false, blocksize::Int = 0 ) + blocksize = blocksize > 0 ? blocksize : ((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A)) + # error messages for disallowing driver - setting combinations pivoted && (blocksize > 1) && throw(ArgumentError(lazy"$driver does not provide a blocked pivoted LQ decomposition")) @@ -176,10 +177,10 @@ function householder_lq!( end function householder_lq!( driver::Native, A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix; - positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1 + positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0 ) # error messages for disallowing driver - setting combinations - blocksize == 1 || + blocksize <= 1 || throw(ArgumentError(lazy"$driver does not provide a blocked LQ decomposition")) pivoted && throw(ArgumentError(lazy"$driver does not provide a pivoted LQ decomposition")) @@ -225,8 +226,10 @@ householder_lq_null!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, Nᴴ; kwargs... lq_null_via_qr!(A, Nᴴ, Householder(; driver, kwargs...)) function householder_lq_null!( driver::LAPACK, A::AbstractMatrix, Nᴴ::AbstractMatrix; - positive::Bool = true, pivoted::Bool = false, blocksize::Int = pivoted ? 1 : YALAPACK.default_qr_blocksize(A) + positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0 ) + blocksize = blocksize > 0 ? blocksize : (pivoted ? 1 : YALAPACK.default_qr_blocksize(A)) + # error messages for disallowing driver - setting combinations pivoted && (blocksize > 1) && throw(ArgumentError(lazy"$driver does not provide a blocked pivoted LQ decomposition")) @@ -248,10 +251,10 @@ function householder_lq_null!( end function householder_lq_null!( driver::Native, A::AbstractMatrix, Nᴴ::AbstractMatrix; - positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1 + positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0 ) # error messages for disallowing driver - setting combinations - blocksize == 1 || + blocksize <= 1 || throw(ArgumentError(lazy"$driver does not provide a blocked LQ decomposition")) pivoted && throw(ArgumentError(lazy"$driver does not provide a pivoted LQ decomposition")) diff --git a/src/implementations/qr.jl b/src/implementations/qr.jl index 766d4dbc..f78d8c44 100644 --- a/src/implementations/qr.jl +++ b/src/implementations/qr.jl @@ -121,8 +121,10 @@ householder_qr!(::DefaultDriver, A, Q, R; kwargs...) = function householder_qr!( driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; positive::Bool = true, pivoted::Bool = false, - blocksize::Int = ((driver !== LAPACK() || pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A)) + blocksize::Int = 0 ) + blocksize = blocksize > 0 ? blocksize : ((driver !== LAPACK() || pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A)) + # error messages for disallowing driver - setting combinations (blocksize == 1 || driver === LAPACK()) || throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition")) @@ -202,10 +204,10 @@ function householder_qr!( end function householder_qr!( driver::Native, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; - positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1 + positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0 ) # error messages for disallowing driver - setting combinations - blocksize == 1 || + blocksize <= 1 || throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition")) pivoted && throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition")) @@ -249,9 +251,9 @@ householder_qr_null!(::DefaultDriver, A, N; kwargs...) = householder_qr_null!(default_householder_driver(A), A, N; kwargs...) function householder_qr_null!( driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, N::AbstractMatrix; - positive::Bool = true, pivoted::Bool = false, - blocksize::Int = ((driver !== LAPACK() || pivoted) ? 1 : YALAPACK.default_qr_blocksize(A)) + positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0 ) + blocksize = blocksize > 0 ? blocksize : ((driver !== LAPACK() || pivoted) ? 1 : YALAPACK.default_qr_blocksize(A)) # error messages for disallowing driver - setting combinations (blocksize == 1 || driver === LAPACK()) || throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition")) @@ -277,10 +279,10 @@ function householder_qr_null!( end function householder_qr_null!( driver::Native, A::AbstractMatrix, N::AbstractMatrix; - positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1 + positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0 ) # error messages for disallowing driver - setting combinations - blocksize == 1 || + blocksize <= 1 || throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition")) pivoted && throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition")) diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 253cd30e..016d17b5 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -74,11 +74,17 @@ The optional `driver` symbol can be used to choose between different implementat - `positive::Bool = true` : Fix the gauge of the resulting factors by making the diagonal elements of `L` or `R` non-negative. - `pivoted::Bool = false` : Use column- or row-pivoting for low-rank input matrices. -- `blocksize::Int` : Use a blocked version of the algorithm if `blocksize > 1`. +- `blocksize::Int` : Use a blocked version of the algorithm if `blocksize > 1`. Use the default if `blocksize ≤ 0`. Depending on the driver, various other keywords may be (un)available to customize the implementation. """ @algdef Householder +function Householder(; + blocksize::Int = 0, driver::Driver = DefaultDriver(), + pivoted::Bool = false, positive::Bool = true + ) + return Householder((; blocksize, driver, pivoted, positive)) +end default_householder_driver(A) = default_householder_driver(typeof(A)) default_householder_driver(::Type) = Native() diff --git a/src/interface/lq.jl b/src/interface/lq.jl index e9c1c32b..9521e5ed 100644 --- a/src/interface/lq.jl +++ b/src/interface/lq.jl @@ -71,8 +71,8 @@ See also [`qr_full(!)`](@ref lq_full) and [`qr_compact(!)`](@ref lq_compact). default_lq_algorithm(A; kwargs...) = default_lq_algorithm(typeof(A); kwargs...) default_lq_algorithm(T::Type; kwargs...) = throw(MethodError(default_lq_algorithm, (T,))) -default_lq_algorithm(::Type{T}; driver = default_householder_driver(T), kwargs...) where {T <: AbstractMatrix} = - Householder(; driver, kwargs...) +default_lq_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix} = + Householder(; kwargs...) default_lq_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} = DiagonalAlgorithm(; kwargs...) default_lq_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index b20a841d..669613f5 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -334,79 +334,79 @@ See also [`left_null(!)`](@ref left_null), [`left_orth(!)`](@ref left_orth) and @inline select_algorithm(::typeof(right_null!), A, alg::Symbol; kwargs...) = select_algorithm(right_null!, A, Val(alg); kwargs...) -function select_algorithm(::typeof(left_orth!), A, ::Val{:qr}; trunc = nothing, kwargs...) +@inline function select_algorithm(::typeof(left_orth!), A, ::Val{:qr}; trunc = nothing, kwargs...) isnothing(trunc) || throw(ArgumentError("QR-based `left_orth` is incompatible with specifying `trunc`")) alg′ = select_algorithm(qr_compact!, A; kwargs...) return LeftOrthViaQR(alg′) end -function select_algorithm(::typeof(left_orth!), A, ::Val{:polar}; trunc = nothing, kwargs...) +@inline function select_algorithm(::typeof(left_orth!), A, ::Val{:polar}; trunc = nothing, kwargs...) isnothing(trunc) || throw(ArgumentError("Polar-based `left_orth` is incompatible with specifying `trunc`")) alg′ = select_algorithm(left_polar!, A; kwargs...) return LeftOrthViaPolar(alg′) end -function select_algorithm(::typeof(left_orth!), A, ::Val{:svd}; trunc = nothing, kwargs...) +@inline function select_algorithm(::typeof(left_orth!), A, ::Val{:svd}; trunc = nothing, kwargs...) alg′ = isnothing(trunc) ? select_algorithm(svd_compact!, A; kwargs...) : select_algorithm(svd_trunc!, A; trunc, kwargs...) return LeftOrthViaSVD(alg′) end -function select_algorithm(::typeof(right_orth!), A, ::Val{:lq}; trunc = nothing, kwargs...) +@inline function select_algorithm(::typeof(right_orth!), A, ::Val{:lq}; trunc = nothing, kwargs...) isnothing(trunc) || throw(ArgumentError("LQ-based `right_orth` is incompatible with specifying `trunc`")) alg = select_algorithm(lq_compact!, A; kwargs...) return RightOrthViaLQ(alg) end -function select_algorithm(::typeof(right_orth!), A, ::Val{:polar}; trunc = nothing, kwargs...) +@inline function select_algorithm(::typeof(right_orth!), A, ::Val{:polar}; trunc = nothing, kwargs...) isnothing(trunc) || throw(ArgumentError("Polar-based `right_orth` is incompatible with specifying `trunc`")) alg = select_algorithm(right_polar!, A; kwargs...) return RightOrthViaPolar(alg) end -function select_algorithm(::typeof(right_orth!), A, ::Val{:svd}; trunc = nothing, kwargs...) +@inline function select_algorithm(::typeof(right_orth!), A, ::Val{:svd}; trunc = nothing, kwargs...) alg′ = isnothing(trunc) ? select_algorithm(svd_compact!, A; kwargs...) : select_algorithm(svd_trunc!, A; trunc, kwargs...) return RightOrthViaSVD(alg′) end -function select_algorithm(::typeof(left_null!), A, ::Val{:qr}; trunc = nothing, kwargs...) +@inline function select_algorithm(::typeof(left_null!), A, ::Val{:qr}; trunc = nothing, kwargs...) isnothing(trunc) || throw(ArgumentError("QR-based `left_null` is incompatible with specifying `trunc`")) alg = select_algorithm(qr_null!, A; kwargs...) return LeftNullViaQR(alg) end -function select_algorithm(::typeof(left_null!), A, ::Val{:svd}; trunc = nothing, kwargs...) +@inline function select_algorithm(::typeof(left_null!), A, ::Val{:svd}; trunc = nothing, kwargs...) alg_svd = select_algorithm(svd_full!, A, get(kwargs, :svd, nothing)) alg = TruncatedAlgorithm(alg_svd, select_null_truncation(trunc)) return LeftNullViaSVD(alg) end -function select_algorithm(::typeof(right_null!), A, ::Val{:lq}; trunc = nothing, kwargs...) +@inline function select_algorithm(::typeof(right_null!), A, ::Val{:lq}; trunc = nothing, kwargs...) isnothing(trunc) || throw(ArgumentError("LQ-based `right_null` is incompatible with specifying `trunc`")) alg = select_algorithm(lq_null!, A; kwargs...) return RightNullViaLQ(alg) end -function select_algorithm(::typeof(right_null!), A, ::Val{:svd}; trunc = nothing, kwargs...) +@inline function select_algorithm(::typeof(right_null!), A, ::Val{:svd}; trunc = nothing, kwargs...) alg_svd = select_algorithm(svd_full!, A; kwargs...) alg = TruncatedAlgorithm(alg_svd, select_null_truncation(trunc)) return RightNullViaSVD(alg) end -default_algorithm(::typeof(left_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} = +@inline default_algorithm(::typeof(left_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} = isnothing(trunc) ? select_algorithm(left_orth!, A, Val(:qr); kwargs...) : select_algorithm(left_orth!, A, Val(:svd); trunc, kwargs...) -default_algorithm(::typeof(right_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} = +@inline default_algorithm(::typeof(right_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} = isnothing(trunc) ? select_algorithm(right_orth!, A, Val(:lq); kwargs...) : select_algorithm(right_orth!, A, Val(:svd); trunc, kwargs...) -default_algorithm(::typeof(left_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} = +@inline default_algorithm(::typeof(left_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} = isnothing(trunc) ? select_algorithm(left_null!, A, Val(:qr); kwargs...) : select_algorithm(left_null!, A, Val(:svd); trunc, kwargs...) -default_algorithm(::typeof(right_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} = +@inline default_algorithm(::typeof(right_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} = isnothing(trunc) ? select_algorithm(right_null!, A, Val(:lq); kwargs...) : select_algorithm(right_null!, A, Val(:svd); trunc, kwargs...) diff --git a/src/interface/qr.jl b/src/interface/qr.jl index 89f695ad..df8ae20e 100644 --- a/src/interface/qr.jl +++ b/src/interface/qr.jl @@ -71,8 +71,8 @@ See also [`lq_full(!)`](@ref lq_full) and [`lq_compact(!)`](@ref lq_compact). default_qr_algorithm(A; kwargs...) = default_qr_algorithm(typeof(A); kwargs...) default_qr_algorithm(T::Type; kwargs...) = throw(MethodError(default_qr_algorithm, (T,))) -default_qr_algorithm(::Type{T}; driver = default_householder_driver(T), kwargs...) where {T <: AbstractMatrix} = - Householder(; driver, kwargs...) +default_qr_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix} = + Householder(; kwargs...) default_qr_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} = DiagonalAlgorithm(; kwargs...) default_qr_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} = diff --git a/test/algorithms.jl b/test/algorithms.jl index 078ed2a0..c1bca7d4 100644 --- a/test/algorithms.jl +++ b/test/algorithms.jl @@ -2,7 +2,7 @@ using MatrixAlgebraKit using Test using TestExtras using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm, - default_algorithm, select_algorithm, Householder, LAPACK + default_algorithm, select_algorithm, Householder, DefaultDriver @testset "default_algorithm" begin A = randn(3, 3) @@ -17,21 +17,21 @@ using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm, LAPACK_MultipleRelativelyRobustRepresentations() end for f in (lq_full!, lq_full, lq_compact!, lq_compact, lq_null!, lq_null) - @test @constinferred(default_algorithm(f, A)) == Householder(; driver = LAPACK()) + @test @constinferred(default_algorithm(f, A)) == Householder() end for f in (left_polar!, left_polar, right_polar!, right_polar) @test @constinferred(default_algorithm(f, A)) == PolarViaSVD(LAPACK_DivideAndConquer()) end for f in (qr_full!, qr_full, qr_compact!, qr_compact, qr_null!, qr_null) - @test @constinferred(default_algorithm(f, A)) == Householder(; driver = LAPACK()) + @test @constinferred(default_algorithm(f, A)) == Householder() end for f in (schur_full!, schur_full, schur_vals!, schur_vals) @test @constinferred(default_algorithm(f, A)) === LAPACK_Expert() end @test @constinferred(default_algorithm(qr_compact!, A; blocksize = 2)) == - Householder(; driver = LAPACK(), blocksize = 2) + Householder(; blocksize = 2) end @testset "select_algorithm" begin