diff --git a/meshmode/discretization/connection/direct.py b/meshmode/discretization/connection/direct.py index a84b2282b..0d55a7c10 100644 --- a/meshmode/discretization/connection/direct.py +++ b/meshmode/discretization/connection/direct.py @@ -30,7 +30,8 @@ import loopy as lp from meshmode.transform_metadata import ( ConcurrentElementInameTag, ConcurrentDOFInameTag, - DiscretizationElementAxisTag, DiscretizationDOFAxisTag) + DiscretizationElementAxisTag, DiscretizationDOFAxisTag, + DiscretizationDOFPickListAxisTag) from pytools import memoize_in, keyed_memoize_method from arraycontext import ( ArrayContext, ArrayT, ArrayOrContainerT, NotAnArrayContainerError, @@ -166,12 +167,14 @@ def _global_from_element_indices( np_full_from_element_indices[~np_from_el_present] = 0 from_el_present = actx.freeze( - actx.tag(NameHint("from_el_present"), - actx.from_numpy( - np_from_el_present.astype(np.int8)))) + actx.tag_axis(0, DiscretizationElementAxisTag(), + actx.tag(NameHint("from_el_present"), + actx.from_numpy( + np_from_el_present.astype(np.int8))))) full_from_element_indices = actx.freeze( - actx.tag(NameHint("from_el_indices"), - actx.from_numpy(np_full_from_element_indices))) + actx.tag_axis(0, DiscretizationElementAxisTag(), + actx.tag(NameHint("from_el_indices"), + actx.from_numpy(np_full_from_element_indices)))) self._global_from_element_indices_cache = ( from_el_present, full_from_element_indices) @@ -553,17 +556,22 @@ def _per_target_group_pick_info( _FromGroupPickData( from_group_index=source_group_index, dof_pick_lists=actx.freeze( - actx.tag(NameHint("dof_pick_lists"), - actx.from_numpy(dof_pick_lists))), + actx.tag_axis(0, DiscretizationDOFPickListAxisTag(), + actx.tag(NameHint("dof_pick_lists"), + actx.from_numpy(dof_pick_lists)))), dof_pick_list_indices=actx.freeze( - actx.tag(NameHint("dof_pick_list_indices"), - actx.from_numpy(dof_pick_list_indices))), + actx.tag_axis(0, DiscretizationElementAxisTag(), + actx.tag(NameHint("dof_pick_list_indices"), + actx.from_numpy(dof_pick_list_indices)))), from_el_present=actx.freeze( - actx.tag(NameHint("from_el_present"), - actx.from_numpy(from_el_present.astype(np.int8)))), + actx.tag_axis(0, DiscretizationElementAxisTag(), + actx.tag(NameHint("from_el_present"), + actx.from_numpy( + from_el_present.astype(np.int8))))), from_element_indices=actx.freeze( - actx.tag(NameHint("from_el_indices"), - actx.from_numpy(from_el_indices))), + actx.tag_axis(0, DiscretizationElementAxisTag(), + actx.tag(NameHint("from_el_indices"), + actx.from_numpy(from_el_indices)))), is_surjective=from_el_present.all() )) @@ -732,25 +740,27 @@ def group_pick_knl(is_surjective: bool): group_pick_info = None if group_pick_info is not None: - group_array_contributions = [] - if actx.permits_advanced_indexing and not _force_use_loopy: for fgpd in group_pick_info: from_element_indices = actx.thaw(fgpd.from_element_indices) if ary[fgpd.from_group_index].size: grp_ary_contrib = ary[fgpd.from_group_index][ + tag_axes(actx, { + 1: DiscretizationDOFAxisTag()}, _reshape_and_preserve_tags( - actx, from_element_indices, (-1, 1)), - actx.thaw(fgpd.dof_pick_lists)[ - actx.thaw(fgpd.dof_pick_list_indices)] - ] + actx, from_element_indices, (-1, 1))), + actx.thaw(fgpd.dof_pick_lists)[ + actx.thaw(fgpd.dof_pick_list_indices)] + ] if not fgpd.is_surjective: from_el_present = actx.thaw(fgpd.from_el_present) grp_ary_contrib = actx.np.where( - _reshape_and_preserve_tags( - actx, from_el_present, (-1, 1)), + tag_axes(actx, { + 1: DiscretizationDOFAxisTag()}, + _reshape_and_preserve_tags( + actx, from_el_present, (-1, 1))), grp_ary_contrib, 0) @@ -800,8 +810,10 @@ def group_pick_knl(is_surjective: bool): mat = self._resample_matrix(actx, i_tgrp, i_batch) if actx.permits_advanced_indexing and not _force_use_loopy: batch_result = actx.np.where( - _reshape_and_preserve_tags( - actx, from_el_present, (-1, 1)), + tag_axes(actx, { + 1: DiscretizationDOFAxisTag()}, + _reshape_and_preserve_tags( + actx, from_el_present, (-1, 1))), actx.einsum("ij,ej->ei", mat, grp_ary[from_element_indices]), 0) @@ -822,11 +834,15 @@ def group_pick_knl(is_surjective: bool): if actx.permits_advanced_indexing and not _force_use_loopy: batch_result = actx.np.where( - _reshape_and_preserve_tags( - actx, from_el_present, (-1, 1)), - from_vec[ + tag_axes(actx, { + 1: DiscretizationDOFAxisTag()}, _reshape_and_preserve_tags( - actx, from_element_indices, (-1, 1)), + actx, from_el_present, (-1, 1))), + from_vec[ + tag_axes(actx, { + 1: DiscretizationDOFAxisTag()}, + _reshape_and_preserve_tags( + actx, from_element_indices, (-1, 1))), pick_list], 0) else: @@ -853,10 +869,13 @@ def group_pick_knl(is_surjective: bool): else: # If no batched data at all, return zeros for this # particular group array - group_array = actx.zeros( + group_array = tag_axes(actx, { + 0: DiscretizationElementAxisTag(), + 1: DiscretizationDOFAxisTag()}, + actx.zeros( shape=(self.to_discr.groups[i_tgrp].nelements, self.to_discr.groups[i_tgrp].nunit_dofs), - dtype=ary.entry_dtype) + dtype=ary.entry_dtype)) group_arrays.append(group_array) diff --git a/meshmode/transform_metadata.py b/meshmode/transform_metadata.py index f622310b1..54db12b14 100644 --- a/meshmode/transform_metadata.py +++ b/meshmode/transform_metadata.py @@ -8,6 +8,7 @@ .. autoclass:: DiscretizationDOFAxisTag .. autoclass:: DiscretizationAmbientDimAxisTag .. autoclass:: DiscretizationTopologicalDimAxisTag +.. autoclass:: DiscretizationDOFPickListAxisTag """ __copyright__ = """ @@ -121,3 +122,12 @@ class DiscretizationTopologicalDimAxisTag(DiscretizationDimAxisTag): Array dimensions tagged with this tag type describe an axis indexing over the discretization's physical coordinate dimensions. """ + + +@tag_dataclass +class DiscretizationDOFPickListAxisTag(DiscretizationEntityAxisTag): + """ + Array dimensions tagged with this tag type describe an axis indexing over + DOF pick lists in + :class:`meshmode.discretization.connection.DirectDiscretizationConnection`. + """