66.. currentmodule:: arraycontext
77
88.. autofunction:: with_container_arithmetic
9+ .. autoclass:: Bcast
10+ .. autoclass:: BcastNLevels
11+ .. autoclass:: BcastUntilActxArray
12+
13+ .. function:: Bcast1
14+
15+ Like :class:`BcastNLevels` with *nlevels* set to 1.
16+
17+ .. function:: Bcast2
18+
19+ Like :class:`BcastNLevels` with *nlevels* set to 2.
20+
21+ .. function:: Bcast3
22+
23+ Like :class:`BcastNLevels` with *nlevels* set to 3.
924"""
1025
1126
3449"""
3550
3651import enum
37- from typing import Any , Callable , Optional , Tuple , TypeVar , Union
52+ from abc import ABC , abstractmethod
53+ from dataclasses import FrozenInstanceError
54+ from functools import partial
55+ from numbers import Number
56+ from typing import Any , Callable , ClassVar , Optional , Tuple , TypeVar , Union
3857from warnings import warn
3958
4059import numpy as np
4160
61+ from arraycontext .context import ArrayContext , ArrayOrContainer
62+
4263
4364# {{{ with_container_arithmetic
4465
@@ -147,8 +168,9 @@ def __instancecheck__(cls, instance: Any) -> bool:
147168 warn (
148169 "Broadcasting container against non-object numpy array. "
149170 "This was never documented to work and will now stop working in "
150- "2025. Convert the array to an object array to preserve the "
151- "current semantics." , DeprecationWarning , stacklevel = 3 )
171+ "2025. Convert the array to an object array or use "
172+ "variants of arraycontext.Bcast to obtain the desired "
173+ "broadcasting semantics." , DeprecationWarning , stacklevel = 3 )
152174 return True
153175 else :
154176 return False
@@ -158,6 +180,125 @@ class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMet
158180 pass
159181
160182
183+ class Bcast :
184+ """
185+ A wrapper object to force arithmetic generated by :func:`with_container_arithmetic`
186+ to broadcast *arg* across a container (with the container as the 'outer' structure).
187+ Since array containers are often nested in complex ways, different subclasses
188+ implement different rules on how broadcasting interacts with the hierarchy,
189+ with :class:`BcastNLevels` and :class:`BcastUntilActxArray` representing
190+ the most common.
191+ """
192+ arg : ArrayOrContainer
193+
194+ # Accessing this attribute is cheaper than isinstance, so use that
195+ # to distinguish _BcastWithNextOperand and _BcastWithoutNextOperand.
196+ _with_next_operand : ClassVar [bool ]
197+
198+ def __init__ (self , arg : ArrayOrContainer ) -> None :
199+ object .__setattr__ (self , "arg" , arg )
200+
201+ def __setattr__ (self , name : str , value : Any ) -> None :
202+ raise FrozenInstanceError ()
203+
204+ def __delattr__ (self , name : str ) -> None :
205+ raise FrozenInstanceError ()
206+
207+
208+ class _BcastWithNextOperand (Bcast , ABC ):
209+ """
210+ A :class:`Bcast` object that gets to see who the next operand will be, in
211+ order to decide whether wrapping the child in :class:`Bcast` is still necessary.
212+ This is much more flexible, but also considerably more expensive, than
213+ :class:`_BcastWithoutNextOperand`.
214+ """
215+
216+ _with_next_operand = True
217+
218+ # purposefully undocumented
219+ @abstractmethod
220+ def _rewrap (self , other_operand : ArrayOrContainer ) -> ArrayOrContainer :
221+ ...
222+
223+
224+ class _BcastWithoutNextOperand (Bcast , ABC ):
225+ """
226+ A :class:`Bcast` object that does not get to see who the next operand will be.
227+ """
228+ _with_next_operand = False
229+
230+ # purposefully undocumented
231+ @abstractmethod
232+ def _rewrap (self ) -> ArrayOrContainer :
233+ ...
234+
235+
236+ class BcastNLevels (_BcastWithoutNextOperand ):
237+ """
238+ A broadcasting rule that lets *arg* broadcast against *nlevels* "levels" of
239+ array containers. Use :func:`Bcast1`, :func:`Bcast2`, :func:`Bcast3` as
240+ convenient aliases for the common cases.
241+
242+ Usage example::
243+
244+ container + Bcast2(actx_array)
245+
246+ .. note::
247+
248+ :mod:`numpy` object arrays do not count against the number of levels.
249+
250+ .. automethod:: __init__
251+ """
252+ nlevels : int
253+
254+ def __init__ (self , nlevels : int , arg : ArrayOrContainer ) -> None :
255+ if nlevels < 1 :
256+ raise ValueError ("nlevels is expected to be one or greater." )
257+
258+ super ().__init__ (arg )
259+ object .__setattr__ (self , "nlevels" , nlevels )
260+
261+ def _rewrap (self ) -> ArrayOrContainer :
262+ if self .nlevels == 1 :
263+ return self .arg
264+ else :
265+ return BcastNLevels (self .nlevels - 1 , self .arg )
266+
267+
268+ Bcast1Level = partial (BcastNLevels , 1 )
269+ Bcast2Levels = partial (BcastNLevels , 2 )
270+ Bcast3Levels = partial (BcastNLevels , 3 )
271+
272+
273+ class BcastUntilActxArray (_BcastWithNextOperand ):
274+ """
275+ A broadcast rule that broadcasts *arg* across array containers until
276+ the 'opposite' operand is one of the :attr:`~arraycontext.ArrayContext.array_types`
277+ of *actx*, or a :class:`~numbers.Number`.
278+
279+ Suggested usage pattern::
280+
281+ bcast = functools.partial(BcastUntilActxArray, actx)
282+
283+ container + bcast(actx_array)
284+
285+ .. automethod:: __init__
286+ """
287+ actx : ArrayContext
288+
289+ def __init__ (self ,
290+ actx : ArrayContext ,
291+ arg : ArrayOrContainer ) -> None :
292+ super ().__init__ (arg )
293+ object .__setattr__ (self , "actx" , actx )
294+
295+ def _rewrap (self , other_operand : ArrayOrContainer ) -> ArrayOrContainer :
296+ if isinstance (other_operand , (* self .actx .array_types , Number )):
297+ return self .arg
298+ else :
299+ return self
300+
301+
161302def with_container_arithmetic (
162303 * ,
163304 bcast_number : bool = True ,
@@ -206,6 +347,14 @@ class has an ``array_context`` attribute. If so, and if :data:`__debug__`
206347
207348 Each operator class also includes the "reverse" operators if applicable.
208349
350+ .. note::
351+
352+ For the generated binary arithmetic operators, if certain types
353+ should be broadcast over the container (with the container as the
354+ 'outer' structure) but are not handled in this way by their types,
355+ you may wrap them in one of the :class:`Bcast` variants to achieve
356+ the desired semantics.
357+
209358 .. note::
210359
211360 To generate the code implementing the operators, this function relies on
@@ -238,6 +387,24 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
238387 #
239388 # - Broadcast rules are hard to change once established, particularly
240389 # because one cannot grep for their use.
390+ #
391+ # Possible advantages of the "Bcast" broadcast-rule-as-object design:
392+ #
393+ # - If one rule does not fit the user's need, they can straightforwardly use
394+ # another.
395+ #
396+ # - It's straightforward to find where certain broadcast rules are used.
397+ #
398+ # - The broadcast rule can contain more state. For example, it's now easy
399+ # for the rule to know what array context should be used to determine
400+ # actx array types.
401+ #
402+ # Possible downsides of the "Bcast" broadcast-rule-as-object design:
403+ #
404+ # - User code is a bit more wordy.
405+ #
406+ # - Rewrapping has the potential to be costly, especially in
407+ # _with_next_operand mode.
241408
242409 # {{{ handle inputs
243410
@@ -349,9 +516,8 @@ def wrap(cls: Any) -> Any:
349516 f"Broadcasting array context array types across { cls } "
350517 "has been explicitly "
351518 "enabled. As of 2025, this will stop working. "
352- "There is no replacement as of right now. "
353- "See the discussion in "
354- "https://github.com/inducer/arraycontext/pull/190. "
519+ "Express these operations using arraycontext.Bcast variants "
520+ "instead. "
355521 "To opt out now (and avoid this warning), "
356522 "pass _bcast_actx_array_type=False. " ,
357523 DeprecationWarning , stacklevel = 2 )
@@ -360,9 +526,8 @@ def wrap(cls: Any) -> Any:
360526 f"Broadcasting array context array types across { cls } "
361527 "has been implicitly "
362528 "enabled. As of 2025, this will no longer work. "
363- "There is no replacement as of right now. "
364- "See the discussion in "
365- "https://github.com/inducer/arraycontext/pull/190. "
529+ "Express these operations using arraycontext.Bcast variants "
530+ "instead. "
366531 "To opt out now (and avoid this warning), "
367532 "pass _bcast_actx_array_type=False." ,
368533 DeprecationWarning , stacklevel = 2 )
@@ -380,7 +545,7 @@ def wrap(cls: Any) -> Any:
380545 gen (f"""
381546 from numbers import Number
382547 import numpy as np
383- from arraycontext import ArrayContainer
548+ from arraycontext import ArrayContainer, Bcast
384549 from warnings import warn
385550
386551 def _raise_if_actx_none(actx):
@@ -400,7 +565,8 @@ def is_numpy_array(arg):
400565 "behavior will change in 2025. If you would like the "
401566 "broadcasting behavior to stay the same, make sure "
402567 "to convert the passed numpy array to an "
403- "object array.",
568+ "object array, or use arraycontext.Bcast to achieve "
569+ "the desired broadcasting semantics.",
404570 DeprecationWarning, stacklevel=3)
405571 return True
406572 else:
@@ -492,6 +658,33 @@ def {fname}(arg1):
492658 cls ._serialize_init_arrays_code ("arg2" ).items ()
493659 })
494660
661+ def get_operand (arg : Union [tuple [str , str ], str ]) -> str :
662+ if isinstance (arg , tuple ):
663+ entry , _container = arg
664+ return entry
665+ else :
666+ return arg
667+
668+ bcast_init_args_arg1_is_outer_with_rewrap = \
669+ cls ._deserialize_init_arrays_code ("arg1" , {
670+ key_arg1 :
671+ _format_binary_op_str (
672+ op_str , expr_arg1 ,
673+ f"arg2._rewrap({ get_operand (expr_arg1 )} )" )
674+ for key_arg1 , expr_arg1 in
675+ cls ._serialize_init_arrays_code ("arg1" ).items ()
676+ })
677+ bcast_init_args_arg2_is_outer_with_rewrap = \
678+ cls ._deserialize_init_arrays_code ("arg2" , {
679+ key_arg2 :
680+ _format_binary_op_str (
681+ op_str ,
682+ f"arg1._rewrap({ get_operand (expr_arg2 )} )" ,
683+ expr_arg2 )
684+ for key_arg2 , expr_arg2 in
685+ cls ._serialize_init_arrays_code ("arg2" ).items ()
686+ })
687+
495688 # {{{ "forward" binary operators
496689
497690 gen (f"def { fname } (arg1, arg2):" )
@@ -544,14 +737,19 @@ def {fname}(arg1):
544737 warn("Broadcasting { cls } over array "
545738 f"context array type {{type(arg2)}} is deprecated "
546739 "and will no longer work in 2025. "
547- "There is no replacement as of right now. "
548- "See the discussion in "
549- "https://github.com/inducer/arraycontext/"
550- "pull/190. ",
740+ "Use arraycontext.Bcast to achieve the desired "
741+ "broadcasting semantics.",
551742 DeprecationWarning, stacklevel=2)
552743
553744 return cls({ bcast_init_args_arg1_is_outer } )
554745
746+ if isinstance(arg2, Bcast):
747+ if arg2._with_next_operand:
748+ return cls({ bcast_init_args_arg1_is_outer_with_rewrap } )
749+ else:
750+ arg2 = arg2._rewrap()
751+ return cls({ bcast_init_args_arg1_is_outer } )
752+
555753 return NotImplemented
556754 """ )
557755 gen (f"cls.__{ dunder_name } __ = { fname } " )
@@ -595,14 +793,19 @@ def {fname}(arg2, arg1):
595793 f"context array type {{type(arg1)}} "
596794 "is deprecated "
597795 "and will no longer work in 2025."
598- "There is no replacement as of right now. "
599- "See the discussion in "
600- "https://github.com/inducer/arraycontext/"
601- "pull/190. ",
796+ "Use arraycontext.Bcast to achieve the "
797+ "desired broadcasting semantics.",
602798 DeprecationWarning, stacklevel=2)
603799
604800 return cls({ bcast_init_args_arg2_is_outer } )
605801
802+ if isinstance(arg1, Bcast):
803+ if arg1._with_next_operand:
804+ return cls({ bcast_init_args_arg2_is_outer_with_rewrap } )
805+ else:
806+ arg1 = arg1._rewrap()
807+ return cls({ bcast_init_args_arg2_is_outer } )
808+
606809 return NotImplemented
607810
608811 cls.__r{ dunder_name } __ = { fname } """ )
0 commit comments