Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 3 additions & 16 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
if TYPE_CHECKING:
import pytato
import pyopencl as cl
import loopy as lp

if getattr(sys, "_BUILDING_SPHINX_DOCS", False):
import pyopencl as cl # noqa: F811
Expand Down Expand Up @@ -219,20 +218,6 @@ def get_target(self):

# {{{ PytatoPyOpenCLArrayContext

from pytato.target.loopy import LoopyPyOpenCLTarget


class _ArgSizeLimitingPytatoLoopyPyOpenCLTarget(LoopyPyOpenCLTarget):
def __init__(self, limit_arg_size_nbytes: int) -> None:
super().__init__()
self.limit_arg_size_nbytes = limit_arg_size_nbytes

@memoize_method
def get_loopy_target(self) -> Optional["lp.PyOpenCLTarget"]:
from loopy import PyOpenCLTarget
return PyOpenCLTarget(limit_arg_size_nbytes=self.limit_arg_size_nbytes)


class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
"""
A :class:`ArrayContext` that uses :mod:`pytato` data types to represent
Expand Down Expand Up @@ -408,7 +393,9 @@ def get_target(self):

logger.info(f"limiting argument buffer size for {dev} to {limit} bytes")

return _ArgSizeLimitingPytatoLoopyPyOpenCLTarget(limit)
from arraycontext.impl.pytato.utils import \
ArgSizeLimitingPytatoLoopyPyOpenCLTarget
return ArgSizeLimitingPytatoLoopyPyOpenCLTarget(limit)
else:
return super().get_target()

Expand Down
25 changes: 24 additions & 1 deletion arraycontext/impl/pytato/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@
"""


from typing import Any, Dict, Set, Tuple, Mapping
from typing import Any, Dict, Set, Tuple, Mapping, Optional, TYPE_CHECKING
from pytools import memoize_method

from pytato.array import SizeParam, Placeholder, make_placeholder, Axis as PtAxis
from pytato.array import Array, DataWrapper, DictOfNamedArrays
from pytato.transform import CopyMapper
from pytools import UniqueNameGenerator
from arraycontext.impl.pyopencl.taggable_cl_array import Axis as ClAxis
from pytato.target.loopy import LoopyPyOpenCLTarget

if TYPE_CHECKING:
import loopy as lp


class _DatawrapperToBoundPlaceholderMapper(CopyMapper):
Expand Down Expand Up @@ -91,3 +97,20 @@ def get_pt_axes_from_cl_axes(axes: Tuple[ClAxis, ...]) -> Tuple[PtAxis, ...]:

def get_cl_axes_from_pt_axes(axes: Tuple[PtAxis, ...]) -> Tuple[ClAxis, ...]:
return tuple(ClAxis(axis.tags) for axis in axes)


# {{{ arg-size-limiting loopy target

class ArgSizeLimitingPytatoLoopyPyOpenCLTarget(LoopyPyOpenCLTarget):
def __init__(self, limit_arg_size_nbytes: int) -> None:
super().__init__()
self.limit_arg_size_nbytes = limit_arg_size_nbytes

@memoize_method
def get_loopy_target(self) -> Optional["lp.PyOpenCLTarget"]:
from loopy import PyOpenCLTarget
return PyOpenCLTarget(limit_arg_size_nbytes=self.limit_arg_size_nbytes)

# }}}

# vim: foldmethod=marker