diff --git a/loopy/__init__.py b/loopy/__init__.py
index c0f8c78cd..a14bf09d5 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -76,7 +76,7 @@
         affine_map_inames, find_unused_axis_tag,
         make_reduction_inames_unique,
         has_schedulable_iname_nesting, get_iname_duplication_options,
-        add_inames_to_insn, add_inames_for_unused_hw_axes)
+        add_inames_to_insn, add_inames_for_unused_hw_axes, map_domain)
 
 from loopy.transform.instruction import (
         find_instructions, map_instructions,
@@ -205,7 +205,7 @@
         "affine_map_inames", "find_unused_axis_tag",
         "make_reduction_inames_unique",
         "has_schedulable_iname_nesting", "get_iname_duplication_options",
-        "add_inames_to_insn", "add_inames_for_unused_hw_axes",
+        "add_inames_to_insn", "add_inames_for_unused_hw_axes", "map_domain",
 
         "add_prefetch", "change_arg_to_image",
         "tag_array_axes", "tag_data_axes",
diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py
index c3b4a42ee..33b52b913 100644
--- a/loopy/transform/iname.py
+++ b/loopy/transform/iname.py
@@ -72,6 +72,8 @@
 
 .. autofunction:: add_inames_to_insn
 
+.. autofunction:: map_domain
+
 .. autofunction:: add_inames_for_unused_hw_axes
 
 """
@@ -1832,6 +1834,347 @@ def add_inames_to_insn(kernel, inames, insn_match):
 # }}}
 
 
+# {{{ map_domain
+
+class _MapDomainMapper(RuleAwareIdentityMapper):
+    """Expression mapper used by :func:`map_domain`: replaces each old iname
+    by its substitution and swaps old for new inames in reductions."""
+    def __init__(self, rule_mapping_context, within, new_inames, substitutions):
+        super(_MapDomainMapper, self).__init__(rule_mapping_context)
+
+        self.within = within
+
+        self.old_inames = frozenset(substitutions)
+        self.new_inames = new_inames
+
+        self.substitutions = substitutions
+
+    def map_reduction(self, expr, expn_state):
+        red_overlap = frozenset(expr.inames) & self.old_inames
+        arg_ctx_overlap = frozenset(expn_state.arg_context) & self.old_inames
+        if (red_overlap
+                and self.within(
+                    expn_state.kernel,
+                    expn_state.instruction)):
+            if len(red_overlap) != len(self.old_inames):
+                raise LoopyError("reduction '%s' involves a part "
+                        "of the map domain inames. Reductions must "
+                        "either involve all or none of the map domain "
+                        "inames." % str(expr))
+
+            if arg_ctx_overlap:
+                if arg_ctx_overlap == red_overlap:
+                    # All variables are shadowed by context, that's OK.
+                    return super(_MapDomainMapper, self).map_reduction(
+                            expr, expn_state)
+                else:
+                    raise LoopyError("reduction '%s' has "
+                            "some of the reduction variables affected "
+                            "by the map_domain shadowed by context. "
+                            "Either all or none must be shadowed."
+                            % str(expr))
+
+            new_inames = list(expr.inames)
+            for old_iname in self.old_inames:
+                new_inames.remove(old_iname)
+            new_inames.extend(self.new_inames)
+
+            from loopy.symbolic import Reduction
+            return Reduction(expr.operation, tuple(new_inames),
+                    self.rec(expr.expr, expn_state),
+                    expr.allow_simultaneous)
+        else:
+            return super(_MapDomainMapper, self).map_reduction(expr, expn_state)
+
+    def map_variable(self, expr, expn_state):
+        if (expr.name in self.old_inames
+                and expr.name not in expn_state.arg_context
+                and self.within(
+                    expn_state.kernel,
+                    expn_state.instruction)):
+            return self.substitutions[expr.name]
+        else:
+            return super(_MapDomainMapper, self).map_variable(expr, expn_state)
+
+
+def _find_aff_subst_from_map(iname, isl_map):
+    """Solve the :class:`islpy.BasicMap` *isl_map* for its input *iname*.
+
+    Returns the affine expression taken from the first equality constraint
+    in which *iname* has coefficient 1 or -1, with *iname* itself removed.
+    Raises :exc:`LoopyError` if no such constraint exists.
+    """
+    if not isinstance(isl_map, isl.BasicMap):
+        raise RuntimeError("isl_map must be a BasicMap")
+
+    dt, dim_idx = isl_map.get_var_dict()[iname]
+
+    assert dt == dim_type.in_
+
+    # Force isl to solve for only this iname on its side of the map, by
+    # projecting out all other "in" variables.
+    isl_map = isl_map.project_out(dt, dim_idx+1, isl_map.dim(dt)-(dim_idx+1))
+    isl_map = isl_map.project_out(dt, 0, dim_idx)
+    dim_idx = 0
+
+    # Convert map to set to avoid "domain of affine expression should be a set".
+    # The old "in" variable will be the last of the out_dims.
+    new_dim_idx = isl_map.dim(dim_type.out)
+    isl_map = isl_map.move_dims(
+            dim_type.out, isl_map.dim(dim_type.out),
+            dt, dim_idx, 1)
+    isl_map = isl_map.range()  # now a set
+    dt = dim_type.set
+    dim_idx = new_dim_idx
+    del new_dim_idx
+
+    for cns in isl_map.get_constraints():
+        if cns.is_equality() and cns.involves_dims(dt, dim_idx, 1):
+            coeff = cns.get_coefficient_val(dt, dim_idx)
+            cns_zeroed = cns.set_coefficient_val(dt, dim_idx, 0)
+            if cns_zeroed.involves_dims(dt, dim_idx, 1):
+                # not suitable, constraint still involves dim, perhaps in a div
+                continue
+
+            if coeff.is_one():
+                return -cns_zeroed.get_aff()
+            elif coeff.is_negone():
+                return cns_zeroed.get_aff()
+            else:
+                # not suitable, coefficient does not have unit coefficient
+                continue
+
+    raise LoopyError("no suitable equation for '%s' found" % iname)
+
+
+# TODO to match convention elsewhere, swap 'dt' and 'dim_type' identifiers
+# (use dt to abbreviate islpy.dim_type, and use dim_type for variables
+# containing a specific dim_type)
+
+def _find_and_rename_dim(old_map, dim_types, old_name, new_name):
+    """Return a copy of *old_map* with the dim named *old_name* renamed to
+    *new_name* in every dim type listed in *dim_types*."""
+    # (This function is only used once here, but do not inline it; it is used many
+    # times in child branch update-dependencies-during-transformations.)
+    new_map = old_map.copy()
+    for dt in dim_types:
+        new_map = new_map.set_dim_name(
+                dt, new_map.find_dim_by_name(dt, old_name), new_name)
+    return new_map
+
+
+@for_each_kernel
+def map_domain(kernel, isl_map, within=None):
+    """Replace inames of *kernel* as described by *isl_map*.
+
+    :arg isl_map: A bijective :class:`islpy.BasicMap`. Its input dims name
+        the inames being replaced, its output dims the inames replacing
+        them. Each loop domain, instruction and reduction must involve
+        either all or none of the input inames.
+    :arg within: Not supported yet; anything but *None* raises
+        :exc:`NotImplementedError`.
+    """
+
+    # FIXME: Express _split_iname_backend in terms of this
+    # Missing/deleted for now:
+    # - slab processing
+    # - priorities processing
+    # FIXME: Process priorities
+    # FIXME: Express affine_map_inames in terms of this, deprecate
+
+    # FIXME: Support within
+    # FIXME: Right now, this requires all inames in a domain (or none) to
+    # be mapped. That makes this awkward to use.
+
+    # {{{ within processing (disabled for now)
+    if within is not None:
+        raise NotImplementedError("within")
+
+    from loopy.match import parse_match
+    within = parse_match(within)
+
+    # {{{ return the same kernel if no kernel matches
+
+    if not any(within(kernel, insn) for insn in kernel.instructions):
+        return kernel
+
+    # }}}
+
+    # }}}
+
+    if not isl_map.is_bijective():
+        raise LoopyError("isl_map must be bijective")
+
+    new_inames = frozenset(isl_map.get_var_dict(dim_type.out))
+    old_inames = frozenset(isl_map.get_var_dict(dim_type.in_))
+
+    # {{{ solve for representation of old inames in terms of new
+
+    substitutions = {}
+    var_substitutions = {}
+    applied_iname_rewrites = kernel.applied_iname_rewrites[:]
+
+    from loopy.symbolic import aff_to_expr
+    from pymbolic import var
+    for iname in old_inames:
+        # Solve once per iname; both tables share the resulting expression.
+        subst_expr = aff_to_expr(
+                _find_aff_subst_from_map(iname, isl_map))
+        substitutions[iname] = subst_expr
+        var_substitutions[var(iname)] = subst_expr
+
+    applied_iname_rewrites.append(var_substitutions)
+    del var_substitutions
+
+    # }}}
+
+    from loopy.schedule.checker.utils import (
+        add_and_name_isl_dims,
+    )
+
+    def process_set(s):
+        var_dict = s.get_var_dict()
+
+        overlap = old_inames & frozenset(var_dict)
+
+        if not overlap:
+            # inames in s are not present in transform map, don't change s
+            return s
+
+        if len(overlap) != len(old_inames):
+            raise LoopyError("loop domain '%s' involves a part "
+                    "of the map domain inames. Domains must "
+                    "either involve all or none of the map domain "
+                    "inames." % s)
+
+        from loopy.schedule.checker.utils import (
+            add_eq_isl_constraint_from_names,
+        )
+
+        # {{{ align dims of isl_map and s
+
+        # FIXME: Make this less gross
+        # FIXME: Make an exported/documented interface of this in islpy
+        from islpy import _align_dim_type
+
+        map_with_s_domain = isl.Map.from_domain(s)
+
+        # {{{ deal with dims missing from transform map (isl_map)
+
+        # If dims in s are missing from transform map, they need to be added
+        # so that intersect_domain doesn't remove them.
+        # Order doesn't matter here because dims will be aligned in the next step.
+        dims_missing_from_transform_map = list(
+            set(s.get_var_names(dim_type.set)) -
+            set(isl_map.get_var_names(dim_type.in_)))
+        augmented_isl_map = add_and_name_isl_dims(
+            isl_map, dim_type.in_, dims_missing_from_transform_map)
+
+        # We want these missing inames to map to themselves so that the transform
+        # has no effect on them. Unfortunately isl will break if the
+        # names of the out dims aren't unique, so we will temporarily rename them
+        # and then change the names back afterward.
+
+        # FIXME: need better way to make sure proxy dim names are unique
+        dims_missing_from_transform_map_proxies = [
+            d+"__prox" for d in dims_missing_from_transform_map]
+        assert not set(dims_missing_from_transform_map_proxies) & set(
+            augmented_isl_map.get_var_dict().keys())
+
+        augmented_isl_map = add_and_name_isl_dims(
+            augmented_isl_map, dim_type.out, dims_missing_from_transform_map_proxies)
+
+        # Set proxy iname equal to real iname
+        for proxy_iname, real_iname in zip(
+                dims_missing_from_transform_map_proxies,
+                dims_missing_from_transform_map):
+            augmented_isl_map = add_eq_isl_constraint_from_names(
+                augmented_isl_map, proxy_iname, real_iname)
+
+        # }}}
+
+        dim_types = [dim_type.param, dim_type.in_, dim_type.out]
+        s_names = [
+                map_with_s_domain.get_dim_name(dt, i)
+                for dt in dim_types
+                for i in range(map_with_s_domain.dim(dt))
+                ]
+        map_names = [
+                augmented_isl_map.get_dim_name(dt, i)
+                for dt in dim_types
+                for i in range(augmented_isl_map.dim(dt))
+                ]
+
+        # (order doesn't matter in s_names/map_names,
+        # _align_dim_type just converts these to sets
+        # to determine which names are in both the obj and template,
+        # not sure why this isn't just handled inside _align_dim_type)
+        aligned_map = _align_dim_type(
+                dim_type.param,
+                augmented_isl_map, map_with_s_domain, False,
+                map_names, s_names)
+        aligned_map = _align_dim_type(
+                dim_type.in_,
+                aligned_map, map_with_s_domain, False,
+                map_names, s_names)
+
+        # }}}
+
+        new_s = aligned_map.intersect_domain(s).range()
+
+        # Now rename the proxy dims back to their original names
+        for proxy_iname, real_iname in zip(
+                dims_missing_from_transform_map_proxies,
+                dims_missing_from_transform_map):
+            new_s = _find_and_rename_dim(
+                new_s, [dim_type.set], proxy_iname, real_iname)
+
+        return new_s
+
+    # FIXME: Revive _project_out_only_if_all_instructions_in_within
+
+    new_domains = [process_set(dom) for dom in kernel.domains]
+
+    # {{{ update within_inames
+
+    new_insns = []
+    for insn in kernel.instructions:
+        overlap = old_inames & insn.within_inames
+        if overlap and within(kernel, insn):
+            if len(overlap) != len(old_inames):
+                raise LoopyError("instruction '%s' is within only a part "
+                        "of the map domain inames. Instructions must "
+                        "either be within all or none of the map domain "
+                        "inames." % insn.id)
+
+            insn = insn.copy(
+                    within_inames=(insn.within_inames - old_inames) | new_inames)
+        else:
+            # leave insn unmodified
+            pass
+
+        new_insns.append(insn)
+
+    # }}}
+
+    kernel = kernel.copy(
+            domains=new_domains,
+            instructions=new_insns,
+            applied_iname_rewrites=applied_iname_rewrites)
+
+    rule_mapping_context = SubstitutionRuleMappingContext(
+            kernel.substitutions, kernel.get_var_name_generator())
+    ins = _MapDomainMapper(rule_mapping_context, within,
+            new_inames, substitutions)
+
+    kernel = ins.map_kernel(kernel)
+    kernel = rule_mapping_context.finish_kernel(kernel)
+
+    return kernel
+
+# }}}
+
+
 @for_each_kernel
 def add_inames_for_unused_hw_axes(kernel, within=None):
     """
diff --git a/test/test_transform.py b/test/test_transform.py
index 51e7c2636..feda064e3 100644
--- a/test/test_transform.py
+++ b/test/test_transform.py
@@ -574,7 +574,6 @@ def test_split_iname_only_if_in_within():
 
 def test_nested_substs_in_insns(ctx_factory):
     ctx = ctx_factory()
-    import loopy as lp
 
     ref_prg = lp.make_kernel(
         "{[i]: 0<=i<10}",
@@ -594,6 +593,232 @@ def test_nested_substs_in_insns(ctx_factory):
     lp.auto_test_vs_ref(ref_prg, ctx, t_unit)
 
 
+# {{{ test_map_domain_vs_split_iname
+
+def test_map_domain_vs_split_iname():
+
+    # {{{ Make kernel
+
+    knl = lp.make_kernel(
+        [
+            "[nx,nt] -> {[x, t]: 0 <= x < nx and 0 <= t < nt}",
+            "[ni] -> {[i]: 0 <= i < ni}",
+        ],
+        """
+        a[x,t] = b[x,t] {id=stmta}
+        c[x,t] = d[x,t] {id=stmtc}
+        e[i] = f[i]
+        """,
+        lang_version=(2018, 2),
+        )
+    knl = lp.add_and_infer_dtypes(knl, {"b,d,f": np.float32})
+    ref_knl = knl
+
+    # }}}
+
+    # {{{ Apply domain change mapping
+
+    knl_map_dom = ref_knl  # loop priority goes away, deps stay
+
+    # Create map_domain mapping:
+    import islpy as isl
+    transform_map = isl.BasicMap(
+        "[nt] -> {[t] -> [t_outer, t_inner]: "
+        "0 <= t_inner < 32 and "
+        "32*t_outer + t_inner = t and "
+        "0 <= 32*t_outer + t_inner < nt}")
+
+    # Call map_domain to transform kernel
+    knl_map_dom = lp.map_domain(knl_map_dom, transform_map)
+
+    # Prioritize loops (prio should eventually be updated in map_domain?)
+    knl_map_dom = lp.prioritize_loops(knl_map_dom, "x, t_outer, t_inner")
+
+    # Get a linearization
+    proc_knl_map_dom = lp.preprocess_kernel(knl_map_dom)
+    lin_knl_map_dom = lp.get_one_linearized_kernel(
+        proc_knl_map_dom["loopy_kernel"], proc_knl_map_dom.callables_table)
+
+    # }}}
+
+    # {{{ Split iname and see if we get the same result
+
+    knl_split_iname = ref_knl
+    knl_split_iname = lp.split_iname(knl_split_iname, "t", 32)
+    knl_split_iname = lp.prioritize_loops(knl_split_iname, "x, t_outer, t_inner")
+    proc_knl_split_iname = lp.preprocess_kernel(knl_split_iname)
+    lin_knl_split_iname = lp.get_one_linearized_kernel(
+        proc_knl_split_iname["loopy_kernel"], proc_knl_split_iname.callables_table)
+
+    from loopy.schedule.checker.utils import (
+        ensure_dim_names_match_and_align,
+    )
+    for d_map_domain, d_split_iname in zip(
+            knl_map_dom["loopy_kernel"].domains,
+            knl_split_iname["loopy_kernel"].domains):
+        d_map_domain_aligned = ensure_dim_names_match_and_align(
+            d_map_domain, d_split_iname)
+        assert d_map_domain_aligned == d_split_iname
+
+    for litem_map_domain, litem_split_iname in zip(
+            lin_knl_map_dom.linearization, lin_knl_split_iname.linearization):
+        assert litem_map_domain == litem_split_iname
+
+    # Can't easily compare instructions because equivalent subscript
+    # expressions may have different orders
+
+    # }}}
+
+# }}}
+
+
+# {{{ test_map_domain_with_transform_map_missing_dims
+
+def test_map_domain_with_transform_map_missing_dims():
+    # Make sure map_domain works correctly when the mapping doesn't include
+    # all the dims in the domain.
+
+    # {{{ Make kernel
+
+    knl = lp.make_kernel(
+        [
+            "[nx,nt] -> {[x, y, z, t]: 0 <= x,y,z < nx and 0 <= t < nt}",
+        ],
+        """
+        a[y,x,t,z] = b[y,x,t,z] {id=stmta}
+        """,
+        lang_version=(2018, 2),
+        )
+    knl = lp.add_and_infer_dtypes(knl, {"b": np.float32})
+    ref_knl = knl
+
+    # }}}
+
+    # {{{ Apply domain change mapping
+
+    knl_map_dom = ref_knl  # loop priority goes away, deps stay
+
+    # Create map_domain mapping that only includes t and y
+    # (x and z should be unaffected)
+    import islpy as isl
+    transform_map = isl.BasicMap(
+        "[nx,nt] -> {[t, y] -> [t_outer, t_inner, y_new]: "
+        "0 <= t_inner < 32 and "
+        "32*t_outer + t_inner = t and "
+        "0 <= 32*t_outer + t_inner < nt and "
+        "y = y_new"
+        "}")
+
+    # Call map_domain to transform kernel
+    knl_map_dom = lp.map_domain(knl_map_dom, transform_map)
+
+    # Prioritize loops (prio should eventually be updated in map_domain?)
+    try:
+        # Use constrain_loop_nesting if it's available
+        desired_prio = "x, t_outer, t_inner, z, y_new"
+        knl_map_dom = lp.constrain_loop_nesting(knl_map_dom, desired_prio)
+    except AttributeError:
+        # For some reason, prioritize_loops can't handle the ordering above
+        # when linearizing knl_split_iname below
+        desired_prio = "z, y_new, x, t_outer, t_inner"
+        knl_map_dom = lp.prioritize_loops(knl_map_dom, desired_prio)
+
+    # Get a linearization
+    proc_knl_map_dom = lp.preprocess_kernel(knl_map_dom)
+    lin_knl_map_dom = lp.get_one_linearized_kernel(
+        proc_knl_map_dom["loopy_kernel"], proc_knl_map_dom.callables_table)
+
+    # }}}
+
+    # {{{ Split iname and see if we get the same result
+
+    knl_split_iname = ref_knl
+    knl_split_iname = lp.split_iname(knl_split_iname, "t", 32)
+    knl_split_iname = lp.rename_iname(knl_split_iname, "y", "y_new")
+    try:
+        # Use constrain_loop_nesting if it's available
+        knl_split_iname = lp.constrain_loop_nesting(knl_split_iname, desired_prio)
+    except AttributeError:
+        knl_split_iname = lp.prioritize_loops(knl_split_iname, desired_prio)
+    proc_knl_split_iname = lp.preprocess_kernel(knl_split_iname)
+    lin_knl_split_iname = lp.get_one_linearized_kernel(
+        proc_knl_split_iname["loopy_kernel"], proc_knl_split_iname.callables_table)
+
+    from loopy.schedule.checker.utils import (
+        ensure_dim_names_match_and_align,
+    )
+    for d_map_domain, d_split_iname in zip(
+            knl_map_dom["loopy_kernel"].domains,
+            knl_split_iname["loopy_kernel"].domains):
+        d_map_domain_aligned = ensure_dim_names_match_and_align(
+            d_map_domain, d_split_iname)
+        assert d_map_domain_aligned == d_split_iname
+
+    for litem_map_domain, litem_split_iname in zip(
+            lin_knl_map_dom.linearization, lin_knl_split_iname.linearization):
+        assert litem_map_domain == litem_split_iname
+
+    # Can't easily compare instructions because equivalent subscript
+    # expressions may have different orders
+
+    # }}}
+
+# }}}
+
+
+def test_diamond_tiling(ctx_factory, interactive=False):
+    ctx = ctx_factory()
+    queue = cl.CommandQueue(ctx)
+
+    ref_knl = lp.make_kernel(
+        "[nx,nt] -> {[ix, it]: 1<=ix {[ix, it] -> [tx, tt, tparity, itt, itx]: "
+        "16*(tx - tt) + itx - itt = ix - it and "
+        "16*(tx + tt + tparity) + itt + itx = ix + it and "
+        "0<=tparity<2 and 0 <= itx - itt < 16 and 0 <= itt+itx < 16}")
+    knl = lp.map_domain(knl_for_transform, m)
+    knl = lp.prioritize_loops(knl, "tt,tparity,tx,itt,itx")
+
+    if interactive:
+        nx = 43
+        u = np.zeros((nx, 200))
+        x = np.linspace(-1, 1, nx)
+        dx = x[1] - x[0]
+        u[:, 0] = u[:, 1] = np.exp(-100*x**2)
+
+        u_dev = cl.array.to_device(queue, u)
+        knl(queue, u=u_dev, dx=dx, dt=dx)
+
+        u = u_dev.get()
+        import matplotlib.pyplot as plt
+        plt.imshow(u.T)
+        plt.show()
+    else:
+        types = {"dt,dx,u": np.float64}
+        knl = lp.add_and_infer_dtypes(knl, types)
+        ref_knl = lp.add_and_infer_dtypes(ref_knl, types)
+
+        lp.auto_test_vs_ref(ref_knl, ctx, knl,
+                parameters={
+                    "nx": 200, "nt": 300,
+                    "dx": 1, "dt": 1
+                    })
+
+
 def test_extract_subst_with_iname_deps_in_templ(ctx_factory):
     knl = lp.make_kernel(
             "{[i, j, k]: 0<=i<100 and 0<=j,k<5}",