We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 193100a commit 0ea000fCopy full SHA for 0ea000f
1 file changed
grudge/array_context.py
@@ -427,6 +427,29 @@ def clone(self) -> Self:
427
428
# }}}
429
430
+# {{{ distributed + eager jax
431
+
432
433
+from arraycontext.impl.jax import EagerJAXArrayContext
434
435
436
+class MPIEagerJaxArrayContext(EagerJAXArrayContext, MPIBasedArrayContext):
437
+ """An array context for using distributed computation with :mod:`jax`
438
+ eager evaluation.
439
440
+ .. autofunction:: __init__
441
+ """
442
443
+ def __init__(self, mpi_communicator) -> None:
444
+ super().__init__()
445
446
+ self.mpi_communicator = mpi_communicator
447
448
+ def clone(self) -> Self:
449
+ return type(self)(self.mpi_communicator)
450
451
+# }}}
452
453
454
# {{{ distributed + pytato array context subclasses
455
0 commit comments