We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 193100a commit 714d73cCopy full SHA for 714d73c
1 file changed
grudge/array_context.py
@@ -427,6 +427,27 @@ def clone(self) -> Self:
427
428
# }}}
429
430
+# {{{ distributed + eager jax
431
+
432
+from arraycontext.impl.jax import EagerJAXArrayContext
433
434
+class MPIEagerJaxArrayContext(EagerJAXArrayContext, MPIBasedArrayContext):
435
+ """An array context for using distributed computation with :mod:`jax`
436
+ eager evaluation.
437
438
+ .. autofunction:: __init__
439
+ """
440
441
+ def __init__(self, mpi_communicator) -> None:
442
+ super().__init__()
443
444
+ self.mpi_communicator = mpi_communicator
445
446
+ def clone(self) -> Self:
447
+ return type(self)(self.mpi_communicator)
448
449
+# }}}
450
451
452
# {{{ distributed + pytato array context subclasses
453
0 commit comments