55.. autoclass:: MPIPyOpenCLArrayContext
66.. autoclass:: MPINumpyArrayContext
77.. class:: MPIPytatoArrayContext
8+ .. autoclass:: MPIEagerJAXArrayContext
9+ .. autoclass:: MPIPytatoJAXArrayContext
810.. autofunction:: get_reasonable_array_context_class
911"""
1012
7678 _HAVE_FUSION_ACTX = False
7779
7880
79- from arraycontext import ArrayContext , EagerJAXArrayContext , NumpyArrayContext
81+ from arraycontext import ArrayContext , EagerJAXArrayContext , NumpyArrayContext , PytatoJAXArrayContext
8082from arraycontext .container import ArrayContainer
8183from arraycontext .impl .pytato .compile import LazilyPyOpenCLCompilingFunctionCaller
8284from arraycontext .pytest import (
8385 _PytestEagerJaxArrayContextFactory ,
8486 _PytestNumpyArrayContextFactory ,
8587 _PytestPyOpenCLArrayContextFactoryWithClass ,
8688 _PytestPytatoPyOpenCLArrayContextFactory ,
89+ _PytestPytatoJaxArrayContextFactory ,
8790 register_pytest_array_context_factory ,
8891)
8992
@@ -449,6 +452,26 @@ def clone(self) -> Self:
449452# }}}
450453
451454
455+ # {{{ distributed + lazy jax
456+
457+ class MPIPytatoJAXArrayContext (PytatoJAXArrayContext , MPIBasedArrayContext ):
458+ """An array context for using distributed computation with :mod:`jax`
459+ lazy evaluation.
460+
461+ .. autofunction:: __init__
462+ """
463+
464+ def __init__ (self , mpi_communicator ) -> None :
465+ super ().__init__ ()
466+
467+ self .mpi_communicator = mpi_communicator
468+
469+ def clone (self ) -> Self :
470+ return type (self )(self .mpi_communicator )
471+
472+ # }}}
473+
474+
452475# {{{ distributed + pytato array context subclasses
453476
454477class MPIBasePytatoPyOpenCLArrayContext (
@@ -551,6 +574,15 @@ def __call__(self):
551574 return self .actx_class ()
552575
553576
577+ class PytestPytatoJAXArrayContextFactory (_PytestPytatoJaxArrayContextFactory ):
578+ actx_class = PytatoJAXArrayContext
579+
580+ def __call__ (self ):
581+ import jax
582+ jax .config .update ("jax_enable_x64" , True )
583+ return self .actx_class ()
584+
585+
554586register_pytest_array_context_factory ("grudge.pyopencl" ,
555587 PytestPyOpenCLArrayContextFactory )
556588register_pytest_array_context_factory ("grudge.pytato-pyopencl" ,
@@ -559,6 +591,8 @@ def __call__(self):
559591 PytestNumpyArrayContextFactory )
560592register_pytest_array_context_factory ("grudge.eager-jax" ,
561593 PytestEagerJAXArrayContextFactory )
594+ register_pytest_array_context_factory ("grudge.lazy-jax" ,
595+ PytestPytatoJAXArrayContextFactory )
562596
563597# }}}
564598
0 commit comments