You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
+5-23Lines changed: 5 additions & 23 deletions
Original file line number
Diff line number
Diff line change
@@ -15,38 +15,20 @@ using LinearAlgebra: BlasFloat
15
15
16
16
include("yacusolver.jl")
17
17
18
-
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT <:BlasFloat, T <:StridedCuMatrix{TT}}
18
+
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT <:BlasFloat, T <:StridedCuVecOrMat{TT}}
19
19
returnCUSOLVER_HouseholderQR(; kwargs...)
20
20
end
21
-
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT <:BlasFloat, T <:StridedCuMatrix{TT}}
21
+
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT <:BlasFloat, T <:StridedCuVecOrMat{TT}}
22
22
qr_alg =CUSOLVER_HouseholderQR(; kwargs...)
23
23
returnLQViaTransposedQR(qr_alg)
24
24
end
25
-
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <:BlasFloat, T <:StridedCuMatrix{TT}}
25
+
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <:BlasFloat, T <:StridedCuVecOrMat{TT}}
26
26
returnCUSOLVER_QRIteration(; kwargs...)
27
27
end
28
-
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <:BlasFloat, T <:StridedCuMatrix{TT}}
28
+
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <:BlasFloat, T <:StridedCuVecOrMat{TT}}
29
29
returnCUSOLVER_Simple(; kwargs...)
30
30
end
31
-
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT <:BlasFloat, T <:StridedCuMatrix{TT}}
32
-
returnCUSOLVER_DivideAndConquer(; kwargs...)
33
-
end
34
-
35
-
# include for block sector support
36
-
function MatrixAlgebraKit.default_qr_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <:BlasFloat, A <:CuVecOrMat{T}}
37
-
returnCUSOLVER_HouseholderQR(; kwargs...)
38
-
end
39
-
function MatrixAlgebraKit.default_lq_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <:BlasFloat, A <:CuVecOrMat{T}}
40
-
qr_alg =CUSOLVER_HouseholderQR(; kwargs...)
41
-
returnLQViaTransposedQR(qr_alg)
42
-
end
43
-
function MatrixAlgebraKit.default_svd_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <:BlasFloat, A <:CuVecOrMat{T}}
44
-
returnCUSOLVER_Jacobi(; kwargs...)
45
-
end
46
-
function MatrixAlgebraKit.default_eig_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <:BlasFloat, A <:CuVecOrMat{T}}
47
-
returnCUSOLVER_Simple(; kwargs...)
48
-
end
49
-
function MatrixAlgebraKit.default_eigh_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <:BlasFloat, A <:CuVecOrMat{T}}
31
+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT <:BlasFloat, T <:StridedCuVecOrMat{TT}}
0 commit comments