Skip to content

Commit 408ce43

Browse files
add MPIPytatoJAXArrayContext
1 parent ef1b0b1 commit 408ce43

3 files changed

Lines changed: 46 additions & 3 deletions

File tree

grudge/array_context.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
.. autoclass:: MPIPyOpenCLArrayContext
66
.. autoclass:: MPINumpyArrayContext
77
.. class:: MPIPytatoArrayContext
8+
.. autoclass:: MPIEagerJAXArrayContext
9+
.. autoclass:: MPIPytatoJAXArrayContext
810
.. autofunction:: get_reasonable_array_context_class
911
"""
1012

@@ -76,13 +78,19 @@
7678
_HAVE_FUSION_ACTX = False
7779

7880

79-
from arraycontext import ArrayContext, EagerJAXArrayContext, NumpyArrayContext
81+
from arraycontext import (
82+
ArrayContext,
83+
EagerJAXArrayContext,
84+
NumpyArrayContext,
85+
PytatoJAXArrayContext,
86+
)
8087
from arraycontext.container import ArrayContainer
8188
from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller
8289
from arraycontext.pytest import (
8390
_PytestEagerJaxArrayContextFactory,
8491
_PytestNumpyArrayContextFactory,
8592
_PytestPyOpenCLArrayContextFactoryWithClass,
93+
_PytestPytatoJaxArrayContextFactory,
8694
_PytestPytatoPyOpenCLArrayContextFactory,
8795
register_pytest_array_context_factory,
8896
)
@@ -449,6 +457,26 @@ def clone(self) -> Self:
449457
# }}}
450458

451459

460+
# {{{ distributed + lazy jax
461+
462+
class MPIPytatoJAXArrayContext(PytatoJAXArrayContext, MPIBasedArrayContext):
463+
"""An array context for using distributed computation with :mod:`jax`
464+
lazy evaluation.
465+
466+
.. autofunction:: __init__
467+
"""
468+
469+
def __init__(self, mpi_communicator) -> None:
470+
super().__init__()
471+
472+
self.mpi_communicator = mpi_communicator
473+
474+
def clone(self) -> Self:
475+
return type(self)(self.mpi_communicator)
476+
477+
# }}}
478+
479+
452480
# {{{ distributed + pytato array context subclasses
453481

454482
class MPIBasePytatoPyOpenCLArrayContext(
@@ -551,6 +579,15 @@ def __call__(self):
551579
return self.actx_class()
552580

553581

582+
class PytestPytatoJAXArrayContextFactory(_PytestPytatoJaxArrayContextFactory):
583+
actx_class = PytatoJAXArrayContext
584+
585+
def __call__(self):
586+
import jax
587+
jax.config.update("jax_enable_x64", True)
588+
return self.actx_class()
589+
590+
554591
register_pytest_array_context_factory("grudge.pyopencl",
555592
PytestPyOpenCLArrayContextFactory)
556593
register_pytest_array_context_factory("grudge.pytato-pyopencl",
@@ -559,6 +596,8 @@ def __call__(self):
559596
PytestNumpyArrayContextFactory)
560597
register_pytest_array_context_factory("grudge.eager-jax",
561598
PytestEagerJAXArrayContextFactory)
599+
register_pytest_array_context_factory("grudge.lazy-jax",
600+
PytestPytatoJAXArrayContextFactory)
562601

563602
# }}}
564603

test/test_dt_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
PytestEagerJAXArrayContextFactory,
3131
PytestNumpyArrayContextFactory,
3232
PytestPyOpenCLArrayContextFactory,
33+
PytestPytatoJAXArrayContextFactory,
3334
PytestPytatoPyOpenCLArrayContextFactory,
3435
)
3536

@@ -38,7 +39,8 @@
3839
[PytestPyOpenCLArrayContextFactory,
3940
PytestPytatoPyOpenCLArrayContextFactory,
4041
PytestNumpyArrayContextFactory,
41-
PytestEagerJAXArrayContextFactory])
42+
PytestEagerJAXArrayContextFactory,
43+
PytestPytatoJAXArrayContextFactory])
4244

4345
import logging
4446

test/test_metrics.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
PytestEagerJAXArrayContextFactory,
3737
PytestNumpyArrayContextFactory,
3838
PytestPyOpenCLArrayContextFactory,
39+
PytestPytatoJAXArrayContextFactory,
3940
PytestPytatoPyOpenCLArrayContextFactory,
4041
)
4142
from grudge.discretization import make_discretization_collection
@@ -46,7 +47,8 @@
4647
[PytestPyOpenCLArrayContextFactory,
4748
PytestPytatoPyOpenCLArrayContextFactory,
4849
PytestNumpyArrayContextFactory,
49-
PytestEagerJAXArrayContextFactory])
50+
PytestEagerJAXArrayContextFactory,
51+
PytestPytatoJAXArrayContextFactory])
5052

5153

5254
# {{{ inverse metric

0 commit comments

Comments
 (0)