55.. autoclass:: MPIPyOpenCLArrayContext
66.. autoclass:: MPINumpyArrayContext
77.. class:: MPIPytatoArrayContext
8+ .. autoclass:: MPIEagerJAXArrayContext
9+ .. autoclass:: MPIPytatoJAXArrayContext
810.. autofunction:: get_reasonable_array_context_class
911"""
1012
7678 _HAVE_FUSION_ACTX = False
7779
7880
79- from arraycontext import ArrayContext , EagerJAXArrayContext , NumpyArrayContext
81+ from arraycontext import (
82+ ArrayContext ,
83+ EagerJAXArrayContext ,
84+ NumpyArrayContext ,
85+ PytatoJAXArrayContext ,
86+ )
8087from arraycontext .container import ArrayContainer
8188from arraycontext .impl .pytato .compile import LazilyPyOpenCLCompilingFunctionCaller
8289from arraycontext .pytest import (
8390 _PytestEagerJaxArrayContextFactory ,
8491 _PytestNumpyArrayContextFactory ,
8592 _PytestPyOpenCLArrayContextFactoryWithClass ,
93+ _PytestPytatoJaxArrayContextFactory ,
8694 _PytestPytatoPyOpenCLArrayContextFactory ,
8795 register_pytest_array_context_factory ,
8896)
@@ -449,6 +457,26 @@ def clone(self) -> Self:
449457# }}}
450458
451459
460+ # {{{ distributed + lazy jax
461+
462+ class MPIPytatoJAXArrayContext (PytatoJAXArrayContext , MPIBasedArrayContext ):
463+ """An array context for using distributed computation with :mod:`jax`
464+ lazy evaluation.
465+
466+ .. autofunction:: __init__
467+ """
468+
469+ def __init__ (self , mpi_communicator ) -> None :
470+ super ().__init__ ()
471+
472+ self .mpi_communicator = mpi_communicator
473+
474+ def clone (self ) -> Self :
475+ return type (self )(self .mpi_communicator )
476+
477+ # }}}
478+
479+
452480# {{{ distributed + pytato array context subclasses
453481
454482class MPIBasePytatoPyOpenCLArrayContext (
@@ -551,6 +579,15 @@ def __call__(self):
551579 return self .actx_class ()
552580
553581
582+ class PytestPytatoJAXArrayContextFactory (_PytestPytatoJaxArrayContextFactory ):
583+ actx_class = PytatoJAXArrayContext
584+
585+ def __call__ (self ):
586+ import jax
587+ jax .config .update ("jax_enable_x64" , True )
588+ return self .actx_class ()
589+
590+
554591register_pytest_array_context_factory ("grudge.pyopencl" ,
555592 PytestPyOpenCLArrayContextFactory )
556593register_pytest_array_context_factory ("grudge.pytato-pyopencl" ,
@@ -559,6 +596,8 @@ def __call__(self):
559596 PytestNumpyArrayContextFactory )
560597register_pytest_array_context_factory ("grudge.eager-jax" ,
561598 PytestEagerJAXArrayContextFactory )
599+ register_pytest_array_context_factory ("grudge.lazy-jax" ,
600+ PytestPytatoJAXArrayContextFactory )
562601
563602# }}}
564603
0 commit comments