Skip to content

Commit 80813d7

Browse files
alexfiklinducer
authored andcommitted
filter out unavailable pytest actx factories
* adds an is_available to check is the array context works * do not create a factory per CL device, since that does not quite work for non-CL array contexts
1 parent 7f3fc73 commit 80813d7

File tree

2 files changed

+79
-24
lines changed

2 files changed

+79
-24
lines changed

arraycontext/pytest.py

Lines changed: 71 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,18 @@
3434

3535
from typing import Any, Callable, Dict, Sequence, Type, Union
3636

37-
import pyopencl as cl
3837
from arraycontext.context import ArrayContext
3938

4039

4140
# {{{ array context factories
4241

4342
class PytestArrayContextFactory:
44-
pass
43+
@classmethod
44+
def is_available(cls) -> bool:
45+
return True
46+
47+
def __call__(self) -> ArrayContext:
48+
raise NotImplementedError
4549

4650

4751
class PytestPyOpenCLArrayContextFactory(PytestArrayContextFactory):
@@ -56,6 +60,14 @@ def __init__(self, device):
5660
"""
5761
self.device = device
5862

63+
@classmethod
64+
def is_available(cls) -> bool:
65+
try:
66+
import pyopencl # noqa: F401
67+
return True
68+
except ImportError:
69+
return False
70+
5971
def get_command_queue(self):
6072
# Get rid of leftovers from past tests.
6173
# CL implementations are surprisingly limited in how many
@@ -66,14 +78,12 @@ def get_command_queue(self):
6678
from gc import collect
6779
collect()
6880

81+
import pyopencl as cl
6982
# On Intel CPU CL, existence of a command queue does not ensure that
7083
# the context survives.
7184
ctx = cl.Context([self.device])
7285
return ctx, cl.CommandQueue(ctx)
7386

74-
def __call__(self) -> ArrayContext:
75-
raise NotImplementedError
76-
7787

7888
class _PytestPyOpenCLArrayContextFactoryWithClass(PytestPyOpenCLArrayContextFactory):
7989
force_device_scalars = True
@@ -107,8 +117,15 @@ class _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars(
107117
force_device_scalars = False
108118

109119

110-
class _PytestPytatoPyOpenCLArrayContextFactory(
111-
PytestPyOpenCLArrayContextFactory):
120+
class _PytestPytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory):
121+
@classmethod
122+
def is_available(cls) -> bool:
123+
try:
124+
import pyopencl # noqa: F401
125+
import pytato # noqa: F401
126+
return True
127+
except ImportError:
128+
return False
112129

113130
@property
114131
def actx_class(self):
@@ -137,6 +154,14 @@ class _PytestEagerJaxArrayContextFactory(PytestArrayContextFactory):
137154
def __init__(self, *args, **kwargs):
138155
pass
139156

157+
@classmethod
158+
def is_available(cls) -> bool:
159+
try:
160+
import jax # noqa: F401
161+
return True
162+
except ImportError:
163+
return False
164+
140165
def __call__(self):
141166
from arraycontext import EagerJAXArrayContext
142167
from jax.config import config
@@ -151,6 +176,15 @@ class _PytestPytatoJaxArrayContextFactory(PytestArrayContextFactory):
151176
def __init__(self, *args, **kwargs):
152177
pass
153178

179+
@classmethod
180+
def is_available(cls) -> bool:
181+
try:
182+
import jax # noqa: F401
183+
import pytato # noqa: F401
184+
return True
185+
except ImportError:
186+
return False
187+
154188
def __call__(self):
155189
from arraycontext import PytatoJAXArrayContext
156190
from jax.config import config
@@ -254,9 +288,19 @@ def pytest_generate_tests_for_array_contexts(
254288
else:
255289
raise ValueError(f"unknown array contexts: {unknown_factories}")
256290

257-
unique_factories = set([
258-
_ARRAY_CONTEXT_FACTORY_REGISTRY.get(factory, factory) # type: ignore[misc]
259-
for factory in unique_factories])
291+
available_factories = {
292+
factory for key in unique_factories
293+
for factory in [_ARRAY_CONTEXT_FACTORY_REGISTRY.get(key, key)]
294+
if (
295+
not isinstance(factory, str)
296+
and issubclass(factory, PytestArrayContextFactory)
297+
and factory.is_available())
298+
}
299+
300+
from pytools import partition
301+
pyopencl_factories, other_factories = partition(
302+
lambda factory: issubclass(factory, PytestPyOpenCLArrayContextFactory),
303+
available_factories)
260304

261305
# }}}
262306

@@ -271,6 +315,7 @@ def inner(metafunc):
271315
return
272316

273317
arg_values, ids = cl_tools.get_pyopencl_fixture_arg_values()
318+
empty_arg_dict = {k: None for k in arg_values[0]}
274319

275320
# }}}
276321

@@ -283,23 +328,29 @@ def inner(metafunc):
283328
"'ctx_factory' / 'ctx_getter' as arguments.")
284329

285330
arg_values_with_actx = []
286-
for arg_dict in arg_values:
331+
332+
if pyopencl_factories:
333+
for arg_dict in arg_values:
334+
arg_values_with_actx.extend([
335+
{factory_arg_name: factory(arg_dict["device"]), **arg_dict}
336+
for factory in pyopencl_factories
337+
])
338+
339+
if other_factories:
287340
arg_values_with_actx.extend([
288-
{factory_arg_name: factory(arg_dict["device"]), **arg_dict}
289-
for factory in unique_factories
341+
{factory_arg_name: factory(), **empty_arg_dict}
342+
for factory in other_factories
290343
])
291344
else:
292345
arg_values_with_actx = arg_values
293346

294-
arg_value_tuples = [
295-
tuple(arg_dict[name] for name in arg_names)
296-
for arg_dict in arg_values_with_actx
297-
]
298-
299347
# }}}
300348

301-
# Sort the actx's so that parallel pytest works
302-
arg_value_tuples = sorted(arg_value_tuples, key=lambda x: x.__str__())
349+
# NOTE: sorts the args so that parallel pytest works
350+
arg_value_tuples = sorted([
351+
tuple([arg_dict[name] for name in arg_names])
352+
for arg_dict in arg_values_with_actx
353+
], key=lambda x: str(x))
303354

304355
metafunc.parametrize(arg_names, arg_value_tuples, ids=ids)
305356

test/test_pytato_arraycontext.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from arraycontext.pytest import _PytestPytatoPyOpenCLArrayContextFactory
2828
from pytools.tag import Tag
2929

30-
30+
import pytest
3131
import logging
3232
logger = logging.getLogger(__name__)
3333

@@ -79,10 +79,15 @@ class BazTag(Tag):
7979

8080

8181
def test_tags_preserved_after_freeze(actx_factory):
82+
actx = actx_factory()
83+
84+
from arraycontext.impl.pytato import _BasePytatoArrayContext
85+
if not isinstance(actx, _BasePytatoArrayContext):
86+
pytest.skip("only pytato-based array context are supported")
87+
8288
from numpy.random import default_rng
8389
rng = default_rng()
8490

85-
actx = actx_factory()
8691
foo = actx.thaw(actx.freeze(
8792
actx.from_numpy(rng.random((10, 4)))
8893
.tagged(FooTag())
@@ -100,7 +105,6 @@ def test_tags_preserved_after_freeze(actx_factory):
100105
if len(sys.argv) > 1:
101106
exec(sys.argv[1])
102107
else:
103-
from pytest import main
104-
main([__file__])
108+
pytest.main([__file__])
105109

106110
# vim: fdm=marker

0 commit comments

Comments
 (0)