@@ -529,7 +529,8 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
529529 return pytato_program , name_in_program_to_tags , name_in_program_to_axes
530530
531531
532- def _args_to_device_buffers (actx , input_id_to_name_in_program , arg_id_to_arg ):
532+ def _args_to_device_buffers (actx , input_id_to_name_in_program , arg_id_to_arg ,
533+ fn_name = "<unknown>" ):
533534 input_kwargs_for_loopy = {}
534535
535536 for arg_id , arg in arg_id_to_arg .items ():
@@ -550,32 +551,20 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
550551 # got a frozen array => do nothing
551552 pass
552553 elif isinstance (arg , pt .Array ):
553- # got an array expression => evaluate it
554- from warnings import warn
555- warn (f"Argument array '{ arg_id } ' to a compiled function is "
556- "unevaluated. Evaluating just-in-time, at "
557- "considerable expense. This is deprecated and will stop "
558- "working in 2023. To avoid this warning, force evaluation "
559- "of all arguments via freeze/thaw." ,
560- DeprecationWarning , stacklevel = 4 )
561-
562- arg = actx .freeze (arg )
554+ # got an array expression => abort
555+ raise ValueError (
556+ f"Argument '{ arg_id } ' to the '{ fn_name } ' compiled function is a"
557+ " pytato array expression. Evaluating it just-in-time"
558+ " potentially causes a significant overhead on each call to the"
559+ " function and is therefore unsupported. "
560+ )
563561 else :
564562 raise NotImplementedError (type (arg ))
565563
566564 input_kwargs_for_loopy [input_id_to_name_in_program [arg_id ]] = arg
567565
568566 return input_kwargs_for_loopy
569567
570-
571- def _args_to_cl_buffers (actx , input_id_to_name_in_program , arg_id_to_arg ):
572- from warnings import warn
573- warn ("_args_to_cl_buffer has been renamed to"
574- " _args_to_device_buffers. This will be"
575- " an error in 2023." , DeprecationWarning , stacklevel = 2 )
576- return _args_to_device_buffers (actx , input_id_to_name_in_program ,
577- arg_id_to_arg )
578-
579568# }}}
580569
581570
@@ -631,7 +620,7 @@ class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction):
631620 type of the callable.
632621 """
633622 actx : PytatoPyOpenCLArrayContext
634- pytato_program : pt .target .BoundProgram
623+ pytato_program : pt .target .loopy . BoundPyOpenCLExecutable
635624 input_id_to_name_in_program : Mapping [tuple [Hashable , ...], str ]
636625 output_id_to_name_in_program : Mapping [tuple [Hashable , ...], str ]
637626 name_in_program_to_tags : Mapping [str , frozenset [Tag ]]
@@ -642,8 +631,10 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
642631 from .utils import get_cl_axes_from_pt_axes
643632 from arraycontext .impl .pyopencl .taggable_cl_array import to_tagged_cl_array
644633
634+ fn_name = self .pytato_program .program .entrypoint
635+
645636 input_kwargs_for_loopy = _args_to_device_buffers (
646- self .actx , self .input_id_to_name_in_program , arg_id_to_arg )
637+ self .actx , self .input_id_to_name_in_program , arg_id_to_arg , fn_name )
647638
648639 evt , out_dict = self .pytato_program (queue = self .actx .queue ,
649640 allocator = self .actx .allocator ,
@@ -674,7 +665,7 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction):
674665 Name of the output array in the program.
675666 """
676667 actx : PytatoPyOpenCLArrayContext
677- pytato_program : pt .target .BoundProgram
668+ pytato_program : pt .target .loopy . BoundPyOpenCLExecutable
678669 input_id_to_name_in_program : Mapping [tuple [Hashable , ...], str ]
679670 output_tags : frozenset [Tag ]
680671 output_axes : tuple [pt .Axis , ...]
@@ -684,8 +675,10 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
684675 from .utils import get_cl_axes_from_pt_axes
685676 from arraycontext .impl .pyopencl .taggable_cl_array import to_tagged_cl_array
686677
678+ fn_name = self .pytato_program .program .entrypoint
679+
687680 input_kwargs_for_loopy = _args_to_device_buffers (
688- self .actx , self .input_id_to_name_in_program , arg_id_to_arg )
681+ self .actx , self .input_id_to_name_in_program , arg_id_to_arg , fn_name )
689682
690683 evt , out_dict = self .pytato_program (queue = self .actx .queue ,
691684 allocator = self .actx .allocator ,
@@ -723,16 +716,18 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction):
723716 type of the callable.
724717 """
725718 actx : PytatoJAXArrayContext
726- pytato_program : pt .target .BoundProgram
719+ pytato_program : pt .target .python . BoundJAXPythonProgram
727720 input_id_to_name_in_program : Mapping [tuple [Hashable , ...], str ]
728721 output_id_to_name_in_program : Mapping [tuple [Hashable , ...], str ]
729722 name_in_program_to_tags : Mapping [str , frozenset [Tag ]]
730723 name_in_program_to_axes : Mapping [str , tuple [pt .Axis , ...]]
731724 output_template : ArrayContainer
732725
733726 def __call__ (self , arg_id_to_arg ) -> ArrayContainer :
727+ fn_name = self .pytato_program .entrypoint
728+
734729 input_kwargs_for_loopy = _args_to_device_buffers (
735- self .actx , self .input_id_to_name_in_program , arg_id_to_arg )
730+ self .actx , self .input_id_to_name_in_program , arg_id_to_arg , fn_name )
736731
737732 out_dict = self .pytato_program (** input_kwargs_for_loopy )
738733
@@ -754,15 +749,17 @@ class CompiledJAXFunctionReturningArray(CompiledFunction):
754749 Name of the output array in the program.
755750 """
756751 actx : PytatoJAXArrayContext
757- pytato_program : pt .target .BoundProgram
752+ pytato_program : pt .target .python . BoundJAXPythonProgram
758753 input_id_to_name_in_program : Mapping [tuple [Hashable , ...], str ]
759754 output_tags : frozenset [Tag ]
760755 output_axes : tuple [pt .Axis , ...]
761756 output_name : str
762757
763758 def __call__ (self , arg_id_to_arg ) -> ArrayContainer :
759+ fn_name = self .pytato_program .entrypoint
760+
764761 input_kwargs_for_loopy = _args_to_device_buffers (
765- self .actx , self .input_id_to_name_in_program , arg_id_to_arg )
762+ self .actx , self .input_id_to_name_in_program , arg_id_to_arg , fn_name )
766763
767764 _evt , out_dict = self .pytato_program (** input_kwargs_for_loopy )
768765
0 commit comments