Skip to content

Commit 0ea000f

Browse files
add MPIEagerJaxArrayContext
1 parent 193100a commit 0ea000f

1 file changed

Lines changed: 23 additions & 0 deletions

File tree

grudge/array_context.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,29 @@ def clone(self) -> Self:
427427

428428
# }}}
429429

430+
# {{{ distributed + eager jax
431+
432+
433+
from arraycontext.impl.jax import EagerJAXArrayContext
434+
435+
436+
class MPIEagerJaxArrayContext(EagerJAXArrayContext, MPIBasedArrayContext):
437+
"""An array context for using distributed computation with :mod:`jax`
438+
eager evaluation.
439+
440+
.. autofunction:: __init__
441+
"""
442+
443+
def __init__(self, mpi_communicator) -> None:
444+
super().__init__()
445+
446+
self.mpi_communicator = mpi_communicator
447+
448+
def clone(self) -> Self:
449+
return type(self)(self.mpi_communicator)
450+
451+
# }}}
452+
430453

431454
# {{{ distributed + pytato array context subclasses
432455

0 commit comments

Comments
 (0)