diff --git a/meshmode/discretization/connection/direct.py b/meshmode/discretization/connection/direct.py index 9eb0fae8a..7052a534a 100644 --- a/meshmode/discretization/connection/direct.py +++ b/meshmode/discretization/connection/direct.py @@ -232,6 +232,12 @@ class _FromGroupPickData(Generic[ArrayT]): from_element_indices: ArrayT is_surjective: bool + @keyed_memoize_method(key=lambda actx: type(actx)) + def indexed_dof_pick_lists(self, actx): + assert actx.permits_advanced_indexing + return actx.freeze( + actx.thaw(self.dof_pick_lists)[actx.thaw(self.dof_pick_list_indices)]) + # }}} @@ -356,7 +362,7 @@ def __init__(self, # {{{ _resample_matrix @keyed_memoize_method(key=lambda actx, to_group_index, ibatch_index: - (to_group_index, ibatch_index)) + (type(actx), to_group_index, ibatch_index)) def _resample_matrix(self, actx: ArrayContext, to_group_index: int, ibatch_index: int): import modepy as mp @@ -435,7 +441,8 @@ def _resample_point_pick_indices(self, to_group_index: int, ibatch_index: int, return result @keyed_memoize_method(lambda actx, to_group_index, ibatch_index, - tol_multiplier=None: (to_group_index, ibatch_index, tol_multiplier)) + tol_multiplier=None: (type(actx), to_group_index, + ibatch_index, tol_multiplier)) def _frozen_resample_point_pick_indices(self, actx: ArrayContext, to_group_index: int, ibatch_index: int, tol_multiplier: Optional[float] = None): @@ -736,8 +743,7 @@ def group_pick_knl(is_surjective: bool): grp_ary_contrib = ary[fgpd.from_group_index][ _reshape_and_preserve_tags( actx, from_element_indices, (-1, 1)), - actx.thaw(fgpd.dof_pick_lists)[ - actx.thaw(fgpd.dof_pick_list_indices)] + actx.thaw(fgpd.indexed_dof_pick_lists(actx)) ] if not fgpd.is_surjective: