-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathsparse_matrix.jl
More file actions
108 lines (89 loc) · 3.11 KB
/
sparse_matrix.jl
File metadata and controls
108 lines (89 loc) · 3.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Copyright (c) 2019: Joaquim Dias Garcia, and contributors
#
# Use of this source code is governed by an MIT-style license that can be found
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
"""
    AbstractIndexing

Supertype of the tags that select the index base (`0` or `1`) used for stored
indices. See [`ZeroBasedIndexing`](@ref) and [`OneBasedIndexing`](@ref).
"""
abstract type AbstractIndexing end

"""
    ZeroBasedIndexing <: AbstractIndexing

Tag for indices whose first value is `0`.
"""
struct ZeroBasedIndexing <: AbstractIndexing end

"""
    OneBasedIndexing <: AbstractIndexing

Tag for indices whose first value is `1`.
"""
struct OneBasedIndexing <: AbstractIndexing end

"""
    first_index(indexing::AbstractIndexing)

Return the first valid index under `indexing`: `0` for `ZeroBasedIndexing` and
`1` for `OneBasedIndexing`.
"""
first_index(::ZeroBasedIndexing) = 0
first_index(::OneBasedIndexing) = 1

"""
    shift(x, from::AbstractIndexing, to::AbstractIndexing)

Convert the index `x` (an `Integer` or an `Array` of integers) from the index
base `from` to the index base `to`.

When `from` and `to` are the same base, `x` itself is returned without copying.
Otherwise an integer is shifted by one and an array is shifted elementwise into
a newly allocated array.
"""
shift(x, ::ZeroBasedIndexing, ::ZeroBasedIndexing) = x
shift(x::Integer, ::ZeroBasedIndexing, ::OneBasedIndexing) = x + 1
shift(x::Array{<:Integer}, ::ZeroBasedIndexing, ::OneBasedIndexing) = x .+ 1
shift(x::Integer, ::OneBasedIndexing, ::ZeroBasedIndexing) = x - 1
# Array counterpart of the scalar method above, so that both conversion
# directions accept both integers and arrays.
shift(x::Array{<:Integer}, ::OneBasedIndexing, ::ZeroBasedIndexing) = x .- 1
shift(x, ::OneBasedIndexing, ::OneBasedIndexing) = x
"""
    SparseMatrixCSRtoCSC{Tv,Ti<:Integer,I<:AbstractIndexing}(n)

Mutable sparse matrix with `n` columns, stored as the three vectors `colptr`,
`rowval` and `nzval`. Values have type `Tv`, stored indices have type `Ti`, and
`indexing` (of type `I`) records the index base of the stored indices.

# Fields

  - `indexing::I`: the index-base tag, built as `I()`.
  - `m::Int`: number of rows. The constructor sets it to `0`; no code in this
    part of the file assigns it, so it is presumably set by the caller.
  - `n::Int`: number of columns.
  - `colptr::Vector{Ti}`: `n + 1` entries, all zero after construction.
  - `rowval::Vector{Ti}`, `nzval::Vector{Tv}`: empty after construction.
"""
mutable struct SparseMatrixCSRtoCSC{Tv,Ti<:Integer,I<:AbstractIndexing}
    indexing::I
    m::Int # Number of rows
    n::Int # Number of columns
    colptr::Vector{Ti}
    rowval::Vector{Ti}
    nzval::Vector{Tv}
    function SparseMatrixCSRtoCSC{Tv,Ti,I}(n) where {Tv,Ti<:Integer,I}
        # Initialise every field: an incomplete `new` would leave `m` holding
        # arbitrary bits and make reading `rowval`/`nzval` throw an
        # `UndefRefError` until they are assigned.
        return new{Tv,Ti,I}(I(), 0, n, zeros(Ti, n + 1), Ti[], Tv[])
    end
end
"""
    allocate_nonzeros(A::SparseMatrixCSRtoCSC)

Turn the values stored in `A.colptr[2:end]` into their running sum, in place,
then replace `A.rowval` and `A.nzval` with uninitialised vectors whose length is
the final entry of `A.colptr`.
"""
function allocate_nonzeros(A::SparseMatrixCSRtoCSC{Tv,Ti}) where {Tv,Ti}
    colptr = A.colptr
    # The first entry is left alone; accumulation starts from the second.
    for k in 2:(length(colptr)-1)
        colptr[k+1] += colptr[k]
    end
    total = colptr[end]
    A.rowval = Vector{Ti}(undef, total)
    A.nzval = Vector{Tv}(undef, total)
    return nothing
end
"""
    final_touch(A::SparseMatrixCSRtoCSC)

Move every entry of `A.colptr` one slot to the right, converting it from
zero-based indexing to `A.indexing` on the way, and store
`first_index(A.indexing)` in the first slot. The previous last entry is
overwritten.
"""
function final_touch(A::SparseMatrixCSRtoCSC)
    colptr = A.colptr
    indexing = A.indexing
    zero_based = ZeroBasedIndexing()
    # Go right-to-left so each slot reads its left neighbour before that
    # neighbour is itself overwritten.
    for k in reverse(2:length(colptr))
        colptr[k] = shift(colptr[k-1], zero_based, indexing)
    end
    colptr[1] = first_index(indexing)
    return nothing
end
"""
    _allocate_terms(colptr, indexmap, terms)

For each `term` in `terms`, look up `indexmap[term.scalar_term.variable].value`
and add one to the entry of `colptr` located one slot after that value.
Returns `nothing`.
"""
function _allocate_terms(colptr, indexmap, terms)
    for t in terms
        column = indexmap[t.scalar_term.variable].value
        # The counter for `column` lives at `column + 1`.
        colptr[column+1] += 1
    end
    return nothing
end
"""
    allocate_terms(A::SparseMatrixCSRtoCSC, indexmap, func)

Forward `func.terms` to `_allocate_terms`, which updates `A.colptr` in place
using `indexmap`.
"""
function allocate_terms(A::SparseMatrixCSRtoCSC, indexmap, func)
    colptr = A.colptr
    terms = func.terms
    return _allocate_terms(colptr, indexmap, terms)
end
"""
    _load_terms(colptr, rowval, nzval, indexmap, terms, offset)

For each `term` in `terms`, increment the entry of `colptr` selected by
`indexmap[term.scalar_term.variable].value` and use the incremented value as the
position at which to write:

  - `offset + term.output_index` into `rowval`, and
  - the negated `term.scalar_term.coefficient` into `nzval`.

Returns `nothing`.
"""
function _load_terms(colptr, rowval, nzval, indexmap, terms, offset)
    for t in terms
        column = indexmap[t.scalar_term.variable].value
        # Advance the write cursor of this column first, then write at it.
        slot = colptr[column] + 1
        colptr[column] = slot
        rowval[slot] = offset + t.output_index
        nzval[slot] = -t.scalar_term.coefficient
    end
    return nothing
end
"""
    load_terms(A::SparseMatrixCSRtoCSC, indexmap, func, offset)

Write the terms of `func` into the storage vectors of `A` by calling
`_load_terms`. `offset` is converted from one-based indexing to `A.indexing`
before being passed on.
"""
function load_terms(A::SparseMatrixCSRtoCSC, indexmap, func, offset)
    colptr, rowval, nzval = A.colptr, A.rowval, A.nzval
    terms = func.terms
    row_offset = shift(offset, OneBasedIndexing(), A.indexing)
    return _load_terms(colptr, rowval, nzval, indexmap, terms, row_offset)
end
"""
    Base.convert(::Type{SparseMatrixCSC{Tv,Ti}}, A::SparseMatrixCSRtoCSC{Tv,Ti,I}) where {Tv,Ti,I}

Convert `A` to a `SparseMatrixCSC`. Note that the field `A.nzval` is **not
copied**, so if `A` is modified after this function is called, the returned
matrix can still be affected. Moreover, if `I` is `OneBasedIndexing`, `colptr`
and `rowval` are not copied either, i.e., the conversion is allocation-free.
"""
function Base.convert(
    ::Type{SparseMatrixCSC{Tv,Ti}},
    A::SparseMatrixCSRtoCSC{Tv,Ti},
) where {Tv,Ti}
    one_based = OneBasedIndexing()
    rows, cols = A.m, A.n
    # `shift` returns its argument unchanged when `A.indexing` is already
    # one-based; otherwise it produces shifted copies.
    colptr = shift(A.colptr, A.indexing, one_based)
    rowval = shift(A.rowval, A.indexing, one_based)
    # `A.nzval` is handed over as-is.
    return SparseMatrixCSC{Tv,Ti}(rows, cols, colptr, rowval, A.nzval)
end