Skip to content

Commit b6ed4ee

Browse files
committed
Cleanups
1 parent e4a5566 commit b6ed4ee

3 files changed

Lines changed: 29 additions & 22 deletions

File tree

src/tensors/abstractblocktensor/conversion.jl

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,51 @@
11
# Conversion
22
# ----------
3-
function Base.convert(::Type{TensorMap}, t::AbstractBlockTensorMap)
4-
S = spacetype(t)
5-
N₁, N₂ = numout(t), numin(t)
6-
cod = ProductSpace{S, N₁}(oplus.(codomain(t).spaces))
7-
dom = ProductSpace{S, N₂}(oplus.(domain(t).spaces))
8-
tdst = similar(t, cod dom)
9-
10-
issparse(t) && zerovector!(tdst)
113

4+
function _copy_subblocks!(tdst, tsrc)
5+
S = spacetype(tsrc)
6+
N₁, N₂ = numout(tsrc), numin(tsrc)
127
for ((f₁, f₂), arr) in subblocks(tdst)
138
blockax = ntuple(N₁ + N₂) do i
149
return if i <= N₁
15-
blockedrange(map(Base.Fix2(dim, f₁.uncoupled[i]), space(t, i)))
10+
blockedrange(map(Base.Fix2(dim, f₁.uncoupled[i]), space(tsrc, i)))
1611
else
17-
blockedrange(map(Base.Fix2(dim, f₂.uncoupled[i - N₁]), space(t, i)'))
12+
blockedrange(map(Base.Fix2(dim, f₂.uncoupled[i - N₁]), space(tsrc, i)'))
1813
end
1914
end
2015

21-
for (k, v) in nonzero_pairs(t)
16+
for (k, v) in nonzero_pairs(tsrc)
2217
indices = getindex.(blockax, Block.(Tuple(k)))
2318
arr_slice = arr[indices...]
2419
# need to check for empty since fusion tree pair might not be present
2520
isempty(arr_slice) || copy!(arr_slice, v[f₁, f₂])
2621
end
2722
end
23+
return tdst
24+
end
25+
26+
function Base.convert(::Type{TensorMap}, t::AbstractBlockTensorMap)
27+
S = spacetype(t)
28+
N₁, N₂ = numout(t), numin(t)
29+
cod = ProductSpace{S, N₁}(oplus.(codomain(t).spaces))
30+
dom = ProductSpace{S, N₂}(oplus.(domain(t).spaces))
31+
tdst = similar(t, cod dom)
32+
33+
issparse(t) && zerovector!(tdst)
34+
_copy_subblocks!(tdst, t)
35+
return tdst
36+
end
37+
38+
function Base.convert(::Type{TT}, t::AbstractBlockTensorMap) where {TT <: TensorMap}
39+
S = spacetype(t)
40+
N₁, N₂ = numout(t), numin(t)
41+
cod = ProductSpace{S, N₁}(oplus.(codomain(t).spaces))
42+
dom = ProductSpace{S, N₂}(oplus.(domain(t).spaces))
43+
tdst = TK.TensorMapWithStorage{scalartype(TT), storagetype(TT)}(undef, cod dom)
44+
issparse(t) && zerovector!(tdst)
2845

46+
_copy_subblocks!(tdst, t)
2947
return tdst
3048
end
31-
# use subtype of TensorMap here to support CuTensorMap
32-
# ROCTensorMap, etc.
33-
Base.convert(::Type{<:TensorMap}, t::AbstractBlockTensorMap) = convert(TensorMap, t)
3449

3550
function Base.convert(::Type{TT}, t::AbstractTensorMap) where {TT <: AbstractBlockTensorMap}
3651
t isa TT && return t

src/tensors/blocktensor.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,6 @@ struct BlockTensorMap{TT <: AbstractTensorMap, E, S, N₁, N₂, N} <:
2626
end
2727
end
2828

29-
# seems necessary to dispatch correctly onto the storage type of TT
30-
# AbstractBlockTensorMap doesn't have this TT field
31-
TensorKit.storagetype(::Type{<:BlockTensorMap{TT}}) where {TT <: AbstractTensorMap} = storagetype(TT)
32-
3329
function BlockTensorMap{TT, E, S, N₁, N₂, N}(
3430
::UndefInitializer, space::TensorMapSumSpace{S, N₁, N₂}
3531
) where {TT, E, S, N₁, N₂, N}

src/tensors/sparseblocktensor.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,6 @@ function sparseblocktensormaptype(
6363
return SparseBlockTensorMap{TT}
6464
end
6565

66-
# seems necessary to dispatch correctly onto the storage type of TT
67-
# AbstractBlockTensorMap doesn't have this TT field
68-
TensorKit.storagetype(::Type{<:SparseBlockTensorMap{TT}}) where {TT <: AbstractTensorMap} = storagetype(TT)
69-
7066
# Constructors
7167
# ------------
7268
function SparseBlockTensorMap{TT}(

0 commit comments

Comments
 (0)