Skip to content

Commit 1bfff08

Browse files
authored
Define default algorithms for SubArray and ReshapedArray (#182)
1 parent 295a354 commit 1bfff08

7 files changed

Lines changed: 45 additions & 30 deletions

File tree

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,38 +15,20 @@ using LinearAlgebra: BlasFloat
1515

1616
include("yacusolver.jl")
1717

18-
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
18+
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
1919
return CUSOLVER_HouseholderQR(; kwargs...)
2020
end
21-
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
21+
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
2222
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
2323
return LQViaTransposedQR(qr_alg)
2424
end
25-
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
25+
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
2626
return CUSOLVER_QRIteration(; kwargs...)
2727
end
28-
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
28+
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
2929
return CUSOLVER_Simple(; kwargs...)
3030
end
31-
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
32-
return CUSOLVER_DivideAndConquer(; kwargs...)
33-
end
34-
35-
# include for block sector support
36-
function MatrixAlgebraKit.default_qr_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
37-
return CUSOLVER_HouseholderQR(; kwargs...)
38-
end
39-
function MatrixAlgebraKit.default_lq_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
40-
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
41-
return LQViaTransposedQR(qr_alg)
42-
end
43-
function MatrixAlgebraKit.default_svd_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
44-
return CUSOLVER_Jacobi(; kwargs...)
45-
end
46-
function MatrixAlgebraKit.default_eig_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
47-
return CUSOLVER_Simple(; kwargs...)
48-
end
49-
function MatrixAlgebraKit.default_eigh_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
31+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuVecOrMat{TT}}
5032
return CUSOLVER_DivideAndConquer(; kwargs...)
5133
end
5234

src/interface/eig.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,18 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc).
161161
# -------------------
162162
default_eig_algorithm(A; kwargs...) = default_eig_algorithm(typeof(A); kwargs...)
163163
default_eig_algorithm(T::Type; kwargs...) = throw(MethodError(default_eig_algorithm, (T,)))
164-
function default_eig_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat}
164+
function default_eig_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat}
165165
return LAPACK_Expert(; kwargs...)
166166
end
167167
function default_eig_algorithm(::Type{T}; kwargs...) where {T <: Diagonal}
168168
return DiagonalAlgorithm(; kwargs...)
169169
end
170+
function default_eig_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A}
171+
return default_eig_algorithm(A)
172+
end
173+
function default_eig_algorithm(::Type{SubArray{T, N, A}}) where {T, N, A}
174+
return default_eig_algorithm(A)
175+
end
170176

171177
for f in (:eig_full!, :eig_vals!)
172178
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}

src/interface/eigh.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,12 +167,18 @@ default_eigh_algorithm(A; kwargs...) = default_eigh_algorithm(typeof(A); kwargs.
167167
function default_eigh_algorithm(T::Type; kwargs...)
168168
throw(MethodError(default_eigh_algorithm, (T,)))
169169
end
170-
function default_eigh_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat}
170+
function default_eigh_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat}
171171
return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...)
172172
end
173173
function default_eigh_algorithm(::Type{T}; kwargs...) where {T <: Diagonal}
174174
return DiagonalAlgorithm(; kwargs...)
175175
end
176+
function default_eigh_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A}
177+
return default_eigh_algorithm(A)
178+
end
179+
function default_eigh_algorithm(::Type{SubArray{T, N, A}}) where {T, N, A}
180+
return default_eigh_algorithm(A)
181+
end
176182

177183
for f in (:eigh_full!, :eigh_vals!)
178184
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}

src/interface/lq.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,18 @@ end
7575
function default_lq_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix}
7676
return Native_HouseholderLQ(; kwargs...)
7777
end
78-
function default_lq_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat}
78+
function default_lq_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat}
7979
return LAPACK_HouseholderLQ(; kwargs...)
8080
end
8181
function default_lq_algorithm(::Type{T}; kwargs...) where {T <: Diagonal}
8282
return DiagonalAlgorithm(; kwargs...)
8383
end
84+
function default_lq_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A}
85+
return default_lq_algorithm(A)
86+
end
87+
function default_lq_algorithm(::Type{SubArray{T, N, A}}) where {T, N, A}
88+
return default_lq_algorithm(A)
89+
end
8490

8591
for f in (:lq_full!, :lq_compact!, :lq_null!)
8692
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}

src/interface/qr.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,18 @@ end
7575
function default_qr_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix}
7676
return Native_HouseholderQR(; kwargs...)
7777
end
78-
function default_qr_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat}
78+
function default_qr_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat}
7979
return LAPACK_HouseholderQR(; kwargs...)
8080
end
8181
function default_qr_algorithm(::Type{T}; kwargs...) where {T <: Diagonal}
8282
return DiagonalAlgorithm(; kwargs...)
8383
end
84+
function default_qr_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A}
85+
return default_qr_algorithm(A)
86+
end
87+
function default_qr_algorithm(::Type{SubArray{T, N, A}}) where {T, N, A}
88+
return default_qr_algorithm(A)
89+
end
8490

8591
for f in (:qr_full!, :qr_compact!, :qr_null!)
8692
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}

src/interface/svd.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,18 @@ default_svd_algorithm(A; kwargs...) = default_svd_algorithm(typeof(A); kwargs...
161161
function default_svd_algorithm(T::Type; kwargs...)
162162
throw(MethodError(default_svd_algorithm, (T,)))
163163
end
164-
function default_svd_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat}
164+
function default_svd_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat}
165165
return LAPACK_DivideAndConquer(; kwargs...)
166166
end
167167
function default_svd_algorithm(::Type{T}; kwargs...) where {T <: Diagonal}
168168
return DiagonalAlgorithm(; kwargs...)
169169
end
170+
function default_svd_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A}
171+
return default_svd_algorithm(A)
172+
end
173+
function default_svd_algorithm(::Type{SubArray{T, N, A}}) where {T, N, A}
174+
return default_svd_algorithm(A)
175+
end
170176

171177
for f in (:svd_full!, :svd_compact!, :svd_vals!)
172178
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}

src/yalapack.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@ using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, Char, LAPACK,
1515
using LinearAlgebra.BLAS: @blasfunc, libblastrampoline
1616
using LinearAlgebra.LAPACK: chkfinite, chktrans, chkside, chkuplofinite, chklapackerror
1717

18-
# type alias for matrices that are definitely supported by YALAPACK
18+
# type alias for vectors/matrices that are definitely supported by YALAPACK
19+
const BlasVec{T <: BlasFloat} = StridedVector{T}
1920
const BlasMat{T <: BlasFloat} = StridedMatrix{T}
20-
# type alias for matrices that are possibly supported by YALAPACK, after conversion
21+
# type alias for vectors/matrices that are possibly supported by YALAPACK, after conversion
22+
const MaybeBlasVec = Union{BlasVec, AbstractVector{<:Integer}}
2123
const MaybeBlasMat = Union{BlasMat, AbstractMatrix{<:Integer}}
24+
const MaybeBlasVecOrMat = Union{MaybeBlasVec, MaybeBlasMat}
2225

2326
# LU factorisation (currently unused in MatrixAlgebraKit)
2427
# for (getrf, getrs, elty) in (

0 commit comments

Comments
 (0)