116116
117117from abc import ABC , abstractmethod
118118from collections .abc import Callable , Hashable , Mapping
119+ import dataclasses
119120from typing import (
120121 TYPE_CHECKING ,
121122 Any ,
156157P = ParamSpec ("P" )
157158
158159
160+ @dataclasses .dataclass (frozen = True , eq = False , repr = False )
161+ class _CSRMatrix :
162+ # FIXME: Type for shape?
163+ shape : Any
164+ elem_values : Array
165+ elem_col_indices : Array
166+ row_starts : Array
167+ tags : frozenset [Tag ] = dataclasses .field (kw_only = True )
168+ axes : tuple [ToTagSetConvertible , ...] = dataclasses .field (kw_only = True )
169+ _matmul_func : Callable [[_CSRMatrix , Array ], Array ] = \
170+ dataclasses .field (kw_only = True )
171+
172+ def __matmul__ (self , other : Array ) -> Array :
173+ return self ._matmul_func (self , other )
174+
175+
159176# {{{ ArrayContext
160177
161178class ArrayContext (ABC ):
@@ -172,6 +189,8 @@ class ArrayContext(ABC):
172189 .. automethod:: to_numpy
173190 .. automethod:: call_loopy
174191 .. automethod:: einsum
192+ .. automethod:: make_csr_matrix
193+ .. automethod:: sparse_matmul
175194 .. attribute:: np
176195
177196 Provides access to a namespace that serves as a work-alike to
@@ -424,6 +443,172 @@ def einsum(self,
424443 )["out" ]
425444 return self .tag (tagged , out_ary )
426445
446+ # FIXME: Not sure what type annotations to use for shape and result
447+ def make_csr_matrix (
448+ self ,
449+ shape ,
450+ elem_values : Array ,
451+ elem_col_indices : Array ,
452+ row_starts : Array ,
453+ * ,
454+ tags : frozenset [Tag ] = frozenset (),
455+ axes : tuple [ToTagSetConvertible , ...] | None = None ):
456+ """Return a context-dependent object that represents a sparse matrix in
457+ compressed sparse row (CSR) format. Result is suitable for passing to
458+ :meth:`sparse_matmul`.
459+
460+ :arg shape: the (two-dimensional) shape of the matrix
461+ :arg elem_values: a one-dimensional array containing the values of all of the
462+ nonzero entries of the matrix, grouped by row.
463+ :arg elem_col_indices: a one-dimensional array containing the column index
464+ values corresponding to each entry in *elem_values*.
465+ :arg row_starts: a one-dimensional array of length `nrows+1`, where each entry
466+ gives the starting index in *elem_values* and *elem_col_indices* for the
467+ given row, with the last entry being equal to `nrows`.
468+ """
469+ return _CSRMatrix (
470+ shape , elem_values , elem_col_indices , row_starts ,
471+ tags = tags , axes = axes ,
472+ _matmul_func = lambda x1 , x2 : self .sparse_matmul (x1 , x2 ))
473+
474+ @memoize_method
475+ def _get_csr_matmul_prg (self , out_ndim : int ) -> loopy .TranslationUnit :
476+ import numpy as np
477+ import loopy as lp
478+
479+ out_extra_inames = tuple (f"i{ n } " for n in range (1 , out_ndim ))
480+ out_inames = tuple (["irow" , * out_extra_inames ])
481+ out_inames_set = frozenset (out_inames )
482+
483+ out_extra_shape_comp_names = tuple (f"n{ n } " for n in range (1 , out_ndim ))
484+ out_shape_comp_names = tuple (["nrows" , * out_extra_shape_comp_names ])
485+
486+ domains = []
487+ domains .append (
488+ "{ [" + "," .join (out_inames ) + "] : "
489+ + " and " .join (
490+ f"0 <= { iname } < { shape_comp_name } "
491+ for iname , shape_comp_name in zip (
492+ out_inames , out_shape_comp_names , strict = True ))
493+ + " }" )
494+ domains .append (
495+ "{ [iel] : iel_lbound <= iel < iel_ubound }" )
496+
497+ address_space = lp .AddressSpace .GLOBAL
498+ temporary_variables = {
499+ "iel_lbound" : lp .TemporaryVariable (
500+ "iel_lbound" ,
501+ shape = (),
502+ address_space = lp .AddressSpace .GLOBAL ,
503+ # FIXME: Need to do anything with tags?
504+ ),
505+ "iel_ubound" : lp .TemporaryVariable (
506+ "iel_ubound" ,
507+ shape = (),
508+ address_space = lp .AddressSpace .GLOBAL ,
509+ # FIXME: Need to do anything with tags?
510+ )}
511+
512+ from loopy .kernel .instruction import make_assignment
513+ from pymbolic import var
514+ # FIXME: Need tags for any of these?
515+ instructions = [
516+ make_assignment (
517+ (var (f"iel_lbound" ),),
518+ var ("row_starts" )[var ("irow" )],
519+ id = "insn0" ,
520+ within_inames = out_inames_set ),
521+ make_assignment (
522+ (var (f"iel_ubound" ),),
523+ var ("row_starts" )[var ("irow" ) + 1 ],
524+ id = "insn1" ,
525+ within_inames = out_inames_set ),
526+ make_assignment (
527+ (var ("out" )[tuple (var (iname ) for iname in out_inames )],),
528+ lp .Reduction (
529+ "sum" ,
530+ tuple ([var ("iel" )]),
531+ var ("elem_values" )[var ("iel" ),]
532+ * var ("array" )[(
533+ var ("elem_col_indices" )[var ("iel" ),],
534+ * (var (iname ) for iname in out_extra_inames ))]),
535+ id = "insn2" ,
536+ within_inames = out_inames_set ,
537+ depends_on = frozenset ({"insn0" , "insn1" }))]
538+
539+ from loopy .version import MOST_RECENT_LANGUAGE_VERSION
540+ from .loopy import _DEFAULT_LOOPY_OPTIONS
541+ knl = lp .make_kernel (
542+ domains = domains ,
543+ instructions = instructions ,
544+ temporary_variables = temporary_variables ,
545+ kernel_data = [
546+ lp .ValueArg ("nrows" , is_input = True ),
547+ lp .ValueArg ("ncols" , is_input = True ),
548+ lp .ValueArg ("nels" , is_input = True ),
549+ * (
550+ lp .ValueArg (shape_comp_name , is_input = True )
551+ for shape_comp_name in out_extra_shape_comp_names ),
552+ lp .GlobalArg ("elem_values" , shape = (var ("nels" ),), is_input = True ),
553+ lp .GlobalArg ("elem_col_indices" , shape = (var ("nels" ),), is_input = True ),
554+ lp .GlobalArg ("row_starts" , shape = lp .auto , is_input = True ),
555+ lp .GlobalArg (
556+ "array" ,
557+ shape = tuple ([
558+ var ("ncols" ),
559+ * (
560+ var (shape_comp_name )
561+ for shape_comp_name in out_extra_shape_comp_names )]),
562+ # order="C",
563+ is_input = True ),
564+ lp .GlobalArg (
565+ "out" ,
566+ shape = tuple ([
567+ var ("nrows" ),
568+ * (
569+ var (shape_comp_name )
570+ for shape_comp_name in out_extra_shape_comp_names )]),
571+ # order="C",
572+ is_input = False ),
573+ ...],
574+ name = "csr_matmul_kernel" ,
575+ lang_version = MOST_RECENT_LANGUAGE_VERSION ,
576+ options = _DEFAULT_LOOPY_OPTIONS ,
577+ default_order = lp .auto ,
578+ default_offset = lp .auto ,
579+ # FIXME: Need to do anything with tags?
580+ )
581+
582+ idx_dtype = knl .default_entrypoint .index_dtype
583+
584+ return lp .add_and_infer_dtypes (
585+ knl ,
586+ {
587+ "," .join ([
588+ "ncols" , "nrows" , "nels" ,
589+ * out_extra_shape_comp_names ]): idx_dtype ,
590+ "elem_values,array,out" : np .float64 ,
591+ "elem_col_indices,row_starts" : idx_dtype })
592+
593+ # FIXME: Not sure what type annotation to use for x1
594+ def sparse_matmul (self , x1 , x2 : Array ) -> Array :
595+ """Multiply a sparse matrix by an array.
596+
597+ :arg x1: the sparse matrix.
598+ :arg x2: the array.
599+ """
600+ if isinstance (x1 , _CSRMatrix ):
601+ prg = self ._get_csr_matmul_prg (x2 .ndim )
602+ out_ary = self .call_loopy (
603+ prg , elem_values = x1 .elem_values ,
604+ elem_col_indices = x1 .elem_col_indices ,
605+ row_starts = x1 .row_starts , array = x2 )["out" ]
606+ # FIXME
607+ # return self.tag(tagged, out_ary)
608+ return out_ary
609+ else :
610+ raise TypeError (f"unrecognized matrix type '{ type (x1 ).__name__ } '." )
611+
427612 @abstractmethod
428613 def clone (self ) -> Self :
429614 """If possible, return a version of *self* that is semantically
0 commit comments