Skip to content

Commit d8a43b7

Browse files
add MPIPytatoJAXArrayContext
1 parent ef1b0b1 commit d8a43b7

File tree

3 files changed

+41
-3
lines changed

3 files changed

+41
-3
lines changed

grudge/array_context.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
.. autoclass:: MPIPyOpenCLArrayContext
66
.. autoclass:: MPINumpyArrayContext
77
.. class:: MPIPytatoArrayContext
8+
.. autoclass:: MPIEagerJAXArrayContext
9+
.. autoclass:: MPIPytatoJAXArrayContext
810
.. autofunction:: get_reasonable_array_context_class
911
"""
1012

@@ -76,14 +78,15 @@
7678
_HAVE_FUSION_ACTX = False
7779

7880

79-
from arraycontext import ArrayContext, EagerJAXArrayContext, NumpyArrayContext
81+
from arraycontext import ArrayContext, EagerJAXArrayContext, NumpyArrayContext, PytatoJAXArrayContext
8082
from arraycontext.container import ArrayContainer
8183
from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller
8284
from arraycontext.pytest import (
8385
_PytestEagerJaxArrayContextFactory,
8486
_PytestNumpyArrayContextFactory,
8587
_PytestPyOpenCLArrayContextFactoryWithClass,
8688
_PytestPytatoPyOpenCLArrayContextFactory,
89+
_PytestPytatoJaxArrayContextFactory,
8790
register_pytest_array_context_factory,
8891
)
8992

@@ -449,6 +452,26 @@ def clone(self) -> Self:
449452
# }}}
450453

451454

455+
# {{{ distributed + lazy jax
456+
457+
class MPIPytatoJAXArrayContext(PytatoJAXArrayContext, MPIBasedArrayContext):
458+
"""An array context for using distributed computation with :mod:`jax`
459+
lazy evaluation.
460+
461+
.. autofunction:: __init__
462+
"""
463+
464+
def __init__(self, mpi_communicator) -> None:
465+
super().__init__()
466+
467+
self.mpi_communicator = mpi_communicator
468+
469+
def clone(self) -> Self:
470+
return type(self)(self.mpi_communicator)
471+
472+
# }}}
473+
474+
452475
# {{{ distributed + pytato array context subclasses
453476

454477
class MPIBasePytatoPyOpenCLArrayContext(
@@ -551,6 +574,15 @@ def __call__(self):
551574
return self.actx_class()
552575

553576

577+
class PytestPytatoJAXArrayContextFactory(_PytestPytatoJaxArrayContextFactory):
578+
actx_class = PytatoJAXArrayContext
579+
580+
def __call__(self):
581+
import jax
582+
jax.config.update("jax_enable_x64", True)
583+
return self.actx_class()
584+
585+
554586
register_pytest_array_context_factory("grudge.pyopencl",
555587
PytestPyOpenCLArrayContextFactory)
556588
register_pytest_array_context_factory("grudge.pytato-pyopencl",
@@ -559,6 +591,8 @@ def __call__(self):
559591
PytestNumpyArrayContextFactory)
560592
register_pytest_array_context_factory("grudge.eager-jax",
561593
PytestEagerJAXArrayContextFactory)
594+
register_pytest_array_context_factory("grudge.lazy-jax",
595+
PytestPytatoJAXArrayContextFactory)
562596

563597
# }}}
564598

test/test_dt_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,16 @@
3131
PytestNumpyArrayContextFactory,
3232
PytestPyOpenCLArrayContextFactory,
3333
PytestPytatoPyOpenCLArrayContextFactory,
34+
PytestPytatoJAXArrayContextFactory,
3435
)
3536

3637

3738
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
3839
[PytestPyOpenCLArrayContextFactory,
3940
PytestPytatoPyOpenCLArrayContextFactory,
4041
PytestNumpyArrayContextFactory,
41-
PytestEagerJAXArrayContextFactory])
42+
PytestEagerJAXArrayContextFactory,
43+
PytestPytatoJAXArrayContextFactory])
4244

4345
import logging
4446

test/test_metrics.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
PytestNumpyArrayContextFactory,
3838
PytestPyOpenCLArrayContextFactory,
3939
PytestPytatoPyOpenCLArrayContextFactory,
40+
PytestPytatoJAXArrayContextFactory,
4041
)
4142
from grudge.discretization import make_discretization_collection
4243

@@ -46,7 +47,8 @@
4647
[PytestPyOpenCLArrayContextFactory,
4748
PytestPytatoPyOpenCLArrayContextFactory,
4849
PytestNumpyArrayContextFactory,
49-
PytestEagerJAXArrayContextFactory])
50+
PytestEagerJAXArrayContextFactory,
51+
PytestPytatoJAXArrayContextFactory])
5052

5153

5254
# {{{ inverse metric

0 commit comments

Comments
 (0)