fix: improvements to bypass scalar indexing and improve GPU support#375

Merged: lkdvos merged 27 commits into main from ksh/cuda_tweaks on May 19, 2026

kshyatt (Member) commented Feb 18, 2026

Needed to get more MPSKit examples working

Comment thread ext/TensorKitCUDAExt/auxiliary.jl Outdated
Comment thread ext/TensorKitCUDAExt/cutensormap.jl Outdated
Comment thread ext/TensorKitCUDAExt/cutensormap.jl Outdated
Comment thread ext/TensorKitCUDAExt/cutensormap.jl Outdated
Comment thread ext/TensorKitCUDAExt/cutensormap.jl Outdated
Comment thread ext/TensorKitCUDAExt/cutensormap.jl Outdated
Comment thread src/tensors/braidingtensor.jl Outdated
Comment thread src/tensors/treetransformers.jl Outdated
codecov Bot commented Feb 26, 2026

Codecov Report

❌ Patch coverage is 66.66667% with 6 lines in your changes missing coverage. Please review.

Files with missing lines                 | Patch % | Lines
ext/TensorKitCUDAExt/cutensormap.jl      | 0.00%   | 3 Missing ⚠️
src/tensors/tensor.jl                    | 0.00%   | 2 Missing ⚠️
src/tensors/abstracttensor.jl            | 85.71%  | 1 Missing ⚠️

Files with missing lines                 | Coverage Δ
ext/TensorKitCUDAExt/truncation.jl       | 96.77% <100.00%> (ø)
src/tensors/adjoint.jl                   | 89.65% <ø> (-0.35%) ⬇️
src/tensors/braidingtensor.jl            | 88.51% <100.00%> (+0.07%) ⬆️
src/tensors/diagonal.jl                  | 90.23% <100.00%> (ø)
src/tensors/indexmanipulations.jl        | 72.50% <100.00%> (-1.13%) ⬇️
src/tensors/tensoroperations.jl          | 96.27% <100.00%> (-1.40%) ⬇️
src/tensors/abstracttensor.jl            | 56.34% <85.71%> (+1.24%) ⬆️
src/tensors/tensor.jl                    | 83.23% <0.00%> (-0.58%) ⬇️
ext/TensorKitCUDAExt/cutensormap.jl      | 70.83% <0.00%> (-3.84%) ⬇️

... and 1 file with indirect coverage changes

@kshyatt kshyatt marked this pull request as draft February 27, 2026 11:14
kshyatt (Member Author) commented Feb 27, 2026

Let's make this a draft too to cut down on CI thrash

@kshyatt kshyatt force-pushed the ksh/cuda_tweaks branch 2 times, most recently from f5857b3 to 32e182d Compare March 12, 2026 12:36
@kshyatt kshyatt force-pushed the ksh/cuda_tweaks branch 2 times, most recently from f5faaf6 to 2359d28 Compare March 23, 2026 14:24
@lkdvos lkdvos mentioned this pull request Mar 26, 2026
lkdvos (Member) left a comment

Left some comments throughout. There are some things that I am not entirely convinced by, but the rest looks great. Thanks for working through all of this!

The similarstoragetype(tensor, storagetype) calls that you added seem like something we should discuss in a separate PR, and it would be great if we could consolidate this one to get the remainder of the fixes in.
Would you be up for splitting these two things, and then getting this merged?

The same more or less holds for some of my other comments: if we can postpone the things that are not obvious but already get the other parts in, that would probably be helpful.

(Note that I am very much aware that none of this is your fault and this PR has lived for too long so the design shifts a bit, for which I do apologize!)

Comment thread ext/TensorKitCUDAExt/cutensormap.jl Outdated
Comment thread src/tensors/abstracttensor.jl Outdated
Comment thread ext/TensorKitCUDAExt/cutensormap.jl Outdated
Comment thread ext/TensorKitCUDAExt/cutensormap.jl Outdated
Comment thread src/tensors/abstracttensor.jl
Comment thread src/tensors/indexmanipulations.jl Outdated
Comment thread src/tensors/indexmanipulations.jl Outdated
Comment thread src/tensors/indexmanipulations.jl Outdated
Comment thread src/tensors/tensoroperations.jl Outdated
kshyatt (Member Author) commented Mar 31, 2026

It's completely fine!! This has stayed open as I work through adding more tests for MPSKit, so I think we can pare off the simpler stuff we agree on, and then discuss things that are more contentious.

github-actions Bot (Contributor) commented Mar 31, 2026

Your PR no longer requires formatting changes. Thank you for your contribution!

Comment thread src/tensors/braidingtensor.jl
Comment thread src/tensors/indexmanipulations.jl Outdated
Comment thread src/tensors/indexmanipulations.jl Outdated
Comment thread src/tensors/indexmanipulations.jl Outdated
Comment thread src/tensors/abstracttensor.jl Outdated
Comment thread src/tensors/adjoint.jl Outdated
Comment thread src/tensors/adjoint.jl Outdated
Comment thread src/tensors/indexmanipulations.jl Outdated
lkdvos (Member) left a comment

It seems like the rebasing and the GitHub UI have made it hard to spot the comments I left before, although I think many of them are still unresolved and could be discussed :)

Comment thread ext/TensorKitCUDAExt/cutensormap.jl Outdated
@kshyatt kshyatt force-pushed the ksh/cuda_tweaks branch from 9e9b0ed to a37a133 Compare May 15, 2026 10:33
@kshyatt kshyatt force-pushed the ksh/cuda_tweaks branch from a37a133 to 3504de3 Compare May 15, 2026 11:26
Comment thread ext/TensorKitCUDAExt/truncation.jl Outdated
kshyatt (Member Author) commented May 18, 2026

Think I addressed everything and cleaned up the diff a bit as well

kshyatt (Member Author) commented May 19, 2026

AMD test fail is unrelated. Does anyone have objections to this getting merged today?

Comment thread ext/TensorKitCUDAExt/truncation.jl Outdated
Comment thread ext/TensorKitCUDAExt/truncation.jl Outdated
Comment thread src/tensors/abstracttensor.jl
Comment on lines +47 to +66
function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::S) where {S <: MatrixAlgebraKit.TruncationStrategy}
    # returning a CuSectorVector wrecks things in truncate_{co}domain
    # because of scalar indexing
    return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy))
end

for strat in (:(MatrixAlgebraKit.TruncationByOrder), :(MatrixAlgebraKit.TruncationByError), :(MatrixAlgebraKit.TruncationIntersection), :(TensorKit.Factorizations.TruncationSpace))
    @eval function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::$strat)
        # returning a CuSectorVector wrecks things in truncate_{co}domain
        # because of scalar indexing
        return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy))
    end
end

function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByValue)
    atol = TensorKit.Factorizations.rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol)
    strategy′ = trunctol(; atol, strategy.by, strategy.keep_below)
    return SectorDict(c => CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated_svd(d, strategy′)) for (c, d) in pairs(values))
end
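
For readers following along without a GPU, the per-sector logic of the TruncationByValue method above can be sketched on the CPU. This is only an illustration: keep_indices_by_value and the Dict-of-Vectors layout are hypothetical stand-ins for the sector vector, and the cutoff formula max(atol, rtol * norm) is an assumption about what rtol_to_atol computes, not something stated in this thread.

```julia
using LinearAlgebra: norm

# Hypothetical CPU stand-in: `spectrum` maps a sector label to its singular values.
# Assumption: the absolute cutoff is max(atol, rtol * norm of *all* values),
# so the tolerance is global while the filtering happens sector by sector.
function keep_indices_by_value(spectrum::Dict{Symbol, Vector{Float64}}; atol = 0.0, rtol = 0.0, p = 2)
    all_values = reduce(vcat, collect(values(spectrum)))
    cutoff = max(atol, rtol * norm(all_values, p))
    # findall returns an ordinary Vector{Int} of the positions to keep
    return Dict(c => findall(>(cutoff), v) for (c, v) in spectrum)
end

spectrum = Dict(:a => [3.0, 0.5], :b => [4.0, 0.1])
kept = keep_indices_by_value(spectrum; rtol = 0.2)
```

With these inputs the cutoff is roughly 1.005, so only the first value in each sector is kept. The result is a dictionary of plain index vectors, which is the shape of output the PR aims for so that later indexing does not happen element by element on device memory.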

Member

Would it make sense to overload the truncate_domain!, truncate_codomain! and truncate_diagonal! functions instead?
This looks quite prone to method ambiguities, and I guess we will also have to copy it for the findtruncated version if we want to support eigenvalue decompositions as well.

Suggested change
- function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::S) where {S <: MatrixAlgebraKit.TruncationStrategy}
-     # returning a CuSectorVector wrecks things in truncate_{co}domain
-     # because of scalar indexing
-     return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy))
- end
- for strat in (:(MatrixAlgebraKit.TruncationByOrder), :(MatrixAlgebraKit.TruncationByError), :(MatrixAlgebraKit.TruncationIntersection), :(TensorKit.Factorizations.TruncationSpace))
-     @eval function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::$strat)
-         # returning a CuSectorVector wrecks things in truncate_{co}domain
-         # because of scalar indexing
-         return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy))
-     end
- end
- function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByValue)
-     atol = TensorKit.Factorizations.rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol)
-     strategy′ = trunctol(; atol, strategy.by, strategy.keep_below)
-     return SectorDict(c => CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated_svd(d, strategy′)) for (c, d) in pairs(values))
- end
+ function TensorKit.Factorizations.truncate_domain!(tdst::CuTensorMap, tsrc::CuTensorMap, inds)
+     for (c, b) in blocks(tdst)
+         I = get(inds, c, nothing)
+         @assert !isnothing(I)
+         I = CUDA.CUDACore.Adapt.adapt(Vector, I)
+         b′ = block(tsrc, c)
+         b .= view(b′, :, I)
+     end
+     return tdst
+ end
+ function TensorKit.Factorizations.truncate_codomain!(tdst::CuTensorMap, tsrc::CuTensorMap, inds)
+     for (c, b) in blocks(tdst)
+         I = get(inds, c, nothing)
+         @assert !isnothing(I)
+         I = CUDA.CUDACore.Adapt.adapt(Vector, I)
+         b′ = block(tsrc, c)
+         b .= view(b′, I, :)
+     end
+     return tdst
+ end
+ function TensorKit.Factorizations.truncate_diagonal!(Ddst::DiagonalCuTensorMap, Dsrc::DiagonalCuTensorMap, inds)
+     for (c, b) in blocks(Ddst)
+         I = get(inds, c, nothing)
+         @assert !isnothing(I)
+         I = CUDA.CUDACore.Adapt.adapt(Vector, I)
+         diagview(b) .= view(diagview(block(Dsrc, c)), I)
+     end
+     return Ddst
+ end

(Warning, did not try to run this!)
Also, should this be adapt or collect?
Also, I added the DiagonalCuTensorMap, not sure if we have that type alias yet. (And also if this is required for them?)
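
The pattern in the suggestion above, where the index vector is an ordinary host-side Vector and each destination block is filled with a single broadcast over a view, can be tried on plain CPU arrays. A minimal sketch, assuming Dicts of Matrix as hypothetical stand-ins for the blocks of a tensor map; truncate_columns! is an illustrative name rather than a TensorKit function, and nothing here exercises CUDA or reproduces the GPUCompiler behaviour discussed in this thread.

```julia
# Hypothetical stand-ins: each Dict entry plays the role of one block (sector => matrix).
function truncate_columns!(dst::Dict{Symbol, Matrix{Float64}},
                           src::Dict{Symbol, Matrix{Float64}},
                           inds::Dict{Symbol, Vector{Int}})
    for (c, b) in dst
        I = get(inds, c, nothing)
        @assert !isnothing(I)
        # One broadcast per block: copy the selected columns, no per-element loop
        b .= view(src[c], :, I)
    end
    return dst
end

src  = Dict(:a => [1.0 2.0 3.0; 4.0 5.0 6.0], :b => [7.0 8.0; 9.0 10.0])
inds = Dict(:a => [1, 3], :b => [2])
dst  = Dict(:a => zeros(2, 2), :b => zeros(2, 1))
truncate_columns!(dst, src, inds)
```

After the call, dst[:a] holds columns 1 and 3 of src[:a], and dst[:b] holds column 2 of src[:b]. Swapping view(src[c], :, I) for view(src[c], I, :) gives the row-selecting analogue of the codomain case.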

Member Author

I tried overriding those, but I was getting compilation errors from GPUCompiler. It was enough of a rabbit hole that I thought it made more sense to punt on this for now.

Comment thread ext/TensorKitCUDAExt/truncation.jl
Co-authored-by: Lukas Devos <ldevos98@gmail.com>
lkdvos previously approved these changes May 19, 2026
lkdvos (Member) left a comment

I think overall the current changes look good to me. The only remaining part is the one on the factorizations, but if this unblocks it, we can always revisit if this shows up in profilers.

kshyatt (Member Author) commented May 19, 2026

Sorry, by factorizations do you mean truncate_{co}domain?

@kshyatt kshyatt enabled auto-merge (squash) May 19, 2026 12:32
@lkdvos lkdvos changed the title from "More tweaks" to "fix: improvements to bypass scalar indexing and improve GPU support" May 19, 2026
@lkdvos lkdvos disabled auto-merge May 19, 2026 13:29
@lkdvos lkdvos merged commit e3bdab4 into main May 19, 2026
37 of 40 checks passed
@lkdvos lkdvos deleted the ksh/cuda_tweaks branch May 19, 2026 13:29
3 participants