66.. currentmodule:: arraycontext
77
88.. autofunction:: with_container_arithmetic
9+ .. autoclass:: Bcast
10+ .. autoclass:: BcastNLevels
11+ .. autoclass:: BcastUntilActxArray
12+
13+ .. function:: Bcast1
14+
15+ Like :class:`BcastNLevels` with *nlevels* set to 1.
16+
17+ .. function:: Bcast2
18+
19+ Like :class:`BcastNLevels` with *nlevels* set to 2.
20+
21+ .. function:: Bcast3
22+
23+ Like :class:`BcastNLevels` with *nlevels* set to 3.
924"""
1025
1126
3449"""
3550
3651import enum
52+ from abc import ABC , abstractmethod
3753from collections .abc import Callable
38- from typing import Any , TypeVar
54+ from dataclasses import FrozenInstanceError
55+ from functools import partial
56+ from numbers import Number
57+ from typing import Any , ClassVar , TypeVar , Union
3958from warnings import warn
4059
4160import numpy as np
4261
62+ from arraycontext .context import ArrayContext , ArrayOrContainer
63+
4364
4465# {{{ with_container_arithmetic
4566
@@ -142,8 +163,9 @@ def __instancecheck__(cls, instance: Any) -> bool:
142163 warn (
143164 "Broadcasting container against non-object numpy array. "
144165 "This was never documented to work and will now stop working in "
145- "2025. Convert the array to an object array to preserve the "
146- "current semantics." , DeprecationWarning , stacklevel = 3 )
166+ "2025. Convert the array to an object array or use "
167+ "variants of arraycontext.Bcast to obtain the desired "
168+ "broadcasting semantics." , DeprecationWarning , stacklevel = 3 )
147169 return True
148170 else :
149171 return False
@@ -153,6 +175,125 @@ class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMet
153175 pass
154176
155177
178+ class Bcast :
179+ """
180+ A wrapper object to force arithmetic generated by :func:`with_container_arithmetic`
181+ to broadcast *arg* across a container (with the container as the 'outer' structure).
182+ Since array containers are often nested in complex ways, different subclasses
183+ implement different rules on how broadcasting interacts with the hierarchy,
184+ with :class:`BcastNLevels` and :class:`BcastUntilActxArray` representing
185+ the most common.
186+ """
187+ arg : ArrayOrContainer
188+
189+ # Accessing this attribute is cheaper than isinstance, so use that
190+ # to distinguish _BcastWithNextOperand and _BcastWithoutNextOperand.
191+ _with_next_operand : ClassVar [bool ]
192+
193+ def __init__ (self , arg : ArrayOrContainer ) -> None :
194+ object .__setattr__ (self , "arg" , arg )
195+
196+ def __setattr__ (self , name : str , value : Any ) -> None :
197+ raise FrozenInstanceError ()
198+
199+ def __delattr__ (self , name : str ) -> None :
200+ raise FrozenInstanceError ()
201+
202+
203+ class _BcastWithNextOperand (Bcast , ABC ):
204+ """
205+ A :class:`Bcast` object that gets to see who the next operand will be, in
206+ order to decide whether wrapping the child in :class:`Bcast` is still necessary.
207+ This is much more flexible, but also considerably more expensive, than
208+ :class:`_BcastWithoutNextOperand`.
209+ """
210+
211+ _with_next_operand = True
212+
213+ # purposefully undocumented
214+ @abstractmethod
215+ def _rewrap (self , other_operand : ArrayOrContainer ) -> ArrayOrContainer :
216+ ...
217+
218+
219+ class _BcastWithoutNextOperand (Bcast , ABC ):
220+ """
221+ A :class:`Bcast` object that does not get to see who the next operand will be.
222+ """
223+ _with_next_operand = False
224+
225+ # purposefully undocumented
226+ @abstractmethod
227+ def _rewrap (self ) -> ArrayOrContainer :
228+ ...
229+
230+
231+ class BcastNLevels (_BcastWithoutNextOperand ):
232+ """
233+ A broadcasting rule that lets *arg* broadcast against *nlevels* "levels" of
234+ array containers. Use :func:`Bcast1`, :func:`Bcast2`, :func:`Bcast3` as
235+ convenient aliases for the common cases.
236+
237+ Usage example::
238+
239+ container + Bcast2(actx_array)
240+
241+ .. note::
242+
243+ :mod:`numpy` object arrays do not count against the number of levels.
244+
245+ .. automethod:: __init__
246+ """
247+ nlevels : int
248+
249+ def __init__ (self , nlevels : int , arg : ArrayOrContainer ) -> None :
250+ if nlevels < 1 :
251+ raise ValueError ("nlevels is expected to be one or greater." )
252+
253+ super ().__init__ (arg )
254+ object .__setattr__ (self , "nlevels" , nlevels )
255+
256+ def _rewrap (self ) -> ArrayOrContainer :
257+ if self .nlevels == 1 :
258+ return self .arg
259+ else :
260+ return BcastNLevels (self .nlevels - 1 , self .arg )
261+
262+
263+ Bcast1Level = partial (BcastNLevels , 1 )
264+ Bcast2Levels = partial (BcastNLevels , 2 )
265+ Bcast3Levels = partial (BcastNLevels , 3 )
266+
267+
268+ class BcastUntilActxArray (_BcastWithNextOperand ):
269+ """
270+ A broadcast rule that broadcasts *arg* across array containers until
271+ the 'opposite' operand is one of the :attr:`~arraycontext.ArrayContext.array_types`
272+ of *actx*, or a :class:`~numbers.Number`.
273+
274+ Suggested usage pattern::
275+
276+ bcast = functools.partial(BcastUntilActxArray, actx)
277+
278+ container + bcast(actx_array)
279+
280+ .. automethod:: __init__
281+ """
282+ actx : ArrayContext
283+
284+ def __init__ (self ,
285+ actx : ArrayContext ,
286+ arg : ArrayOrContainer ) -> None :
287+ super ().__init__ (arg )
288+ object .__setattr__ (self , "actx" , actx )
289+
290+ def _rewrap (self , other_operand : ArrayOrContainer ) -> ArrayOrContainer :
291+ if isinstance (other_operand , (* self .actx .array_types , Number )):
292+ return self .arg
293+ else :
294+ return self
295+
296+
156297def with_container_arithmetic (
157298 * ,
158299 number_bcasts_across : bool | None = None ,
@@ -207,6 +348,14 @@ class has an ``array_context`` attribute. If so, and if :data:`__debug__`
207348
208349 Each operator class also includes the "reverse" operators if applicable.
209350
351+ .. note::
352+
353+ For the generated binary arithmetic operators, if certain types
354+ should be broadcast over the container (with the container as the
355+ 'outer' structure) but are not handled in this way by their types,
356+ you may wrap them in one of the :class:`Bcast` variants to achieve
357+ the desired semantics.
358+
210359 .. note::
211360
212361 To generate the code implementing the operators, this function relies on
@@ -239,6 +388,24 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
239388 #
240389 # - Broadcast rules are hard to change once established, particularly
241390 # because one cannot grep for their use.
391+ #
392+ # Possible advantages of the "Bcast" broadcast-rule-as-object design:
393+ #
394+ # - If one rule does not fit the user's need, they can straightforwardly use
395+ # another.
396+ #
397+ # - It's straightforward to find where certain broadcast rules are used.
398+ #
399+ # - The broadcast rule can contain more state. For example, it's now easy
400+ # for the rule to know what array context should be used to determine
401+ # actx array types.
402+ #
403+ # Possible downsides of the "Bcast" broadcast-rule-as-object design:
404+ #
405+ # - User code is a bit more wordy.
406+ #
407+ # - Rewrapping has the potential to be costly, especially in
408+ # _with_next_operand mode.
242409
243410 # {{{ handle inputs
244411
@@ -404,9 +571,8 @@ def wrap(cls: Any) -> Any:
404571 f"Broadcasting array context array types across { cls } "
405572 "has been explicitly "
406573 "enabled. As of 2025, this will stop working. "
407- "There is no replacement as of right now. "
408- "See the discussion in "
409- "https://github.com/inducer/arraycontext/pull/190. "
574+ "Express these operations using arraycontext.Bcast variants "
575+ "instead. "
410576 "To opt out now (and avoid this warning), "
411577 "pass _bcast_actx_array_type=False. " ,
412578 DeprecationWarning , stacklevel = 2 )
@@ -415,9 +581,8 @@ def wrap(cls: Any) -> Any:
415581 f"Broadcasting array context array types across { cls } "
416582 "has been implicitly "
417583 "enabled. As of 2025, this will no longer work. "
418- "There is no replacement as of right now. "
419- "See the discussion in "
420- "https://github.com/inducer/arraycontext/pull/190. "
584+ "Express these operations using arraycontext.Bcast variants "
585+ "instead. "
421586 "To opt out now (and avoid this warning), "
422587 "pass _bcast_actx_array_type=False." ,
423588 DeprecationWarning , stacklevel = 2 )
@@ -435,7 +600,7 @@ def wrap(cls: Any) -> Any:
435600 gen (f"""
436601 from numbers import Number
437602 import numpy as np
438- from arraycontext import ArrayContainer
603+ from arraycontext import ArrayContainer, Bcast
439604 from warnings import warn
440605
441606 def _raise_if_actx_none(actx):
@@ -455,7 +620,8 @@ def is_numpy_array(arg):
455620 "behavior will change in 2025. If you would like the "
456621 "broadcasting behavior to stay the same, make sure "
457622 "to convert the passed numpy array to an "
458- "object array.",
623+ "object array, or use arraycontext.Bcast to achieve "
624+ "the desired broadcasting semantics.",
459625 DeprecationWarning, stacklevel=3)
460626 return True
461627 else:
@@ -553,6 +719,33 @@ def {fname}(arg1):
553719 cls ._serialize_init_arrays_code ("arg2" ).items ()
554720 })
555721
722+ def get_operand (arg : Union [tuple [str , str ], str ]) -> str :
723+ if isinstance (arg , tuple ):
724+ entry , _container = arg
725+ return entry
726+ else :
727+ return arg
728+
729+ bcast_init_args_arg1_is_outer_with_rewrap = \
730+ cls ._deserialize_init_arrays_code ("arg1" , {
731+ key_arg1 :
732+ _format_binary_op_str (
733+ op_str , expr_arg1 ,
734+ f"arg2._rewrap({ get_operand (expr_arg1 )} )" )
735+ for key_arg1 , expr_arg1 in
736+ cls ._serialize_init_arrays_code ("arg1" ).items ()
737+ })
738+ bcast_init_args_arg2_is_outer_with_rewrap = \
739+ cls ._deserialize_init_arrays_code ("arg2" , {
740+ key_arg2 :
741+ _format_binary_op_str (
742+ op_str ,
743+ f"arg1._rewrap({ get_operand (expr_arg2 )} )" ,
744+ expr_arg2 )
745+ for key_arg2 , expr_arg2 in
746+ cls ._serialize_init_arrays_code ("arg2" ).items ()
747+ })
748+
556749 # {{{ "forward" binary operators
557750
558751 gen (f"def { fname } (arg1, arg2):" )
@@ -605,14 +798,19 @@ def {fname}(arg1):
605798 warn("Broadcasting { cls } over array "
606799 f"context array type {{type(arg2)}} is deprecated "
607800 "and will no longer work in 2025. "
608- "There is no replacement as of right now. "
609- "See the discussion in "
610- "https://github.com/inducer/arraycontext/"
611- "pull/190. ",
801+ "Use arraycontext.Bcast to achieve the desired "
802+ "broadcasting semantics.",
612803 DeprecationWarning, stacklevel=2)
613804
614805 return cls({ bcast_init_args_arg1_is_outer } )
615806
807+ if isinstance(arg2, Bcast):
808+ if arg2._with_next_operand:
809+ return cls({ bcast_init_args_arg1_is_outer_with_rewrap } )
810+ else:
811+ arg2 = arg2._rewrap()
812+ return cls({ bcast_init_args_arg1_is_outer } )
813+
616814 return NotImplemented
617815 """ )
618816 gen (f"cls.__{ dunder_name } __ = { fname } " )
@@ -656,14 +854,19 @@ def {fname}(arg2, arg1):
656854 f"context array type {{type(arg1)}} "
657855 "is deprecated "
658856 "and will no longer work in 2025."
659- "There is no replacement as of right now. "
660- "See the discussion in "
661- "https://github.com/inducer/arraycontext/"
662- "pull/190. ",
857+ "Use arraycontext.Bcast to achieve the "
858+ "desired broadcasting semantics.",
663859 DeprecationWarning, stacklevel=2)
664860
665861 return cls({ bcast_init_args_arg2_is_outer } )
666862
863+ if isinstance(arg1, Bcast):
864+ if arg1._with_next_operand:
865+ return cls({ bcast_init_args_arg2_is_outer_with_rewrap } )
866+ else:
867+ arg1 = arg1._rewrap()
868+ return cls({ bcast_init_args_arg2_is_outer } )
869+
667870 return NotImplemented
668871
669872 cls.__r{ dunder_name } __ = { fname } """ )
0 commit comments