Skip to content

Commit 95efa1e

Browse files
committed
add sparse matrix interface
1 parent f51c185 commit 95efa1e

2 files changed

Lines changed: 236 additions & 1 deletion

File tree

arraycontext/context.py

Lines changed: 191 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@
114114
"""
115115

116116

117+
import dataclasses
117118
from abc import ABC, abstractmethod
118119
from collections.abc import Callable, Hashable, Mapping
119120
from typing import (
@@ -138,7 +139,7 @@
138139
from numpy.typing import DTypeLike
139140

140141
import loopy
141-
from pytools.tag import ToTagSetConvertible
142+
from pytools.tag import Tag, ToTagSetConvertible
142143

143144
from .fake_numpy import BaseFakeNumpyNamespace
144145
from .typing import (
@@ -156,6 +157,22 @@
156157
P = ParamSpec("P")
157158

158159

160+
@dataclasses.dataclass(frozen=True, eq=False, repr=False)
161+
class _CSRMatrix:
162+
# FIXME: Type for shape?
163+
shape: Any
164+
elem_values: Array
165+
elem_col_indices: Array
166+
row_starts: Array
167+
tags: frozenset[Tag] = dataclasses.field(kw_only=True)
168+
axes: tuple[ToTagSetConvertible, ...] = dataclasses.field(kw_only=True)
169+
_matmul_func: Callable[[_CSRMatrix, Array], Array] = \
170+
dataclasses.field(kw_only=True)
171+
172+
def __matmul__(self, other: Array) -> Array:
173+
return self._matmul_func(self, other)
174+
175+
159176
# {{{ ArrayContext
160177

161178
class ArrayContext(ABC):
@@ -172,6 +189,8 @@ class ArrayContext(ABC):
172189
.. automethod:: to_numpy
173190
.. automethod:: call_loopy
174191
.. automethod:: einsum
192+
.. automethod:: make_csr_matrix
193+
.. automethod:: sparse_matmul
175194
.. attribute:: np
176195
177196
Provides access to a namespace that serves as a work-alike to
@@ -424,6 +443,177 @@ def einsum(self,
424443
)["out"]
425444
return self.tag(tagged, out_ary)
426445

446+
# FIXME: Not sure what type annotations to use for shape and result
447+
def make_csr_matrix(
448+
self,
449+
shape,
450+
elem_values: Array,
451+
elem_col_indices: Array,
452+
row_starts: Array,
453+
*,
454+
tags: ToTagSetConvertible = frozenset(),
455+
axes: tuple[ToTagSetConvertible, ...] | None = None) -> Any:
456+
"""Return a context-dependent object that represents a sparse matrix in
457+
compressed sparse row (CSR) format. Result is suitable for passing to
458+
:meth:`sparse_matmul`.
459+
460+
:arg shape: the (two-dimensional) shape of the matrix
461+
:arg elem_values: a one-dimensional array containing the values of all of the
462+
nonzero entries of the matrix, grouped by row.
463+
:arg elem_col_indices: a one-dimensional array containing the column index
464+
values corresponding to each entry in *elem_values*.
465+
:arg row_starts: a one-dimensional array of length `nrows+1`, where each entry
466+
gives the starting index in *elem_values* and *elem_col_indices* for the
467+
given row, with the last entry being equal to `nrows`.
468+
"""
469+
if axes is None:
470+
axes = (frozenset(), frozenset())
471+
472+
return _CSRMatrix(
473+
shape, elem_values, elem_col_indices, row_starts,
474+
tags=tags, axes=axes,
475+
_matmul_func=lambda x1, x2: self.sparse_matmul(x1, x2))
476+
477+
@memoize_method
478+
def _get_csr_matmul_prg(self, out_ndim: int) -> loopy.TranslationUnit:
479+
import numpy as np
480+
481+
import loopy as lp
482+
483+
out_extra_inames = tuple(f"i{n}" for n in range(1, out_ndim))
484+
out_inames = ("irow", *out_extra_inames)
485+
out_inames_set = frozenset(out_inames)
486+
487+
out_extra_shape_comp_names = tuple(f"n{n}" for n in range(1, out_ndim))
488+
out_shape_comp_names = ("nrows", *out_extra_shape_comp_names)
489+
490+
domains: list[str] = []
491+
domains.append(
492+
"{ [" + ",".join(out_inames) + "] : "
493+
+ " and ".join(
494+
f"0 <= {iname} < {shape_comp_name}"
495+
for iname, shape_comp_name in zip(
496+
out_inames, out_shape_comp_names, strict=True))
497+
+ " }")
498+
domains.append(
499+
"{ [iel] : iel_lbound <= iel < iel_ubound }")
500+
501+
temporary_variables: Mapping[str, lp.TemporaryVariable] = {
502+
"iel_lbound": lp.TemporaryVariable(
503+
"iel_lbound",
504+
shape=(),
505+
address_space=lp.AddressSpace.GLOBAL,
506+
# FIXME: Need to do anything with tags?
507+
),
508+
"iel_ubound": lp.TemporaryVariable(
509+
"iel_ubound",
510+
shape=(),
511+
address_space=lp.AddressSpace.GLOBAL,
512+
# FIXME: Need to do anything with tags?
513+
)}
514+
515+
from loopy.kernel.instruction import make_assignment
516+
from pymbolic import var
517+
# FIXME: Need tags for any of these?
518+
instructions: list[lp.Assignment | lp.CallInstruction] = [
519+
make_assignment(
520+
(var("iel_lbound"),),
521+
var("row_starts")[var("irow")],
522+
id="insn0",
523+
within_inames=out_inames_set),
524+
make_assignment(
525+
(var("iel_ubound"),),
526+
var("row_starts")[var("irow") + 1],
527+
id="insn1",
528+
within_inames=out_inames_set),
529+
make_assignment(
530+
(var("out")[tuple(var(iname) for iname in out_inames)],),
531+
lp.Reduction(
532+
"sum",
533+
(var("iel"),),
534+
var("elem_values")[var("iel"),]
535+
* var("array")[(
536+
var("elem_col_indices")[var("iel"),],
537+
*(var(iname) for iname in out_extra_inames))]),
538+
id="insn2",
539+
within_inames=out_inames_set,
540+
depends_on=frozenset({"insn0", "insn1"}))]
541+
542+
from loopy.version import MOST_RECENT_LANGUAGE_VERSION
543+
544+
from .loopy import _DEFAULT_LOOPY_OPTIONS
545+
546+
knl = lp.make_kernel(
547+
domains=domains,
548+
instructions=instructions,
549+
temporary_variables=temporary_variables,
550+
kernel_data=[
551+
lp.ValueArg("nrows", is_input=True),
552+
lp.ValueArg("ncols", is_input=True),
553+
lp.ValueArg("nels", is_input=True),
554+
*(
555+
lp.ValueArg(shape_comp_name, is_input=True)
556+
for shape_comp_name in out_extra_shape_comp_names),
557+
lp.GlobalArg("elem_values", shape=(var("nels"),), is_input=True),
558+
lp.GlobalArg("elem_col_indices", shape=(var("nels"),), is_input=True),
559+
lp.GlobalArg("row_starts", shape=lp.auto, is_input=True),
560+
lp.GlobalArg(
561+
"array",
562+
shape=(
563+
var("ncols"),
564+
*(
565+
var(shape_comp_name)
566+
for shape_comp_name in out_extra_shape_comp_names),),
567+
# order="C",
568+
is_input=True),
569+
lp.GlobalArg(
570+
"out",
571+
shape=(
572+
var("nrows"),
573+
*(
574+
var(shape_comp_name)
575+
for shape_comp_name in out_extra_shape_comp_names),),
576+
# order="C",
577+
is_input=False),
578+
...],
579+
name="csr_matmul_kernel",
580+
lang_version=MOST_RECENT_LANGUAGE_VERSION,
581+
options=_DEFAULT_LOOPY_OPTIONS,
582+
default_order=lp.auto,
583+
default_offset=lp.auto,
584+
# FIXME: Need to do anything with tags?
585+
)
586+
587+
idx_dtype = knl.default_entrypoint.index_dtype
588+
589+
return lp.add_and_infer_dtypes(
590+
knl,
591+
{
592+
",".join([
593+
"ncols", "nrows", "nels",
594+
*out_extra_shape_comp_names]): idx_dtype,
595+
"elem_values,array,out": np.float64,
596+
"elem_col_indices,row_starts": idx_dtype})
597+
598+
# FIXME: Not sure what type annotation to use for x1
599+
def sparse_matmul(self, x1, x2: Array) -> Array:
600+
"""Multiply a sparse matrix by an array.
601+
602+
:arg x1: the sparse matrix.
603+
:arg x2: the array.
604+
"""
605+
if isinstance(x1, _CSRMatrix):
606+
prg = self._get_csr_matmul_prg(len(x2.shape))
607+
out_ary = self.call_loopy(
608+
prg, elem_values=x1.elem_values,
609+
elem_col_indices=x1.elem_col_indices,
610+
row_starts=x1.row_starts, array=x2)["out"]
611+
# FIXME
612+
# return self.tag(tagged, out_ary)
613+
return out_ary
614+
else:
615+
raise TypeError(f"unrecognized matrix type '{type(x1).__name__}'.")
616+
427617
@abstractmethod
428618
def clone(self) -> Self:
429619
"""If possible, return a version of *self* that is semantically

arraycontext/impl/pytato/__init__.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,7 @@ def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
833833
dag = pt.transform.materialize_with_mpms(dag)
834834
return dag
835835

836+
@override
836837
def einsum(self, spec, *args, arg_names=None, tagged=()):
837838
import pytato as pt
838839

@@ -876,6 +877,32 @@ def preprocess_arg(name, arg):
876877
for name, arg in zip(arg_names, args, strict=True)
877878
]).tagged(_preprocess_array_tags(tagged))
878879

880+
# FIXME: Not sure what type annotations to use for shape and result
881+
@override
882+
def make_csr_matrix(
883+
self,
884+
shape,
885+
elem_values: Array,
886+
elem_col_indices: Array,
887+
row_starts: Array,
888+
*,
889+
tags: ToTagSetConvertible = _EMPTY_TAG_SET,
890+
axes: tuple[ToTagSetConvertible, ...] | None = None) -> Any:
891+
import pytato as pt
892+
assert isinstance(elem_values, pt.Array)
893+
assert isinstance(elem_col_indices, pt.Array)
894+
assert isinstance(row_starts, pt.Array)
895+
return pt.make_csr_matrix(
896+
shape, elem_values, elem_col_indices, row_starts,
897+
# FIXME: Do I need to call _preprocess_array_tags on axes?
898+
tags=_preprocess_array_tags(tags), axes=axes)
899+
900+
# FIXME: Not sure what type annotation to use for x1
901+
@override
902+
def sparse_matmul(self, x1, x2: Array) -> Array:
903+
import pytato as pt
904+
return pt.sparse_matmul(x1, x2)
905+
879906
def clone(self):
880907
return type(self)(self.queue, self.allocator)
881908

@@ -1115,6 +1142,24 @@ def preprocess_arg(name: str | None, arg: Array):
11151142
for name, arg in zip(arg_names, args, strict=True)
11161143
]).tagged(_preprocess_array_tags(tagged)))
11171144

1145+
# FIXME: Not sure what type annotations to use for shape and result
1146+
@override
1147+
def make_csr_matrix(
1148+
self,
1149+
shape,
1150+
elem_values: Array,
1151+
elem_col_indices: Array,
1152+
row_starts: Array,
1153+
*,
1154+
tags: ToTagSetConvertible = frozenset(),
1155+
axes: tuple[ToTagSetConvertible, ...] | None = None) -> Any:
1156+
raise NotImplementedError
1157+
1158+
# FIXME: Not sure what type annotation to use for x1
1159+
@override
1160+
def sparse_matmul(self, x1, x2: Array) -> Array:
1161+
raise NotImplementedError
1162+
11181163
@override
11191164
def clone(self):
11201165
return type(self)()

0 commit comments

Comments
 (0)