@@ -124,29 +124,31 @@ def map_node_min(self, expr):
124124 def _map_elementwise_reduction (self , reduction_name , expr ):
125125 import loopy as lp
126126 from arraycontext import make_loopy_program
127- from meshmode .transform_metadata import (
128- ConcurrentElementInameTag , ConcurrentDOFInameTag )
127+ from meshmode .transform_metadata import ConcurrentElementInameTag
128+ actx = self . array_context
129129
130- @memoize_in (self .places , "elementwise_node_" + reduction_name )
130+ @memoize_in (actx , (
131+ EvaluationMapperBase ._map_elementwise_reduction ,
132+ f"elementwise_node_{ reduction_name } " ))
131133 def node_knl ():
132134 t_unit = make_loopy_program (
133135 """{[iel, idof, jdof]:
134136 0<=iel<nelements and
135137 0<=idof, jdof<ndofs}""" ,
136138 """
137- result[iel, idof] = %s(jdof, operand[iel, jdof])
139+ <> el_result = %s(jdof, operand[iel, jdof])
140+ result[iel, idof] = el_result
138141 """ % reduction_name ,
139- name = "nodewise_reduce " )
142+ name = f"elementwise_node_ { reduction_name } " )
140143
141144 return lp .tag_inames (t_unit , {
142145 "iel" : ConcurrentElementInameTag (),
143- "idof" : ConcurrentDOFInameTag (),
144146 })
145147
146- @memoize_in (self .places , "elementwise_" + reduction_name )
148+ @memoize_in (actx , (
149+ EvaluationMapperBase ._map_elementwise_reduction ,
150+ f"elementwise_element_{ reduction_name } " ))
147151 def element_knl ():
148- # FIXME: This computes the reduction value redundantly for each
149- # output DOF.
150152 t_unit = make_loopy_program (
151153 """{[iel, jdof]:
152154 0<=iel<nelements and
@@ -155,37 +157,27 @@ def element_knl():
155157 """
156158 result[iel, 0] = %s(jdof, operand[iel, jdof])
157159 """ % reduction_name ,
158- name = "elementwise_reduce " )
160+ name = f"elementwise_element_ { reduction_name } " )
159161
160162 return lp .tag_inames (t_unit , {
161163 "iel" : ConcurrentElementInameTag (),
162164 })
163165
164- discr = self .places .get_discretization (
165- expr .dofdesc .geometry , expr .dofdesc .discr_stage )
166+ dofdesc = expr .dofdesc
166167 operand = self .rec (expr .operand )
167- assert operand .shape == (len (discr .groups ),)
168-
169- def _reduce (knl , result ):
170- for g_operand , g_result in zip (operand , result ):
171- self .array_context .call_loopy (
172- knl , operand = g_operand , result = g_result )
173-
174- return result
175-
176- dtype = operand .entry_dtype
177- granularity = expr .dofdesc .granularity
178- if granularity is sym .GRANULARITY_NODE :
179- return _reduce (node_knl (),
180- discr .empty (self .array_context , dtype = dtype ))
181- elif granularity is sym .GRANULARITY_ELEMENT :
182- result = DOFArray (self .array_context , tuple ([
183- self .array_context .empty ((grp .nelements , 1 ), dtype = dtype )
184- for grp in discr .groups
168+
169+ if dofdesc .granularity is sym .GRANULARITY_NODE :
170+ return type (operand )(actx , tuple ([
171+ actx .call_loopy (node_knl (), operand = operand_i )["result" ]
172+ for operand_i in operand
173+ ]))
174+ elif dofdesc .granularity is sym .GRANULARITY_ELEMENT :
175+ return type (operand )(actx , tuple ([
176+ actx .call_loopy (element_knl (), operand = operand_i )["result" ]
177+ for operand_i in operand
185178 ]))
186- return _reduce (element_knl (), result )
187179 else :
188- raise ValueError (f"unsupported granularity: { granularity } " )
180+ raise ValueError (f"unsupported granularity: { dofdesc . granularity } " )
189181
def map_elementwise_sum(self, expr):
    """Evaluate an elementwise sum expression.

    Delegates to ``self._map_elementwise_reduction`` with the reduction
    name ``"sum"``, forwarding *expr* unchanged, and returns whatever that
    helper returns.
    """
    reduction_name = "sum"
    return self._map_elementwise_reduction(reduction_name, expr)
0 commit comments