-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathBlockSparseArraysTensorAlgebraExt.jl
More file actions
60 lines (54 loc) · 2.18 KB
/
BlockSparseArraysTensorAlgebraExt.jl
File metadata and controls
60 lines (54 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
module BlockSparseArraysTensorAlgebraExt
using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
using TensorAlgebra: TensorAlgebra, BlockedTuple, FusionStyle, fuseaxes
"""
    BlockReshapeFusion <: FusionStyle

Singleton fusion style returned by `TensorAlgebra.FusionStyle` for
`AbstractBlockSparseArray` types. The `matricize` and `unmatricize` methods
defined in this extension dispatch on it.
"""
struct BlockReshapeFusion <: FusionStyle end

# Trait method: every subtype of `AbstractBlockSparseArray` selects `BlockReshapeFusion`.
TensorAlgebra.FusionStyle(::Type{<:AbstractBlockSparseArray}) = BlockReshapeFusion()
using BlockArrays: Block, blocklength, blocks
using BlockSparseArrays: blocksparse
using SparseArraysBase: eachstoredindex
using TensorAlgebra: TensorAlgebra, matricize, unmatricize
# Matricize `a` one stored block at a time.
#
# The array of blocks of `a` is reshaped onto the block grid of the axes returned by
# `fuseaxes(axes(a), length1, length2)`. Each stored block is then matricized with the
# same `length1`/`length2`, and the results are assembled with `blocksparse` on those axes.
function TensorAlgebra.matricize(
    ::BlockReshapeFusion, a::AbstractArray, length1::Val, length2::Val
)
    fused_axes = fuseaxes(axes(a), length1, length2)
    block_grid = reshape(blocks(a), map(blocklength, fused_axes))

    # Map an index into the reshaped block grid to its `Block` key and to the
    # matricized block stored at that position.
    blockkey(I) = Block(Tuple(I))
    blockvalue(I) = matricize(block_grid[I], length1, length2)

    stored = eachstoredindex(block_grid)
    if !isempty(stored)
        matricized_blocks = Dict(blockkey(I) => blockvalue(I) for I in stored)
        return blocksparse(matricized_blocks, fused_axes)
    end

    # No stored blocks: build the empty dictionary with explicitly constrained key and
    # value types. This seems to only be necessary in Julia versions below v1.11, try
    # removing it when support for those versions is dropped.
    K = Base.promote_op(blockkey, eltype(stored))
    V = Base.promote_op(blockvalue, eltype(stored))
    # Fall back to an abstract matrix type when no concrete block type was inferred.
    V′ = isconcretetype(V) ? V : AbstractMatrix{eltype(a)}
    return blocksparse(Dict{K, V′}(), fused_axes)
end
using BlockArrays: blocklengths
# Unmatricize `m` one stored block at a time.
#
# The target axes are `codomain_axes` followed by `domain_axes`. The array of blocks of
# `m` is reshaped onto the block grid of those axes, each stored block is unmatricized
# using the per-block axes (split into codomain/domain groups via `BlockedTuple`), and
# the results are assembled with `blocksparse`.
#
# The case with no stored blocks is handled explicitly, mirroring the `matricize` method
# in this file: the empty dictionary gets explicit key/value types instead of relying on
# what `Dict(generator)` infers for an empty generator.
function TensorAlgebra.unmatricize(
    ::BlockReshapeFusion,
    m::AbstractMatrix,
    codomain_axes::Tuple{Vararg{AbstractUnitRange}},
    domain_axes::Tuple{Vararg{AbstractUnitRange}},
)
    ax = (codomain_axes..., domain_axes...)
    reshaped_blocks_m = reshape(blocks(m), map(blocklength, ax))

    # Key of the output block at position `I` of the reshaped block grid.
    key(I) = Block(Tuple(I))

    # Unmatricized block at position `I` of the reshaped block grid.
    function value(I)
        # Axis of block `I[i]` along every dimension, grouped into the codomain and
        # domain parts so the block is split the same way as the full array.
        block_axes_I = BlockedTuple(
            map(ntuple(identity, length(ax))) do i
                return Base.axes1(ax[i][Block(I[i])])
            end,
            (length(codomain_axes), length(domain_axes)),
        )
        return unmatricize(reshaped_blocks_m[I], block_axes_I)
    end

    Is = eachstoredindex(reshaped_blocks_m)
    bs = if isempty(Is)
        # Catch the empty case and make sure the type is constrained properly
        # (same treatment as in `matricize`; reported there as only being needed
        # in Julia versions below v1.11).
        K = Base.promote_op(key, eltype(Is))
        V = Base.promote_op(value, eltype(Is))
        # Fall back to an abstract array type of the output rank when no concrete
        # block type was inferred.
        V′ = isconcretetype(V) ? V : AbstractArray{eltype(m), length(ax)}
        Dict{K, V′}()
    else
        Dict(key(I) => value(I) for I in Is)
    end
    return blocksparse(bs, ax)
end
end