fix: improvements to bypass scalar indexing and improve GPU support #375
Let's make this a draft too to cut down on CI thrash
lkdvos left a comment
I left some comments throughout. There are some things that I am not entirely convinced by, but the rest looks great. Thanks for working through all of this!
The similarstoragetype(tensor, storagetype) calls that you added seem like something we should discuss in a separate PR, and it would be great if we could consolidate this one to get the remainder of the fixes in.
Would you be up for splitting these two things, and then getting this merged?
The same more or less holds for some of the comments I made too: if we postpone the things that are not obvious and get the other parts in now, that would probably be helpful.
(Note that I am very much aware that none of this is your fault. This PR has lived for too long, so the design has shifted a bit, for which I do apologize!)
It's completely fine!! This has stayed open as I work through adding more tests for MPSKit, so I think we can pare off the simpler stuff we agree on, and then discuss things that are more contentious.
lkdvos left a comment
It seems like the rebasing and the GitHub UI have made it hard to spot the comments I left before, although I think many of them are still unresolved and could be discussed :)
I think I addressed everything and cleaned up the diff a bit as well.
The AMD test failure is unrelated. Does anyone have objections to this getting merged today?
function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::S) where {S <: MatrixAlgebraKit.TruncationStrategy}
    # returning a CuSectorVector wrecks things in truncate_{co}domain
    # because of scalar indexing
    return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy))
end

for strat in (:(MatrixAlgebraKit.TruncationByOrder), :(MatrixAlgebraKit.TruncationByError), :(MatrixAlgebraKit.TruncationIntersection), :(TensorKit.Factorizations.TruncationSpace))
    @eval function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::$strat)
        # returning a CuSectorVector wrecks things in truncate_{co}domain
        # because of scalar indexing
        return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy))
    end
end

function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByValue)
    atol = TensorKit.Factorizations.rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol)
    strategy′ = trunctol(; atol, strategy.by, strategy.keep_below)
    return SectorDict(c => CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated_svd(d, strategy′)) for (c, d) in pairs(values))
end
Would it make sense to overload the truncate_domain!, truncate_codomain! and truncate_diagonal! functions instead?
The current approach looks quite prone to method ambiguities, and I guess we would also have to copy it for the findtruncated version if we want to support eigenvalue decompositions.
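The ambiguity concern can be illustrated with a minimal, package-free sketch. All names here (Strategy, ByOrder, ByError, HostValues, DeviceValues, pick) are hypothetical stand-ins rather than the MatrixAlgebraKit or TensorKit API. The sketch assumes the library defines methods that are generic in the values argument but specific in the strategy, while an extension adds a method that is specific in the values but generic in the strategy.

```julia
# Hypothetical stand-ins for truncation strategies and value containers.
abstract type Strategy end
struct ByOrder <: Strategy end
struct ByError <: Strategy end

abstract type Values end
struct HostValues <: Values end
struct DeviceValues <: Values end

# "Library" methods: generic in the values, specific in the strategy.
pick(::Values, ::ByOrder) = :generic_by_order
pick(::Values, ::ByError) = :generic_by_error

# "Extension" method: specific in the values, generic in the strategy.
pick(::DeviceValues, ::Strategy) = :device_fallback

# For (DeviceValues, ByOrder) neither applicable method is more specific,
# so the call throws a MethodError reporting the ambiguity.
ambiguous = try
    pick(DeviceValues(), ByOrder())
    false
catch err
    err isa MethodError
end

# One way out: define one method per concrete strategy, which is more
# specific than both of the ambiguous candidates.
for S in (:ByOrder, :ByError)
    @eval pick(::DeviceValues, ::$S) = :device_specific
end

resolved = pick(DeviceValues(), ByOrder())
```

Enumerating strategies this way resolves the dispatch, but every newly added strategy type needs its own method, which is the maintenance cost the comment points at.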
function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::S) where {S <: MatrixAlgebraKit.TruncationStrategy}
    # returning a CuSectorVector wrecks things in truncate_{co}domain
    # because of scalar indexing
    return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy))
end
for strat in (:(MatrixAlgebraKit.TruncationByOrder), :(MatrixAlgebraKit.TruncationByError), :(MatrixAlgebraKit.TruncationIntersection), :(TensorKit.Factorizations.TruncationSpace))
    @eval function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::$strat)
        # returning a CuSectorVector wrecks things in truncate_{co}domain
        # because of scalar indexing
        return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy))
    end
end
function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByValue)
    atol = TensorKit.Factorizations.rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol)
    strategy′ = trunctol(; atol, strategy.by, strategy.keep_below)
    return SectorDict(c => CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated_svd(d, strategy′)) for (c, d) in pairs(values))
end
function TensorKit.Factorizations.truncate_domain!(tdst::CuTensorMap, tsrc::CuTensorMap, inds)
    for (c, b) in blocks(tdst)
        I = get(inds, c, nothing)
        @assert !isnothing(I)
        I = CUDA.CUDACore.Adapt.adapt(Vector, I)
        b′ = block(tsrc, c)
        b .= view(b′, :, I)
    end
    return tdst
end
function TensorKit.Factorizations.truncate_codomain!(tdst::CuTensorMap, tsrc::CuTensorMap, inds)
    for (c, b) in blocks(tdst)
        I = get(inds, c, nothing)
        @assert !isnothing(I)
        I = CUDA.CUDACore.Adapt.adapt(Vector, I)
        b′ = block(tsrc, c)
        b .= view(b′, I, :)
    end
    return tdst
end
function TensorKit.Factorizations.truncate_diagonal!(Ddst::DiagonalCuTensorMap, Dsrc::DiagonalCuTensorMap, inds)
    for (c, b) in blocks(Ddst)
        I = get(inds, c, nothing)
        @assert !isnothing(I)
        I = CUDA.CUDACore.Adapt.adapt(Vector, I)
        diagview(b) .= view(diagview(block(Dsrc, c)), I)
    end
    return Ddst
end
(Warning: I did not try to run this!)
Also, should this be adapt or collect?
Finally, I added a method for DiagonalCuTensorMap, but I am not sure whether we have that type alias yet, or whether the overload is even required for diagonal tensors.
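The per-block pattern in the suggestion (look up the kept indices for each sector, bring them to the host, then broadcast a view of the source block into the destination) can be tried on the CPU with a small sketch. This is only an illustration under assumptions: plain Dicts of matrices stand in for the tensor blocks, truncate_columns! is a hypothetical name, and collect is used as a Base-only stand-in for the adapt(Vector, ...) call, so it does not settle the adapt-versus-collect question.

```julia
# Hypothetical CPU analogue of the suggested truncate_domain! overload:
# dictionaries keyed by a sector label hold the blocks.
function truncate_columns!(dst::AbstractDict, src::AbstractDict, inds::AbstractDict)
    for (c, b) in dst
        I = get(inds, c, nothing)
        isnothing(I) && throw(ArgumentError("no indices stored for sector $c"))
        I = collect(I)            # materialize the indices as a host Vector
        b .= view(src[c], :, I)   # copy only the kept columns, without a scalar loop
    end
    return dst
end

src = Dict(:a => [1.0 2.0 3.0; 4.0 5.0 6.0], :b => [7.0 8.0; 9.0 10.0])
inds = Dict(:a => [1, 3], :b => 2:2)
dst = Dict(:a => zeros(2, 2), :b => zeros(2, 1))

truncate_columns!(dst, src, inds)
```

The codomain variant would index rows instead, with view(src[c], I, :).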
I tried overriding those, but I was getting compilation errors from GPUCompiler. It was enough of a rabbit hole that I thought it made more sense to punt on this for now.
Co-authored-by: Lukas Devos <ldevos98@gmail.com>
lkdvos left a comment
I think overall the current changes look good to me. The only remaining part is the one on the factorizations, but if this unblocks it, we can always revisit if this shows up in profilers.
Sorry, by factorizations you mean
Needed to get more MPSKit examples working