Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ end

function MatrixAlgebraKit.householder_qr!(
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
blocksize == 1 ||
blocksize <= 1 ||
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
pivoted &&
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))
Expand Down Expand Up @@ -102,9 +102,9 @@ end

function MatrixAlgebraKit.householder_qr_null!(
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, N::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
blocksize == 1 ||
blocksize <= 1 ||
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
pivoted &&
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))
Expand Down
27 changes: 19 additions & 8 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,19 @@ See also [`@algdef`](@ref).
"""
struct Algorithm{name, K} <: AbstractAlgorithm
kwargs::K

# Ensure keywords are always in canonical order
function Algorithm{Name}(kwargs::NamedTuple) where {Name}
kwargs_sorted = _sortkeys(kwargs)
return new{Name, typeof(kwargs_sorted)}(kwargs_sorted)
end
end
Algorithm{Name}(; kwargs...) where {Name} = Algorithm{Name}(NamedTuple(kwargs))

# Utility generated function to canonicalize keys in type-stable way
@generated _sortkeys(nt::NamedTuple{K}) where {K} =
:(NamedTuple{$(Tuple(sort!(collect(K))))}(nt))

name(alg::Algorithm) = name(typeof(alg))
name(::Type{<:Algorithm{N}}) where {N} = N

Expand Down Expand Up @@ -88,7 +100,9 @@ Finally, the same behavior is obtained when the keyword arguments are
passed as the third positional argument in the form of a `NamedTuple`.
""" select_algorithm

@inline function select_algorithm(f::F, A, alg::Alg = nothing; kwargs...) where {F, Alg}
# WARNING: In order to keep everything type stable, this function is marked as foldable.
# This mostly means that the `default_algorithm` implementation must be foldable as well
Base.@assume_effects :foldable function select_algorithm(f::F, A, alg::Alg = nothing; kwargs...) where {F, Alg}
if isnothing(alg)
return default_algorithm(f, A; kwargs...)
elseif alg isa Symbol
Expand Down Expand Up @@ -117,8 +131,10 @@ In general, this is called by [`select_algorithm`](@ref) if no algorithm is spec
explicitly.
New types should prefer to register their default algorithms in the type domain.
""" default_algorithm
default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...)
default_algorithm(f::F, A, B; kwargs...) where {F} = default_algorithm(f, typeof(A), typeof(B); kwargs...)
@inline default_algorithm(f::F, A; kwargs...) where {F} =
default_algorithm(f, typeof(A); kwargs...)
@inline default_algorithm(f::F, A, B; kwargs...) where {F} =
default_algorithm(f, typeof(A), typeof(B); kwargs...)
# avoid infinite recursion:
function default_algorithm(f::F, ::Type{T}; kwargs...) where {F, T}
throw(MethodError(default_algorithm, (f, T)))
Expand Down Expand Up @@ -299,11 +315,6 @@ macro algdef(name)
return esc(
quote
const $name{K} = Algorithm{$(QuoteNode(name)), K}
function $name(; kwargs...)
# TODO: is this necessary/useful?
kw = NamedTuple(kwargs) # normalize type
return $name{typeof(kw)}(kw)
end
function Base.show(io::IO, alg::$name)
return ($_show_alg)(io, alg)
end
Expand Down
17 changes: 10 additions & 7 deletions src/implementations/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,10 @@ householder_lq!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, L, Q; kwargs...) =
lq_via_qr!(A, L, Q, Householder(; driver, kwargs...))
function householder_lq!(
driver::LAPACK, A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
positive = true, pivoted = false,
blocksize = ((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A))
positive = true, pivoted = false, blocksize::Int = 0
)
blocksize = blocksize > 0 ? blocksize : ((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A))

# error messages for disallowing driver - setting combinations
pivoted && (blocksize > 1) &&
throw(ArgumentError(lazy"$driver does not provide a blocked pivoted LQ decomposition"))
Expand Down Expand Up @@ -176,10 +177,10 @@ function householder_lq!(
end
function householder_lq!(
driver::Native, A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
# error messages for disallowing driver - setting combinations
blocksize == 1 ||
blocksize <= 1 ||
throw(ArgumentError(lazy"$driver does not provide a blocked LQ decomposition"))
pivoted &&
throw(ArgumentError(lazy"$driver does not provide a pivoted LQ decomposition"))
Expand Down Expand Up @@ -225,8 +226,10 @@ householder_lq_null!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, Nᴴ; kwargs...
lq_null_via_qr!(A, Nᴴ, Householder(; driver, kwargs...))
function householder_lq_null!(
driver::LAPACK, A::AbstractMatrix, Nᴴ::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = pivoted ? 1 : YALAPACK.default_qr_blocksize(A)
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
blocksize = blocksize > 0 ? blocksize : (pivoted ? 1 : YALAPACK.default_qr_blocksize(A))

# error messages for disallowing driver - setting combinations
pivoted && (blocksize > 1) &&
throw(ArgumentError(lazy"$driver does not provide a blocked pivoted LQ decomposition"))
Expand All @@ -248,10 +251,10 @@ function householder_lq_null!(
end
function householder_lq_null!(
driver::Native, A::AbstractMatrix, Nᴴ::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
# error messages for disallowing driver - setting combinations
blocksize == 1 ||
blocksize <= 1 ||
throw(ArgumentError(lazy"$driver does not provide a blocked LQ decomposition"))
pivoted &&
throw(ArgumentError(lazy"$driver does not provide a pivoted LQ decomposition"))
Expand Down
16 changes: 9 additions & 7 deletions src/implementations/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,10 @@ householder_qr!(::DefaultDriver, A, Q, R; kwargs...) =
function householder_qr!(
driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false,
blocksize::Int = ((driver !== LAPACK() || pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A))
blocksize::Int = 0
)
blocksize = blocksize > 0 ? blocksize : ((driver !== LAPACK() || pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A))

# error messages for disallowing driver - setting combinations
(blocksize == 1 || driver === LAPACK()) ||
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
Expand Down Expand Up @@ -202,10 +204,10 @@ function householder_qr!(
end
function householder_qr!(
driver::Native, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
# error messages for disallowing driver - setting combinations
blocksize == 1 ||
blocksize <= 1 ||
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
pivoted &&
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))
Expand Down Expand Up @@ -249,9 +251,9 @@ householder_qr_null!(::DefaultDriver, A, N; kwargs...) =
householder_qr_null!(default_householder_driver(A), A, N; kwargs...)
function householder_qr_null!(
driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, N::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false,
blocksize::Int = ((driver !== LAPACK() || pivoted) ? 1 : YALAPACK.default_qr_blocksize(A))
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
blocksize = blocksize > 0 ? blocksize : ((driver !== LAPACK() || pivoted) ? 1 : YALAPACK.default_qr_blocksize(A))
# error messages for disallowing driver - setting combinations
(blocksize == 1 || driver === LAPACK()) ||
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
Expand All @@ -277,10 +279,10 @@ function householder_qr_null!(
end
function householder_qr_null!(
driver::Native, A::AbstractMatrix, N::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
# error messages for disallowing driver - setting combinations
blocksize == 1 ||
blocksize <= 1 ||
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
pivoted &&
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))
Expand Down
8 changes: 7 additions & 1 deletion src/interface/decompositions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,17 @@ The optional `driver` symbol can be used to choose between different implementat

- `positive::Bool = true` : Fix the gauge of the resulting factors by making the diagonal elements of `L` or `R` non-negative.
- `pivoted::Bool = false` : Use column- or row-pivoting for low-rank input matrices.
- `blocksize::Int` : Use a blocked version of the algorithm if `blocksize > 1`.
- `blocksize::Int` : Use a blocked version of the algorithm if `blocksize > 1`. Use the default if `blocksize ≤ 0`.

Depending on the driver, various other keywords may be (un)available to customize the implementation.
"""
@algdef Householder
function Householder(;
blocksize::Int = 0, driver::Driver = DefaultDriver(),
pivoted::Bool = false, positive::Bool = true
)
return Householder((; blocksize, driver, pivoted, positive))
end

default_householder_driver(A) = default_householder_driver(typeof(A))
default_householder_driver(::Type) = Native()
Expand Down
4 changes: 2 additions & 2 deletions src/interface/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ See also [`qr_full(!)`](@ref lq_full) and [`qr_compact(!)`](@ref lq_compact).
default_lq_algorithm(A; kwargs...) = default_lq_algorithm(typeof(A); kwargs...)

default_lq_algorithm(T::Type; kwargs...) = throw(MethodError(default_lq_algorithm, (T,)))
default_lq_algorithm(::Type{T}; driver = default_householder_driver(T), kwargs...) where {T <: AbstractMatrix} =
Householder(; driver, kwargs...)
default_lq_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix} =
Householder(; kwargs...)
default_lq_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} =
DiagonalAlgorithm(; kwargs...)
default_lq_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} =
Expand Down
4 changes: 2 additions & 2 deletions src/interface/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ See also [`lq_full(!)`](@ref lq_full) and [`lq_compact(!)`](@ref lq_compact).
default_qr_algorithm(A; kwargs...) = default_qr_algorithm(typeof(A); kwargs...)

default_qr_algorithm(T::Type; kwargs...) = throw(MethodError(default_qr_algorithm, (T,)))
default_qr_algorithm(::Type{T}; driver = default_householder_driver(T), kwargs...) where {T <: AbstractMatrix} =
Householder(; driver, kwargs...)
default_qr_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix} =
Householder(; kwargs...)
default_qr_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} =
DiagonalAlgorithm(; kwargs...)
default_qr_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} =
Expand Down
Loading