Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
9e1ebfe
More tweaks for GPU support
kshyatt Feb 18, 2026
ea787ae
fix typo
kshyatt Apr 27, 2026
c3d33e2
Fix TC once again
kshyatt May 11, 2026
eb6b985
Remove unneeded Adjoint methods
kshyatt May 12, 2026
c4653c8
Remove unneeded TensorMapWithStorage?
kshyatt May 12, 2026
b46e23e
Death to to_cpu
kshyatt May 12, 2026
5932066
Remove unneeded similarstoragetype method
kshyatt May 12, 2026
46e5037
Add in TensorMap constructor
kshyatt May 12, 2026
9fdc185
Restore former braiding tensor methods
kshyatt May 12, 2026
e948370
Fix type issue for sortperm
kshyatt May 13, 2026
28dca1e
Remove stale type params
kshyatt May 13, 2026
259309e
Apply suggestions from code review
kshyatt May 14, 2026
f23ce5c
Fix bad result of suggestion
kshyatt May 14, 2026
3504de3
Another fix?
kshyatt May 14, 2026
830d955
Force inds to move back to the CPU
kshyatt May 16, 2026
1fe40a7
Return to glorious scalartype
kshyatt May 16, 2026
55d25b5
Restore DiagonalTensorMap ctor
kshyatt May 16, 2026
a08bd23
Resolve trunc ambiguity
kshyatt May 16, 2026
921ffc6
Remove extra CUDA ctor
kshyatt May 16, 2026
ed07969
Restore chopped argument
kshyatt May 16, 2026
285ac16
Also remove no longer needed method
kshyatt May 16, 2026
ae018d1
Remove forced Int eltype
kshyatt May 17, 2026
248f8a3
Get rid of no-op ctor
kshyatt May 18, 2026
06d1ac5
Try to resolve ambiguity
kshyatt May 18, 2026
8efbb64
Cover all truncation strategies
kshyatt May 19, 2026
fa16c1a
Short-circuit logic in `findtruncated`
kshyatt May 19, 2026
0d844fb
Formatter
kshyatt May 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 1 addition & 14 deletions ext/TensorKitCUDAExt/cutensormap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,6 @@ function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S}
return isempty(inds) ? zero(scalartype(t)) : @allowscalar @inbounds t.data[only(inds)]
end

# Convert an `AbstractTensorMap` with matching space type and index counts to
# the requested `CuTensorMap` type `TT`.
#
# If `t` already has exactly the type `TT` it is returned as-is (no copy).
# Otherwise a new, uninitialized `TT` is created over `space(t)` and the
# contents of `t` are copied into it with `copy!`.
function Base.convert(
        TT::Type{CuTensorMap{T, S, N₁, N₂}},
        t::AbstractTensorMap{<:Any, S, N₁, N₂}
    ) where {T, S, N₁, N₂}
    # Exact type match: nothing to do.
    typeof(t) === TT && return t
    # Different tensor type: allocate the destination and fill it from `t`.
    return copy!(TT(undef, space(t)), t)
end

function LinearAlgebra.isposdef(t::CuTensorMap)
domain(t) == codomain(t) ||
throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same"))
Expand All @@ -138,10 +126,9 @@ function Base.promote_rule(
return CuTensorMap{T, S, N₁, N₂}
end

TensorKit.promote_storage_rule(::Type{CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} =
TensorKit.promote_storage_rule(::Type{<:CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} =
CuArray{T, N, CUDA.default_memory}


# CuTensorMap exponentiation:
function TensorKit.exp!(t::CuTensorMap)
domain(t) == codomain(t) ||
Expand Down
24 changes: 24 additions & 0 deletions ext/TensorKitCUDAExt/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ function MatrixAlgebraKit.findtruncated(
fill!(v, dim(c))
end

isempty(parent(values)) && return similar(values, Bool)

perm = sortperm(parent(values); strategy.by, strategy.rev)
cumulative_dim = cumsum(Base.permute!(parent(dims), perm))

Expand All @@ -36,6 +38,8 @@ function MatrixAlgebraKit.findtruncated(
end
end

isempty(parent(values)) && return similar(values, Bool)

perm = sortperm(parent(values); by = abs, rev = false)
cumulative_err = cumsum(Base.permute!(parent(ϵᵖ), perm))

Expand All @@ -44,6 +48,26 @@ function MatrixAlgebraKit.findtruncated(
return result
end

# Fallback `findtruncated_svd` for a `CuSectorVector` and any truncation strategy.
#
# Delegates to `MatrixAlgebraKit.findtruncated` and then adapts the result to
# `Vector` storage. Returning a `CuSectorVector` here wrecks things in
# `truncate_{co}domain` because of scalar indexing.
function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::S) where {S <: MatrixAlgebraKit.TruncationStrategy}
    kept = MatrixAlgebraKit.findtruncated(values, strategy)
    return CUDA.CUDACore.Adapt.adapt(Vector, kept)
end

# Define one `findtruncated_svd` method per listed strategy type, each calling
# `findtruncated` and adapting the result to `Vector` storage.
# NOTE(review): presumably these explicit methods exist to avoid dispatch
# ambiguities with strategy-specific methods defined elsewhere — confirm.
for strat in (
        :(MatrixAlgebraKit.TruncationByOrder),
        :(MatrixAlgebraKit.TruncationByError),
        :(MatrixAlgebraKit.TruncationIntersection),
        :(TensorKit.Factorizations.TruncationSpace),
    )
    @eval function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::$strat)
        # Returning a CuSectorVector wrecks things in truncate_{co}domain
        # because of scalar indexing, so hand back a `Vector`-adapted result.
        kept = MatrixAlgebraKit.findtruncated(values, strategy)
        return CUDA.CUDACore.Adapt.adapt(Vector, kept)
    end
end

# `findtruncated_svd` for value-based truncation of a `CuSectorVector`.
#
# The tolerances in `strategy` are first turned into a single absolute
# tolerance via `rtol_to_atol` (evaluated on all of `values`). A `trunctol`
# strategy built from it is then applied to each `(c, d)` pair of `values`
# separately; every per-pair result is adapted to `Vector` storage and the
# results are collected into a `SectorDict` keyed by `c`.
function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByValue)
    atol = TensorKit.Factorizations.rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol)
    blockstrategy = trunctol(; atol, strategy.by, strategy.keep_below)
    return SectorDict(
        c => CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated_svd(d, blockstrategy))
            for (c, d) in pairs(values)
    )
end

Comment on lines +51 to +70
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to overload the truncate_domain! and truncate_codomain! and truncate_diagonal! functions instead?
This looks like it is quite prone to ambiguity, and I guess we will also have to copy this for the findtruncated version too if we want to have eigenvalue decompositions as well.

Suggested change
function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::S) where {S <: MatrixAlgebraKit.TruncationStrategy}
# returning a CuSectorVector wrecks things in truncate_{co}domain
# because of scalar indexing
return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy))
end
for strat in (:(MatrixAlgebraKit.TruncationByOrder), :(MatrixAlgebraKit.TruncationByError), :(MatrixAlgebraKit.TruncationIntersection), :(TensorKit.Factorizations.TruncationSpace))
@eval function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::$strat)
# returning a CuSectorVector wrecks things in truncate_{co}domain
# because of scalar indexing
return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy))
end
end
function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByValue)
atol = TensorKit.Factorizations.rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol)
strategy′ = trunctol(; atol, strategy.by, strategy.keep_below)
return SectorDict(c => CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated_svd(d, strategy′)) for (c, d) in pairs(values))
end
# Fill every block of `tdst` with the columns of the matching block of `tsrc`
# selected by `inds`, and return `tdst`.
#
# `inds` must have an entry for each block key `c` of `tdst` (asserted). The
# entry is adapted to `Vector` storage before it is used to index the view.
function TensorKit.Factorizations.truncate_domain!(tdst::CuTensorMap, tsrc::CuTensorMap, inds)
    for (c, b) in blocks(tdst)
        I = get(inds, c, nothing)
        @assert !isnothing(I)
        hostinds = CUDA.CUDACore.Adapt.adapt(Vector, I)
        b .= view(block(tsrc, c), :, hostinds)
    end
    return tdst
end
# Fill every block of `tdst` with the rows of the matching block of `tsrc`
# selected by `inds`, and return `tdst`.
#
# `inds` must have an entry for each block key `c` of `tdst` (asserted). The
# entry is adapted to `Vector` storage before it is used to index the view.
function TensorKit.Factorizations.truncate_codomain!(tdst::CuTensorMap, tsrc::CuTensorMap, inds)
    for (c, b) in blocks(tdst)
        I = get(inds, c, nothing)
        @assert !isnothing(I)
        hostinds = CUDA.CUDACore.Adapt.adapt(Vector, I)
        b .= view(block(tsrc, c), hostinds, :)
    end
    return tdst
end
# Fill the diagonal of every block of `Ddst` with the entries of the matching
# `Dsrc` block diagonal selected by `inds`, and return `Ddst`.
#
# `inds` must have an entry for each block key `c` of `Ddst` (asserted). The
# entry is adapted to `Vector` storage before it is used to index the view.
function TensorKit.Factorizations.truncate_diagonal!(Ddst::DiagonalCuTensorMap, Dsrc::DiagonalCuTensorMap, inds)
    for (c, b) in blocks(Ddst)
        I = get(inds, c, nothing)
        @assert !isnothing(I)
        hostinds = CUDA.CUDACore.Adapt.adapt(Vector, I)
        srcdiag = diagview(block(Dsrc, c))
        diagview(b) .= view(srcdiag, hostinds)
    end
    return Ddst
end

(Warning, did not try to run this!)
Also, should this be adapt or collect?
Also, I added the DiagonalCuTensorMap, not sure if we have that type alias yet. (And also if this is required for them?)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried overriding those, but I kept getting compilation errors from GPUCompiler. It was enough of a rabbit hole that I thought it made more sense to punt on this for now.

# Needed until MatrixAlgebraKit patch hits...
function MatrixAlgebraKit._ind_intersect(A::CuVector{Bool}, B::CuVector{Int})
Comment thread
lkdvos marked this conversation as resolved.
result = fill!(similar(A), false)
Expand Down
9 changes: 4 additions & 5 deletions src/tensors/abstracttensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ storagetype(t) = storagetype(typeof(t))
function storagetype(::Type{T}) where {T <: AbstractTensorMap}
if T isa Union
# attempt to be slightly more specific by promoting unions
Ma = storagetype(T.a)
Mb = storagetype(T.b)
return promote_storagetype(Ma, Mb)
return promote_storagetype(T.a, T.b)
else
# fallback definition by using scalartype
return similarstoragetype(scalartype(T))
Comment thread
lkdvos marked this conversation as resolved.
Expand Down Expand Up @@ -103,8 +101,9 @@ similarstoragetype(X::Type, ::Type{T}) where {T <: Number} =

# implement on tensors
similarstoragetype(::Type{TT}) where {TT <: AbstractTensorMap} = similarstoragetype(storagetype(TT))
similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} =
similarstoragetype(storagetype(TT), T)
# Storage type similar to that of tensor type `TT`, but for scalar type `T`:
# look up the storage type of `TT` and defer to the two-argument method on it.
function similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number}
    A = storagetype(TT)
    return similarstoragetype(A, T)
end
Comment thread
kshyatt marked this conversation as resolved.

# implement on arrays
similarstoragetype(::Type{A}) where {A <: DenseVector{<:Number}} = A
Expand Down
2 changes: 0 additions & 2 deletions src/tensors/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ Base.@propagate_inbounds function subblock(t::AdjointTensorMap, (f₁, f₂)::Tu
return permutedims(conj(data), (domainind(tp)..., codomainind(tp)...))
end

to_cpu(t::AdjointTensorMap) = adjoint(to_cpu(adjoint(t)))

# Show
#------
function Base.showarg(io::IO, t::AdjointTensorMap, toplevel::Bool)
Expand Down
7 changes: 1 addition & 6 deletions src/tensors/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ struct TensorMap{T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} <: Abstrac
end
return TensorMap{T, S, N₁, N₂, A}(data, space)
end

# constructors from data
function TensorMap{T, S, N₁, N₂, A}(
data::A, space::TensorMapSpace{S, N₁, N₂}
Expand All @@ -34,6 +33,7 @@ struct TensorMap{T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} <: Abstrac
return new{T, S, N₁, N₂, A}(data, space)
end
end
# Build a `TensorMap` from an existing one with the same scalar type, space
# type and index counts: the data of `t` is passed through the storage
# constructor `A` and paired with `space(t)`.
function TensorMap{T, S, N₁, N₂, A}(
        t::TensorMap{T, S, N₁, N₂}
    ) where {T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}}
    data = A(t.data)
    return TensorMap(data, space(t))
end

"""
Tensor{T, S, N, A<:DenseVector{T}} = TensorMap{T, S, N, 0, A}
Expand Down Expand Up @@ -407,11 +407,6 @@ for randf in (:rand, :randn, :randexp, :randisometry)
end
end

# Moving arbitrary TensorMaps to CPU
#-----------------------------
to_cpu(t::TensorMapWithStorage{T, Vector{T}}) where {T} = t # no op
to_cpu(t::TensorMap) = convert(TensorMapWithStorage{scalartype(t), similarstoragetype(scalartype(t))}, t)

# Efficient copy constructors
#-----------------------------
Base.copy(t::TensorMap) = typeof(t)(copy(t.data), t.space)
Expand Down
72 changes: 36 additions & 36 deletions test/amd/tensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ for V in spacelist
for T in (Int, Float32, ComplexF64)
t = @constinferred AMDGPU.rand(T, W)
d = convert(Dict, t)
@test TensorKit.to_cpu(t) == convert(TensorMap, d)
@test adapt(Array, t) == convert(TensorMap, d)
end
end
symmetricbraiding && @timedtestset "Basic linear algebra" begin
Expand Down Expand Up @@ -189,10 +189,10 @@ for V in spacelist
t = AMDGPU.rand(T, W)
t2 = @constinferred AMDGPU.rand!(similar(t))
α = rand(T)
@test norm(t, 2) ≈ norm(TensorKit.to_cpu(t), 2)
@test dot(t2, t) ≈ dot(TensorKit.to_cpu(t2), TensorKit.to_cpu(t))
@test TensorKit.to_cpu(α * t) ≈ α * TensorKit.to_cpu(t)
@test TensorKit.to_cpu(t + t) ≈ 2 * TensorKit.to_cpu(t)
@test norm(t, 2) ≈ norm(adapt(Array, t), 2)
@test dot(t2, t) ≈ dot(adapt(Array, t2), adapt(Array, t))
@test adapt(Array, α * t) ≈ α * adapt(Array, t)
@test adapt(Array, t + t) ≈ 2 * adapt(Array, t)
end
end
@timedtestset "Real and imaginary parts" begin
Expand All @@ -202,17 +202,17 @@ for V in spacelist

tr = @constinferred real(t)
@test scalartype(tr) <: Real
@test real(TensorKit.to_cpu(t)) == TensorKit.to_cpu(tr)
@test real(adapt(Array, t)) == adapt(Array, tr)
@test storagetype(tr) == ROCVector{real(T), AMDGPU.Mem.HIPBuffer}

ti = @constinferred imag(t)
@test scalartype(ti) <: Real
@test imag(TensorKit.to_cpu(t)) == TensorKit.to_cpu(ti)
@test imag(adapt(Array, t)) == adapt(Array, ti)
@test storagetype(ti) == ROCVector{real(T), AMDGPU.Mem.HIPBuffer}

tc = @inferred complex(t)
@test scalartype(tc) <: Complex
@test complex(TensorKit.to_cpu(t)) == TensorKit.to_cpu(tc)
@test complex(adapt(Array, t)) == adapt(Array, tc)
@test storagetype(tc) == ROCVector{complex(T), AMDGPU.Mem.HIPBuffer}

tc2 = @inferred complex(tr, ti)
Expand Down Expand Up @@ -275,13 +275,13 @@ for V in spacelist
p1 = ntuple(n -> p[n], k)
p2 = ntuple(n -> p[k + n], 5 - k)
dt2 = AMDGPU.@allowscalar permute(t, (p1, p2))
ht2 = permute(TensorKit.to_cpu(t), (p1, p2))
@test ht2 == TensorKit.to_cpu(dt2)
ht2 = permute(adapt(Array, t), (p1, p2))
@test ht2 == adapt(Array, dt2)
end

dt3 = AMDGPU.@allowscalar repartition(t, k)
ht3 = repartition(TensorKit.to_cpu(t), k)
@test ht3 == TensorKit.to_cpu(dt3)
ht3 = repartition(adapt(Array, t), k)
@test ht3 == adapt(Array, dt3)
end
end
symmetricbraiding && @timedtestset "Full trace: test self-consistency" begin
Expand Down Expand Up @@ -339,10 +339,10 @@ for V in spacelist
@tensor dHrA12[a, s1, s2, c] := drhoL[a, a'] * conj(dA1[a', t1, b]) *
dA2[b, t2, c'] * drhoR[c', c] *
dH[s1, s2, t1, t2]
@tensor hHrA12[a, s1, s2, c] := TensorKit.to_cpu(drhoL)[a, a'] * conj(TensorKit.to_cpu(dA1)[a', t1, b]) *
TensorKit.to_cpu(dA2)[b, t2, c'] * TensorKit.to_cpu(drhoR)[c', c] *
TensorKit.to_cpu(dH)[s1, s2, t1, t2]
@test TensorKit.to_cpu(dHrA12) ≈ hHrA12
@tensor hHrA12[a, s1, s2, c] := adapt(Array, drhoL)[a, a'] * conj(adapt(Array, dA1)[a', t1, b]) *
adapt(Array, dA2)[b, t2, c'] * adapt(Array, drhoR)[c', c] *
adapt(Array, dH)[s1, s2, t1, t2]
@test adapt(Array, dHrA12) ≈ hHrA12
end
end=# # doesn't yet work because of AdjointTensor
BraidingStyle(I) isa HasBraiding && @timedtestset "Index flipping: test flipping inverse" begin
Expand Down Expand Up @@ -422,31 +422,31 @@ for V in spacelist
t1 = AMDGPU.rand(T, W1, W1)
t2 = AMDGPU.rand(T, W2, W2)
t = AMDGPU.rand(T, W1, W2)
ht1 = TensorKit.to_cpu(t1)
ht2 = TensorKit.to_cpu(t2)
ht = TensorKit.to_cpu(t)
@test TensorKit.to_cpu(t1 * t) ≈ ht1 * ht
@test TensorKit.to_cpu(t1' * t) ≈ ht1' * ht
@test TensorKit.to_cpu(t2 * t') ≈ ht2 * ht'
@test TensorKit.to_cpu(t2' * t') ≈ ht2' * ht'
ht1 = adapt(Array, t1)
ht2 = adapt(Array, t2)
ht = adapt(Array, t)
@test adapt(Array, t1 * t) ≈ ht1 * ht
@test adapt(Array, t1' * t) ≈ ht1' * ht
@test adapt(Array, t2 * t') ≈ ht2 * ht'
@test adapt(Array, t2' * t') ≈ ht2' * ht'

#=AMDGPU.@allowscalar begin
@test TensorKit.to_cpu(inv(t1)) ≈ inv(ht1)
@test TensorKit.to_cpu(pinv(t)) ≈ pinv(ht)
@test adapt(Array, inv(t1)) ≈ inv(ht1)
@test adapt(Array, pinv(t)) ≈ pinv(ht)

if T == Float32 || T == ComplexF32
continue
end

@test TensorKit.to_cpu(t1 \ t) ≈ ht1 \ ht
@test TensorKit.to_cpu(t1' \ t) ≈ ht1' \ ht
@test TensorKit.to_cpu(t2 \ t') ≈ ht2 \ ht'
@test TensorKit.to_cpu(t2' \ t') ≈ ht2' \ ht'
@test adapt(Array, t1 \ t) ≈ ht1 \ ht
@test adapt(Array, t1' \ t) ≈ ht1' \ ht
@test adapt(Array, t2 \ t') ≈ ht2 \ ht'
@test adapt(Array, t2' \ t') ≈ ht2' \ ht'

@test TensorKit.to_cpu(t2 / t) ≈ ht2 / ht
@test TensorKit.to_cpu(t2' / t) ≈ ht2' / ht
@test TensorKit.to_cpu(t1 / t') ≈ ht1 / ht'
@test TensorKit.to_cpu(t1' / t') ≈ ht1' / ht'
@test adapt(Array, t2 / t) ≈ ht2 / ht
@test adapt(Array, t2' / t) ≈ ht2' / ht
@test adapt(Array, t1 / t') ≈ ht1 / ht'
@test adapt(Array, t1' / t') ≈ ht1' / ht'
end=#
end
end
Expand All @@ -456,11 +456,11 @@ for V in spacelist
#=t = project_hermitian!(AMDGPU.randn(T, W, W))
s = dim(W)
@test (@constinferred sqrt(t))^2 ≈ t
@test TensorKit.to_cpu(sqrt(t)) ≈ sqrt(TensorKit.to_cpu(t))
@test adapt(Array, sqrt(t)) ≈ sqrt(adapt(Array, t))
expt = @constinferred exp(t)
@test TensorKit.to_cpu(expt) ≈ exp(TensorKit.to_cpu(t))
@test adapt(Array, expt) ≈ exp(adapt(Array, t))
@test exp(@constinferred log(project_hermitian!(expt))) ≈ expt
@test TensorKit.to_cpu(log(project_hermitian!(expt))) ≈ log(TensorKit.to_cpu(expt))
@test adapt(Array, log(project_hermitian!(expt))) ≈ log(adapt(Array, expt))

@test (@constinferred cos(t))^2 + (@constinferred sin(t))^2 ≈
id(storagetype(t), W)
Expand Down
Loading
Loading