Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 94 additions & 2 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,21 @@
from arraycontext.container.traversal import (rec_map_array_container,
with_array_context)
from arraycontext.metadata import NameHint
from pytools import memoize_method

if TYPE_CHECKING:
import pytato
import pyopencl as cl
import loopy as lp

if getattr(sys, "_BUILDING_SPHINX_DOCS", False):
import pyopencl as cl # noqa: F811


import logging
logger = logging.getLogger(__name__)


# {{{ tag conversion

def _preprocess_array_tags(tags: ToTagSetConvertible) -> FrozenSet[Tag]:
Expand Down Expand Up @@ -203,13 +209,30 @@ def supports_nonscalar_broadcasting(self):
def permits_advanced_indexing(self):
return True

def get_target(self):
    """
    Return the code generation target for this array context.

    This implementation always returns *None*. The value is what callers
    in this file hand to :func:`pytato.generate_loopy` as ``target=``;
    presumably *None* lets :mod:`pytato` choose its own default target
    (NOTE(review): confirm against pytato). Subclasses may override this
    to supply a specific target.
    """
    return None

# }}}

# }}}


# {{{ PytatoPyOpenCLArrayContext

from pytato.target.loopy import LoopyPyOpenCLTarget


class _ArgSizeLimitingPytatoLoopyPyOpenCLTarget(LoopyPyOpenCLTarget):
    """
    A :class:`pytato.target.loopy.LoopyPyOpenCLTarget` whose loopy target
    is created with an explicit ``limit_arg_size_nbytes``.

    .. attribute:: limit_arg_size_nbytes

        The byte limit that is forwarded to :class:`loopy.PyOpenCLTarget`.
    """

    def __init__(self, limit_arg_size_nbytes: int) -> None:
        """
        :arg limit_arg_size_nbytes: the value later passed to
            :class:`loopy.PyOpenCLTarget` as ``limit_arg_size_nbytes``.
        """
        super().__init__()
        self.limit_arg_size_nbytes = limit_arg_size_nbytes

    @memoize_method
    def get_loopy_target(self) -> Optional["lp.PyOpenCLTarget"]:
        """
        Return a :class:`loopy.PyOpenCLTarget` configured with
        :attr:`limit_arg_size_nbytes`.

        Memoized, so repeated calls on one instance yield the same target
        object.
        """
        # Imported lazily: at module level loopy is only imported for
        # type checking.
        import loopy

        nbytes = self.limit_arg_size_nbytes
        return loopy.PyOpenCLTarget(limit_arg_size_nbytes=nbytes)


class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
"""
A :class:`ArrayContext` that uses :mod:`pytato` data types to represent
Expand All @@ -232,7 +255,11 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
"""
def __init__(
        self, queue: "cl.CommandQueue", allocator=None, *,
        use_memory_pool: Optional[bool] = None,
        compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None,

        # do not use: only for testing
        _force_svm_arg_limit: Optional[int] = None,
        ) -> None:
    """
    :arg queue: the command queue this array context operates on. Its
        context is also stored to keep it alive.
    :arg allocator: a callable taking a size in bytes and returning a
        memory object. If *None*, a default is chosen: an SVM allocator
        when the queue's device reports coarse-grain buffer SVM, and an
        immediate buffer allocator otherwise.
    :arg use_memory_pool: if true, wrap the *default* allocator in the
        matching memory pool. May only be given when *allocator* is
        *None*.
    :arg compile_trace_callback: A function of three arguments
        *(what, stage, ir)*, where *stage* is a string and *ir* is an
        intermediate representation. This interface should be considered
        unstable.

    :raises TypeError: if both *allocator* and *use_memory_pool* are
        specified.

    After construction, ``self.using_svm`` records whether the allocator
    in use hands out SVM memory.
    """
    if allocator is not None and use_memory_pool is not None:
        raise TypeError("may not specify both allocator and use_memory_pool")

    if allocator is None:
        from pyopencl.characterize import has_coarse_grain_buffer_svm

        if has_coarse_grain_buffer_svm(queue.device):
            using_svm = True

            from pyopencl.tools import SVMAllocator
            allocator = SVMAllocator(queue.context, queue=queue)

            if use_memory_pool:
                from pyopencl.tools import SVMPool
                allocator = SVMPool(allocator)
        else:
            using_svm = False

            from pyopencl.tools import ImmediateAllocator
            allocator = ImmediateAllocator(queue.context)

            if use_memory_pool:
                from pyopencl.tools import MemoryPool
                allocator = MemoryPool(allocator)
    else:
        # Check whether the passed allocator allocates SVM. Only the
        # import is guarded: an ImportError raised from inside the
        # allocator itself must not be mistaken for "no SVM support".
        try:
            from pyopencl import SVMPointer
        except ImportError:
            using_svm = False
        else:
            # The small probe allocation is not bound to a name, so it is
            # not kept alive for the remainder of the constructor.
            using_svm = isinstance(allocator(4), SVMPointer)

    self.using_svm = using_svm

    import pytato as pt
    import pyopencl.array as cla
    super().__init__(compile_trace_callback=compile_trace_callback)
    self.queue = queue

    self.allocator = allocator
    self.array_types = (pt.Array, cla.Array)

    # unused, but necessary to keep the context alive
    self.context = self.queue.context

    self._force_svm_arg_limit = _force_svm_arg_limit

@property
def _frozen_array_types(self) -> Tuple[Type, ...]:
import pyopencl.array as cla
Expand Down Expand Up @@ -321,6 +389,29 @@ def _to_numpy(ary):
self._rec_map_container(_to_numpy, self.freeze(array)),
actx=None)

@memoize_method
def get_target(self):
    """
    Return the target passed to :func:`pytato.generate_loopy` for this
    array context.

    An argument-size-limiting target is returned when either

    - ``_force_svm_arg_limit`` was given to the constructor (testing
      only), in which case that value is the limit, or
    - the context uses SVM, the queue's device is a GPU and it reports
      coarse-grain buffer SVM, in which case the limit is the device's
      ``max_parameter_size``.

    Otherwise the superclass's target is returned. The result is memoized
    per instance.
    """
    import pyopencl as cl
    import pyopencl.characterize as cl_char

    dev = self.queue.device
    forced_limit = self._force_svm_arg_limit

    if forced_limit is not None:
        # A forced limit wins outright; do not query the device for a
        # value that would be discarded anyway.
        limit = forced_limit
    elif (self.using_svm
            and dev.type & cl.device_type.GPU
            and cl_char.has_coarse_grain_buffer_svm(dev)):
        limit = dev.max_parameter_size
    else:
        return super().get_target()

    # Lazy %-formatting: the message is only built if INFO is enabled.
    logger.info("limiting argument buffer size for %s to %s bytes", dev, limit)

    return _ArgSizeLimitingPytatoLoopyPyOpenCLTarget(limit)

def freeze(self, array):
if np.isscalar(array):
return array
Expand Down Expand Up @@ -415,7 +506,8 @@ def _record_leaf_ary_in_dict(
pt_prg = pt.generate_loopy(transformed_dag,
options=_DEFAULT_LOOPY_OPTIONS,
cl_device=self.queue.device,
function_name=function_name)
function_name=function_name,
target=self.get_target())
pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program)
self._freeze_prg_cache[normalized_expr] = pt_prg
else:
Expand Down
4 changes: 3 additions & 1 deletion arraycontext/impl/pytato/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,9 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
options=lp.Options(
return_dict=True,
no_numpy=True),
function_name=_prg_id_to_kernel_name(prg_id))
function_name=_prg_id_to_kernel_name(prg_id),
target=self.actx.get_target(),
)
assert isinstance(pytato_program, BoundPyOpenCLProgram)

self.actx._compile_trace_callback(
Expand Down
21 changes: 21 additions & 0 deletions test/test_pytato_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,27 @@ def test_tags_preserved_after_freeze(actx_factory):
assert foo.axes[1].tags_of_type(BazTag)


def test_arg_size_limit(actx_factory):
    """
    Check that ``_force_svm_arg_limit`` reaches the generated program:
    the IR reported at the ``"final"`` compile-trace stage must carry a
    target whose ``limit_arg_size_nbytes`` equals the forced value.
    """
    ran_callback = False

    def on_compile_trace(what, stage, ir):
        nonlocal ran_callback

        # Only the last stage is inspected.
        if stage != "final":
            return

        assert ir.target.limit_arg_size_nbytes == 42
        ran_callback = True

    def twice(x):
        return 2 * x

    queue = actx_factory().queue
    actx = _PytatoPyOpenCLArrayContextForTests(
            queue,
            compile_trace_callback=on_compile_trace,
            _force_svm_arg_limit=42)

    compiled_twice = actx.compile(twice)
    compiled_twice(99)

    # Guards against the callback never having seen a "final" stage.
    assert ran_callback


if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
Expand Down