3434
3535from typing import Any , Callable , Dict , Sequence , Type , Union
3636
37- import pyopencl as cl
3837from arraycontext .context import ArrayContext
3938
4039
4140# {{{ array context factories
4241
4342class PytestArrayContextFactory :
44- pass
43+ @classmethod
44+ def is_available (cls ) -> bool :
45+ return True
46+
47+ def __call__ (self ) -> ArrayContext :
48+ raise NotImplementedError
4549
4650
4751class PytestPyOpenCLArrayContextFactory (PytestArrayContextFactory ):
@@ -56,6 +60,14 @@ def __init__(self, device):
5660 """
5761 self .device = device
5862
63+ @classmethod
64+ def is_available (cls ) -> bool :
65+ try :
66+ import pyopencl # noqa: F401
67+ return True
68+ except ImportError :
69+ return False
70+
5971 def get_command_queue (self ):
6072 # Get rid of leftovers from past tests.
6173 # CL implementations are surprisingly limited in how many
@@ -66,14 +78,12 @@ def get_command_queue(self):
6678 from gc import collect
6779 collect ()
6880
81+ import pyopencl as cl
6982 # On Intel CPU CL, existence of a command queue does not ensure that
7083 # the context survives.
7184 ctx = cl .Context ([self .device ])
7285 return ctx , cl .CommandQueue (ctx )
7386
74- def __call__ (self ) -> ArrayContext :
75- raise NotImplementedError
76-
7787
7888class _PytestPyOpenCLArrayContextFactoryWithClass (PytestPyOpenCLArrayContextFactory ):
7989 force_device_scalars = True
@@ -107,8 +117,15 @@ class _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars(
107117 force_device_scalars = False
108118
109119
110- class _PytestPytatoPyOpenCLArrayContextFactory (
111- PytestPyOpenCLArrayContextFactory ):
120+ class _PytestPytatoPyOpenCLArrayContextFactory (PytestPyOpenCLArrayContextFactory ):
121+ @classmethod
122+ def is_available (cls ) -> bool :
123+ try :
124+ import pyopencl # noqa: F401
125+ import pytato # noqa: F401
126+ return True
127+ except ImportError :
128+ return False
112129
113130 @property
114131 def actx_class (self ):
@@ -137,6 +154,14 @@ class _PytestEagerJaxArrayContextFactory(PytestArrayContextFactory):
137154 def __init__ (self , * args , ** kwargs ):
138155 pass
139156
157+ @classmethod
158+ def is_available (cls ) -> bool :
159+ try :
160+ import jax # noqa: F401
161+ return True
162+ except ImportError :
163+ return False
164+
140165 def __call__ (self ):
141166 from arraycontext import EagerJAXArrayContext
142167 from jax .config import config
@@ -151,6 +176,15 @@ class _PytestPytatoJaxArrayContextFactory(PytestArrayContextFactory):
151176 def __init__ (self , * args , ** kwargs ):
152177 pass
153178
179+ @classmethod
180+ def is_available (cls ) -> bool :
181+ try :
182+ import jax # noqa: F401
183+ import pytato # noqa: F401
184+ return True
185+ except ImportError :
186+ return False
187+
154188 def __call__ (self ):
155189 from arraycontext import PytatoJAXArrayContext
156190 from jax .config import config
@@ -254,9 +288,19 @@ def pytest_generate_tests_for_array_contexts(
254288 else :
255289 raise ValueError (f"unknown array contexts: { unknown_factories } " )
256290
257- unique_factories = set ([
258- _ARRAY_CONTEXT_FACTORY_REGISTRY .get (factory , factory ) # type: ignore[misc]
259- for factory in unique_factories ])
291+ available_factories = {
292+ factory for key in unique_factories
293+ for factory in [_ARRAY_CONTEXT_FACTORY_REGISTRY .get (key , key )]
294+ if (
295+ not isinstance (factory , str )
296+ and issubclass (factory , PytestArrayContextFactory )
297+ and factory .is_available ())
298+ }
299+
300+ from pytools import partition
301+ pyopencl_factories , other_factories = partition (
302+ lambda factory : issubclass (factory , PytestPyOpenCLArrayContextFactory ),
303+ available_factories )
260304
261305 # }}}
262306
@@ -271,6 +315,7 @@ def inner(metafunc):
271315 return
272316
273317 arg_values , ids = cl_tools .get_pyopencl_fixture_arg_values ()
318+ empty_arg_dict = {k : None for k in arg_values [0 ]}
274319
275320 # }}}
276321
@@ -283,23 +328,29 @@ def inner(metafunc):
283328 "'ctx_factory' / 'ctx_getter' as arguments." )
284329
285330 arg_values_with_actx = []
286- for arg_dict in arg_values :
331+
332+ if pyopencl_factories :
333+ for arg_dict in arg_values :
334+ arg_values_with_actx .extend ([
335+ {factory_arg_name : factory (arg_dict ["device" ]), ** arg_dict }
336+ for factory in pyopencl_factories
337+ ])
338+
339+ if other_factories :
287340 arg_values_with_actx .extend ([
288- {factory_arg_name : factory (arg_dict [ "device" ] ), ** arg_dict }
289- for factory in unique_factories
341+ {factory_arg_name : factory (), ** empty_arg_dict }
342+ for factory in other_factories
290343 ])
291344 else :
292345 arg_values_with_actx = arg_values
293346
294- arg_value_tuples = [
295- tuple (arg_dict [name ] for name in arg_names )
296- for arg_dict in arg_values_with_actx
297- ]
298-
299347 # }}}
300348
301- # Sort the actx's so that parallel pytest works
302- arg_value_tuples = sorted (arg_value_tuples , key = lambda x : x .__str__ ())
349+ # NOTE: sorts the args so that parallel pytest works
350+ arg_value_tuples = sorted ([
351+ tuple ([arg_dict [name ] for name in arg_names ])
352+ for arg_dict in arg_values_with_actx
353+ ], key = lambda x : str (x ))
303354
304355 metafunc .parametrize (arg_names , arg_value_tuples , ids = ids )
305356
0 commit comments