diff --git a/pytato/target/loopy/__init__.py b/pytato/target/loopy/__init__.py index 9bacb912b..f544989d8 100644 --- a/pytato/target/loopy/__init__.py +++ b/pytato/target/loopy/__init__.py @@ -85,15 +85,25 @@ class LoopyPyOpenCLTarget(LoopyTarget): The :mod:`pyopencl` device used to construct the :class:`loopy.PyOpenCLTarget`, or *None*. + + .. attribute:: loopy_target + + An optional :class:`loopy.PyOpenCLTarget` that gets used as the + underlying code generation target. """ - def __init__(self, device: Optional["pyopencl.Device"] = None): + def __init__(self, device: Optional["pyopencl.Device"] = None, + loopy_target: Optional[loopy.PyOpenCLTarget] = None) -> None: if device is not None: from warnings import warn warn("Passing 'device' is deprecated and will stop working in 2023.", DeprecationWarning, stacklevel=2) + self.loopy_target = loopy_target def get_loopy_target(self) -> "loopy.LoopyPyOpenCLTarget": + if self.loopy_target: + return self.loopy_target + import loopy as lp return lp.PyOpenCLTarget()