|
1 | 1 | # Conversion |
2 | 2 | # ---------- |
3 | | -function Base.convert(::Type{TensorMap}, t::AbstractBlockTensorMap) |
4 | | - S = spacetype(t) |
5 | | - N₁, N₂ = numout(t), numin(t) |
6 | | - cod = ProductSpace{S, N₁}(oplus.(codomain(t).spaces)) |
7 | | - dom = ProductSpace{S, N₂}(oplus.(domain(t).spaces)) |
8 | | - tdst = similar(t, cod ← dom) |
9 | | - |
10 | | - issparse(t) && zerovector!(tdst) |
11 | 3 |
|
| 4 | +function _copy_subblocks!(tdst, tsrc) |
| 5 | + S = spacetype(tsrc) |
| 6 | + N₁, N₂ = numout(tsrc), numin(tsrc) |
12 | 7 | for ((f₁, f₂), arr) in subblocks(tdst) |
13 | 8 | blockax = ntuple(N₁ + N₂) do i |
14 | 9 | return if i <= N₁ |
15 | | - blockedrange(map(Base.Fix2(dim, f₁.uncoupled[i]), space(t, i))) |
| 10 | + blockedrange(map(Base.Fix2(dim, f₁.uncoupled[i]), space(tsrc, i))) |
16 | 11 | else |
17 | | - blockedrange(map(Base.Fix2(dim, f₂.uncoupled[i - N₁]), space(t, i)')) |
| 12 | + blockedrange(map(Base.Fix2(dim, f₂.uncoupled[i - N₁]), space(tsrc, i)')) |
18 | 13 | end |
19 | 14 | end |
20 | 15 |
|
21 | | - for (k, v) in nonzero_pairs(t) |
| 16 | + for (k, v) in nonzero_pairs(tsrc) |
22 | 17 | indices = getindex.(blockax, Block.(Tuple(k))) |
23 | 18 | arr_slice = arr[indices...] |
24 | 19 | # need to check for empty since fusion tree pair might not be present |
25 | 20 | isempty(arr_slice) || copy!(arr_slice, v[f₁, f₂]) |
26 | 21 | end |
27 | 22 | end |
| 23 | + return tdst |
| 24 | +end |
| 25 | + |
| 26 | +function Base.convert(::Type{TensorMap}, t::AbstractBlockTensorMap) |
| 27 | + S = spacetype(t) |
| 28 | + N₁, N₂ = numout(t), numin(t) |
| 29 | + cod = ProductSpace{S, N₁}(oplus.(codomain(t).spaces)) |
| 30 | + dom = ProductSpace{S, N₂}(oplus.(domain(t).spaces)) |
| 31 | + tdst = similar(t, cod ← dom) |
| 32 | + |
| 33 | + issparse(t) && zerovector!(tdst) |
| 34 | + _copy_subblocks!(tdst, t) |
| 35 | + return tdst |
| 36 | +end |
| 37 | + |
| 38 | +function Base.convert(::Type{TT}, t::AbstractBlockTensorMap) where {TT <: TensorMap} |
| 39 | + S = spacetype(t) |
| 40 | + N₁, N₂ = numout(t), numin(t) |
| 41 | + cod = ProductSpace{S, N₁}(oplus.(codomain(t).spaces)) |
| 42 | + dom = ProductSpace{S, N₂}(oplus.(domain(t).spaces)) |
| 43 | + tdst = TK.TensorMapWithStorage{scalartype(TT), storagetype(TT)}(undef, cod ← dom) |
| 44 | + issparse(t) && zerovector!(tdst) |
28 | 45 |
|
| 46 | + _copy_subblocks!(tdst, t) |
29 | 47 | return tdst |
30 | 48 | end |
31 | | -# use subtype of TensorMap here to support CuTensorMap |
32 | | -# ROCTensorMap, etc. |
33 | | -Base.convert(::Type{<:TensorMap}, t::AbstractBlockTensorMap) = convert(TensorMap, t) |
34 | 49 |
|
35 | 50 | function Base.convert(::Type{TT}, t::AbstractTensorMap) where {TT <: AbstractBlockTensorMap} |
36 | 51 | t isa TT && return t |
|
0 commit comments