Skip to content

Commit 73c43b1

Browse files
committed
add sparse matrix interface
1 parent f51c185 commit 73c43b1

2 files changed

Lines changed: 227 additions & 0 deletions

File tree

arraycontext/context.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116

117117
from abc import ABC, abstractmethod
118118
from collections.abc import Callable, Hashable, Mapping
119+
import dataclasses
119120
from typing import (
120121
TYPE_CHECKING,
121122
Any,
@@ -156,6 +157,22 @@
156157
P = ParamSpec("P")
157158

158159

160+
@dataclasses.dataclass(frozen=True, eq=False, repr=False)
161+
class _CSRMatrix:
162+
# FIXME: Type for shape?
163+
shape: Any
164+
elem_values: Array
165+
elem_col_indices: Array
166+
row_starts: Array
167+
tags: frozenset[Tag] = dataclasses.field(kw_only=True)
168+
axes: tuple[ToTagSetConvertible, ...] = dataclasses.field(kw_only=True)
169+
_matmul_func: Callable[[_CSRMatrix, Array], Array] = \
170+
dataclasses.field(kw_only=True)
171+
172+
def __matmul__(self, other: Array) -> Array:
173+
return self._matmul_func(self, other)
174+
175+
159176
# {{{ ArrayContext
160177

161178
class ArrayContext(ABC):
@@ -172,6 +189,8 @@ class ArrayContext(ABC):
172189
.. automethod:: to_numpy
173190
.. automethod:: call_loopy
174191
.. automethod:: einsum
192+
.. automethod:: make_csr_matrix
193+
.. automethod:: sparse_matmul
175194
.. attribute:: np
176195
177196
Provides access to a namespace that serves as a work-alike to
@@ -424,6 +443,172 @@ def einsum(self,
424443
)["out"]
425444
return self.tag(tagged, out_ary)
426445

446+
# FIXME: Not sure what type annotations to use for shape and result
def make_csr_matrix(
        self,
        shape,
        elem_values: Array,
        elem_col_indices: Array,
        row_starts: Array,
        *,
        tags: frozenset[Tag] = frozenset(),
        axes: tuple[ToTagSetConvertible, ...] | None = None):
    """Return a context-dependent object that represents a sparse matrix in
    compressed sparse row (CSR) format. Result is suitable for passing to
    :meth:`sparse_matmul`.

    :arg shape: the (two-dimensional) shape of the matrix
    :arg elem_values: a one-dimensional array containing the values of all of the
        nonzero entries of the matrix, grouped by row.
    :arg elem_col_indices: a one-dimensional array containing the column index
        values corresponding to each entry in *elem_values*.
    :arg row_starts: a one-dimensional array of length `nrows+1`, where each entry
        gives the starting index in *elem_values* and *elem_col_indices* for the
        given row, with the last entry being equal to the total number of
        stored entries (i.e. the length of *elem_values*).
    :arg tags: tags stored on the returned matrix object.
    :arg axes: optional per-axis tags stored on the returned matrix object.
    """
    # The bound method already has the (matrix, array) call signature that
    # the matrix object's ``@`` operator expects, so no wrapper is needed.
    return _CSRMatrix(
        shape, elem_values, elem_col_indices, row_starts,
        tags=tags, axes=axes,
        _matmul_func=self.sparse_matmul)
473+
474+
@memoize_method
def _get_csr_matmul_prg(self, out_ndim: int) -> loopy.TranslationUnit:
    """Build the :mod:`loopy` program that multiplies a CSR matrix by a dense
    array.

    The kernel (``csr_matmul_kernel``) takes ``elem_values``,
    ``elem_col_indices``, ``row_starts`` and ``array`` as inputs and writes
    ``out``. Axis 0 of ``out`` is the matrix row index; any further axes are
    carried over unchanged from ``array``.

    :arg out_ndim: the number of axes of ``array`` and ``out``.
    :returns: the result of :func:`loopy.add_and_infer_dtypes` applied to the
        freshly built kernel.
    """
    import numpy as np
    import loopy as lp

    # One iname / shape component per output axis: "irow"/"nrows" for the row
    # axis, "i1", "i2", ... / "n1", "n2", ... for the remaining axes.
    out_extra_inames = tuple(f"i{n}" for n in range(1, out_ndim))
    out_inames = ("irow", *out_extra_inames)
    out_inames_set = frozenset(out_inames)

    out_extra_shape_comp_names = tuple(f"n{n}" for n in range(1, out_ndim))
    out_shape_comp_names = ("nrows", *out_extra_shape_comp_names)

    domains = [
        # Iteration space over every entry of the output.
        "{ [" + ",".join(out_inames) + "] : "
        + " and ".join(
            f"0 <= {iname} < {shape_comp_name}"
            for iname, shape_comp_name in zip(
                out_inames, out_shape_comp_names, strict=True))
        + " }",
        # Iteration space over the stored entries of the current row; the
        # bounds are read from row_starts by the first two instructions.
        "{ [iel] : iel_lbound <= iel < iel_ubound }",
    ]

    # NOTE(review): both bounds are scalars placed in the global address
    # space although they are assigned once per output index; confirm this
    # is safe should the output inames ever be parallelized.
    temporary_variables = {
        name: lp.TemporaryVariable(
            name,
            shape=(),
            address_space=lp.AddressSpace.GLOBAL,
            # FIXME: Need to do anything with tags?
            )
        for name in ("iel_lbound", "iel_ubound")}

    from loopy.kernel.instruction import make_assignment
    from pymbolic import var

    extra_shape_vars = tuple(
        var(shape_comp_name)
        for shape_comp_name in out_extra_shape_comp_names)

    # FIXME: Need tags for any of these?
    instructions = [
        # iel_lbound = row_starts[irow]
        make_assignment(
            (var("iel_lbound"),),
            var("row_starts")[var("irow")],
            id="insn0",
            within_inames=out_inames_set),
        # iel_ubound = row_starts[irow + 1]
        make_assignment(
            (var("iel_ubound"),),
            var("row_starts")[var("irow") + 1],
            id="insn1",
            within_inames=out_inames_set),
        # out[irow, i1, ...] =
        #     sum(iel, elem_values[iel] * array[elem_col_indices[iel], i1, ...])
        make_assignment(
            (var("out")[tuple(var(iname) for iname in out_inames)],),
            lp.Reduction(
                "sum",
                (var("iel"),),
                var("elem_values")[var("iel"),]
                * var("array")[(
                    var("elem_col_indices")[var("iel"),],
                    *(var(iname) for iname in out_extra_inames))]),
            id="insn2",
            within_inames=out_inames_set,
            # The reduction bounds must be in place before summing.
            depends_on=frozenset({"insn0", "insn1"}))]

    from loopy.version import MOST_RECENT_LANGUAGE_VERSION
    from .loopy import _DEFAULT_LOOPY_OPTIONS
    knl = lp.make_kernel(
        domains=domains,
        instructions=instructions,
        temporary_variables=temporary_variables,
        kernel_data=[
            lp.ValueArg("nrows", is_input=True),
            lp.ValueArg("ncols", is_input=True),
            lp.ValueArg("nels", is_input=True),
            *(
                lp.ValueArg(shape_comp_name, is_input=True)
                for shape_comp_name in out_extra_shape_comp_names),
            lp.GlobalArg("elem_values", shape=(var("nels"),), is_input=True),
            lp.GlobalArg(
                "elem_col_indices", shape=(var("nels"),), is_input=True),
            lp.GlobalArg("row_starts", shape=lp.auto, is_input=True),
            lp.GlobalArg(
                "array",
                shape=(var("ncols"), *extra_shape_vars),
                # order="C",
                is_input=True),
            lp.GlobalArg(
                "out",
                shape=(var("nrows"), *extra_shape_vars),
                # order="C",
                is_input=False),
            ...],
        name="csr_matmul_kernel",
        lang_version=MOST_RECENT_LANGUAGE_VERSION,
        options=_DEFAULT_LOOPY_OPTIONS,
        default_order=lp.auto,
        default_offset=lp.auto,
        # FIXME: Need to do anything with tags?
    )

    idx_dtype = knl.default_entrypoint.index_dtype

    return lp.add_and_infer_dtypes(
        knl,
        {
            ",".join([
                "ncols", "nrows", "nels",
                *out_extra_shape_comp_names]): idx_dtype,
            # NOTE: the value dtype is hard-coded to float64 here.
            "elem_values,array,out": np.float64,
            "elem_col_indices,row_starts": idx_dtype})
592+
593+
# FIXME: Not sure what type annotation to use for x1
def sparse_matmul(self, x1, x2: Array) -> Array:
    """Compute the product of a sparse matrix with an array.

    :arg x1: the sparse matrix.
    :arg x2: the array.
    :raises TypeError: if *x1* is not a recognized sparse matrix type.
    """
    # Reject anything that is not a known sparse matrix container up front.
    if not isinstance(x1, _CSRMatrix):
        raise TypeError(f"unrecognized matrix type '{type(x1).__name__}'.")

    # The kernel is specialized on the dimensionality of the dense operand.
    csr_prg = self._get_csr_matmul_prg(x2.ndim)
    kernel_results = self.call_loopy(
        csr_prg,
        elem_values=x1.elem_values,
        elem_col_indices=x1.elem_col_indices,
        row_starts=x1.row_starts,
        array=x2)
    # FIXME
    # return self.tag(tagged, out_ary)
    return kernel_results["out"]
611+
427612
@abstractmethod
428613
def clone(self) -> Self:
429614
"""If possible, return a version of *self* that is semantically

arraycontext/impl/pytato/__init__.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,7 @@ def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
833833
dag = pt.transform.materialize_with_mpms(dag)
834834
return dag
835835

836+
@override
836837
def einsum(self, spec, *args, arg_names=None, tagged=()):
837838
import pytato as pt
838839

@@ -876,6 +877,29 @@ def preprocess_arg(name, arg):
876877
for name, arg in zip(arg_names, args, strict=True)
877878
]).tagged(_preprocess_array_tags(tagged))
878879

880+
# FIXME: Not sure what type annotations to use for shape and result
@override
def make_csr_matrix(
        self,
        shape,
        elem_values: Array,
        elem_col_indices: Array,
        row_starts: Array,
        *,
        tags: ToTagSetConvertible = _EMPTY_TAG_SET,
        axes: tuple[ToTagSetConvertible, ...] | None = None):
    """Create a CSR sparse matrix by delegating to ``pytato.make_csr_matrix``.

    *tags* is passed through ``_preprocess_array_tags`` first; *axes* is
    forwarded unchanged.
    """
    import pytato as pt

    build_csr = pt.make_csr_matrix
    # FIXME: Do I need to call _preprocess_array_tags on axes?
    return build_csr(
        shape,
        elem_values,
        elem_col_indices,
        row_starts,
        tags=_preprocess_array_tags(tags),
        axes=axes)
896+
897+
# FIXME: Not sure what type annotation to use for x1
@override
def sparse_matmul(self, x1, x2: Array) -> Array:
    """Multiply the sparse matrix *x1* by the array *x2* by delegating to
    ``pytato.sparse_matmul``.
    """
    import pytato as pt

    product = pt.sparse_matmul(x1, x2)
    return product
902+
879903
def clone(self):
    """Return a new context of the same type, constructed from this
    context's ``queue`` and ``allocator``.
    """
    ctx_type = type(self)
    return ctx_type(self.queue, self.allocator)
881905

@@ -1115,6 +1139,24 @@ def preprocess_arg(name: str | None, arg: Array):
11151139
for name, arg in zip(arg_names, args, strict=True)
11161140
]).tagged(_preprocess_array_tags(tagged)))
11171141

1142+
# FIXME: Not sure what type annotations to use for shape and result
@override
def make_csr_matrix(
        self,
        shape,
        elem_values: Array,
        elem_col_indices: Array,
        row_starts: Array,
        *,
        tags: frozenset[Tag] = frozenset(),
        axes: tuple[ToTagSetConvertible, ...] | None = None):
    """Not implemented for this array context.

    :raises NotImplementedError: always.
    """
    # Name the concrete context type so the failure is easy to diagnose.
    raise NotImplementedError(
        f"{type(self).__name__} does not support creating CSR matrices")
1154+
1155+
# FIXME: Not sure what type annotation to use for x1
@override
def sparse_matmul(self, x1, x2: Array) -> Array:
    """Not implemented for this array context.

    :raises NotImplementedError: always.
    """
    # Name the concrete context type so the failure is easy to diagnose.
    raise NotImplementedError(
        f"{type(self).__name__} does not support sparse matrix products")
1159+
11181160
@override
def clone(self):
    """Return a new, default-constructed context of the same type."""
    ctx_type = type(self)
    return ctx_type()

0 commit comments

Comments
 (0)