Skip to content

Commit 714d73c

Browse files
add MPIEagerJaxArrayContext
1 parent 193100a commit 714d73c

1 file changed

Lines changed: 21 additions & 0 deletions

File tree

grudge/array_context.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,27 @@ def clone(self) -> Self:
427427

428428
# }}}
429429

430+
# {{{ distributed + eager jax
431+
432+
from arraycontext.impl.jax import EagerJAXArrayContext
433+
434+
class MPIEagerJaxArrayContext(EagerJAXArrayContext, MPIBasedArrayContext):
435+
"""An array context for using distributed computation with :mod:`jax`
436+
eager evaluation.
437+
438+
.. autofunction:: __init__
439+
"""
440+
441+
def __init__(self, mpi_communicator) -> None:
442+
super().__init__()
443+
444+
self.mpi_communicator = mpi_communicator
445+
446+
def clone(self) -> Self:
447+
return type(self)(self.mpi_communicator)
448+
449+
# }}}
450+
430451

431452
# {{{ distributed + pytato array context subclasses
432453

0 commit comments

Comments
 (0)