diff --git a/doc/ref_kernel.rst b/doc/ref_kernel.rst
index 922315685..09eceb1d1 100644
--- a/doc/ref_kernel.rst
+++ b/doc/ref_kernel.rst
@@ -220,6 +220,8 @@ Tag Meaning
 Identifiers
 -----------
 
+.. _reserved-identifiers:
+
 Reserved Identifiers
 ^^^^^^^^^^^^^^^^^^^^
 
diff --git a/loopy/__init__.py b/loopy/__init__.py
index a73f83bb9..54c06c680 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -69,18 +69,20 @@ from loopy.version import VERSION, MOST_RECENT_LANGUAGE_VERSION
 
 from loopy.transform.iname import (
-        set_loop_priority, prioritize_loops, untag_inames,
+        set_loop_priority, prioritize_loops, constrain_loop_nesting,
+        untag_inames,
         split_iname, chunk_iname, join_inames, tag_inames, duplicate_inames,
         rename_iname, remove_unused_inames,
         split_reduction_inward, split_reduction_outward,
         affine_map_inames, find_unused_axis_tag,
         make_reduction_inames_unique,
         has_schedulable_iname_nesting, get_iname_duplication_options,
-        add_inames_to_insn, add_inames_for_unused_hw_axes)
+        add_inames_to_insn, add_inames_for_unused_hw_axes, map_domain)
 
 from loopy.transform.instruction import (
         find_instructions, map_instructions,
-        set_instruction_priority, add_dependency,
+        set_instruction_priority,
+        add_dependency, add_dependency_v2,
         remove_instructions, replace_instruction_ids,
         tag_instructions,
@@ -129,6 +131,8 @@ from loopy.schedule import (
     generate_loop_schedules, get_one_scheduled_kernel, get_one_linearized_kernel,
     linearize)
+from loopy.schedule.checker import (
+    find_unsatisfied_dependencies)
 from loopy.statistics import (ToCountMap, ToCountPolynomialMap,
     CountGranularity, stringify_stats_mapping, Op, MemAccess, get_op_map,
     get_mem_access_map, get_synchronization_map, gather_access_footprints,
@@ -194,7 +198,8 @@
 
         # {{{ transforms
 
-        "set_loop_priority", "prioritize_loops", "untag_inames",
+        "set_loop_priority", "prioritize_loops", "constrain_loop_nesting",
+        "untag_inames",
         "split_iname", "chunk_iname", "join_inames", "tag_inames",
         "duplicate_inames", "rename_iname", "remove_unused_inames",
@@ -202,7 +207,7 @@
         "affine_map_inames", "find_unused_axis_tag",
         "make_reduction_inames_unique",
         "has_schedulable_iname_nesting", "get_iname_duplication_options",
-        "add_inames_to_insn", "add_inames_for_unused_hw_axes",
+        "add_inames_to_insn", "add_inames_for_unused_hw_axes", "map_domain",
 
         "add_prefetch", "change_arg_to_image",
         "tag_array_axes", "tag_data_axes",
@@ -212,7 +217,8 @@
         "rename_argument", "set_temporary_scope",
 
         "find_instructions", "map_instructions",
-        "set_instruction_priority", "add_dependency",
+        "set_instruction_priority",
+        "add_dependency", "add_dependency_v2",
         "remove_instructions", "replace_instruction_ids",
         "tag_instructions",
@@ -268,7 +274,7 @@
         "generate_loop_schedules", "get_one_scheduled_kernel",
         "get_one_linearized_kernel", "linearize",
-
+        "find_unsatisfied_dependencies",
         "GeneratedProgram", "CodeGenerationResult",
         "PreambleInfo",
         "generate_code", "generate_code_v2", "generate_body",
diff --git a/loopy/kernel/__init__.py b/loopy/kernel/__init__.py
index 2f39614b8..43d97453d 100644
--- a/loopy/kernel/__init__.py
+++ b/loopy/kernel/__init__.py
@@ -249,6 +249,7 @@ class LoopKernel(ImmutableRecordWithoutPickling, Taggable):
 
     .. automethod:: tagged
     .. automethod:: without_tags
     """
+    # TODO document loop_nest_constraints attribute
 
     # {{{ constructor
 
@@ -268,6 +269,7 @@ def __init__(self, domains, instructions, args=None,
             iname_slab_increments=None,
             loop_priority=frozenset(),
+            loop_nest_constraints=None,
 
             silenced_warnings=None,
 
             applied_iname_rewrites=None,
@@ -380,6 +382,7 @@ def __init__(self, domains, instructions, args=None,
                 assumptions=assumptions,
                 iname_slab_increments=iname_slab_increments,
                 loop_priority=loop_priority,
+                loop_nest_constraints=loop_nest_constraints,
                 silenced_warnings=silenced_warnings,
                 temporary_variables=temporary_variables,
                 local_sizes=local_sizes,
@@ -1543,6 +1546,7 @@ def __setstate__(self, state):
                 "substitutions",
                 "iname_slab_increments",
                 "loop_priority",
+                "loop_nest_constraints",
                 "silenced_warnings",
                 "options",
                 "state",
diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py
index 9fb9757a6..58e06c2b9 100644
--- a/loopy/kernel/instruction.py
+++ b/loopy/kernel/instruction.py
@@ -75,7 +75,7 @@ class UseStreamingStoreTag(Tag):
 # {{{ instructions: base class
 
 class InstructionBase(ImmutableRecord, Taggable):
-    """A base class for all types of instruction that can occur in
+    r"""A base class for all types of instruction that can occur in
     a kernel.
 
     .. attribute:: id
@@ -87,7 +87,7 @@ class InstructionBase(ImmutableRecord, Taggable):
 
     .. attribute:: depends_on
 
-        a :class:`frozenset` of :attr:`id` values of :class:`InstructionBase`
+        A :class:`frozenset` of :attr:`id` values of :class:`InstructionBase`
         instances that *must* be executed before this one. Note that
         :func:`loopy.preprocess_kernel` (usually invoked automatically)
         augments this by adding dependencies on any writes to temporaries read
@@ -106,6 +106,15 @@ class InstructionBase(ImmutableRecord, Taggable):
         :func:`loopy.make_kernel`. Note, that this is not meant as
         a user-facing interface.
 
+    .. attribute:: dependencies
+
+        A :class:`dict` mapping :attr:`id` values of :class:`InstructionBase`
+        instances (each referring to a statement whose instances must be
+        executed before instances of this statement) to lists (one list per
+        key) of :class:`islpy.Map`\ s mapping each instance of the dependee
+        statement to all instances of this statement that must occur later.
+        Note that this dict will eventually replace the `depends_on`
+        attribute.
+
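+        For example (an illustrative entry; the dimension names here are
+        only schematic), a dependency requiring instance ``i0`` of dependee
+        ``stmt_a`` to execute before instance ``i1`` of this statement
+        whenever ``i0 <= i1`` might be stored as::
+
+            {"stmt_a": [isl.Map("{ [i0] -> [i1] : i0 <= i1 }")]}
+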
     .. attribute:: depends_on_is_final
 
         A :class:`bool` determining whether :attr:`depends_on` constitutes
@@ -212,6 +221,8 @@ class InstructionBase(ImmutableRecord, Taggable):
     pymbolic_set_fields = {"predicates"}
 
     def __init__(self, id, depends_on, depends_on_is_final,
+            dependencies,
+            non_linearizing_deps,
             groups, conflicts_with_groups,
             no_sync_with,
             within_inames_is_final, within_inames,
@@ -241,6 +252,12 @@ def __init__(self, id, depends_on, depends_on_is_final,
         if depends_on is None:
             depends_on = frozenset()
 
+        if dependencies is None:
+            dependencies = {}
+        # TODO dependee ids for deps that don't affect cartoon dag
+        if non_linearizing_deps is None:
+            non_linearizing_deps = set()
+
         if groups is None:
             groups = frozenset()
 
@@ -297,6 +314,8 @@ def __init__(self, id, depends_on, depends_on_is_final,
                 id=id,
                 depends_on=depends_on,
                 depends_on_is_final=depends_on_is_final,
+                dependencies=dependencies,
+                non_linearizing_deps=non_linearizing_deps,  # TODO
                 no_sync_with=no_sync_with,
                 groups=groups, conflicts_with_groups=conflicts_with_groups,
                 within_inames_is_final=within_inames_is_final,
@@ -392,6 +411,11 @@ def get_str_options(self):
 
         if self.depends_on:
             result.append("dep="+":".join(self.depends_on))
+        if self.dependencies:
+            result.append("dependencies="+":".join(self.dependencies.keys()))
+        if self.non_linearizing_deps:
+            result.append(
+                "non_linearizing_deps="+":".join(self.non_linearizing_deps))
         if self.no_sync_with:
             result.append("nosync="+":".join(
                 "%s@%s" % entry for entry in self.no_sync_with))
@@ -461,6 +485,8 @@ def __setstate__(self, val):
         if self.id is not None:  # pylint:disable=access-member-before-definition
             self.id = intern(self.id)
         self.depends_on = intern_frozenset_of_ids(self.depends_on)
+        # TODO something with dependencies?
+        # TODO something with non_linearizing_deps?
         self.groups = intern_frozenset_of_ids(self.groups)
         self.conflicts_with_groups = (
                 intern_frozenset_of_ids(self.conflicts_with_groups))
@@ -887,6 +913,8 @@ def __init__(self,
             id=None,
             depends_on=None,
             depends_on_is_final=None,
+            dependencies=None,
+            non_linearizing_deps=None,  # TODO
             groups=None,
             conflicts_with_groups=None,
             no_sync_with=None,
@@ -903,6 +931,8 @@ def __init__(self,
                 id=id,
                 depends_on=depends_on,
                 depends_on_is_final=depends_on_is_final,
+                dependencies=dependencies,
+                non_linearizing_deps=non_linearizing_deps,  # TODO
                 groups=groups,
                 conflicts_with_groups=conflicts_with_groups,
                 no_sync_with=no_sync_with,
@@ -1038,6 +1068,8 @@ def __init__(self,
             id=None,
             depends_on=None,
             depends_on_is_final=None,
+            dependencies=None,
+            non_linearizing_deps=None,  # TODO
             groups=None,
             conflicts_with_groups=None,
             no_sync_with=None,
@@ -1051,6 +1083,8 @@ def __init__(self,
                 id=id,
                 depends_on=depends_on,
                 depends_on_is_final=depends_on_is_final,
+                dependencies=dependencies,
+                non_linearizing_deps=non_linearizing_deps,  # TODO
                 groups=groups,
                 conflicts_with_groups=conflicts_with_groups,
                 no_sync_with=no_sync_with,
@@ -1329,13 +1363,21 @@ class CInstruction(InstructionBase):
 
     def __init__(self,
             iname_exprs, code,
-            read_variables=frozenset(), assignees=tuple(),
-            id=None, depends_on=None, depends_on_is_final=None,
-            groups=None, conflicts_with_groups=None,
+            read_variables=frozenset(),
+            assignees=tuple(),
+            id=None,
+            depends_on=None,
+            depends_on_is_final=None,
+            dependencies=None,
+            non_linearizing_deps=None,  # TODO
+            groups=None,
+            conflicts_with_groups=None,
             no_sync_with=None,
-            within_inames_is_final=None, within_inames=None,
+            within_inames_is_final=None,
+            within_inames=None,
             priority=0,
-            predicates=frozenset(), tags=None):
+            predicates=frozenset(),
+            tags=None):
         """
         :arg iname_exprs: Like :attr:`iname_exprs`, but instead of tuples,
             simple strings representing inames are also allowed. A single
@@ -1350,11 +1392,14 @@ def __init__(self,
                 id=id,
                 depends_on=depends_on,
                 depends_on_is_final=depends_on_is_final,
+                dependencies=dependencies,
+                non_linearizing_deps=non_linearizing_deps,  # TODO
                 groups=groups, conflicts_with_groups=conflicts_with_groups,
                 no_sync_with=no_sync_with,
                 within_inames_is_final=within_inames_is_final,
                 within_inames=within_inames,
-                priority=priority, predicates=predicates, tags=tags)
+                priority=priority, predicates=predicates,
+                tags=tags)
 
         # {{{ normalize iname_exprs
 
@@ -1495,16 +1540,27 @@ class NoOpInstruction(_DataObliviousInstruction):
 
         ... nop
     """
nop """ - def __init__(self, id=None, depends_on=None, depends_on_is_final=None, - groups=None, conflicts_with_groups=None, + def __init__( + self, + id=None, + depends_on=None, + depends_on_is_final=None, + dependencies=None, # TODO + non_linearizing_deps=None, + groups=None, + conflicts_with_groups=None, no_sync_with=None, - within_inames_is_final=None, within_inames=None, + within_inames_is_final=None, + within_inames=None, priority=None, - predicates=None, tags=None): + predicates=None, + tags=None): super().__init__( id=id, depends_on=depends_on, depends_on_is_final=depends_on_is_final, + dependencies=dependencies, + non_linearizing_deps=non_linearizing_deps, # TODO groups=groups, conflicts_with_groups=conflicts_with_groups, no_sync_with=no_sync_with, @@ -1554,12 +1610,22 @@ class BarrierInstruction(_DataObliviousInstruction): fields = _DataObliviousInstruction.fields | {"synchronization_kind", "mem_kind"} - def __init__(self, id, depends_on=None, depends_on_is_final=None, - groups=None, conflicts_with_groups=None, + def __init__( + self, + id, + depends_on=None, + depends_on_is_final=None, + dependencies=None, # TODO + non_linearizing_deps=None, + groups=None, + conflicts_with_groups=None, no_sync_with=None, - within_inames_is_final=None, within_inames=None, + within_inames_is_final=None, + within_inames=None, priority=None, - predicates=None, tags=None, synchronization_kind="global", + predicates=None, + tags=None, + synchronization_kind="global", mem_kind="local"): if predicates: @@ -1569,6 +1635,8 @@ def __init__(self, id, depends_on=None, depends_on_is_final=None, id=id, depends_on=depends_on, depends_on_is_final=depends_on_is_final, + dependencies=dependencies, + non_linearizing_deps=non_linearizing_deps, # TODO groups=groups, conflicts_with_groups=conflicts_with_groups, no_sync_with=no_sync_with, diff --git a/loopy/options.py b/loopy/options.py index 9f12814b0..da242e648 100644 --- a/loopy/options.py +++ b/loopy/options.py @@ -242,6 +242,7 @@ def __init__( disable_global_barriers=kwargs.get("disable_global_barriers", False), check_dep_resolution=kwargs.get("check_dep_resolution", True), + use_dependencies_v2=kwargs.get("use_dependencies_v2", False), enforce_variable_access_ordered=kwargs.get( "enforce_variable_access_ordered", True), diff --git a/loopy/schedule/__init__.py b/loopy/schedule/__init__.py index 5822f44ed..7aecacfd6 100644 --- a/loopy/schedule/__init__.py +++ b/loopy/schedule/__init__.py @@ -253,57 +253,100 @@ def find_loop_nest_around_map(kernel): return result -def find_loop_insn_dep_map(kernel, loop_nest_with_map, loop_nest_around_map): +def find_loop_insn_dep_map( + kernel, loop_nest_with_map, loop_nest_around_map, + simplified_depends_on_graph): """Returns a dictionary mapping inames to other instruction ids that need to be scheduled before the iname should be eligible for scheduling. + + :arg loop_nest_with_map: Dictionary mapping iname1 to a set containing + iname2 iff either iname1 nests around iname2 or iname2 nests around + iname1 + + :arg loop_nest_around_map: Dictionary mapping iname1 to a set containing + iname2 iff iname2 nests around iname1 + + :arg simplified_depends_on_graph: Dictionary mapping depender statement IDs + to sets of dependee statement IDs, as produced by + `loopy.schedule.checker.dependency.filter_deps_by_intersection_with_SAME`, + which will be used to acquire depndee statement ids if + `kernel.options.use_dependencies_v2` is 'True' (otherwise old + dependencies in insn.depends_on will be used). 
+ """ result = {} from loopy.kernel.data import ConcurrentTag, IlpBaseTag + # For each insn, examine its inames (`iname`) and its dependees' inames + # (`dep_iname`) to determine which instructions must be scheduled before + # entering the iname loop. + # Create result dict, which maps iname to instructions that must be + # scheduled prior to entering iname. + + # For each insn, loop over its non-concurrent inames (`iname`) for insn in kernel.instructions: for iname in kernel.insn_inames(insn): + # (Ignore concurrent inames) if kernel.iname_tags_of_type(iname, ConcurrentTag): continue + # Let iname_dep be the set of ids associated with result[iname] + # (if iname is not already in result, add iname as a key) iname_dep = result.setdefault(iname, set()) - for dep_insn_id in insn.depends_on: + # Loop over instructions on which insn depends (dep_insn) + # and determine whether dep_insn must be schedued before + # iname, in which case add its id to iname_dep (result[iname]) + if kernel.options.use_dependencies_v2: + dependee_ids = simplified_depends_on_graph.get(insn.id, set()) + else: + dependee_ids = insn.depends_on + + for dep_insn_id in dependee_ids: if dep_insn_id in iname_dep: # already depending, nothing to check continue - dep_insn = kernel.id_to_insn[dep_insn_id] - dep_insn_inames = dep_insn.within_inames + dep_insn = kernel.id_to_insn[dep_insn_id] # Dependee + dep_insn_inames = dep_insn.within_inames # Dependee inames + # Check whether insn's iname is also in dependee inames if iname in dep_insn_inames: - # Nothing to be learned, dependency is in loop over iname + # Nothing to be learned, dependee is inside loop over iname # already. continue # To make sure dep_insn belongs outside of iname, we must prove - # that all inames that dep_insn will be executed in nest + # that all inames in which dep_insn will be executed nest # outside of the loop over *iname*. (i.e. nested around, or # before). + # Loop over each of the dependee's inames (dep_insn_iname) may_add_to_loop_dep_map = True for dep_insn_iname in dep_insn_inames: + + # If loop_nest_around_map says dep_insn_iname nests around + # iname, dep_insn_iname is guaranteed to nest outside of + # iname, we're safe, so continue if dep_insn_iname in loop_nest_around_map[iname]: - # dep_insn_iname is guaranteed to nest outside of iname - # -> safe. continue + # If dep_insn_iname is concurrent, continue + # (parallel tags don't really nest, so disregard them here) if kernel.iname_tags_of_type(dep_insn_iname, (ConcurrentTag, IlpBaseTag)): - # Parallel tags don't really nest, so we'll disregard - # them here. continue + # If loop_nest_with_map says dep_insn_iname does not nest + # inside or around iname, it must be nested separately; + # we're safe, so continue if dep_insn_iname not in loop_nest_with_map.get(iname, []): - # dep_insn_iname does not nest with iname, so its nest - # must occur outside. continue + # If none of the three cases above succeeds for any + # dep_insn_iname in dep_insn_inames, we cannot add dep_insn + # to iname's set of insns in result dict. may_add_to_loop_dep_map = False break @@ -318,6 +361,10 @@ def find_loop_insn_dep_map(kernel, loop_nest_with_map, loop_nest_around_map): dep_insn=dep_insn_id, insn=insn.id)) + # If at least one of the three cases above succeeds for every + # dep_insn_iname, we can add dep_insn to iname's set of insns + # in result dict. 
+                # (means dep_insn must be scheduled before entering iname loop)
                 iname_dep.add(dep_insn_id)
 
     return result
 
@@ -333,16 +380,24 @@ def group_insn_counts(kernel):
     return result
 
 
-def gen_dependencies_except(kernel, insn_id, except_insn_ids):
-    insn = kernel.id_to_insn[insn_id]
-    for dep_id in insn.depends_on:
+def gen_dependencies_except(
+        kernel, insn_id, except_insn_ids, simplified_depends_on_graph):
+
+    # Get dependee IDs
+    if kernel.options.use_dependencies_v2:
+        dependee_ids = simplified_depends_on_graph.get(insn_id, set())
+    else:
+        dependee_ids = kernel.id_to_insn[insn_id].depends_on
 
+    for dep_id in dependee_ids:
         if dep_id in except_insn_ids:
             continue
 
         yield dep_id
 
-        yield from gen_dependencies_except(kernel, dep_id, except_insn_ids)
+        yield from gen_dependencies_except(
+            kernel, dep_id, except_insn_ids, simplified_depends_on_graph)
 
 
 def get_priority_tiers(wanted, priorities):
@@ -631,8 +686,10 @@ class SchedulerState(ImmutableRecord):
         order with instruction priorities as tie breaker.
     """
 
+    # TODO document simplified_depends_on_graph
+
     @property
-    def last_entered_loop(self):
+    def deepest_active_iname(self):
         if self.active_inames:
             return self.active_inames[-1]
         else:
@@ -641,12 +698,20 @@ def last_entered_loop(self):
 # }}}
 
 
-def get_insns_in_topologically_sorted_order(kernel):
+def get_insns_in_topologically_sorted_order(
+        kernel, simplified_depends_on_graph):
     from pytools.graph import compute_topological_order
 
     rev_dep_map = {insn.id: set() for insn in kernel.instructions}
     for insn in kernel.instructions:
-        for dep in insn.depends_on:
+
+        if kernel.options.use_dependencies_v2:
+            dependee_ids = simplified_depends_on_graph.get(
+                insn.id, set())
+        else:
+            dependee_ids = insn.depends_on
+
+        for dep in dependee_ids:
             rev_dep_map[dep].add(insn.id)
 
     # For breaking ties, we compare the features of an instruction
@@ -680,7 +745,8 @@ def key(insn_id):
 
 # {{{ schedule_as_many_run_insns_as_possible
 
-def schedule_as_many_run_insns_as_possible(sched_state, template_insn):
+def schedule_as_many_run_insns_as_possible(
+        sched_state, template_insn, use_dependencies_v2):
     """
     Returns an instance of :class:`loopy.schedule.SchedulerState`, by appending
     all reachable instructions that are similar to *template_insn*.
 
     We define
@@ -748,7 +814,14 @@ def is_similar_to_template(insn):
         if is_similar_to_template(insn):
             # check reachability
-            if not (insn.depends_on & ignored_unscheduled_insn_ids):
+
+            if use_dependencies_v2:
+                dependee_ids = sched_state.simplified_depends_on_graph.get(
+                    insn.id, set())
+            else:
+                dependee_ids = insn.depends_on
+
+            if not (dependee_ids & ignored_unscheduled_insn_ids):
                 if insn.id in sched_state.prescheduled_insn_ids:
                     if next_preschedule_insn_id() == insn.id:
                         preschedule.pop(0)
@@ -937,24 +1010,34 @@ def insn_sort_key(insn_id):
     for insn_id in insn_ids_to_try:
         insn = kernel.id_to_insn[insn_id]
 
-        is_ready = insn.depends_on <= sched_state.scheduled_insn_ids
+        # make sure dependees have been scheduled
+        if kernel.options.use_dependencies_v2:
+            dependee_ids = sched_state.simplified_depends_on_graph.get(
+                insn.id, set())
+        else:
+            dependee_ids = insn.depends_on
+
+        is_ready = dependee_ids <= sched_state.scheduled_insn_ids
 
         if not is_ready:
             continue
 
-        want = insn.within_inames - sched_state.parallel_inames
-        have = active_inames_set - sched_state.parallel_inames
+        nonconc_insn_inames_wanted = insn.within_inames - sched_state.parallel_inames
+        nonconc_active_inames = active_inames_set - sched_state.parallel_inames
 
-        if want != have:
+        if nonconc_insn_inames_wanted != nonconc_active_inames:
+            # We don't have the inames we need, may need to open more loops
             is_ready = False
 
             if debug_mode:
-                if want-have:
+                if nonconc_insn_inames_wanted-nonconc_active_inames:
                     print("instruction '%s' is missing inames '%s'"
-                            % (format_insn(kernel, insn.id), ",".join(want-have)))
-                if have-want:
+                            % (format_insn(kernel, insn.id), ",".join(
+                                nonconc_insn_inames_wanted-nonconc_active_inames)))
+                if nonconc_active_inames-nonconc_insn_inames_wanted:
                     print("instruction '%s' won't work under inames '%s'"
-                            % (format_insn(kernel, insn.id), ",".join(have-want)))
+                            % (format_insn(kernel, insn.id), ",".join(
+                                nonconc_active_inames-nonconc_insn_inames_wanted)))
 
         # {{{ check if scheduling this insn is compatible with preschedule
 
@@ -1006,9 +1089,10 @@ def insn_sort_key(insn_id):
 
         # }}}
 
-        # {{{ determine reachability
+        # {{{ determine reachability (no active inames conflict w/insn, but
+        # may need more inames)
 
-        if (not is_ready and have <= want):
+        if (not is_ready and nonconc_active_inames <= nonconc_insn_inames_wanted):
             reachable_insn_ids.add(insn_id)
 
         # }}}
@@ -1016,7 +1100,13 @@ def insn_sort_key(insn_id):
         if is_ready and debug_mode:
             print("ready to schedule '%s'" % format_insn(kernel, insn.id))
 
+        # (if we wanted, we could check to see whether adding insn would
+        # violate dependencies_v2 here, as done in old in-progress branch:
+        # https://gitlab.tiker.net/jdsteve2/loopy/-/merge_requests/15/diffs)
+
         if is_ready and not debug_mode:
+            # schedule this instruction and recurse
+
             iid_set = frozenset([insn.id])
 
             # {{{ update active group counts for added instruction
 
@@ -1068,8 +1158,8 @@ def insn_sort_key(insn_id):
                     insns_in_topologically_sorted_order=new_toposorted_insns,
                     )
 
-            new_sched_state = schedule_as_many_run_insns_as_possible(new_sched_state,
-                    insn)
+            new_sched_state = schedule_as_many_run_insns_as_possible(
+                    new_sched_state, insn, kernel.options.use_dependencies_v2)
 
             # Don't be eager about entering/leaving loops--if progress has been
             # made, revert to top of scheduler and see if more progress can be
@@ -1086,42 +1176,49 @@ def insn_sort_key(insn_id):
 
     # }}}
 
+    # No insns are ready to be scheduled now, but some may be reachable
+    # reachable_insn_ids = no active inames conflict w/insn, but may need
+    # more inames
+
     # {{{ see if we're ready to leave the innermost loop
 
-    last_entered_loop = sched_state.last_entered_loop
+    deepest_active_iname = sched_state.deepest_active_iname
 
-    if last_entered_loop is not None:
+    if deepest_active_iname is not None:
         can_leave = True
 
         if (
-                last_entered_loop in sched_state.prescheduled_inames
+                deepest_active_iname in sched_state.prescheduled_inames
                 and not (
                     isinstance(next_preschedule_item, LeaveLoop)
-                    and next_preschedule_item.iname == last_entered_loop)):
+                    and next_preschedule_item.iname == deepest_active_iname)):
             # A prescheduled loop can only be left if the preschedule agrees.
             if debug_mode:
                 print("cannot leave '%s' because of preschedule constraints"
-                        % last_entered_loop)
+                        % deepest_active_iname)
             can_leave = False
-        elif last_entered_loop not in sched_state.breakable_inames:
+        elif deepest_active_iname not in sched_state.breakable_inames:
             # If the iname is not breakable, then check that we've
             # scheduled all the instructions that require it.
 
             for insn_id in sched_state.unscheduled_insn_ids:
                 insn = kernel.id_to_insn[insn_id]
-                if last_entered_loop in insn.within_inames:
+                if deepest_active_iname in insn.within_inames:
+                    # cannot leave deepest_active_iname; insn still depends on it
                     if debug_mode:
                         print("cannot leave '%s' because '%s' still depends on it"
-                                % (last_entered_loop, format_insn(kernel, insn.id)))
+                                % (deepest_active_iname, format_insn(kernel, insn.id)))
 
                         # check if there's a dependency of insn that needs to be
-                        # outside of last_entered_loop.
-                        for subdep_id in gen_dependencies_except(kernel, insn_id,
-                                sched_state.scheduled_insn_ids):
-                            want = (kernel.insn_inames(subdep_id)
+                        # outside of deepest_active_iname.
+                        for subdep_id in gen_dependencies_except(
+                                kernel, insn_id,
+                                sched_state.scheduled_insn_ids,
+                                sched_state.simplified_depends_on_graph):
+                            nonconc_subdep_insn_inames_wanted = (
+                                    kernel.insn_inames(subdep_id)
                                     - sched_state.parallel_inames)
-                            if (
-                                    last_entered_loop not in want):
+                            if (deepest_active_iname
+                                    not in nonconc_subdep_insn_inames_wanted):
                                 print(
                                         "%(warn)swarning:%(reset_all)s '%(iname)s', "
                                         "which the schedule is "
@@ -1135,7 +1232,7 @@ def insn_sort_key(insn_id):
                                         % {
                                             "warn": Fore.RED + Style.BRIGHT,
                                             "reset_all": Style.RESET_ALL,
-                                            "iname": last_entered_loop,
+                                            "iname": deepest_active_iname,
                                             "subdep": format_insn_id(kernel, subdep_id),
                                             "dep": format_insn_id(kernel, insn_id),
                                             "subdep_i": format_insn(kernel, subdep_id),
@@ -1162,23 +1259,57 @@ def insn_sort_key(insn_id):
                 if ignore_count:
                     ignore_count -= 1
                 else:
-                    assert sched_item.iname == last_entered_loop
+                    assert sched_item.iname == deepest_active_iname
                     if seen_an_insn:
                         can_leave = True
                     break
 
+        # {{{ don't leave if doing so would violate must_nest constraints
+
+        # don't leave if must_nest constraints require that
+        # additional inames be nested inside the current iname
+        if can_leave:
+            must_nest_graph = (
+                sched_state.kernel.loop_nest_constraints.must_nest_graph
+                if sched_state.kernel.loop_nest_constraints else None)
+
+            if must_nest_graph:
+                # get inames that must nest inside the current iname
+                must_nest_inside = must_nest_graph[deepest_active_iname]
+
+                if must_nest_inside:
+                    # get scheduled inames that are nested inside current iname
+                    within_deepest_active_iname = False
+                    actually_nested_inside = set()
+                    for sched_item in sched_state.schedule:
+                        if isinstance(sched_item, EnterLoop):
+                            if within_deepest_active_iname:
+                                actually_nested_inside.add(sched_item.iname)
+                            elif sched_item.iname == deepest_active_iname:
+                                within_deepest_active_iname = True
+                        elif (isinstance(sched_item, LeaveLoop) and
+                                sched_item.iname == deepest_active_iname):
+                            break
+
+                    # if any iname that must nest inside the current iname
+                    # has not yet been scheduled inside it, we cannot leave
+                    if not must_nest_inside.issubset(actually_nested_inside):
+                        can_leave = False
+
+        # }}}
+
         if can_leave and not debug_mode:
 
             for sub_sched in generate_loop_schedules_internal(
                     sched_state.copy(
                         schedule=(
                             sched_state.schedule
-                            + (LeaveLoop(iname=last_entered_loop),)),
+                            + (LeaveLoop(iname=deepest_active_iname),)),
                         active_inames=sched_state.active_inames[:-1],
                         insn_ids_to_try=insn_ids_to_try,
                         preschedule=(
                             sched_state.preschedule
-                            if last_entered_loop
+                            if deepest_active_iname
                             not in sched_state.prescheduled_inames
                             else sched_state.preschedule[1:]),
                         ),
                     debug=debug):
@@ -1192,11 +1323,11 @@ def insn_sort_key(insn_id):
     # {{{ see if any loop can be entered now
 
     # Find inames that are being referenced by as yet unscheduled instructions.
-    needed_inames = set()
+    unscheduled_nonconc_insn_inames_needed = set()
     for insn_id in sched_state.unscheduled_insn_ids:
-        needed_inames.update(kernel.insn_inames(insn_id))
+        unscheduled_nonconc_insn_inames_needed.update(kernel.insn_inames(insn_id))
 
-    needed_inames = (needed_inames
+    unscheduled_nonconc_insn_inames_needed = (unscheduled_nonconc_insn_inames_needed
             # There's no notion of 'entering' a parallel loop
             - sched_state.parallel_inames
 
@@ -1205,7 +1336,8 @@ def insn_sort_key(insn_id):
 
     if debug_mode:
         print(75*"-")
-        print("inames still needed :", ",".join(needed_inames))
+        print("inames still needed :", ",".join(
+            unscheduled_nonconc_insn_inames_needed))
         print("active inames :", ",".join(sched_state.active_inames))
         print("inames entered so far :", ",".join(sched_state.entered_inames))
         print("reachable insns:", ",".join(reachable_insn_ids))
@@ -1214,12 +1346,15 @@ def insn_sort_key(insn_id):
             for grp, c in sched_state.active_group_counts.items()))
         print(75*"-")
 
-    if needed_inames:
+    if unscheduled_nonconc_insn_inames_needed:
         iname_to_usefulness = {}
 
-        for iname in needed_inames:
+        for iname in unscheduled_nonconc_insn_inames_needed:
 
             # {{{ check if scheduling this iname now is allowed/plausible
+            # based on preschedule constraints, loop_nest_around_map,
+            # loop_insn_dep_map, and data dependencies;
+            # if not, continue
 
             if (
                     iname in sched_state.prescheduled_inames
@@ -1233,6 +1368,9 @@ def insn_sort_key(insn_id):
             currently_accessible_inames = (
                     active_inames_set | sched_state.parallel_inames)
+
+            # check loop_nest_around_map to determine whether inames that must
+            # nest around iname are available
             if (
                     not sched_state.loop_nest_around_map[iname]
                     <= currently_accessible_inames):
@@ -1240,6 +1378,9 @@ def insn_sort_key(insn_id):
                     print("scheduling %s prohibited by loop nest-around map"
                             % iname)
                 continue
 
+            # loop_insn_dep_map: dict mapping inames to other insn ids that need
+            # to be scheduled before the iname should be eligible for scheduling.
+            # If loop dependency map prohibits scheduling of iname, continue
             if (
                     not sched_state.loop_insn_dep_map.get(iname, set())
                     <= sched_state.scheduled_insn_ids):
@@ -1289,23 +1430,31 @@ def insn_sort_key(insn_id):
 
             # }}}
 
+            # so far, scheduling of iname is allowed/plausible
+
             # {{{ determine if that gets us closer to being able to schedule an insn
 
             usefulness = None  # highest insn priority enabled by iname
 
+            # suppose we were to activate this iname...
+            # would that get us closer to scheduling an insn?
             hypothetically_active_loops = active_inames_set | {iname}
+
+            # loop over reachable_insn_ids (reachable insn: no active inames
+            # conflict w/insn, but may need more inames)
             for insn_id in reachable_insn_ids:
                 insn = kernel.id_to_insn[insn_id]
 
-                want = insn.within_inames
+                wanted_insn_inames = insn.within_inames
 
-                if hypothetically_active_loops <= want:
+                if hypothetically_active_loops <= wanted_insn_inames:
                     if usefulness is None:
                         usefulness = insn.priority
                     else:
                         usefulness = max(usefulness, insn.priority)
 
             if usefulness is None:
+                # iname won't get us closer to scheduling insn
                 if debug_mode:
                     print("iname '%s' deemed not useful" % iname)
                 continue
@@ -1314,67 +1463,128 @@ def insn_sort_key(insn_id):
 
         # }}}
 
-        # {{{ tier building
-
-        # Build priority tiers. If a schedule is found in the first tier, then
-        # loops in the second are not even tried (and so on).
-        loop_priority_set = set().union(*[set(prio)
-                                          for prio in
-                                          sched_state.kernel.loop_priority])
-        useful_loops_set = set(iname_to_usefulness.keys())
-        useful_and_desired = useful_loops_set & loop_priority_set
-
-        if useful_and_desired:
-            wanted = (
-                useful_and_desired
-                - sched_state.ilp_inames
-                - sched_state.vec_inames
-            )
-            priority_tiers = [t for t in
-                              get_priority_tiers(wanted,
-                                                 sched_state.kernel.loop_priority
-                                                 )
-                              ]
-
-            # Update the loop priority set, because some constraints may have
-            # have been contradictary.
-            loop_priority_set = set().union(*[set(t) for t in priority_tiers])
-
-            priority_tiers.append(
+        # keys of iname_to_usefulness are now inames that get us closer to
+        # scheduling an insn
+
+        if sched_state.kernel.loop_nest_constraints:
+            # {{{ use loop_nest_constraints in determining next_iname_candidates
+
+            # inames not yet entered that would get us closer to scheduling an insn:
+            useful_loops_set = set(iname_to_usefulness.keys())
+
+            from loopy.transform.iname import (
+                check_all_must_not_nests,
+                get_graph_sources,
+            )
+            from pytools.graph import compute_induced_subgraph
+
+            # since vec_inames must be innermost,
+            # they are not valid candidates unless only vec_inames remain
+            if useful_loops_set - sched_state.vec_inames:
+                useful_loops_set -= sched_state.vec_inames
+
+            # to enter an iname without violating must_nest constraints,
+            # iname must be a source in the induced subgraph of must_nest_graph
+            # containing inames in useful_loops_set
+            must_nest_graph_full = (
+                sched_state.kernel.loop_nest_constraints.must_nest_graph
+                if sched_state.kernel.loop_nest_constraints else None)
+            if must_nest_graph_full:
+                must_nest_graph_useful = compute_induced_subgraph(
+                        must_nest_graph_full,
                         useful_loops_set
-                        - loop_priority_set
-                        - sched_state.ilp_inames
-                        - sched_state.vec_inames
                         )
+                source_inames = get_graph_sources(must_nest_graph_useful)
+            else:
+                source_inames = useful_loops_set
+
+            # since graph has a key for every iname,
+            # sources should be the only valid iname candidates
+
+            # check whether entering any source_inames violates
+            # must-not-nest constraints, given the currently active inames
+            must_not_nest_constraints = (
+                sched_state.kernel.loop_nest_constraints.must_not_nest
+                if sched_state.kernel.loop_nest_constraints else None)
+            if must_not_nest_constraints:
+                next_iname_candidates = set()
+                for next_iname in source_inames:
+                    iname_orders_to_check = [
+                        (active_iname, next_iname)
+                        for active_iname in active_inames_set]
+
+                    if check_all_must_not_nests(
+                            iname_orders_to_check, must_not_nest_constraints):
+                        next_iname_candidates.add(next_iname)
+            else:
+                next_iname_candidates = source_inames
+
+            # }}}
+        else:
-            priority_tiers = [
-                useful_loops_set
+            # {{{ old tier building
+
+            # Build priority tiers. If a schedule is found in the first tier,
+            # then loops in the second are not even tried (and so on).
+            loop_priority_set = set().union(*[set(prio)
+                                              for prio in
+                                              sched_state.kernel.loop_priority])
+            useful_loops_set = set(iname_to_usefulness.keys())
+            useful_and_desired = useful_loops_set & loop_priority_set
+
+            if useful_and_desired:
+                wanted = (
+                    useful_and_desired
                     - sched_state.ilp_inames
                     - sched_state.vec_inames
-                ]
-
-            # vectorization must be the absolute innermost loop
-            priority_tiers.extend([
-                [iname]
-                for iname in sched_state.ilp_inames
-                if iname in useful_loops_set
-            ])
+                )
+                priority_tiers = [t for t in
+                                  get_priority_tiers(wanted,
+                                                     sched_state.kernel.loop_priority
+                                                     )
+                                  ]
+
+                # Update the loop priority set, because some constraints may
+                # have been contradictory.
+                loop_priority_set = set().union(*[set(t) for t in priority_tiers])
+
+                priority_tiers.append(
+                    useful_loops_set
+                    - loop_priority_set
+                    - sched_state.ilp_inames
+                    - sched_state.vec_inames
+                )
+            else:
+                priority_tiers = [
+                    useful_loops_set
+                    - sched_state.ilp_inames
+                    - sched_state.vec_inames
+                ]
+
+            # vectorization must be the absolute innermost loop
+            priority_tiers.extend([
+                [iname]
+                for iname in sched_state.ilp_inames
+                if iname in useful_loops_set
+            ])
+
+            priority_tiers.extend([
+                [iname]
+                for iname in sched_state.vec_inames
+                if iname in useful_loops_set
+            ])
 
-            priority_tiers.extend([
-                [iname]
-                for iname in sched_state.vec_inames
-                if iname in useful_loops_set
-            ])
+            # }}}
 
-        # }}}
+        if sched_state.kernel.loop_nest_constraints:
+            # {{{ loop over next_iname_candidates generated w/ loop_nest_constraints
 
-        if debug_mode:
-            print("useful inames: %s" % ",".join(useful_loops_set))
-        else:
-            for tier in priority_tiers:
+            if debug_mode:
+                print("useful inames: %s" % ",".join(useful_loops_set))
+            else:
                 found_viable_schedule = False
 
-                for iname in sorted(tier,
+                # loop over iname candidates; enter inames and recurse:
+                for iname in sorted(next_iname_candidates,
                         key=lambda iname: (
                             iname_to_usefulness.get(iname, 0),
                             # Sort by iname to achieve deterministic
@@ -1382,6 +1592,7 @@ def insn_sort_key(insn_id):
                             iname),
                         reverse=True):
 
+                    # enter the loop and recurse
                     for sub_sched in generate_loop_schedules_internal(
                             sched_state.copy(
                                 schedule=(
@@ -1395,16 +1606,63 @@ def insn_sort_key(insn_id):
                                 insn_ids_to_try=insn_ids_to_try,
                                 preschedule=(
                                     sched_state.preschedule
-                                    if iname not in sched_state.prescheduled_inames
+                                    if iname not in
+                                    sched_state.prescheduled_inames
                                     else sched_state.preschedule[1:]),
                                 ),
                             debug=debug):
+
                         found_viable_schedule = True
                         yield sub_sched
 
+                # TODO what happens if found_viable_schedule is false?
                 if found_viable_schedule:
                     return
 
+            # }}}
+        else:
+            # {{{ old looping over tiers
+
+            if debug_mode:
+                print("useful inames: %s" % ",".join(useful_loops_set))
+            else:
+                for tier in priority_tiers:
+                    found_viable_schedule = False
+
+                    for iname in sorted(tier,
+                            key=lambda iname: (
+                                iname_to_usefulness.get(iname, 0),
+                                # Sort by iname to achieve deterministic
+                                # ordering of generated schedules.
+                                iname),
+                            reverse=True):
+
+                        for sub_sched in generate_loop_schedules_internal(
+                                sched_state.copy(
+                                    schedule=(
+                                        sched_state.schedule
+                                        + (EnterLoop(iname=iname),)),
+                                    active_inames=(
+                                        sched_state.active_inames + (iname,)),
+                                    entered_inames=(
+                                        sched_state.entered_inames
+                                        | frozenset((iname,))),
+                                    insn_ids_to_try=insn_ids_to_try,
+                                    preschedule=(
+                                        sched_state.preschedule
+                                        if iname not in
+                                        sched_state.prescheduled_inames
+                                        else sched_state.preschedule[1:]),
+                                    ),
+                                debug=debug):
+                            found_viable_schedule = True
+                            yield sub_sched
+
+                    if found_viable_schedule:
+                        return
+
+            # }}}
+
     # }}}
 
     if debug_mode:
@@ -1415,10 +1673,32 @@ def insn_sort_key(insn_id):
         if inp:
             raise ScheduleDebugInputError(inp)
 
+    # {{{ make sure ALL must_nest_constraints are satisfied
+
+    # (the check above avoids contradicting some must_nest constraints,
+    # but we don't know if all required nestings are present)
+    # TODO is this the only place we need to check all must_nest constraints?
+    must_constraints_satisfied = True
+    if sched_state.kernel.loop_nest_constraints:
+        from loopy.transform.iname import (
+            get_iname_nestings,
+            loop_nest_constraints_satisfied,
+        )
+        must_nest_constraints = sched_state.kernel.loop_nest_constraints.must_nest
+        if must_nest_constraints:
+            sched_tiers = get_iname_nestings(sched_state.schedule)
+            must_constraints_satisfied = loop_nest_constraints_satisfied(
+                sched_tiers, must_nest_constraints,
+                must_not_nest_constraints=None,  # (checked upon loop creation)
+                all_inames=kernel.all_inames())
+
+    # }}}
+
     if (
             not sched_state.active_inames
             and not sched_state.unscheduled_insn_ids
-            and not sched_state.preschedule):
+            and not sched_state.preschedule
+            and must_constraints_satisfied):
 
         # if done, yield result
         debug.log_success(sched_state.schedule)
 
@@ -1754,10 +2034,10 @@ def _insn_ids_reaching_end(schedule, kind, reverse):
     return insn_ids_alive_at_scope[-1]
 
 
-def append_barrier_or_raise_error(kernel_name, schedule, dep, verify_only):
+def append_barrier_or_raise_error(
+        kernel_name, schedule, dep, verify_only, use_dependencies_v2=False):
     if verify_only:
-        from loopy.diagnostic import MissingBarrierError
-        raise MissingBarrierError(
+        err_str = (
                 "%s: Dependency '%s' (for variable '%s') "
                 "requires synchronization "
                 "by a %s barrier (add a 'no_sync_with' "
@@ -1769,6 +2049,14 @@ def append_barrier_or_raise_error(kernel_name, schedule, dep, verify_only):
                     tgt=dep.target.id, src=dep.source.id),
                 dep.variable,
                 dep.var_kind))
+        # TODO need to update all this with v2 deps. For now, make this a warning.
+        # Do full fix for this later
+        if use_dependencies_v2:
+            from warnings import warn
+            warn(err_str)
+        else:
+            from loopy.diagnostic import MissingBarrierError
+            raise MissingBarrierError(err_str)
     else:
         comment = "for {} ({})".format(
                 dep.variable, dep.dep_descr.format(
@@ -1836,7 +2124,8 @@ def insert_barriers_at_outer_level(schedule, reverse=False):
                         dep_tracker.gen_dependencies_with_target_at(insn)
                         for insn in loop_head):
                     append_barrier_or_raise_error(
-                            kernel.name, result, dep, verify_only)
+                            kernel.name, result, dep, verify_only,
+                            kernel.options.use_dependencies_v2)
                     # This barrier gets inserted outside the loop, hence it is
                     # executed unconditionally and so kills all sources before
                     # the loop.
@@ -1869,7 +2158,8 @@ def insert_barriers_at_outer_level(schedule, reverse=False):
                 for dep in dep_tracker.gen_dependencies_with_target_at(
                         sched_item.insn_id):
                     append_barrier_or_raise_error(
-                            kernel.name, result, dep, verify_only)
+                            kernel.name, result, dep, verify_only,
+                            kernel.options.use_dependencies_v2)
                     dep_tracker.discard_all_sources()
                     break
         result.append(sched_item)
@@ -1998,13 +2288,32 @@ def generate_loop_schedules_inner(kernel, callables_table, debug_args=None):
 
     loop_nest_with_map = find_loop_nest_with_map(kernel)
     loop_nest_around_map = find_loop_nest_around_map(kernel)
+
+    # {{{ create simplified dependency graph with edge from depender* to
+    # dependee* iff intersection (SAME_map & DEP_map) is not empty
+
+    if kernel.options.use_dependencies_v2:
+        from loopy.schedule.checker.dependency import (
+            filter_deps_by_intersection_with_SAME,
+        )
+
+        # Get dep graph edges with edges FROM depender TO dependee
+        simplified_depends_on_graph = filter_deps_by_intersection_with_SAME(kernel)
+    else:
+        simplified_depends_on_graph = None
+
+    # }}}
+
     sched_state = SchedulerState(
             kernel=kernel,
             loop_nest_around_map=loop_nest_around_map,
             loop_insn_dep_map=find_loop_insn_dep_map(
                 kernel,
                 loop_nest_with_map=loop_nest_with_map,
-                loop_nest_around_map=loop_nest_around_map),
+                loop_nest_around_map=loop_nest_around_map,
+                simplified_depends_on_graph=simplified_depends_on_graph,
+                ),
+            simplified_depends_on_graph=simplified_depends_on_graph,
             breakable_inames=ilp_inames,
             ilp_inames=ilp_inames,
             vec_inames=vec_inames,
@@ -2034,7 +2343,8 @@
             active_group_counts={},
             insns_in_topologically_sorted_order=(
-                get_insns_in_topologically_sorted_order(kernel)),
+                get_insns_in_topologically_sorted_order(
+                    kernel, simplified_depends_on_graph)),
             )
 
     schedule_gen_kwargs = {}
@@ -2133,7 +2443,7 @@ def print_longest_dead_end():
                 key_builder=LoopyKeyBuilder())
 
 
-def _get_one_scheduled_kernel_inner(kernel, callables_table):
+def _get_one_scheduled_kernel_inner(kernel, callables_table, debug_args={}):
     # This helper function exists to ensure that the generator chain is fully
     # out of scope after the function returns. This allows it to be
     # garbage-collected in the exit handler of the
     # MinRecursionLimitForScheduling, which would otherwise retain a reference
     # to the generator chain.
     #
     # See https://gitlab.tiker.net/inducer/sumpy/issues/31 for context.
 
-    return next(iter(generate_loop_schedules(kernel, callables_table)))
+    return next(iter(generate_loop_schedules(
+        kernel, callables_table, debug_args=debug_args)))
 
 
 def get_one_scheduled_kernel(kernel, callables_table):
@@ -2155,7 +2466,7 @@
     return get_one_linearized_kernel(kernel, callables_table)
 
 
-def get_one_linearized_kernel(kernel, callables_table):
+def get_one_linearized_kernel(kernel, callables_table, debug_args={}):
     from loopy import CACHING_ENABLED
 
     # must include *callables_table* within the cache key as the preschedule
@@ -2176,7 +2487,7 @@
     with ProcessLogger(logger, "%s: schedule" % kernel.name):
         with MinRecursionLimitForScheduling(kernel):
             result = _get_one_scheduled_kernel_inner(kernel,
-                    callables_table)
+                    callables_table, debug_args)
 
     if CACHING_ENABLED and not from_cache:
         schedule_cache.store_if_not_present(sched_cache_key, result)
diff --git a/loopy/schedule/checker/__init__.py b/loopy/schedule/checker/__init__.py
new file mode 100644
index 000000000..ec4b863fe
--- /dev/null
+++ b/loopy/schedule/checker/__init__.py
@@ -0,0 +1,299 @@
+__copyright__ = "Copyright (C) 2019 James Stevens"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+
+# {{{ get pairwise statement orderings
+
+def get_pairwise_statement_orderings(
+        knl,
+        lin_items,
+        stmt_id_pairs,
+        ):
+    r"""For each statement pair in a subset of all statement pairs found in a
+    linearized kernel, determine the (relative) order in which the statement
+    instances are executed. For each pair, represent this relative ordering
+    using three ``statement instance orderings`` (SIOs):
+
+    - The intra-thread SIO: A :class:`islpy.Map` from each instance of the
+      first statement to all instances of the second statement that occur
+      later, such that both statement instances in each before-after pair are
+      executed within the same work-item (thread).
+
+    - The intra-group SIO: A :class:`islpy.Map` from each instance of the first
+      statement to all instances of the second statement that occur later, such
+      that both statement instances in each before-after pair are executed
+      within the same work-group (though potentially by different work-items).
+
+    - The global SIO: A :class:`islpy.Map` from each instance of the first
+      statement to all instances of the second statement that occur later, even
+      if the two statement instances in a given before-after pair are executed
+      within different work-groups.
+
+    :arg knl: A preprocessed :class:`loopy.kernel.LoopKernel` containing the
+        linearization items that will be used to create the SIOs.
+
+    :arg lin_items: A list of :class:`loopy.schedule.ScheduleItem`
+        (to be renamed to `loopy.schedule.LinearizationItem`) containing all
+        linearization items for which SIOs will be created. To allow usage of
+        this routine during linearization, a truncated (i.e. partial)
+        linearization may be passed through this argument.
+
+    :arg stmt_id_pairs: A sequence containing pairs of statement identifiers.
+
+    :returns: A dictionary mapping each two-tuple of statement identifiers
+        provided in `stmt_id_pairs` to a :class:`StatementOrdering`, which
+        contains the three SIOs described above.
+
+    .. doctest::
+
+        >>> import loopy as lp
+        >>> import numpy as np
+        >>> # Make kernel -----------------------------------------------------------
+        >>> knl = lp.make_kernel(
+        ...     "{[j,k]: 0<=j<pj and 0<=k<pk}",
+        ...     [
+        ...         "a[j] = 1  {id=stmt_a}",
+        ...         "b[k] = a[0]  {id=stmt_b,dep=stmt_a}",
+        ...     ])
+        >>> knl = lp.add_and_infer_dtypes(knl, {"a": np.float32, "b": np.float32})
+        >>> # Preprocess
+        >>> knl = lp.preprocess_kernel(knl)
+        >>> # Get a linearization
+        >>> knl = lp.get_one_linearized_kernel(
+        ...     knl["loopy_kernel"], knl.callables_table)
+        >>> # Get pairwise order info -----------------------------------------------
+        >>> from loopy.schedule.checker import get_pairwise_statement_orderings
+        >>> sio_dict = get_pairwise_statement_orderings(
+        ...     knl,
+        ...     knl.linearization,
+        ...     [("stmt_a", "stmt_b")],
+        ...     )
+        >>> # Print map
+        >>> print(str(sio_dict[("stmt_a", "stmt_b")].sio_intra_thread
+        ...     ).replace("{ ", "{\n").replace(" :", "\n:"))
+        [pj, pk] -> {
+        [_lp_linchk_stmt' = 0, j'] -> [_lp_linchk_stmt = 1, k]
+        : pj > 0 and pk > 0 and 0 <= j' < pj and 0 <= k < pk }
+
+    """
+
+    # {{{ make sure kernel has been preprocessed
+
+    from loopy.kernel import KernelState
+    assert knl.state in [
+        KernelState.PREPROCESSED,
+        KernelState.LINEARIZED]
+
+    # }}}
+
+    # {{{ Find any EnterLoop inames that are tagged as concurrent
+    # so that get_pairwise_statement_orderings_inner() knows to ignore them
+    # (In the future, this should only include inames tagged with 'vec'.)
+    from loopy.schedule.checker.utils import (
+        partition_inames_by_concurrency,
+        get_EnterLoop_inames,
+    )
+    conc_inames, _ = partition_inames_by_concurrency(knl)
+    enterloop_inames = get_EnterLoop_inames(lin_items)
+    conc_loop_inames = conc_inames & enterloop_inames
+
+    # The only concurrent EnterLoop inames should be Vec and ILP
+    from loopy.kernel.data import (VectorizeTag, IlpBaseTag)
+    for conc_iname in conc_loop_inames:
+        # Assert that there exists an ilp or vectorize tag (out of the
+        # potentially multiple other tags on this concurrent iname).
+        assert any(
+            isinstance(tag, (VectorizeTag, IlpBaseTag))
+            for tag in knl.iname_to_tags[conc_iname])
+
+    # }}}
+
+    # {{{ Create the SIOs
+
+    from loopy.schedule.checker.schedule import (
+        get_pairwise_statement_orderings_inner
+    )
+    return get_pairwise_statement_orderings_inner(
+        knl,
+        lin_items,
+        stmt_id_pairs,
+        loops_to_ignore=conc_loop_inames,
+        )
+
+    # }}}
+
+# }}}
+
+
+# {{{ find_unsatisfied_dependencies()
+
+def find_unsatisfied_dependencies(
+        knl,
+        lin_items=None,
+        ):
+    """For each statement (:class:`loopy.InstructionBase`) found in a
+    preprocessed kernel, determine which dependencies, if any, have been
+    violated by the linearization described by `lin_items`, and return these
+    dependencies.
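+
+    A minimal usage sketch (hypothetical kernel names; assumes the kernel's
+    statements carry the new ``dependencies`` maps and that the kernel has
+    been linearized)::
+
+        lin_knl = lp.get_one_linearized_kernel(proc_knl, callables_table)
+        for dep_info in lp.find_unsatisfied_dependencies(lin_knl):
+            print(dep_info.statement_pair, dep_info.dependency)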
+
+    :arg knl: A preprocessed (or linearized) :class:`loopy.kernel.LoopKernel`
+        containing the statements (:class:`loopy.InstructionBase`) whose
+        dependencies will be checked against the linearization items.
+
+    :arg lin_items: A list of :class:`loopy.schedule.ScheduleItem`
+        (to be renamed to `loopy.schedule.LinearizationItem`) containing all
+        linearization items in `knl.linearization`. To allow usage of
+        this routine during linearization, a truncated (i.e. partial)
+        linearization may be passed through this argument. If not provided,
+        `knl.linearization` will be used.
+
+    :returns: A list of unsatisfied dependencies, each described using a
+        :class:`collections.namedtuple` containing the following:
+
+        - `statement_pair`: The (before, after) pair of statement IDs involved
+          in the dependency.
+        - `dependency`: A :class:`islpy.Map` from each instance of the first
+          statement to all instances of the second statement that must occur
+          later.
+        - `statement_ordering`: A statement ordering information tuple
+          resulting from `lp.get_pairwise_statement_orderings`, a
+          :class:`collections.namedtuple` containing the intra-thread
+          statement instance ordering (SIO) (`sio_intra_thread`),
+          intra-group SIO (`sio_intra_group`), and global
+          SIO (`sio_global`), each realized as an :class:`islpy.Map` from each
+          instance of the first statement to all instances of the second
+          statement that occur later, as well as the intra-thread pairwise
+          schedule (`pwsched_intra_thread`), intra-group pairwise schedule
+          (`pwsched_intra_group`), and the global pairwise schedule
+          (`pwsched_global`), each containing a pair of mappings from statement
+          instances to points in a lexicographic ordering, one for each
+          statement. Note that a pairwise schedule alone cannot be used to
+          reproduce the corresponding SIO without the corresponding (unique)
+          lexicographic order map, which is not returned.
+
+    """
+
+    # {{{ Handle lin_items=None and make sure kernel has been preprocessed
+
+    from loopy.kernel import KernelState
+    if lin_items is None:
+        assert knl.state == KernelState.LINEARIZED
+        lin_items = knl.linearization
+    else:
+        # Note: kernels must always be preprocessed before scheduling
+        assert knl.state in [
+            KernelState.PREPROCESSED,
+            KernelState.LINEARIZED]
+
+    # }}}
+
+    # {{{ Create map from dependent statement id pairs to dependencies
+
+    # To minimize time complexity, all pairwise schedules will be created
+    # in one pass, which first requires finding all pairs of statements involved
+    # in deps. We will also need to collect the deps for each statement pair,
+    # so do this at the same time.
+
+    stmt_pairs_to_deps = {}
+
+    # stmt_pairs_to_deps:
+    # {(stmt_id_before1, stmt_id_after1): [dep1, dep2, ...],
+    #  (stmt_id_before2, stmt_id_after2): [dep1, dep2, ...],
+    #  ...}
+
+    from loopy.kernel.instruction import BarrierInstruction
+    # TODO (fix) for now, don't check deps on/by barriers
+    for stmt_after in knl.instructions:
+        if not isinstance(stmt_after, BarrierInstruction):
+            for before_id, dep_list in stmt_after.dependencies.items():
+                if not isinstance(knl.id_to_insn[before_id], BarrierInstruction):
+                    # (don't check whether a dep map duplicates one already
+                    # found; duplicate deps should be rare)
+                    stmt_pairs_to_deps.setdefault(
+                        (before_id, stmt_after.id), []).extend(dep_list)
+    # }}}
+
+    # {{{ Get statement instance orderings
+
+    pworders = get_pairwise_statement_orderings(
+        knl,
+        lin_items,
+        stmt_pairs_to_deps.keys(),
+        )
+
+    # }}}
+
+    # {{{ For each depender-dependee pair of statements, check all deps vs. SIO
+
+    # Collect info about unsatisfied deps
+    unsatisfied_deps = []
+    from collections import namedtuple
+    UnsatisfiedDependencyInfo = namedtuple(
+        "UnsatisfiedDependencyInfo",
+        ["statement_pair", "dependency", "statement_ordering"])
+
+    for stmt_id_pair, dependencies in stmt_pairs_to_deps.items():
+
+        # Get the pairwise ordering info (includes SIOs)
+        pworder = pworders[stmt_id_pair]
+
+        # Check each dep for this statement pair
+        for dependency in dependencies:
+
+            # Align constraint map space to match SIO so we can
+            # check to see whether the constraint map is a subset of the SIO
+            from loopy.schedule.checker.utils import (
+                ensure_dim_names_match_and_align,
+            )
+            aligned_dep_map = ensure_dim_names_match_and_align(
+                dependency, pworder.sio_intra_thread)
+
+            # Spaces must match
+            assert aligned_dep_map.space == pworder.sio_intra_thread.space
+            assert aligned_dep_map.space == pworder.sio_intra_group.space
+            assert aligned_dep_map.space == pworder.sio_global.space
+            assert (aligned_dep_map.get_var_dict() ==
+                    pworder.sio_intra_thread.get_var_dict())
+            assert (aligned_dep_map.get_var_dict() ==
+                    pworder.sio_intra_group.get_var_dict())
+            assert (aligned_dep_map.get_var_dict() ==
+                    pworder.sio_global.get_var_dict())
+
+            # Check dependency
+            if not aligned_dep_map.is_subset(
+                    pworder.sio_intra_thread |
+                    pworder.sio_intra_group |
+                    pworder.sio_global
+                    ):
+
+                unsatisfied_deps.append(UnsatisfiedDependencyInfo(
+                    stmt_id_pair, aligned_dep_map, pworder))
+
+                # Could break here if we don't care about remaining deps
+
+    # }}}
+
+    return unsatisfied_deps
+
+# }}}
+
+# vim: foldmethod=marker
diff --git a/loopy/schedule/checker/dependency.py b/loopy/schedule/checker/dependency.py
new file mode 100644
index 000000000..47199a243
--- /dev/null
+++ b/loopy/schedule/checker/dependency.py
@@ -0,0 +1,138 @@
+__copyright__ = "Copyright (C) 2019 James Stevens"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+import islpy as isl
+
+
+class DependencyType:
+    """Strings specifying a particular type of dependency relationship.
+
+    .. attribute:: SAME
+
+        A :class:`str` specifying the following dependency relationship:
+
+        If ``S = {i, j, ...}`` is a set of inames used in both statements
+        ``insn0`` and ``insn1``, and ``{i', j', ...}`` represent the values
+        of the inames in ``insn0``, and ``{i, j, ...}`` represent the
+        values of the inames in ``insn1``, then the dependency
+        ``insn0 happens before insn1 iff SAME({i, j})`` specifies that
+        ``insn0 happens before insn1 iff {i' = i and j' = j and ...}``.
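+
+        For instance (an illustrative map, with ``'`` marking the 'before'
+        statement), ``SAME({i, j})`` over ``0 <= i, j < n`` corresponds to::
+
+            [n] -> {[i', j'] -> [i, j] : i' = i and j' = j and 0 <= i, j < n}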
+
+        Note that ``SAME({}) = True``.
+
+    .. attribute:: PRIOR
+
+        A :class:`str` specifying the following dependency relationship:
+
+        If ``S = {i, j, k, ...}`` is a set of inames used in both statements
+        ``insn0`` and ``insn1``, and ``{i', j', k', ...}`` represent the values
+        of the inames in ``insn0``, and ``{i, j, k, ...}`` represent the
+        values of the inames in ``insn1``, then the dependency
+        ``insn0 happens before insn1 iff PRIOR({i, j, k})`` specifies one of
+        two possibilities, depending on whether the loop nest ordering is
+        known. If the loop nest ordering is unknown, then
+        ``insn0 happens before insn1 iff {i' < i and j' < j and k' < k ...}``.
+        If the loop nest ordering is known, the condition becomes
+        ``{i', j', k', ...}`` is lexicographically less than ``{i, j, k, ...}``,
+        i.e., ``i' < i or (i' = i and j' < j) or (i' = i and j' = j and k' < k) ...``.
+
+    """
+
+    SAME = "same"
+    PRIOR = "prior"
+
+
+def filter_deps_by_intersection_with_SAME(knl):
+    # Determine which dep relations have a non-empty intersection with
+    # the SAME relation
+    # TODO document
+
+    from loopy.schedule.checker.utils import (
+        append_mark_to_strings,
+        partition_inames_by_concurrency,
+        create_elementwise_comparison_conjunction_set,
+        convert_map_to_set,
+        convert_set_back_to_map,
+    )
+    from loopy.schedule.checker.schedule import (
+        BEFORE_MARK,
+    )
+    _, non_conc_inames = partition_inames_by_concurrency(knl)
+
+    # NOTE: deps_filtered will map depender->dependee
+    deps_filtered = {}
+    for stmt in knl.instructions:
+
+        if hasattr(stmt, "dependencies") and stmt.dependencies:
+
+            depender_id = stmt.id
+
+            for dependee_id, dep_maps in stmt.dependencies.items():
+
+                # Continue if we've been told to ignore this dependee
+                if stmt.non_linearizing_deps is None:
+                    dependees_to_ignore = set()
+                else:
+                    dependees_to_ignore = stmt.non_linearizing_deps
+                if dependee_id in dependees_to_ignore:
+                    # TODO better fix for this...?
+                    continue
+
+                # Continue if we already have this pair
+                if depender_id in deps_filtered.keys() and (
+                        dependee_id in deps_filtered[depender_id]):
+                    continue
+
+                for dep_map in dep_maps:
+                    # Create isl map representing "SAME" dep for these two insns
+
+                    # Get shared nonconcurrent inames
+                    depender_inames = knl.id_to_insn[depender_id].within_inames
+                    dependee_inames = knl.id_to_insn[dependee_id].within_inames
+                    shared_nc_inames = (
+                        depender_inames & dependee_inames & non_conc_inames)
+
+                    # Temporarily convert to set
+                    dep_set_space, n_in_dims, n_out_dims = convert_map_to_set(
+                        dep_map.space)
+
+                    # Create SAME relation
+                    same_set_affs = isl.affs_from_space(dep_set_space)
+                    same_set = create_elementwise_comparison_conjunction_set(
+                        shared_nc_inames,
+                        append_mark_to_strings(shared_nc_inames, BEFORE_MARK),
+                        same_set_affs)
+
+                    # Convert back to map
+                    same_map = convert_set_back_to_map(
+                        same_set, n_in_dims, n_out_dims)
+
+                    # Don't need to intersect same_map with iname bounds (I think..?)
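+                    # (Illustration, assuming shared_nc_inames = {"i"} and
+                    # BEFORE_MARK = "'": same_map is then
+                    # { [i'] -> [i] : i' = i }, modulo any additional
+                    # unconstrained dims present in dep_map's space.)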
+
+                    # See whether the intersection of dep map and SAME is empty
+                    intersect_dep_and_same = same_map & dep_map
+
+                    if not intersect_dep_and_same.is_empty():
+                        deps_filtered.setdefault(depender_id, set()).add(dependee_id)
+                        break  # No need to check any more deps for this pair
+
+    return deps_filtered
diff --git a/loopy/schedule/checker/lexicographic_order_map.py b/loopy/schedule/checker/lexicographic_order_map.py
new file mode 100644
index 000000000..5821202cb
--- /dev/null
+++ b/loopy/schedule/checker/lexicographic_order_map.py
@@ -0,0 +1,201 @@
+# coding: utf-8
+__copyright__ = "Copyright (C) 2019 James Stevens"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+import islpy as isl
+
+
+def get_statement_ordering_map(
+        sched_before, sched_after, lex_map, before_mark):
+    """Return a statement ordering represented as a map from each statement
+    instance to all statement instances occurring later.
+
+    :arg sched_before: An :class:`islpy.Map` representing a schedule
+        as a mapping from statement instances (for one particular statement)
+        to lexicographic time. The statement represented will typically
+        be the dependee in a dependency relationship.
+
+    :arg sched_after: An :class:`islpy.Map` representing a schedule
+        as a mapping from statement instances (for one particular statement)
+        to lexicographic time. The statement represented will typically
+        be the depender in a dependency relationship.
+
+    :arg lex_map: An :class:`islpy.Map` representing a lexicographic
+        ordering as a mapping from each point in lexicographic time
+        to every point that occurs later in lexicographic time. E.g.::
+
+            {[i0', i1', i2', ...] -> [i0, i1, i2, ...] :
+                i0' < i0 or (i0' = i0 and i1' < i1)
+                or (i0' = i0 and i1' = i1 and i2' < i2) ...}
+
+    :arg before_mark: A :class:`str` to be appended to the names of the
+        map dimensions representing the 'before' statement in the
+        'happens before' relationship.
+
+    :returns: An :class:`islpy.Map` representing the statement ordering as
+        a mapping from each statement instance to all statement instances
+        occurring later. I.e., we compose relations B, L, and A as
+        B ∘ L ∘ A^-1, where B is `sched_before`, A is `sched_after`,
+        and L is `lex_map`.
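+
+        For example (an illustrative sketch, not from the original
+        docstring), with single-dimension schedules
+        ``B = { [i'] -> [t'] : t' = i' }``, ``A = { [i] -> [t] : t = i }``,
+        and ``L = { [t'] -> [t] : t' < t }``, the resulting statement
+        ordering is ``{ [i'] -> [i] : i' < i }``.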
+ + """ + + # Perform the composition of relations + sio = sched_before.apply_range( + lex_map).apply_range(sched_after.reverse()) + + # Append mark to in_ dims + from loopy.schedule.checker.utils import ( + append_mark_to_isl_map_var_names, + ) + return append_mark_to_isl_map_var_names( + sio, isl.dim_type.in_, before_mark) + + +def _create_lex_order_set( + dim_names, + in_dim_mark, + var_name_to_pwaff=None, + ): + """Return an :class:`islpy.Set` representing a lexicographic ordering + over a space with the number of dimensions provided in `dim_names` + (the set itself will have twice this many dimensions in order to + represent the ordering as before-after pairs of points). + + :arg dim_names: A list of :class:`str` variable names to be used + to describe lexicographic space dimensions for a point in a lexicographic + ordering. (see example below) + + :arg in_dim_mark: A :class:`str` to be appended to dimension names to + distinguish corresponding dimensions in before-after pairs of points. + (see example below) + + :arg var_name_to_pwaff: A dictionary mapping variable names in `dim_names` to + :class:`islpy.PwAff` instances that represent each of the variables + (var_name_to_pwaff may be produced by `islpy.make_zero_and_vars`). + The key '0' is also included and represents a :class:`islpy.PwAff` zero + constant. This dictionary defines the space to be used for the set and + must also include versions of `dim_names` with the `in_dim_mark` + appended. If no value is passed, the dictionary will be made using + `dim_names` and `dim_names` with the `in_dim_mark` appended. + + :returns: An :class:`islpy.Set` representing a big-endian lexicographic + ordering with the number of dimensions provided in `dim_names`. The set + has two dimensions for each name in `dim_names`, one identified by the + given name and another identified by the same name with `in_dim_mark` + appended. The set contains all points which meet a 'happens before' + constraint defining the lexicographic ordering. E.g., if + `dim_names = [i0, i1, i2]` and `in_dim_mark="'"`, + return the set containing all points in a 3-dimensional, big-endian + lexicographic ordering such that point + `[i0', i1', i2']` happens before `[i0, i1, i2]`. I.e., return:: + + {[i0', i1', i2', i0, i1, i2] : + i0' < i0 or (i0' = i0 and i1' < i1) + or (i0' = i0 and i1' = i1 and i2' < i2)} + + """ + + from loopy.schedule.checker.utils import ( + append_mark_to_strings, + ) + + in_dim_names = append_mark_to_strings(dim_names, mark=in_dim_mark) + + # If no var_name_to_pwaff passed, make them using the names provided + # (make sure to pass var names in desired order of space dims) + if var_name_to_pwaff is None: + var_name_to_pwaff = isl.make_zero_and_vars( + in_dim_names+dim_names, + []) + + # Initialize set with constraint i0' < i0 + lex_order_set = var_name_to_pwaff[in_dim_names[0]].lt_set( + var_name_to_pwaff[dim_names[0]]) + + # For each dim d, starting with d=1, equality_conj_set will be constrained + # by d equalities, e.g., (i0' = i0 and i1' = i1 and ... i(d-1)' = i(d-1)). + equality_conj_set = var_name_to_pwaff[0].eq_set( + var_name_to_pwaff[0]) # initialize to 'true' + + for i in range(1, len(in_dim_names)): + + # Add the next equality constraint to equality_conj_set + equality_conj_set = equality_conj_set & \ + var_name_to_pwaff[in_dim_names[i-1]].eq_set( + var_name_to_pwaff[dim_names[i-1]]) + + # Create a set constrained by adding a less-than constraint for this dim, + # e.g., (i1' < i1), to the current equality conjunction set. 
+        # For each dim d, starting with d=1, this full conjunction will have
+        # d equalities and one inequality, e.g.,
+        # (i0' = i0 and i1' = i1 and ... i(d-1)' = i(d-1) and id' < id)
+        full_conj_set = var_name_to_pwaff[in_dim_names[i]].lt_set(
+            var_name_to_pwaff[dim_names[i]]) & equality_conj_set
+
+        # Union this new constraint with the current lex_order_set
+        lex_order_set = lex_order_set | full_conj_set
+
+    return lex_order_set
+
+
+def create_lex_order_map(
+        dim_names,
+        in_dim_mark,
+        ):
+    """Return a map from each point in a lexicographic ordering to every
+    point that occurs later in the lexicographic ordering.
+
+    :arg dim_names: A list of :class:`str` variable names for the
+        lexicographic space dimensions.
+
+    :arg in_dim_mark: A :class:`str` to be appended to `dim_names` to create
+        the names for the input dimensions of the map, thereby distinguishing
+        them from the corresponding output dimensions in before-after pairs of
+        points. (see example below)
+
+    :returns: An :class:`islpy.Map` representing a lexicographic
+        ordering as a mapping from each point in lexicographic time
+        to every point that occurs later in lexicographic time.
+        E.g., if `dim_names = [i0, i1, i2]` and `in_dim_mark = "'"`,
+        return the map::
+
+            {[i0', i1', i2'] -> [i0, i1, i2] :
+                i0' < i0 or (i0' = i0 and i1' < i1)
+                or (i0' = i0 and i1' = i1 and i2' < i2)}
+
+    """
+
+    n_dims = len(dim_names)
+    dim_type = isl.dim_type
+
+    # First, get a set representing the lexicographic ordering.
+    lex_order_set = _create_lex_order_set(
+        dim_names,
+        in_dim_mark=in_dim_mark,
+        )
+
+    # Now convert that set to a map.
+    lex_map = isl.Map.from_domain(lex_order_set)
+    return lex_map.move_dims(
+        dim_type.out, 0, dim_type.in_,
+        n_dims, n_dims)
diff --git a/loopy/schedule/checker/schedule.py b/loopy/schedule/checker/schedule.py
new file mode 100644
index 000000000..ab9af51df
--- /dev/null
+++ b/loopy/schedule/checker/schedule.py
@@ -0,0 +1,1013 @@
+__copyright__ = "Copyright (C) 2019 James Stevens"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+import islpy as isl
+from dataclasses import dataclass
+dt = isl.dim_type
+
+
+# {{{ Constants
+
+__doc__ = """
+
+.. data:: LIN_CHECK_IDENTIFIER_PREFIX
+
+    The :class:`str` prefix for identifiers involved in linearization
+    checking.
+
+.. data:: LEX_VAR_PREFIX
+
+    The :class:`str` prefix for the variables representing the
+    dimensions in the lexicographic ordering used in a pairwise schedule.
+    E.g., a prefix of ``_lp_linchk_lex`` might yield lexicographic dimension
+    variables ``_lp_linchk_lex0``, ``_lp_linchk_lex1``, ``_lp_linchk_lex2``.
+    Cf. :ref:`reserved-identifiers`.
+
+.. data:: STATEMENT_VAR_NAME
+
+    The :class:`str` name for the statement-identifying dimension of maps
+    representing schedules and statement instance orderings.
+
+.. data:: LTAG_VAR_NAMES
+
+    A list of :class:`str` names for map dimensions carrying values for local
+    (intra work-group) thread identifiers in maps representing schedules and
+    statement instance orderings.
+
+.. data:: GTAG_VAR_NAMES
+
+    A list of :class:`str` names for map dimensions carrying values for group
+    identifiers in maps representing schedules and statement instance orderings.
+
+.. data:: BEFORE_MARK
+
+    The :class:`str` identifier to be appended to input dimension names in
+    maps representing schedules and statement instance orderings.
+
+"""
+
+LIN_CHECK_IDENTIFIER_PREFIX = "_lp_linchk_"
+LEX_VAR_PREFIX = "%slex" % (LIN_CHECK_IDENTIFIER_PREFIX)
+STATEMENT_VAR_NAME = "%sstmt" % (LIN_CHECK_IDENTIFIER_PREFIX)
+LTAG_VAR_NAMES = []
+GTAG_VAR_NAMES = []
+for par_level in [0, 1, 2]:
+    LTAG_VAR_NAMES.append("%slid%d" % (LIN_CHECK_IDENTIFIER_PREFIX, par_level))
+    GTAG_VAR_NAMES.append("%sgid%d" % (LIN_CHECK_IDENTIFIER_PREFIX, par_level))
+BEFORE_MARK = "'"
+
+# }}}
+
+
+# {{{ Helper Functions
+
+# {{{ _pad_tuple_with_zeros
+
+def _pad_tuple_with_zeros(tup, desired_length):
+    return tup[:] + tuple([0]*(desired_length-len(tup)))
+
+# }}}
+
+
+# {{{ _simplify_lex_dims
+
+def _simplify_lex_dims(tup0, tup1):
+    """Simplify a pair of lex tuples in order to reduce the complexity of
+    resulting maps. Remove lex tuple dimensions with matching integer values
+    since these do not provide information on relative ordering. Once a
+    dimension is found where both tuples have non-matching integer values,
+    remove any faster-updating lex dimensions since they are not necessary
+    to specify a relative ordering.
+    """
+
+    new_tup0 = []
+    new_tup1 = []
+
+    # Loop over dims from slowest updating to fastest
+    for d0, d1 in zip(tup0, tup1):
+        if isinstance(d0, int) and isinstance(d1, int):
+
+            # Both vals are ints for this dim
+            if d0 == d1:
+                # Do not keep this dim
+                continue
+            elif d0 > d1:
+                # These ints inform us about the relative ordering of
+                # two statements. While their values may be larger than 1 in
+                # the lexicographic ordering describing a larger set of
+                # statements, in a pairwise schedule, only ints 0 and 1 are
+                # necessary to specify relative order. To keep the pairwise
+                # schedules as simple and comprehensible as possible, use only
+                # integers 0 and 1 to specify this relative ordering.
+                # (doesn't take much extra time since we are already going
+                # through these to remove unnecessary lex tuple dims)
+                new_tup0.append(1)
+                new_tup1.append(0)
+
+                # No further dims needed to fully specify ordering
+                break
+            else:  # d1 > d0
+                new_tup0.append(0)
+                new_tup1.append(1)
+
+                # No further dims needed to fully specify ordering
+                break
+        else:
+            # Keep this dim without modifying
+            new_tup0.append(d0)
+            new_tup1.append(d1)
+
+    if len(new_tup0) == 0:
+        # Statements map to the exact same point(s) in the lex ordering,
+        # which is okay, but to represent this, our lex tuple cannot be empty.
+ return (0, ), (0, ) + else: + return tuple(new_tup0), tuple(new_tup1) + +# }}} + +# }}} + + +# {{{ class SpecialLexPointWRTLoop + +class SpecialLexPointWRTLoop: + """Strings identifying a particular point or set of points in a + lexicographic ordering of statements, specified relative to a loop. + + .. attribute:: PRE + A :class:`str` indicating the last lexicographic point that + precedes the loop. + + .. attribute:: FIRST + A :class:`str` indicating the first lexicographic point in the + first loop iteration (i.e., with the iname set to its min. val). + + .. attribute:: TOP + A :class:`str` indicating the first lexicographic point in + an arbitrary loop iteration. + + .. attribute:: BOTTOM + A :class:`str` indicating the last lexicographic point in + an arbitrary loop iteration. + + .. attribute:: LAST + A :class:`str` indicating the last lexicographic point in the + last loop iteration (i.e., with the iname set to its max val). + + .. attribute:: POST + A :class:`str` indicating the first lexicographic point that + follows the loop. + """ + + PRE = "pre" + FIRST = "first" + TOP = "top" + BOTTOM = "bottom" + LAST = "last" + POST = "post" + +# }}} + + +# {{{ class StatementOrdering + +@dataclass +class StatementOrdering: + r"""A container for the three statement instance orderings (described + below) used to formalize the ordering of statement instances for a pair of + statements. + + Also included (mostly for testing and debugging) are the + intra-thread pairwise schedule (`pwsched_intra_thread`), intra-group + pairwise schedule (`pwsched_intra_group`), and global pairwise schedule + (`pwsched_global`), each containing a pair of mappings from statement + instances to points in a lexicographic ordering, one for each statement. + Each SIO is created by composing the two mappings in the corresponding + pairwise schedule with an associated mapping defining the ordering of + points in the lexicographical space (not included). + """ + + sio_intra_thread: isl.Map + sio_intra_group: isl.Map + sio_global: isl.Map + pwsched_intra_thread: tuple + pwsched_intra_group: tuple + pwsched_global: tuple + +# }}} + + +# {{{ _gather_blex_ordering_info + +def _gather_blex_ordering_info( + sync_kind, + lin_items, loops_with_barriers, loops_to_ignore, + all_stmt_ids, iname_bounds_pwaff, + all_par_lex_dim_names, gid_lex_dim_names, + ): + """For the given sync_kind ("local" or "global"), create a mapping from + statement instances to blex space (dict), as well as a mapping + defining the blex ordering (isl map from blex space -> blex space) + + Note that, unlike in the intra-thread case, there will be a single + blex ordering map defining the blex ordering for all statement pairs, + rather than separate (smaller) lex ordering maps for each pair + """ + from loopy.schedule import (EnterLoop, LeaveLoop, Barrier, RunInstruction) + from loopy.schedule.checker.lexicographic_order_map import ( + create_lex_order_map, + ) + from loopy.schedule.checker.utils import ( + add_and_name_isl_dims, + append_mark_to_strings, + add_eq_isl_constraint_from_names, + ) + slex = SpecialLexPointWRTLoop + + # {{{ First, create map from stmt instances to blex space. 
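+    # (Illustrative sketch, an assumption not taken from this source: for a
+    # linearization like
+    #     EnterLoop(i); RunInstruction(a); Barrier; RunInstruction(b); LeaveLoop(i)
+    # the local-barrier blex tuples would be roughly a -> (1, i, 0) and
+    # b -> (1, i, 1), i.e., the fastest-updating dim only advances at the
+    # barrier, so both statements in one barrier-delimited section share a
+    # blex point.)
+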
+ + # At the same time, gather information necessary to create the + # blex ordering map, i.e., for each loop, gather the 6 lex order tuples + # defined above in SpecialLexPointWRTLoop that will be required to + # create sub-maps which will be *excluded* (subtracted) from a standard + # lexicographic ordering in order to create the blex ordering + + stmt_inst_to_blex = {} # Map stmt instances to blex space + iname_to_blex_dim = {} # Map from inames to corresponding blex space dim + blex_exclusion_info = {} # Info for creating maps to exclude from blex order + blex_order_map_params = set() # Params needed in blex order map + n_seq_blex_dims = 1 # Num dims representing sequential order in blex space + next_blex_tuple = [0] # Next tuple of points in blex order + + for lin_item in lin_items: + if isinstance(lin_item, EnterLoop): + enter_iname = lin_item.iname + if enter_iname in loops_with_barriers[sync_kind] - loops_to_ignore: + pre_loop_blex_pt = next_blex_tuple[:] + + # Increment next_blex_tuple[-1] for statements in the section + # of code between this EnterLoop and the matching LeaveLoop. + next_blex_tuple[-1] += 1 + + # Upon entering a loop, add one blex dimension for the loop + # iteration, add second blex dim to enumerate sections of + # code within new loop + next_blex_tuple.append(enter_iname) + next_blex_tuple.append(0) + + # Store 3 tuples that will be used later to create pairs + # that will later be subtracted from the blex order map + lbound = iname_bounds_pwaff[enter_iname][0] + first_iter_blex_pt = next_blex_tuple[:] + first_iter_blex_pt[-2] = lbound + blex_exclusion_info[enter_iname] = { + slex.PRE: tuple(pre_loop_blex_pt), + slex.TOP: tuple(next_blex_tuple), + slex.FIRST: tuple(first_iter_blex_pt), + } + # (copy these three blex points when creating dict because + # the lists will continue to be updated) + + # Store any new params found + blex_order_map_params |= set(lbound.get_var_names(dt.param)) + + elif isinstance(lin_item, LeaveLoop): + leave_iname = lin_item.iname + if leave_iname in loops_with_barriers[sync_kind] - loops_to_ignore: + + # Update max blex dims + n_seq_blex_dims = max(n_seq_blex_dims, len(next_blex_tuple)) + + # Record the blex dim for this loop iname + iname_to_blex_dim[leave_iname] = len(next_blex_tuple)-2 + + # Update next blex pt + pre_end_loop_blex_pt = next_blex_tuple[:] + # Upon leaving a loop: + # - Pop lex dim for enumerating code sections within this loop + # - Pop lex dim for the loop iteration + # - Increment lex dim val enumerating items in current section + next_blex_tuple.pop() + next_blex_tuple.pop() + next_blex_tuple[-1] += 1 + + # Store 3 tuples that will be used later to create pairs + # that will later be subtracted from the blex order map + ubound = iname_bounds_pwaff[leave_iname][1] + last_iter_blex_pt = pre_end_loop_blex_pt[:] + last_iter_blex_pt[-2] = ubound + blex_exclusion_info[leave_iname][slex.BOTTOM] = tuple( + pre_end_loop_blex_pt) + blex_exclusion_info[leave_iname][slex.LAST] = tuple( + last_iter_blex_pt) + blex_exclusion_info[leave_iname][slex.POST] = tuple( + next_blex_tuple) + # (copy these three blex points when creating dict because + # the lists will continue to be updated) + + # Store any new params found + blex_order_map_params |= set(ubound.get_var_names(dt.param)) + + elif isinstance(lin_item, RunInstruction): + # Add stmt->blex pair to stmt_inst_to_blex + stmt_inst_to_blex[lin_item.insn_id] = tuple(next_blex_tuple) + + # (Don't increment blex dim val) + + elif isinstance(lin_item, Barrier): + # Increment blex dim 
val if the sync scope matches + if lin_item.synchronization_kind == sync_kind: + next_blex_tuple[-1] += 1 + + lp_stmt_id = lin_item.originating_insn_id + + if lp_stmt_id is None: + # Barriers without stmt ids were inserted as a result of a + # dependency. They don't themselves have dependencies. + # Don't map this barrier to a blex tuple. + continue + + # This barrier has a stmt id. + # If it was included in listed stmts, process it. + # Otherwise, there's nothing left to do (we've already + # incremented next_blex_tuple if necessary, and this barrier + # does not need to be assigned to a designated point in blex + # time) + if lp_stmt_id in all_stmt_ids: + # If sync scope matches, give this barrier its own point in + # lex time and update blex tuple after barrier. + # Otherwise, add stmt->blex pair to stmt_inst_to_blex, but + # don't update the blex tuple (just like with any other + # stmt) + if lin_item.synchronization_kind == sync_kind: + stmt_inst_to_blex[lp_stmt_id] = tuple(next_blex_tuple) + next_blex_tuple[-1] += 1 + else: + stmt_inst_to_blex[lp_stmt_id] = tuple(next_blex_tuple) + else: + from loopy.schedule import (CallKernel, ReturnFromKernel) + # No action needed for these types of linearization item + assert isinstance( + lin_item, (CallKernel, ReturnFromKernel)) + pass + + blex_order_map_params = sorted(blex_order_map_params) + + # At this point, some blex tuples may have more dimensions than others; + # the missing dims are the fastest-updating dims, and their values should + # be zero. Add them. + for stmt, tup in stmt_inst_to_blex.items(): + stmt_inst_to_blex[stmt] = _pad_tuple_with_zeros(tup, n_seq_blex_dims) + + # }}} + + # {{{ Second, create the blex order map + + # {{{ Create the initial (pre-subtraction) blex order map + + # Create names for the blex dimensions for sequential loops + seq_blex_dim_names = [ + LEX_VAR_PREFIX+str(i) for i in range(n_seq_blex_dims)] + seq_blex_dim_names_prime = append_mark_to_strings( + seq_blex_dim_names, mark=BEFORE_MARK) + + # Begin with the blex order map created as a standard lexicographical order + blex_order_map = create_lex_order_map( + dim_names=seq_blex_dim_names, + in_dim_mark=BEFORE_MARK, + ) + + # Add LID/GID dims to blex order map + blex_order_map = add_and_name_isl_dims( + blex_order_map, dt.out, all_par_lex_dim_names) + blex_order_map = add_and_name_isl_dims( + blex_order_map, dt.in_, + append_mark_to_strings(all_par_lex_dim_names, mark=BEFORE_MARK)) + if sync_kind == "local": + # For intra-group case, constrain GID 'before' to equal GID 'after' + for var_name in gid_lex_dim_names: + blex_order_map = add_eq_isl_constraint_from_names( + blex_order_map, var_name, var_name+BEFORE_MARK) + # (if sync_kind == "global", don't need constraints on LID/GID vars) + + # }}} + + # {{{ Subtract unwanted pairs from happens-before blex map + + # Create map from iname to corresponding blex dim name + iname_to_blex_var = {} + for iname, dim in iname_to_blex_dim.items(): + iname_to_blex_var[iname] = seq_blex_dim_names[dim] + iname_to_blex_var[iname+BEFORE_MARK] = seq_blex_dim_names_prime[dim] + + # Add bounds params needed in blex map + blex_order_map = add_and_name_isl_dims( + blex_order_map, dt.param, blex_order_map_params) + + # Get a set representing blex_order_map space + n_blex_dims = n_seq_blex_dims + len(all_par_lex_dim_names) + blex_set_template = isl.align_spaces( + isl.Map("[ ] -> { [ ] -> [ ] }"), blex_order_map + ).move_dims( + dt.in_, n_blex_dims, dt.out, 0, n_blex_dims + ).domain() + blex_set_affs = 
isl.affs_from_space(blex_set_template.space)
+
+    # {{{ Create blex map to subtract for each iname in blex_exclusion_info
+
+    maps_to_subtract = []
+    for iname, key_lex_tuples in blex_exclusion_info.items():
+
+        # {{{ Create blex map to subtract for one iname
+
+        """Create the blex->blex pairs that must be subtracted from the
+        initial blex order map for this particular loop using the 6 blex
+        tuples in key_lex_tuples:
+        PRE->FIRST, BOTTOM(iname')->TOP(iname'+1), LAST->POST
+        """
+
+        # Note:
+        # only key_lex_tuples[slex.FIRST] & key_lex_tuples[slex.LAST] are pwaffs
+
+        # {{{ _create_blex_set_from_tuple_pair
+
+        def _create_blex_set_from_tuple_pair(before, after, wrap_cond=False):
+            """Given a before->after tuple pair in the key_lex_tuples, which may
+            have dim vals described by ints, strings (inames), and pwaffs,
+            create an ISL set in blex space that can be converted into
+            the ISL map to be subtracted
+            """
+            # (Vars from outside func used here:
+            # iname, blex_set_affs, blex_set_template, iname_to_blex_var,
+            # n_seq_blex_dims, seq_blex_dim_names,
+            # seq_blex_dim_names_prime)
+
+            # Start with a set representing blex_order_map space
+            blex_set = blex_set_template.copy()
+
+            # Add marks to inames in the 'before' tuple
+            # (all strings should be inames)
+            before_prime = tuple(
+                v+BEFORE_MARK if isinstance(v, str) else v for v in before)
+            before_padded = _pad_tuple_with_zeros(before_prime, n_seq_blex_dims)
+            after_padded = _pad_tuple_with_zeros(after, n_seq_blex_dims)
+
+            # Assign vals in the tuple to dims in the ISL set
+            for dim_name, dim_val in zip(
+                    seq_blex_dim_names_prime+seq_blex_dim_names,
+                    before_padded+after_padded):
+
+                if isinstance(dim_val, int):
+                    # Set idx to int val
+                    blex_set &= blex_set_affs[dim_name].eq_set(
+                        blex_set_affs[0]+dim_val)
+                elif isinstance(dim_val, str):
+                    # This is an iname, set idx to corresponding blex var
+                    blex_set &= blex_set_affs[dim_name].eq_set(
+                        blex_set_affs[iname_to_blex_var[dim_val]])
+                else:
+                    # This is a pwaff iname bound, align and intersect
+                    assert isinstance(dim_val, isl.PwAff)
+                    pwaff_aligned = isl.align_spaces(dim_val, blex_set_affs[0])
+                    # (doesn't matter which blex_set_affs item we align to^)
+                    blex_set &= blex_set_affs[dim_name].eq_set(pwaff_aligned)
+
+            if wrap_cond:
+                # This is the BOTTOM->TOP pair, add condition i = i' + 1
+                blex_set &= blex_set_affs[iname_to_blex_var[iname]].eq_set(
+                    blex_set_affs[iname_to_blex_var[iname+BEFORE_MARK]] + 1)
+
+            return blex_set
+
+        # }}} end _create_blex_set_from_tuple_pair()
+
+        # Create pairs to be subtracted
+        # (set will be converted to map)
+
+        # Enter loop case: PRE->FIRST
+        full_blex_set = _create_blex_set_from_tuple_pair(
+            key_lex_tuples[slex.PRE], key_lex_tuples[slex.FIRST])
+        # Wrap loop case: BOTTOM(iname')->TOP(iname'+1)
+        full_blex_set |= _create_blex_set_from_tuple_pair(
+            key_lex_tuples[slex.BOTTOM], key_lex_tuples[slex.TOP],
+            wrap_cond=True)
+        # Leave loop case: LAST->POST
+        full_blex_set |= _create_blex_set_from_tuple_pair(
+            key_lex_tuples[slex.LAST], key_lex_tuples[slex.POST])
+
+        # Add condition to fix iteration value for *surrounding* loops (j = j')
+        for surrounding_iname in key_lex_tuples[slex.PRE][1::2]:
+            s_blex_var = iname_to_blex_var[surrounding_iname]
+            full_blex_set &= blex_set_affs[s_blex_var].eq_set(
+                blex_set_affs[s_blex_var+BEFORE_MARK])
+
+        # Convert blex set back to map
+        map_to_subtract = isl.Map.from_domain(full_blex_set).move_dims(
+            dt.out, 0, dt.in_, n_blex_dims, n_blex_dims)
+
+        # }}}
+
+        maps_to_subtract.append(map_to_subtract)
+
+    # }}}
+
+    # {{{
Subtract transitive closure of union of blex maps to subtract + + if maps_to_subtract: + + # Get union of maps + map_to_subtract = maps_to_subtract[0] + for other_map in maps_to_subtract[1:]: + map_to_subtract |= other_map + + # Get transitive closure of maps + map_to_subtract_closure, closure_exact = map_to_subtract.transitive_closure() + + assert closure_exact # TODO warn instead? + + # Subtract closure from blex order map + blex_order_map = blex_order_map - map_to_subtract_closure + + # }}} + + # }}} + + return ( + stmt_inst_to_blex, # map stmt instances to blex space + blex_order_map, + seq_blex_dim_names, + ) + +# }}} + + +# {{{ get_pairwise_statement_orderings_inner + +def get_pairwise_statement_orderings_inner( + knl, + lin_items, + stmt_id_pairs, + loops_to_ignore=frozenset(), + ): + r"""For each statement pair in a subset of all statement pairs found in a + linearized kernel, determine the (relative) order in which the statement + instances are executed. For each pair, represent this relative ordering + using three ``statement instance orderings`` (SIOs): + + - The intra-thread SIO: A :class:`islpy.Map` from each instance of the + first statement to all instances of the second statement that occur + later, such that both statement instances in each before-after pair are + executed within the same work-item (thread). + + - The intra-group SIO: A :class:`islpy.Map` from each instance of the first + statement to all instances of the second statement that occur later, such + that both statement instances in each before-after pair are executed + within the same work-group (though potentially by different work-items). + + - The global SIO: A :class:`islpy.Map` from each instance of the first + statement to all instances of the second statement that occur later, even + if the two statement instances in a given before-after pair are executed + within different work-groups. + + :arg knl: A preprocessed :class:`loopy.kernel.LoopKernel` containing the + linearization items that will be used to create the SIOs. This + kernel will be used to get the domains associated with the inames + used in the statements, and to determine which inames have been + tagged with parallel tags. + + :arg lin_items: A list of :class:`loopy.schedule.ScheduleItem` + (to be renamed to `loopy.schedule.LinearizationItem`) containing + all linearization items for which SIOs will be + created. To allow usage of this routine during linearization, a + truncated (i.e. partial) linearization may be passed through this + argument + + :arg stmt_id_pairs: A list containing pairs of statement identifiers. + + :arg loops_to_ignore: A set of inames that will be ignored when + determining the relative ordering of statements. This will typically + contain concurrent inames tagged with the ``vec`` or ``ilp`` array + access tags. + + :returns: A dictionary mapping each two-tuple of statement identifiers + provided in `stmt_id_pairs` to a :class:`StatementOrdering`, which + contains the three SIOs described above. 
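+
+        Example (an illustrative sketch; the statement ids and the use of
+        ``knl.linearization`` are assumptions, not taken from this
+        docstring)::
+
+            pworders = get_pairwise_statement_orderings_inner(
+                knl, knl.linearization, [("stmt_a", "stmt_b")])
+            sio = pworders[("stmt_a", "stmt_b")].sio_intra_thread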
+ """ + + from loopy.schedule import (EnterLoop, LeaveLoop, Barrier, RunInstruction) + from loopy.kernel.data import (LocalInameTag, GroupInameTag) + from loopy.schedule.checker.lexicographic_order_map import ( + create_lex_order_map, + get_statement_ordering_map, + ) + from loopy.schedule.checker.utils import ( + add_and_name_isl_dims, + append_mark_to_strings, + add_eq_isl_constraint_from_names, + sorted_union_of_names_in_isl_sets, + create_symbolic_map_from_tuples, + insert_and_name_isl_dims, + ) + + all_stmt_ids = set().union(*stmt_id_pairs) + + # {{{ Intra-thread lex order creation + + # First, use one pass through lin_items to generate an *intra-thread* + # lexicographic ordering describing the relative order of all statements + # represented by all_stmt_ids + + # For each statement, map the stmt_id to a tuple representing points + # in the intra-thread lexicographic ordering containing items of :class:`int` or + # :class:`str` :mod:`loopy` inames + stmt_inst_to_lex_intra_thread = {} + + # Keep track of the next tuple of points in our lexicographic + # ordering, initially this as a 1-d point with value 0 + next_lex_tuple = [0] + + # While we're passing through, determine which loops contain barriers, + # this information will be used later when creating *intra-group* and + # *global* lexicographic orderings + loops_with_barriers = {"local": set(), "global": set()} + current_inames = set() + + for lin_item in lin_items: + if isinstance(lin_item, EnterLoop): + iname = lin_item.iname + current_inames.add(iname) + + if iname in loops_to_ignore: + continue + + # Increment next_lex_tuple[-1] for statements in the section + # of code between this EnterLoop and the matching LeaveLoop. + # (not technically necessary if no statement was added in the + # previous section; gratuitous incrementing is counteracted + # in the simplification step below) + next_lex_tuple[-1] += 1 + + # Upon entering a loop, add one lex dimension for the loop iteration, + # add second lex dim to enumerate sections of code within new loop + next_lex_tuple.append(iname) + next_lex_tuple.append(0) + + elif isinstance(lin_item, LeaveLoop): + iname = lin_item.iname + current_inames.remove(iname) + + if iname in loops_to_ignore: + continue + + # Upon leaving a loop: + # - Pop lex dim for enumerating code sections within this loop + # - Pop lex dim for the loop iteration + # - Increment lex dim val enumerating items in current section of code + next_lex_tuple.pop() + next_lex_tuple.pop() + next_lex_tuple[-1] += 1 + + # (not technically necessary if no statement was added in the + # previous section; gratuitous incrementing is counteracted + # in the simplification step below) + + elif isinstance(lin_item, RunInstruction): + lp_stmt_id = lin_item.insn_id + + # Only process listed stmts, otherwise ignore + if lp_stmt_id in all_stmt_ids: + # Add item to stmt_inst_to_lex_intra_thread + stmt_inst_to_lex_intra_thread[lp_stmt_id] = tuple(next_lex_tuple) + + # Increment lex dim val enumerating items in current section of code + next_lex_tuple[-1] += 1 + + elif isinstance(lin_item, Barrier): + lp_stmt_id = lin_item.originating_insn_id + loops_with_barriers[lin_item.synchronization_kind] |= current_inames + + if lp_stmt_id is None: + # Barriers without stmt ids were inserted as a result of a + # dependency. They don't themselves have dependencies. Ignore them. + + # FIXME: It's possible that we could record metadata about them + # (e.g. 
what dependency produced them) and verify that they're + # adequately protecting all statement instance pairs. + + continue + + # If barrier was identified in listed stmts, process it + if lp_stmt_id in all_stmt_ids: + # Add item to stmt_inst_to_lex_intra_thread + stmt_inst_to_lex_intra_thread[lp_stmt_id] = tuple(next_lex_tuple) + + # Increment lex dim val enumerating items in current section of code + next_lex_tuple[-1] += 1 + + else: + from loopy.schedule import (CallKernel, ReturnFromKernel) + # No action needed for these types of linearization item + assert isinstance( + lin_item, (CallKernel, ReturnFromKernel)) + pass + + # }}} + + # {{{ Create lex dim names representing parallel axes + + # Create lex dim names representing lid/gid axes. + # At the same time, create the dicts that will be used later to create map + # constraints that match each parallel iname to the corresponding lex dim + # name in schedules, i.e., i = lid0, j = lid1, etc. + lid_lex_dim_names = set() + gid_lex_dim_names = set() + par_iname_constraint_dicts = {} + for iname in knl.all_inames(): + ltag = knl.iname_tags_of_type(iname, LocalInameTag) + if ltag: + assert len(ltag) == 1 # (should always be true) + ltag_var = LTAG_VAR_NAMES[ltag.pop().axis] + lid_lex_dim_names.add(ltag_var) + par_iname_constraint_dicts[iname] = {1: 0, iname: 1, ltag_var: -1} + + continue # Shouldn't be any GroupInameTags + + gtag = knl.iname_tags_of_type(iname, GroupInameTag) + if gtag: + assert len(gtag) == 1 # (should always be true) + gtag_var = GTAG_VAR_NAMES[gtag.pop().axis] + gid_lex_dim_names.add(gtag_var) + par_iname_constraint_dicts[iname] = {1: 0, iname: 1, gtag_var: -1} + + # Sort for consistent dimension ordering + lid_lex_dim_names = sorted(lid_lex_dim_names) + gid_lex_dim_names = sorted(gid_lex_dim_names) + + # }}} + + # {{{ Intra-group and global blex ("barrier-lex") order creation + + # (may be combined with pass above in future) + + """In blex space, we order barrier-delimited sections of code. + Each statement instance within a single barrier-delimited section will + map to the same blex point. The resulting statement instance ordering + (SIO) will map each statement to all statements that occur in a later + barrier-delimited section. + + To achieve this, we will first create a map from statement instances to + lexicographic space almost as we did above in the intra-thread case, + though we will not increment the fastest-updating lex dim with each + statement, and we will increment it with each barrier encountered. To + denote these differences, we refer to this space as 'blex' space. + + The resulting pairwise schedule, if composed with a map defining a + standard lexicographic ordering to create an SIO, would include a number + of unwanted 'before->after' pairs of statement instances, so before + creating the SIO, we will subtract unwanted pairs from a standard + lex order map, yielding the 'blex' order map. + """ + + # {{{ Get upper and lower bound for each loop that contains a barrier + + iname_bounds_pwaff = {} + for iname in loops_with_barriers["local"] | loops_with_barriers["global"]: + bounds = knl.get_iname_bounds(iname) + iname_bounds_pwaff[iname] = ( + bounds.lower_bound_pw_aff, bounds.upper_bound_pw_aff) + + # }}} + + # {{{ Create blex order maps and blex tuples defining statement ordering (x2) + + all_par_lex_dim_names = lid_lex_dim_names + gid_lex_dim_names + + # Get the blex schedule blueprint (dict will become a map below) and + # blex order map w.r.t. 
local and global barriers
+    (stmt_inst_to_lblex,
+     lblex_order_map,
+     seq_lblex_dim_names) = _gather_blex_ordering_info(
+        "local",
+        lin_items, loops_with_barriers, loops_to_ignore,
+        all_stmt_ids, iname_bounds_pwaff,
+        all_par_lex_dim_names, gid_lex_dim_names,
+        )
+    (stmt_inst_to_gblex,
+     gblex_order_map,
+     seq_gblex_dim_names) = _gather_blex_ordering_info(
+        "global",
+        lin_items, loops_with_barriers, loops_to_ignore,
+        all_stmt_ids, iname_bounds_pwaff,
+        all_par_lex_dim_names, gid_lex_dim_names,
+        )
+
+    # }}}
+
+    # }}} end intra-group and global blex order creation
+
+    # {{{ Create pairwise schedules (ISL maps) for each stmt pair
+
+    # {{{ _get_map_for_stmt()
+
+    def _get_map_for_stmt(
+            stmt_id, lex_points, int_sid, lex_dim_names):
+
+        # Get inames domain for statement instance (a BasicSet)
+        within_inames = knl.id_to_insn[stmt_id].within_inames
+        dom = knl.get_inames_domain(
+            within_inames).project_out_except(within_inames, [dt.set])
+
+        # Create map space (an isl space in current implementation)
+        # {('statement', <inames>) ->
+        #  (lexicographic ordering dims)}
+        dom_inames_ordered = sorted_union_of_names_in_isl_sets([dom])
+
+        in_names_sched = [STATEMENT_VAR_NAME] + dom_inames_ordered[:]
+        sched_space = isl.Space.create_from_names(
+            isl.DEFAULT_CONTEXT,
+            in_=in_names_sched,
+            out=lex_dim_names,
+            params=[],
+            )
+
+        # Insert 'statement' dim into domain so that its space allows
+        # for intersection with sched map later
+        dom_to_intersect = insert_and_name_isl_dims(
+            dom, dt.set, [STATEMENT_VAR_NAME], 0)
+
+        # Each map will map statement instances -> lex time.
+        # At this point, statement instance tuples consist of a single int.
+        # Add all inames from domains to each map domain tuple.
+        tuple_pair = [(
+            (int_sid, ) + tuple(dom_inames_ordered),
+            lex_points
+            )]
+
+        # Note that lex_points may have fewer dims than the out-dim of sched_space
+        # if sched_space includes concurrent LID/GID dims. This is okay because
+        # the following symbolic map creation step, when assigning dim values,
+        # zips the space dims with the lex tuple, and any leftover LID/GID dims
+        # will not be assigned a value yet, which is what we want.
+
+        # Create map
+        sched_map = create_symbolic_map_from_tuples(
+            tuple_pairs_with_domains=zip(tuple_pair, [dom_to_intersect, ]),
+            space=sched_space,
+            )
+
+        # Set inames equal to relevant gid/lid var names
+        for iname, constraint_dict in par_iname_constraint_dicts.items():
+            # Even though all parallel thread dims are active throughout the
+            # whole kernel, they may be assigned (tagged) to one iname for some
+            # subset of statements and another iname for a different subset of
+            # statements (e.g., a tiled, parallel matmul).
+ # So before adding each parallel iname constraint, make sure the + # iname applies to this statement: + if iname in dom_inames_ordered: + sched_map = sched_map.add_constraint( + isl.Constraint.eq_from_names(sched_map.space, constraint_dict)) + + return sched_map + + # }}} + + pairwise_sios = {} + + for stmt_ids in stmt_id_pairs: + # Determine integer IDs that will represent each statement in mapping + # (dependency map creation assumes sid_before=0 and sid_after=1, unless + # before and after refer to same stmt, in which case + # sid_before=sid_after=0) + int_sids = [0, 0] if stmt_ids[0] == stmt_ids[1] else [0, 1] + + # {{{ Create SIO for intra-thread case (lid0' == lid0, gid0' == gid0, etc) + + # Simplify tuples to the extent possible ------------------------------------ + + lex_tuples = [stmt_inst_to_lex_intra_thread[stmt_id] for stmt_id in stmt_ids] + + # At this point, one of the lex tuples may have more dimensions than + # another; the missing dims are the fastest-updating dims, and their + # values should be zero. Add them. + max_lex_dims = max([len(lex_tuple) for lex_tuple in lex_tuples]) + lex_tuples_padded = [ + _pad_tuple_with_zeros(lex_tuple, max_lex_dims) + for lex_tuple in lex_tuples] + + # Now generate maps from the blueprint -------------------------------------- + + lex_tuples_simplified = _simplify_lex_dims(*lex_tuples_padded) + + # Create names for the output dimensions for sequential loops + seq_lex_dim_names = [ + LEX_VAR_PREFIX+str(i) for i in range(len(lex_tuples_simplified[0]))] + + intra_thread_sched_maps = [ + _get_map_for_stmt( + stmt_id, lex_tuple, int_sid, + seq_lex_dim_names+all_par_lex_dim_names) + for stmt_id, lex_tuple, int_sid + in zip(stmt_ids, lex_tuples_simplified, int_sids) + ] + + # Create pairwise lex order map (pairwise only in the intra-thread case) + lex_order_map = create_lex_order_map( + dim_names=seq_lex_dim_names, + in_dim_mark=BEFORE_MARK, + ) + + # Add lid/gid dims to lex order map + lex_order_map = add_and_name_isl_dims( + lex_order_map, dt.out, all_par_lex_dim_names) + lex_order_map = add_and_name_isl_dims( + lex_order_map, dt.in_, + append_mark_to_strings(all_par_lex_dim_names, mark=BEFORE_MARK)) + # Constrain lid/gid vars to be equal + for var_name in all_par_lex_dim_names: + lex_order_map = add_eq_isl_constraint_from_names( + lex_order_map, var_name, var_name+BEFORE_MARK) + + # Create statement instance ordering, + # maps each statement instance to all statement instances occurring later + sio_intra_thread = get_statement_ordering_map( + *intra_thread_sched_maps, # note, func accepts exactly two maps + lex_order_map, + before_mark=BEFORE_MARK, + ) + + # }}} + + # {{{ Create SIOs for intra-group case (gid0' == gid0, etc) and global case + + def _get_sched_maps_and_sio( + stmt_inst_to_blex, blex_order_map, seq_blex_dim_names): + # (Vars from outside func used here: + # stmt_ids, int_sids, all_par_lex_dim_names) + + # Use *unsimplified* lex tuples w/ blex map, which are already padded + blex_tuples_padded = [stmt_inst_to_blex[stmt_id] for stmt_id in stmt_ids] + + par_sched_maps = [ + _get_map_for_stmt( + stmt_id, blex_tuple, int_sid, + seq_blex_dim_names+all_par_lex_dim_names) # all par names + for stmt_id, blex_tuple, int_sid + in zip(stmt_ids, blex_tuples_padded, int_sids) + ] + + # Note that for the intra-group case, we already constrained GID + # 'before' to equal GID 'after' earlier in _gather_blex_ordering_info() + + # Create statement instance ordering + sio_par = get_statement_ordering_map( + *par_sched_maps, # note, func accepts 
exactly two maps + blex_order_map, + before_mark=BEFORE_MARK, + ) + + return par_sched_maps, sio_par + + pwsched_intra_group, sio_intra_group = _get_sched_maps_and_sio( + stmt_inst_to_lblex, lblex_order_map, seq_lblex_dim_names) + pwsched_global, sio_global = _get_sched_maps_and_sio( + stmt_inst_to_gblex, gblex_order_map, seq_gblex_dim_names) + + # }}} + + # Store sched maps along with SIOs + pairwise_sios[tuple(stmt_ids)] = StatementOrdering( + sio_intra_thread=sio_intra_thread, + sio_intra_group=sio_intra_group, + sio_global=sio_global, + pwsched_intra_thread=tuple(intra_thread_sched_maps), + pwsched_intra_group=tuple(pwsched_intra_group), + pwsched_global=tuple(pwsched_global), + ) + + # }}} + + return pairwise_sios + +# }}} + +# vim: foldmethod=marker diff --git a/loopy/schedule/checker/utils.py b/loopy/schedule/checker/utils.py new file mode 100644 index 000000000..5d0858dfb --- /dev/null +++ b/loopy/schedule/checker/utils.py @@ -0,0 +1,445 @@ +__copyright__ = "Copyright (C) 2019 James Stevens" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import islpy as isl +dt = isl.dim_type + + +def prettier_map_string(map_obj): + return str(map_obj + ).replace("{ ", "{\n").replace(" }", "\n}").replace("; ", ";\n") + + +def insert_and_name_isl_dims(isl_set, dim_type, names, new_idx_start): + new_set = isl_set.insert_dims(dim_type, new_idx_start, len(names)) + for i, name in enumerate(names): + new_set = new_set.set_dim_name(dim_type, new_idx_start+i, name) + return new_set + + +def add_and_name_isl_dims(isl_map, dim_type, names): + new_idx_start = isl_map.dim(dim_type) + new_map = isl_map.add_dims(dim_type, len(names)) + for i, name in enumerate(names): + new_map = new_map.set_dim_name(dim_type, new_idx_start+i, name) + return new_map + + +def reorder_dims_by_name( + isl_set, dim_type, desired_dims_ordered): + """Return an isl_set with the dimensions of the specified dim_type + in the specified order. + + :arg isl_set: A :class:`islpy.Set` whose dimensions are + to be reordered. + + :arg dim_type: A :class:`islpy.dim_type`, i.e., an :class:`int`, + specifying the dimension to be reordered. + + :arg desired_dims_ordered: A :class:`list` of :class:`str` elements + representing the desired dimensions in order by dimension name. + + :returns: An :class:`islpy.Set` matching `isl_set` with the + dimension order matching `desired_dims_ordered`. 
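+
+    For example (an illustrative sketch, not from the original
+    docstring)::
+
+        s = isl.Set("{ [i, j] : 0 <= i, j < 8 }")
+        reorder_dims_by_name(s, isl.dim_type.set, ["j", "i"])
+        # -> { [j, i] : 0 <= i, j < 8 }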
+ + """ + + assert dim_type != dt.param + assert set(isl_set.get_var_names(dim_type)) == set(desired_dims_ordered) + + other_dim_type = dt.param + other_dim_len = len(isl_set.get_var_names(other_dim_type)) + + new_set = isl_set.copy() + for desired_idx, name in enumerate(desired_dims_ordered): + + current_idx = new_set.find_dim_by_name(dim_type, name) + if current_idx != desired_idx: + # First move to other dim because isl is stupid + new_set = new_set.move_dims( + other_dim_type, other_dim_len, dim_type, current_idx, 1) + # Now move it where we actually want it + new_set = new_set.move_dims( + dim_type, desired_idx, other_dim_type, other_dim_len, 1) + + return new_set + + +def move_dim_to_index( + isl_map, dim_name, dim_type, destination_idx): + """Return an isl map with the specified dimension moved to + the specified index. + + :arg isl_map: A :class:`islpy.Map`. + + :arg dim_name: A :class:`str` specifying the name of the dimension + to be moved. + + :arg dim_type: A :class:`islpy.dim_type`, i.e., an :class:`int`, + specifying the type of dimension to be reordered. + + :arg destination_idx: A :class:`int` specifying the desired dimension + index of the dimention to be moved. + + :returns: An :class:`islpy.Map` matching `isl_map` with the + specified dimension moved to the specified index. + + """ + + assert dim_type != dt.param + + layover_dim_type = dt.param + layover_dim_len = len(isl_map.get_var_names(layover_dim_type)) + + current_idx = isl_map.find_dim_by_name(dim_type, dim_name) + if current_idx == -1: + raise ValueError("Dimension name %s not found in dim type %s of %s" + % (dim_name, dim_type, isl_map)) + + if current_idx != destination_idx: + # First move to other dim because isl is stupid + new_map = isl_map.move_dims( + layover_dim_type, layover_dim_len, dim_type, current_idx, 1) + # Now move it where we actually want it + new_map = new_map.move_dims( + dim_type, destination_idx, layover_dim_type, layover_dim_len, 1) + + return new_map + + +def remove_dim_by_name(isl_map, dim_type, dim_name): + idx = isl_map.find_dim_by_name(dim_type, dim_name) + if idx == -1: + raise ValueError("Dim '%s' not found. Cannot remove dim.") + return isl_map.remove_dims(dim_type, idx, 1) + + +def ensure_dim_names_match_and_align(obj_map, tgt_map): + + # first make sure names match + if not all( + set(obj_map.get_var_names(dt)) == set(tgt_map.get_var_names(dt)) + for dt in + [dt.in_, dt.out, dt.param]): + raise ValueError( + "Cannot align spaces; names don't match:\n%s\n%s" + % (prettier_map_string(obj_map), prettier_map_string(tgt_map)) + ) + + return isl.align_spaces(obj_map, tgt_map) + + +def add_eq_isl_constraint_from_names(isl_map, var1, var2): + # add constraint var1 = var2 + return isl_map.add_constraint( + isl.Constraint.eq_from_names( + isl_map.space, + {1: 0, var1: 1, var2: -1})) + + +def find_and_rename_dim(old_map, dim_types, old_name, new_name, must_exist=False): + new_map = old_map.copy() + for dim_type in dim_types: + idx = new_map.find_dim_by_name(dim_type, old_name) + if idx == -1: + if must_exist: + raise ValueError( + "must_exist=True but did not find old_name %s in %s" + % (old_name, old_map)) + else: + continue + new_map = new_map.set_dim_name(dim_type, idx, new_name) + return new_map + + +def append_mark_to_isl_map_var_names(old_isl_map, dim_type, mark): + """Return an :class:`islpy.Map` with a mark appended to the specified + dimension names. + + :arg old_isl_map: An :class:`islpy.Map`. 
+
+    :arg dim_type: An :class:`islpy.dim_type`, i.e., an :class:`int`,
+        specifying the dimension to be marked.
+
+    :arg mark: A :class:`str` to be appended to the specified dimension
+        names.
+
+    :returns: An :class:`islpy.Map` matching `old_isl_map` with
+        `mark` appended to the `dim_type` dimension names.
+
+    """
+
+    new_map = old_isl_map.copy()
+    for i in range(len(old_isl_map.get_var_names(dim_type))):
+        new_map = new_map.set_dim_name(dim_type, i, old_isl_map.get_dim_name(
+            dim_type, i)+mark)
+    return new_map
+
+
+def append_mark_to_strings(strings, mark):
+    return [s+mark for s in strings]
+
+
+# {{{ make_dep_map
+
+def make_dep_map(s, self_dep=False, knl_with_domains=None):
+
+    # TODO put this function in the right place
+
+    from loopy.schedule.checker.schedule import (
+        BEFORE_MARK,
+        STATEMENT_VAR_NAME,
+    )
+
+    map_init = isl.Map(s)
+
+    # TODO something smarter than this assert
+    for dim_name in map_init.get_var_names(dt.in_):
+        assert BEFORE_MARK not in dim_name
+
+    # Append BEFORE_MARK to in-vars
+    map_marked = append_mark_to_isl_map_var_names(
+        map_init, dt.in_, BEFORE_MARK)
+
+    # Insert statement dims:
+    map_with_stmts = insert_and_name_isl_dims(
+        map_marked, dt.in_, [STATEMENT_VAR_NAME+BEFORE_MARK], 0)
+    map_with_stmts = insert_and_name_isl_dims(
+        map_with_stmts, dt.out, [STATEMENT_VAR_NAME], 0)
+
+    # Assign values 0 or 1 to statement dims
+    sid_after = 0 if self_dep else 1
+
+    map_with_stmts = map_with_stmts.add_constraint(
+        isl.Constraint.eq_from_names(
+            map_with_stmts.space,
+            {1: 0, STATEMENT_VAR_NAME+BEFORE_MARK: -1}))
+
+    map_with_stmts = map_with_stmts.add_constraint(
+        isl.Constraint.eq_from_names(
+            map_with_stmts.space,
+            {1: sid_after, STATEMENT_VAR_NAME: -1}))
+
+    if knl_with_domains is not None:
+        # Intersect map with knl domains
+        inames_in = map_init.get_var_names(dt.in_)
+        inames_out = map_init.get_var_names(dt.out)
+
+        inames_in_dom = knl_with_domains.get_inames_domain(
+            inames_in).project_out_except(inames_in, [dt.set])
+        inames_out_dom = knl_with_domains.get_inames_domain(
+            inames_out).project_out_except(inames_out, [dt.set])
+
+        # Mark dependee inames
+        inames_in_dom_marked = append_mark_to_isl_map_var_names(
+            inames_in_dom, dt.set, BEFORE_MARK)
+
+        # align_spaces adds the stmt var
+        inames_in_dom_marked_aligned = isl.align_spaces(
+            inames_in_dom_marked, map_with_stmts.domain(),
+            obj_bigger_ok=True)  # e.g., params might exist
+        inames_out_dom_aligned = isl.align_spaces(
+            inames_out_dom, map_with_stmts.range(),
+            obj_bigger_ok=True)  # e.g., params might exist
+
+        map_with_stmts = map_with_stmts.intersect_range(
+            inames_out_dom_aligned
+            ).intersect_domain(inames_in_dom_marked_aligned)
+
+    return map_with_stmts
+
+# }}}
+
+
+def sorted_union_of_names_in_isl_sets(
+        isl_sets,
+        set_dim=dt.set):
+    r"""Return a sorted list of the union of all variable names found in
+    the provided :class:`islpy.Set`\ s.
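+
+    E.g. (an illustrative sketch, not from the original docstring)::
+
+        sorted_union_of_names_in_isl_sets(
+            [isl.Set("{ [j, i] }"), isl.Set("{ [k, i] }")])
+        # -> ["i", "j", "k"]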
+ """ + + inames = set().union(*[isl_set.get_var_names(set_dim) for isl_set in isl_sets]) + + # Sorting is not necessary, but keeps results consistent between runs + return sorted(inames) + + +def convert_map_to_set(isl_map): + # also works for spaces + n_in_dims = len(isl_map.get_var_names(dt.in_)) + n_out_dims = len(isl_map.get_var_names(dt.out)) + return isl_map.move_dims( + dt.in_, n_in_dims, dt.out, 0, n_out_dims + ).domain(), n_in_dims, n_out_dims + + +def convert_set_back_to_map(isl_set, n_old_in_dims, n_old_out_dims): + return isl.Map.from_domain( + isl_set).move_dims(dt.out, 0, dt.in_, n_old_in_dims, n_old_out_dims) + + +def create_symbolic_map_from_tuples( + tuple_pairs_with_domains, + space, + ): + """Return an :class:`islpy.Map` constructed using the provided space, + mapping input->output tuples provided in `tuple_pairs_with_domains`, + with each set of tuple variables constrained by the domains provided. + + :arg tuple_pairs_with_domains: A :class:`list` with each element being + a tuple of the form `((tup_in, tup_out), domain)`. + `tup_in` and `tup_out` are tuples containing elements of type + :class:`int` and :class:`str` representing values for the + input and output dimensions in `space`, and `domain` is a + :class:`islpy.Set` constraining variable bounds. + + :arg space: A :class:`islpy.Space` to be used to create the map. + + :returns: A :class:`islpy.Map` constructed using the provided space + as follows. For each `((tup_in, tup_out), domain)` in + `tuple_pairs_with_domains`, map + `(tup_in)->(tup_out) : domain`, where `tup_in` and `tup_out` are + numeric or symbolic values assigned to the input and output + dimension variables in `space`, and `domain` specifies conditions + on these values. + + """ + # FIXME allow None for domains + + space_out_names = space.get_var_names(dt.out) + space_in_names = space.get_var_names(dt.in_) + + def _conjunction_of_dim_eq_conditions(dim_names, values, var_name_to_pwaff): + condition = var_name_to_pwaff[0].eq_set(var_name_to_pwaff[0]) + for dim_name, val in zip(dim_names, values): + if isinstance(val, int): + condition = condition \ + & var_name_to_pwaff[dim_name].eq_set(var_name_to_pwaff[0]+val) + else: + condition = condition \ + & var_name_to_pwaff[dim_name].eq_set(var_name_to_pwaff[val]) + return condition + + # Get islvars from space + var_name_to_pwaff = isl.affs_from_space( + space.move_dims( + dt.out, 0, + dt.in_, 0, + len(space_in_names), + ).range() + ) + + # Initialize union of maps to empty + union_of_maps = isl.Map.from_domain( + var_name_to_pwaff[0].eq_set(var_name_to_pwaff[0]+1) # 0 == 1 (false) + ).move_dims( + dt.out, 0, dt.in_, len(space_in_names), len(space_out_names)) + + # Loop through tuple pairs + for (tup_in, tup_out), dom in tuple_pairs_with_domains: + + # Set values for 'in' dimension using tuple vals + condition = _conjunction_of_dim_eq_conditions( + space_in_names, tup_in, var_name_to_pwaff) + + # Set values for 'out' dimension using tuple vals + condition = condition & _conjunction_of_dim_eq_conditions( + space_out_names, tup_out, var_name_to_pwaff) + + # Convert set to map by moving dimensions around + map_from_set = isl.Map.from_domain(condition) + map_from_set = map_from_set.move_dims( + dt.out, 0, dt.in_, + len(space_in_names), len(space_out_names)) + + # Align the *out* dims of dom with the space *in_* dims + # in preparation for intersection + dom_with_set_dim_aligned = reorder_dims_by_name( + dom, dt.set, + space_in_names, + ) + + # Intersect domain with this map + union_of_maps = 
union_of_maps.union(
+            map_from_set.intersect_domain(dom_with_set_dim_aligned))
+
+    return union_of_maps
+
+
+def partition_inames_by_concurrency(knl):
+    from loopy.kernel.data import ConcurrentTag
+    conc_inames = set()
+    non_conc_inames = set()
+
+    all_inames = knl.all_inames()
+    for iname in all_inames:
+        if knl.iname_tags_of_type(iname, ConcurrentTag):
+            conc_inames.add(iname)
+        else:
+            non_conc_inames.add(iname)
+
+    return conc_inames, non_conc_inames
+
+
+def get_EnterLoop_inames(linearization_items):
+    from loopy.schedule import EnterLoop
+
+    # Note: each iname must live in a length-1 list; otherwise set().union()
+    # would split each iname string into individual characters
+    return set().union(*[
+        [item.iname, ] for item in linearization_items
+        if isinstance(item, EnterLoop)
+        ])
+
+
+def create_elementwise_comparison_conjunction_set(
+        names0, names1, var_name_to_pwaff, op="eq"):
+    """Create a set constrained by the conjunction of conditions comparing
+    `names0` to `names1`.
+
+    :arg names0: A list of :class:`str` representing variable names.
+
+    :arg names1: A list of :class:`str` representing variable names.
+
+    :arg var_name_to_pwaff: A dictionary from variable names to :class:`islpy.PwAff`
+        instances that represent each of the variables
+        (var_name_to_pwaff may be produced by `islpy.make_zero_and_vars`). The key
+        '0' is also included and represents a :class:`islpy.PwAff` zero constant.
+
+    :arg op: A :class:`str` describing the operator to use when creating
+        the set constraints. Options: `eq` for `=`, `ne` for `!=`,
+        `lt` for `<`.
+
+    :returns: A set involving `var_name_to_pwaff` constrained by the constraints
+        `{names0[0] <op> names1[0] and names0[1] <op> names1[1] and ...}`.
+
+    """
+
+    # initialize set with constraint that is always true
+    conj_set = var_name_to_pwaff[0].eq_set(var_name_to_pwaff[0])
+    for n0, n1 in zip(names0, names1):
+        if op == "eq":
+            conj_set = conj_set & var_name_to_pwaff[n0].eq_set(var_name_to_pwaff[n1])
+        elif op == "ne":
+            conj_set = conj_set & var_name_to_pwaff[n0].ne_set(var_name_to_pwaff[n1])
+        elif op == "lt":
+            conj_set = conj_set & var_name_to_pwaff[n0].lt_set(var_name_to_pwaff[n1])
+
+    return conj_set
diff --git a/loopy/statistics.py b/loopy/statistics.py
index 88e930ce4..41bcbb181 100755
--- a/loopy/statistics.py
+++ b/loopy/statistics.py
@@ -1316,13 +1316,23 @@ def map_subscript(self, expr):
         except AttributeError:
             var_tags = frozenset()
 
+        is_temp = False
         if name in self.knl.arg_dict:
             array = self.knl.arg_dict[name]
+        elif name in self.knl.temporary_variables:
+            # this is a temporary variable, but it might have global address space
+            from loopy.kernel.data import AddressSpace
+            array = self.knl.temporary_variables[name]
+            if array.address_space != AddressSpace.GLOBAL:
+                # this is a temporary variable
+                return self.rec(expr.index)
+            # this is a temporary variable with global address space
+            is_temp = True
         else:
             # this is a temporary variable
             return self.rec(expr.index)
 
-        if not isinstance(array, lp.ArrayArg):
+        if (not is_temp) and not isinstance(array, lp.ArrayArg):
             # this array is not in global memory
             return self.rec(expr.index)
 
diff --git a/loopy/tools.py b/loopy/tools.py
index 0e7a50998..84cbce7e9 100644
--- a/loopy/tools.py
+++ b/loopy/tools.py
@@ -102,6 +102,8 @@ def update_for_BasicSet(self, key_hash, key):  # noqa
         getattr(prn, "print_"+key._base_name)(key)
         key_hash.update(prn.get_str().encode("utf8"))
 
+    update_for_Map = update_for_BasicSet  # noqa
+
     def update_for_type(self, key_hash, key):
         try:
             method = getattr(self, "update_for_type_"+key.__name__)
diff --git a/loopy/transform/add_barrier.py
b/loopy/transform/add_barrier.py index 7a220418f..4fa2fbaff 100644 --- a/loopy/transform/add_barrier.py +++ b/loopy/transform/add_barrier.py @@ -90,10 +90,37 @@ def add_barrier(kernel, insn_before="", insn_after="", id_based_on=None, new_kernel = kernel.copy(instructions=kernel.instructions + [barrier_to_add]) if insn_after is not None: + # TODO this should be a new dependency new_kernel = add_dependency(kernel=new_kernel, insn_match=insn_after, depends_on="id:"+id) + for insn_before_id in insns_before: + # make v2 dep: + from loopy.schedule.checker.utils import ( + append_mark_to_strings, + make_dep_map, + ) + from loopy.schedule.checker.schedule import BEFORE_MARK + inames_before = new_kernel.id_to_insn[insn_before_id].within_inames + inames_before_marked = append_mark_to_strings( + inames_before, BEFORE_MARK) + + inames_after = set(within_inames) if within_inames else set() + + shared_inames = inames_after & inames_before + + in_space_str = ", ".join(inames_before_marked) + out_space_str = ", ".join(inames_after) + constraint_str = " and ".join([ + "{0}{1} = {0}".format(iname, BEFORE_MARK) for iname in shared_inames]) + + dep_v2 = make_dep_map( + f"{{ [{in_space_str}] -> [{out_space_str}] : {constraint_str} }}", + knl_with_domains=new_kernel) + from loopy import add_dependency_v2 + new_kernel = add_dependency_v2(new_kernel, id, insn_before_id, dep_v2) + return new_kernel # }}} diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py index c3b4a42ee..16847a3b6 100644 --- a/loopy/transform/iname.py +++ b/loopy/transform/iname.py @@ -22,12 +22,13 @@ import islpy as isl -from islpy import dim_type +from islpy import dim_type as dt from loopy.symbolic import ( RuleAwareIdentityMapper, RuleAwareSubstitutionMapper, SubstitutionRuleMappingContext) from loopy.diagnostic import LoopyError +from pytools import Record from loopy.translation_unit import (TranslationUnit, for_each_kernel) @@ -72,6 +73,8 @@ .. autofunction:: add_inames_to_insn +.. autofunction:: map_domain + .. autofunction:: add_inames_for_unused_hw_axes """ @@ -121,6 +124,997 @@ def prioritize_loops(kernel, loop_priority): # }}} +# {{{ Handle loop nest constraints + +# {{{ Classes to house loop nest constraints + +# {{{ UnexpandedInameSet + +class UnexpandedInameSet(Record): + def __init__(self, inames, complement=False): + Record.__init__( + self, + inames=inames, + complement=complement, + ) + + def contains(self, inames): + if isinstance(inames, set): + return (not (inames & self.inames) if self.complement + else inames.issubset(self.inames)) + else: + return (inames not in self.inames if self.complement + else inames in self.inames) + + def get_inames_represented(self, iname_universe=None): + """Return the set of inames represented by the UnexpandedInameSet + """ + if self.complement: + if not iname_universe: + raise ValueError( + "Cannot expand UnexpandedInameSet %s without " + "iname_universe." % (self)) + return iname_universe-self.inames + else: + return self.inames.copy() + + def __lt__(self, other): + # FIXME is this function really necessary? If so, what should it return? + return self.__hash__() < other.__hash__() + + def __hash__(self): + return hash(repr(self)) + + def update_persistent_hash(self, key_hash, key_builder): + """Custom hash computation function for use with + :class:`pytools.persistent_dict.PersistentDict`. 
+        """
+
+        key_builder.rec(key_hash, self.inames)
+        key_builder.rec(key_hash, self.complement)
+
+    def __str__(self):
+        return "%s{%s}" % ("~" if self.complement else "",
+            ",".join(i for i in sorted(self.inames)))
+
+# }}}
+
+
+# {{{ LoopNestConstraints
+
+class LoopNestConstraints(Record):
+    def __init__(self, must_nest=None, must_not_nest=None,
+            must_nest_graph=None):
+        Record.__init__(
+            self,
+            must_nest=must_nest,
+            must_not_nest=must_not_nest,
+            must_nest_graph=must_nest_graph,
+            )
+
+    def __hash__(self):
+        return hash(repr(self))
+
+    def update_persistent_hash(self, key_hash, key_builder):
+        """Custom hash computation function for use with
+        :class:`pytools.persistent_dict.PersistentDict`.
+        """
+
+        key_builder.rec(key_hash, self.must_nest)
+        key_builder.rec(key_hash, self.must_not_nest)
+        key_builder.rec(key_hash, self.must_nest_graph)
+
+    def __str__(self):
+        return "LoopNestConstraints(\n" \
+            "    must_nest = " + str(self.must_nest) + "\n" \
+            "    must_not_nest = " + str(self.must_not_nest) + "\n" \
+            "    must_nest_graph = " + str(self.must_nest_graph) + "\n" \
+            ")"
+
+# }}}
+
+# }}}
+
+
+# {{{ Initial loop nest constraint creation
+
+# {{{ process_loop_nest_specification
+
+def process_loop_nest_specification(
+        nesting,
+        max_tuple_size=None,
+        complement_sets_allowed=True,
+        ):
+
+    # Ensure that user-supplied nesting conforms to syntax rules, and
+    # convert string representations of nestings to a tuple of
+    # UnexpandedInameSets
+
+    import re
+
+    def _raise_loop_nest_input_error(msg):
+        valid_prio_rules = (
+            "Valid `must_nest` description formats: "  # noqa
+            "\"iname, iname, ...\" or (str, str, str, ...), "  # noqa
+            "where str can be of form "  # noqa
+            "\"iname\" or \"{iname, iname, ...}\". "  # noqa
+            "No set complements allowed.\n"  # noqa
+            "Valid `must_not_nest` description tuples must have length 2: "  # noqa
+            "\"iname, iname\", \"iname, ~iname\", or "  # noqa
+            "(str, str), where str can be of form "  # noqa
+            "\"iname\", \"~iname\", \"{iname, iname, ...}\", or "  # noqa
+            "\"~{iname, iname, ...}\"."  # noqa
+            )
+        raise ValueError(
+            "Invalid loop nest prioritization: %s\n"
+            "Loop nest prioritization formatting rules:\n%s"
+            % (msg, valid_prio_rules))
+
+    def _error_on_regex_match(match_str, target_str):
+        if re.findall(match_str, target_str):
+            _raise_loop_nest_input_error(
+                "Unrecognized character(s) %s in nest string %s"
+                % (re.findall(match_str, target_str), target_str))
+
+    def _process_iname_set_str(iname_set_str):
+        # Convert something like ~{i,j} or ~i or "i,j" to an UnexpandedInameSet
+
+        # Remove leading/trailing whitespace
+        iname_set_str_stripped = iname_set_str.strip()
+
+        if not iname_set_str_stripped:
+            _raise_loop_nest_input_error(
+                "Found 0 inames in string %s."
+                % (iname_set_str))
+
+        # Process complement sets
+        if iname_set_str_stripped[0] == "~":
+            # Make sure complement is allowed
+            if not complement_sets_allowed:
+                _raise_loop_nest_input_error(
+                    "Complement (~) not allowed in this loop nest string %s. "
+                    "If you have a use-case where allowing a currently "
+                    "disallowed set complement would be helpful, and the "
+                    "desired nesting constraint cannot easily be expressed "
+                    "another way, "
+                    "please contact the Loo.py maintainers."
% (iname_set_str))
+
+            # Remove tilde
+            iname_set_str_stripped = iname_set_str_stripped[1:]
+            if "~" in iname_set_str_stripped:
+                _raise_loop_nest_input_error(
+                    "Multiple complement symbols found in iname set string %s"
+                    % (iname_set_str))
+
+            # Make sure that braces are included if multiple inames present
+            if "," in iname_set_str_stripped and not (
+                    iname_set_str_stripped.startswith("{") and
+                    iname_set_str_stripped.endswith("}")):
+                _raise_loop_nest_input_error(
+                    "Complements of sets containing multiple inames must "
+                    "enclose inames in braces: %s is not valid."
+                    % (iname_set_str))
+
+            complement = True
+        else:
+            complement = False
+
+        # Remove leading/trailing spaces
+        iname_set_str_stripped = iname_set_str_stripped.strip(" ")
+
+        # Make sure braces are valid and strip them
+        if iname_set_str_stripped[0] == "{":
+            if not iname_set_str_stripped[-1] == "}":
+                _raise_loop_nest_input_error(
+                    "Invalid braces: %s" % (iname_set_str))
+            else:
+                # Remove enclosing braces
+                iname_set_str_stripped = iname_set_str_stripped[1:-1]
+                # (If there are dangling braces around, they will be caught next)
+
+        # Remove any more spaces
+        iname_set_str_stripped = iname_set_str_stripped.strip()
+
+        # Should be no remaining special characters besides comma and space
+        _error_on_regex_match(r"([^,\w ])", iname_set_str_stripped)
+
+        # Split by commas or spaces to get inames
+        inames = re.findall(r"([\w]+)(?:[ |,]*|$)", iname_set_str_stripped)
+
+        # Make sure iname count matches what we expect from comma count
+        if len(inames) != iname_set_str_stripped.count(",") + 1:
+            _raise_loop_nest_input_error(
+                "Found %d inames but expected %d in string %s."
+                % (len(inames), iname_set_str_stripped.count(",") + 1,
+                    iname_set_str))
+
+        if len(inames) == 0:
+            _raise_loop_nest_input_error(
+                "Found empty set in string %s."
+                % (iname_set_str))
+
+        # NOTE this won't catch certain cases of bad syntax, e.g., ("{h i j,,}", "k")
+
+        return UnexpandedInameSet(
+            set([s.strip() for s in iname_set_str_stripped.split(",")]),
+            complement=complement)
+
+    if isinstance(nesting, str):
+        # Enforce that constraints involving iname sets be passed as tuple.
+        # Iname sets defined negatively with a *single* iname are allowed here.
+
+        # Check for any special characters besides comma, space, and tilde.
+        # E.g., curly braces would indicate that an iname set was NOT
+        # passed as a tuple, which is not allowed.
+        _error_on_regex_match(r"([^,\w~ ])", nesting)
+
+        # Split by comma and process each tier
+        nesting_as_tuple = tuple(
+            _process_iname_set_str(set_str) for set_str in nesting.split(","))
+    else:
+        assert isinstance(nesting, (tuple, list))
+        # Process each tier
+        nesting_as_tuple = tuple(
+            _process_iname_set_str(set_str) for set_str in nesting)
+
+    # Check max_tuple_size
+    if max_tuple_size and len(nesting_as_tuple) > max_tuple_size:
+        _raise_loop_nest_input_error(
+            "Loop nest prioritization tuple %s exceeds max tuple size %d."
+            % (nesting_as_tuple, max_tuple_size))
+
+    # Make sure nesting has len > 1
+    if len(nesting_as_tuple) <= 1:
+        _raise_loop_nest_input_error(
+            "Loop nest prioritization tuple %s must have length > 1."
+            % (nesting_as_tuple,))
+
+    # Return tuple of UnexpandedInameSets
+    return nesting_as_tuple
+
+# }}}
+
+
+# {{{ constrain_loop_nesting
+
+@for_each_kernel
+def constrain_loop_nesting(
+        kernel, must_nest=None, must_not_nest=None):
+    r"""Add the provided constraints to the kernel.
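+
+    For example (illustrative), ``must_nest=("g", "{h,i}", "j")`` requires
+    ``g`` to nest outside ``h`` and ``i``, and ``h`` and ``i`` to nest
+    outside ``j``; ``must_not_nest=("k", "~j")`` forbids ``k`` from nesting
+    outside any iname other than ``j``.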
+
+    :arg must_nest: A tuple or comma-separated string representing
+        an ordering of loop nesting tiers that must appear in the
+        linearized kernel. Each item in the tuple is converted to an
+        :class:`UnexpandedInameSet`.
+
+    :arg must_not_nest: A two-tuple or comma-separated string representing
+        an ordering of loop nesting tiers that must not appear in the
+        linearized kernel. Each item in the tuple is converted to an
+        :class:`UnexpandedInameSet`.
+
+    """
+
+    # {{{ Get any current constraints, if they exist
+    if kernel.loop_nest_constraints:
+        if kernel.loop_nest_constraints.must_nest:
+            must_nest_constraints_old = kernel.loop_nest_constraints.must_nest
+        else:
+            must_nest_constraints_old = set()
+
+        if kernel.loop_nest_constraints.must_not_nest:
+            must_not_nest_constraints_old = \
+                kernel.loop_nest_constraints.must_not_nest
+        else:
+            must_not_nest_constraints_old = set()
+
+        if kernel.loop_nest_constraints.must_nest_graph:
+            must_nest_graph_old = kernel.loop_nest_constraints.must_nest_graph
+        else:
+            must_nest_graph_old = {}
+    else:
+        must_nest_constraints_old = set()
+        must_not_nest_constraints_old = set()
+        must_nest_graph_old = {}
+
+    # }}}
+
+    # {{{ Process must_nest
+
+    if must_nest:
+        # {{{ Parse must_nest, check for conflicts, combine with old constraints
+
+        # {{{ Parse must_nest (no complements allowed)
+        must_nest_tuple = process_loop_nest_specification(
+            must_nest, complement_sets_allowed=False)
+        # }}}
+
+        # {{{ Error if someone prioritizes a concurrent iname
+
+        from loopy.kernel.data import ConcurrentTag
+        for iname_set in must_nest_tuple:
+            for iname in iname_set.inames:
+                if kernel.iname_tags_of_type(iname, ConcurrentTag):
+                    raise ValueError(
+                        "iname %s tagged with ConcurrentTag, "
+                        "cannot use iname in must-nest constraint %s."
+                        % (iname, must_nest_tuple))
+
+        # }}}
+
+        # {{{ Update must_nest graph (and check for cycles)
+
+        must_nest_graph_new = update_must_nest_graph(
+            must_nest_graph_old, must_nest_tuple, kernel.all_inames())
+
+        # }}}
+
+        # {{{ Make sure must_nest constraints don't violate must_not_nest
+        # (this may not catch all problems)
+        check_must_not_nest_against_must_nest_graph(
+            must_not_nest_constraints_old, must_nest_graph_new)
+        # }}}
+
+        # {{{ Check for conflicts with inames tagged 'vec' (must be innermost)
+
+        from loopy.kernel.data import VectorizeTag
+        for iname in kernel.all_inames():
+            if kernel.iname_tags_of_type(iname, VectorizeTag) and (
+                    must_nest_graph_new.get(iname, set())):
+                # Must-nest graph doesn't allow iname to be a leaf, error
+                raise ValueError(
+                    "Iname %s tagged as 'vec', but loop nest constraints "
+                    "%s require that iname %s nest outside of inames %s. "
+                    "Vectorized inames must nest innermost; cannot "
+                    "impose loop nest specification."
% (iname, must_nest, iname,
+                        must_nest_graph_new.get(iname, set())))
+
+        # }}}
+
+        # {{{ Add new must_nest constraints to existing must_nest constraints
+        must_nest_constraints_new = must_nest_constraints_old | set(
+            [must_nest_tuple, ])
+        # }}}
+
+        # }}}
+    else:
+        # {{{ No new must_nest constraints, just keep the old ones
+
+        must_nest_constraints_new = must_nest_constraints_old
+        must_nest_graph_new = must_nest_graph_old
+
+        # }}}
+
+    # }}}
+
+    # {{{ Process must_not_nest
+
+    if must_not_nest:
+        # {{{ Parse must_not_nest, check for conflicts, combine with old constraints
+
+        # {{{ Parse must_not_nest; complements allowed; max_tuple_size=2
+
+        must_not_nest_tuple = process_loop_nest_specification(
+            must_not_nest, max_tuple_size=2)
+
+        # }}}
+
+        # {{{ Make sure must_not_nest constraints don't violate must_nest
+
+        # (cycles are allowed in must_not_nest constraints)
+        import itertools
+        must_pairs = []
+        for iname_before, inames_after in must_nest_graph_new.items():
+            must_pairs.extend(list(itertools.product([iname_before], inames_after)))
+
+        if not check_must_not_nest(must_pairs, must_not_nest_tuple):
+            raise ValueError(
+                "constrain_loop_nesting: nest constraint conflict detected. "
+                "must_not_nest constraints %s inconsistent with "
+                "must_nest constraints %s."
+                % (must_not_nest_tuple, must_nest_constraints_new))
+
+        # }}}
+
+        # {{{ Add new must_not_nest constraints to existing must_not_nest constraints
+        must_not_nest_constraints_new = must_not_nest_constraints_old | set([
+            must_not_nest_tuple, ])
+        # }}}
+
+        # }}}
+    else:
+        # {{{ No new must_not_nest constraints, just keep the old ones
+
+        must_not_nest_constraints_new = must_not_nest_constraints_old
+
+        # }}}
+
+    # }}}
+
+    nest_constraints = LoopNestConstraints(
+        must_nest=must_nest_constraints_new,
+        must_not_nest=must_not_nest_constraints_new,
+        must_nest_graph=must_nest_graph_new,
+        )
+
+    return kernel.copy(loop_nest_constraints=nest_constraints)
+
+# }}}
+
+
+# {{{ update_must_nest_graph
+
+def update_must_nest_graph(must_nest_graph, must_nest, all_inames):
+    # Note: there should *not* be any complements in the must_nest tuples
+
+    from copy import deepcopy
+    new_graph = deepcopy(must_nest_graph)
+
+    # First, each iname must be a node in the graph
+    for missing_iname in all_inames - new_graph.keys():
+        new_graph[missing_iname] = set()
+
+    # Expand must_nest into (before, after) pairs
+    must_nest_expanded = _expand_iname_sets_in_tuple(must_nest, all_inames)
+
+    # Update must_nest_graph with new pairs
+    for before, after in must_nest_expanded:
+        new_graph[before].add(after)
+
+    # Compute transitive closure
+    from pytools.graph import compute_transitive_closure, contains_cycle
+    new_graph_closure = compute_transitive_closure(new_graph)
+    # Note: compute_transitive_closure now allows cycles, will not error
+
+    # Check for inconsistent must_nest constraints by checking for cycle:
+    if contains_cycle(new_graph_closure):
+        raise ValueError(
+            "update_must_nest_graph: Nest constraint cycle detected. "
+            "must_nest constraints %s inconsistent with existing "
+            "must_nest constraints %s."
+            % (must_nest, must_nest_graph))
+
+    return new_graph_closure
+
+# }}}
+
+
+# {{{ _expand_iname_sets_in_tuple
+
+def _expand_iname_sets_in_tuple(
+        iname_sets_tuple,
+        iname_universe=None,
+        ):
+
+    # First convert UnexpandedInameSets to sets.
+    # Note that must_nest constraints cannot be negatively defined.
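+    # For example (illustrative), the tuple ({i}, {j, k}, {l}) expands to the
+    # pairs (i, j), (i, k), (i, l), (j, l), and (k, l).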
+ positively_defined_iname_sets = [ + iname_set.get_inames_represented(iname_universe) + for iname_set in iname_sets_tuple] + + # Now expand all priority tuples into (before, after) pairs using + # Cartesian product of all pairs of sets + # (Assumes prio_sets length > 1) + import itertools + loop_priority_pairs = set() + for i, before_set in enumerate(positively_defined_iname_sets[:-1]): + for after_set in positively_defined_iname_sets[i+1:]: + loop_priority_pairs.update( + list(itertools.product(before_set, after_set))) + + # Make sure no priority tuple contains an iname twice + for prio_tuple in loop_priority_pairs: + if len(set(prio_tuple)) != len(prio_tuple): + raise ValueError( + "Loop nesting %s contains cycle: %s. " + % (iname_sets_tuple, prio_tuple)) + + return loop_priority_pairs + +# }}} + +# }}} + + +# {{{ Checking constraints + +# {{{ check_must_nest + +def check_must_nest(all_loop_nests, must_nest, all_inames): + r"""Determine whether must_nest constraint is satisfied by + all_loop_nests + + :arg all_loop_nests: A list of lists of inames, each representing + the nesting order of nested loops. + + :arg must_nest: A tuple of :class:`UnexpandedInameSet`\ s describing + nestings that must appear in all_loop_nests. + + :returns: A :class:`bool` indicating whether the must nest constraints + are satisfied by the provided loop nesting. + + """ + + # In order to make sure must_nest is satisfied, we + # need to expand all must_nest tiers + + # FIXME instead of expanding tiers into all pairs up front, + # create these pairs one at a time so that we can stop as soon as we fail + + must_nest_expanded = _expand_iname_sets_in_tuple(must_nest) + + # must_nest_expanded contains pairs + for before, after in must_nest_expanded: + found = False + for nesting in all_loop_nests: + if before in nesting and after in nesting and ( + nesting.index(before) < nesting.index(after)): + found = True + break + if not found: + return False + return True + +# }}} + + +# {{{ check_must_not_nest + +def check_must_not_nest(all_loop_nests, must_not_nest): + r"""Determine whether must_not_nest constraint is satisfied by + all_loop_nests + + :arg all_loop_nests: A list of lists of inames, each representing + the nesting order of nested loops. + + :arg must_not_nest: A two-tuple of :class:`UnexpandedInameSet`\ s + describing nestings that must not appear in all_loop_nests. + + :returns: A :class:`bool` indicating whether the must_not_nest constraints + are satisfied by the provided loop nesting. + + """ + + # Note that must_not_nest may only contain two tiers + + for nesting in all_loop_nests: + + # Go through each pair in all_loop_nests + for i, iname_before in enumerate(nesting): + for iname_after in nesting[i+1:]: + + # Check whether it violates must not nest + if (must_not_nest[0].contains(iname_before) + and must_not_nest[1].contains(iname_after)): + # Stop as soon as we fail + return False + return True + +# }}} + + +# {{{ check_all_must_not_nests + +def check_all_must_not_nests(all_loop_nests, must_not_nests): + r"""Determine whether all must_not_nest constraints are satisfied by + all_loop_nests + + :arg all_loop_nests: A list of lists of inames, each representing + the nesting order of nested loops. + + :arg must_not_nests: A set of two-tuples of :class:`UnexpandedInameSet`\ s + describing nestings that must not appear in all_loop_nests. + + :returns: A :class:`bool` indicating whether the must_not_nest constraints + are satisfied by the provided loop nesting. 
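+
+    For example (illustrative), with ``must_not_nests = {(~{i}, {j})}``, any
+    nesting in which an iname other than ``i`` appears outside ``j`` fails
+    the check.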
+
+    """
+
+    for must_not_nest in must_not_nests:
+        if not check_must_not_nest(all_loop_nests, must_not_nest):
+            return False
+    return True
+
+# }}}
+
+
+# {{{ loop_nest_constraints_satisfied
+
+def loop_nest_constraints_satisfied(
+        all_loop_nests,
+        must_nest_constraints=None,
+        must_not_nest_constraints=None,
+        all_inames=None):
+    r"""Determine whether the provided must-nest and must-not-nest
+    constraints are satisfied by all_loop_nests
+
+    :arg all_loop_nests: A set of lists of inames, each representing
+        the nesting order of loops.
+
+    :arg must_nest_constraints: An iterable of tuples of
+        :class:`UnexpandedInameSet`\ s, each describing nestings that must
+        appear in all_loop_nests.
+
+    :arg must_not_nest_constraints: An iterable of two-tuples of
+        :class:`UnexpandedInameSet`\ s, each describing nestings that must not
+        appear in all_loop_nests.
+
+    :returns: A :class:`bool` indicating whether the constraints
+        are satisfied by the provided loop nesting.
+
+    """
+
+    # Check must-nest constraints
+    if must_nest_constraints:
+        for must_nest in must_nest_constraints:
+            if not check_must_nest(
+                    all_loop_nests, must_nest, all_inames):
+                return False
+
+    # Check must-not-nest constraints
+    if must_not_nest_constraints:
+        for must_not_nest in must_not_nest_constraints:
+            if not check_must_not_nest(
+                    all_loop_nests, must_not_nest):
+                return False
+
+    return True
+
+# }}}
+
+
+# {{{ check_must_not_nest_against_must_nest_graph
+
+def check_must_not_nest_against_must_nest_graph(
+        must_not_nest_constraints, must_nest_graph):
+    r"""Ensure none of the must_not_nest constraints are violated by
+    nestings represented in the must_nest_graph
+
+    :arg must_not_nest_constraints: A set of two-tuples of
+        :class:`UnexpandedInameSet`\ s describing nestings that must not appear
+        in loop nestings.
+
+    :arg must_nest_graph: A :class:`dict` mapping each iname to other inames
+        that must be nested inside it.
+
+    """
+
+    if must_not_nest_constraints and must_nest_graph:
+        import itertools
+        must_pairs = []
+        for iname_before, inames_after in must_nest_graph.items():
+            must_pairs.extend(
+                list(itertools.product([iname_before], inames_after)))
+        if any(not check_must_not_nest(must_pairs, must_not_nest_tuple)
+                for must_not_nest_tuple in must_not_nest_constraints):
+            raise ValueError(
+                "Nest constraint conflict detected. "
+                "must_not_nest constraints %s inconsistent with "
+                "must_nest relationships (must_nest graph: %s)."
+                % (must_not_nest_constraints, must_nest_graph))
+
+# }}}
+
+
+# {{{ get_iname_nestings
+
+def get_iname_nestings(linearization):
+    """Return a list of iname tuples representing the deepest loop nestings
+    in a kernel linearization.
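+
+    For example (illustrative), the linearization ``EnterLoop(i)``,
+    ``EnterLoop(j)``, ``LeaveLoop(j)``, ``EnterLoop(k)``, ``LeaveLoop(k)``,
+    ``LeaveLoop(i)`` yields ``[("i", "j"), ("i", "k")]``.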
+    """
+    from loopy.schedule import EnterLoop, LeaveLoop
+    nestings = []
+    current_tiers = []
+    already_exiting_loops = False
+    for lin_item in linearization:
+        if isinstance(lin_item, EnterLoop):
+            already_exiting_loops = False
+            current_tiers.append(lin_item.iname)
+        elif isinstance(lin_item, LeaveLoop):
+            if not already_exiting_loops:
+                nestings.append(tuple(current_tiers))
+                already_exiting_loops = True
+            del current_tiers[-1]
+    return nestings
+
+# }}}
+
+
+# {{{ get_graph_sources
+
+def get_graph_sources(graph):
+    sources = set(graph.keys())
+    for non_sources in graph.values():
+        sources -= non_sources
+    return sources
+
+# }}}
+
+# }}}
+
+
+# {{{ updating constraints during transformation
+
+# {{{ replace_inames_in_nest_constraints
+
+def replace_inames_in_nest_constraints(
+        inames_to_replace, replacement_inames, old_constraints,
+        coalesce_new_iname_duplicates=False,
+        ):
+    """
+    :arg inames_to_replace: A set of inames that may exist in
+        `old_constraints`, each of which is to be replaced with all inames
+        in `replacement_inames`.
+
+    :arg replacement_inames: A set of inames, all of which will replace each
+        iname in `inames_to_replace` in `old_constraints`.
+
+    :arg old_constraints: An iterable of tuples containing one or more
+        :class:`UnexpandedInameSet` objects.
+    """
+
+    # replace each iname in inames_to_replace
+    # with *all* inames in replacement_inames
+
+    # loop through old_constraints and handle each nesting independently
+    new_constraints = set()
+    for old_nesting in old_constraints:
+        # loop through each iname_set in this nesting and perform replacement
+        new_nesting = []
+        for iname_set in old_nesting:
+
+            # find inames to be replaced
+            inames_found = inames_to_replace & iname_set.inames
+
+            # create the new set of inames with the replacements
+            if inames_found:
+                new_inames = iname_set.inames - inames_found
+                new_inames.update(replacement_inames)
+            else:
+                new_inames = iname_set.inames.copy()
+
+            new_nesting.append(
+                UnexpandedInameSet(new_inames, iname_set.complement))
+
+        # if we've removed things, new_nesting might only contain 1 item,
+        # in which case it's meaningless and we should just remove it
+        if len(new_nesting) > 1:
+            new_constraints.add(tuple(new_nesting))
+
+    # When joining inames, we may need to coalesce:
+    # e.g., if we join `i` and `j` into `ij`, and old_nesting was
+    # [{i, k}, {j, h}], at this point we have [{ij, k}, {ij, h}]
+    # which contains a cycle. If coalescing is enabled, change this
+    # to [{k}, {ij}, {h}] to remove the cycle.
+ if coalesce_new_iname_duplicates: + + def coalesce_duplicate_inames_in_nesting(nesting, coalesce_candidates): + # TODO would like this to be fully generic, but for now, assumes + # all UnexpandedInameSets have complement=False, which works if + # we're only using this for must_nest constraints since they cannot + # have complements + for iname_set in nesting: + assert not iname_set.complement + + import copy + # copy and convert nesting to list so we can modify + coalesced_nesting = list(copy.deepcopy(nesting)) + + # repeat coalescing step until we don't find any adjacent pairs + # containing duplicates (among coalesce_candidates) + found_duplicates = True + while found_duplicates: + found_duplicates = False + # loop through each iname_set in nesting and coalesce + # (assume new_nesting has at least 2 items) + i = 0 + while i < len(coalesced_nesting)-1: + iname_set_before = coalesced_nesting[i] + iname_set_after = coalesced_nesting[i+1] + # coalesce for each iname candidate + for iname in coalesce_candidates: + if (iname_set_before.inames == set([iname, ]) and + iname_set_after.inames == set([iname, ])): + # before/after contain single iname to be coalesced, + # -> remove iname_set_after + del coalesced_nesting[i+1] + found_duplicates = True + elif (iname_set_before.inames == set([iname, ]) and + iname in iname_set_after.inames): + # before contains single iname to be coalesced, + # after contains iname along with others, + # -> remove iname from iname_set_after.inames + coalesced_nesting[i+1] = UnexpandedInameSet( + inames=iname_set_after.inames - set([iname, ]), + complement=iname_set_after.complement, + ) + found_duplicates = True + elif (iname in iname_set_before.inames and + iname_set_after.inames == set([iname, ])): + # after contains single iname to be coalesced, + # before contains iname along with others, + # -> remove iname from iname_set_before.inames + coalesced_nesting[i] = UnexpandedInameSet( + inames=iname_set_before.inames - set([iname, ]), + complement=iname_set_before.complement, + ) + found_duplicates = True + elif (iname in iname_set_before.inames and + iname in iname_set_after.inames): + # before and after contain iname along with others, + # -> remove iname from iname_set_{before,after}.inames + # and insert it in between them + coalesced_nesting[i] = UnexpandedInameSet( + inames=iname_set_before.inames - set([iname, ]), + complement=iname_set_before.complement, + ) + coalesced_nesting[i+1] = UnexpandedInameSet( + inames=iname_set_after.inames - set([iname, ]), + complement=iname_set_after.complement, + ) + coalesced_nesting.insert(i+1, UnexpandedInameSet( + inames=set([iname, ]), + complement=False, + )) + found_duplicates = True + # else, iname was not found in both sets, so do nothing + i = i + 1 + + return tuple(coalesced_nesting) + + # loop through new_constraints; handle each nesting independently + coalesced_constraints = set() + for new_nesting in new_constraints: + coalesced_constraints.add( + coalesce_duplicate_inames_in_nesting( + new_nesting, replacement_inames)) + + return coalesced_constraints + else: + return new_constraints + +# }}} + + +# {{{ replace_inames_in_graph + +def replace_inames_in_graph( + inames_to_replace, replacement_inames, old_graph): + # replace each iname in inames_to_replace with all inames in replacement_inames + + new_graph = {} + iname_to_replace_found_as_key = False + union_of_inames_after_for_replaced_keys = set() + for iname, inames_after in old_graph.items(): + # create new inames_after + new_inames_after = 
inames_after.copy()
+        inames_found = inames_to_replace & new_inames_after
+
+        if inames_found:
+            new_inames_after -= inames_found
+            new_inames_after.update(replacement_inames)
+
+        # update dict
+        if iname in inames_to_replace:
+            iname_to_replace_found_as_key = True
+            union_of_inames_after_for_replaced_keys = \
+                union_of_inames_after_for_replaced_keys | new_inames_after
+            # don't add this iname as a key in new graph,
+            # its replacements will be added below
+        else:
+            new_graph[iname] = new_inames_after
+
+    # add replacement iname keys
+    if iname_to_replace_found_as_key:
+        for new_key in replacement_inames:
+            new_graph[new_key] = union_of_inames_after_for_replaced_keys.copy()
+
+    # check for cycle
+    from pytools.graph import contains_cycle
+    if contains_cycle(new_graph):
+        raise ValueError(
+            "replace_inames_in_graph: Loop priority cycle detected. "
+            "Cannot replace inames %s with inames %s."
+            % (inames_to_replace, replacement_inames))
+
+    return new_graph
+
+# }}}
+
+
+# {{{ replace_inames_in_all_nest_constraints
+
+def replace_inames_in_all_nest_constraints(
+        kernel, old_inames, new_inames,
+        coalesce_new_iname_duplicates=False,
+        pairs_that_must_not_violate_constraints=set(),
+        ):
+    # replace each iname in old_inames with all inames in new_inames
+
+    # get old must_nest and must_not_nest
+    # (must_nest_graph will be rebuilt)
+    if kernel.loop_nest_constraints:
+        old_must_nest = kernel.loop_nest_constraints.must_nest
+        old_must_not_nest = kernel.loop_nest_constraints.must_not_nest
+        # (these could still be None)
+    else:
+        old_must_nest = None
+        old_must_not_nest = None
+
+    if old_must_nest:
+        # check to make sure special pairs don't conflict with constraints
+        for iname_before, iname_after in pairs_that_must_not_violate_constraints:
+            if iname_before in kernel.loop_nest_constraints.must_nest_graph[
+                    iname_after]:
+                raise ValueError(
+                    "Implied nestings violate existing must-nest constraints."
+                    "\nimplied nestings: %s\nmust-nest constraints: %s"
+                    % (pairs_that_must_not_violate_constraints, old_must_nest))
+
+        new_must_nest = replace_inames_in_nest_constraints(
+            old_inames, new_inames, old_must_nest,
+            coalesce_new_iname_duplicates=coalesce_new_iname_duplicates,
+            )
+    else:
+        new_must_nest = None
+
+    if old_must_not_nest:
+        # check to make sure special pairs don't conflict with constraints
+        if not check_all_must_not_nests(
+                pairs_that_must_not_violate_constraints, old_must_not_nest):
+            raise ValueError(
+                "Implied nestings violate existing must-not-nest constraints."
+                "\nimplied nestings: %s\nmust-not-nest constraints: %s"
+                % (pairs_that_must_not_violate_constraints, old_must_not_nest))
+
+        new_must_not_nest = replace_inames_in_nest_constraints(
+            old_inames, new_inames, old_must_not_nest,
+            coalesce_new_iname_duplicates=False,
+            # (for now, never coalesce must-not-nest constraints)
+            )
+        # each must not nest constraint may only contain two tiers
+        # TODO coalesce_new_iname_duplicates?
+    else:
+        new_must_not_nest = None
+
+    # Rebuild must_nest graph
+    if new_must_nest:
+        new_must_nest_graph = {}
+        new_all_inames = (
+            kernel.all_inames() - set(old_inames)) | set(new_inames)
+        from pytools.graph import CycleError
+        for must_nest_tuple in new_must_nest:
+            try:
+                new_must_nest_graph = update_must_nest_graph(
+                    new_must_nest_graph, must_nest_tuple, new_all_inames)
+            except CycleError:
+                raise ValueError(
+                    "Loop priority cycle detected when replacing inames %s "
+                    "with inames %s. 
Previous must_nest constraints: %s"
+                    % (old_inames, new_inames, old_must_nest))
+
+        # make sure none of the must_nest constraints violate must_not_nest
+        # this may not catch all problems
+        check_must_not_nest_against_must_nest_graph(
+            new_must_not_nest, new_must_nest_graph)
+    else:
+        new_must_nest_graph = None
+
+    return kernel.copy(
+        loop_nest_constraints=LoopNestConstraints(
+            must_nest=new_must_nest,
+            must_not_nest=new_must_not_nest,
+            must_nest_graph=new_must_nest_graph,
+            )
+        )
+
+# }}}
+
+# }}}
+
+# }}}
+
+
 # {{{ split/chunk inames
 
 # {{{ backend
@@ -270,6 +1264,85 @@ def _split_iname_backend(kernel, iname_to_split,
             fixed_length, fixed_length_is_inner)
         for dom in kernel.domains]
 
+    # {{{ Split iname in dependencies
+
+    from loopy.transform.instruction import map_dependency_maps
+    from loopy.schedule.checker.schedule import BEFORE_MARK
+    from loopy.schedule.checker.utils import (
+        convert_map_to_set,
+        remove_dim_by_name,
+    )
+
+    def _split_iname_in_depender(dep):
+
+        # If iname is not present in dep, return unmodified dep
+        if iname_to_split not in dep.get_var_names(dt.out):
+            return dep
+
+        # Temporarily convert map to set for processing
+        set_from_map, n_in_dims, n_out_dims = convert_map_to_set(dep)
+
+        # Split iname
+        set_from_map = _split_iname_in_set(
+            set_from_map, iname_to_split, inner_iname, outer_iname,
+            fixed_length, fixed_length_is_inner)
+
+        # Dim order: [old_inames' ..., old_inames ..., i_outer, i_inner]
+
+        # Convert set back to map
+        map_from_set = isl.Map.from_domain(set_from_map)
+        # Move original out dims + 2 new dims:
+        map_from_set = map_from_set.move_dims(
+            dt.out, 0, dt.in_, n_in_dims, n_out_dims+2)
+
+        # Remove iname that was split:
+        map_from_set = remove_dim_by_name(
+            map_from_set, dt.out, iname_to_split)
+
+        return map_from_set
+
+    def _split_iname_in_dependee(dep):
+
+        iname_to_split_marked = iname_to_split+BEFORE_MARK
+
+        # If iname is not present in dep, return unmodified dep
+        if iname_to_split_marked not in dep.get_var_names(dt.in_):
+            return dep
+
+        # Temporarily convert map to set for processing
+        set_from_map, n_in_dims, n_out_dims = convert_map_to_set(dep)
+
+        # Split iname'
+        set_from_map = _split_iname_in_set(
+            set_from_map, iname_to_split_marked,
+            inner_iname+BEFORE_MARK, outer_iname+BEFORE_MARK,
+            fixed_length, fixed_length_is_inner)
+
+        # Dim order: [old_inames' ..., old_inames ..., i_outer', i_inner']
+
+        # Convert set back to map
+        map_from_set = isl.Map.from_domain(set_from_map)
+        # Move original out dims (the new split dims stay on the in side):
+        map_from_set = map_from_set.move_dims(
+            dt.out, 0, dt.in_, n_in_dims, n_out_dims)
+
+        # Remove iname that was split:
+        map_from_set = remove_dim_by_name(
+            map_from_set, dt.in_, iname_to_split_marked)
+
+        return map_from_set
+
+    # TODO figure out proper way to create false match condition
+    false_id_match = "not id:*"
+    kernel = map_dependency_maps(
+        kernel, _split_iname_in_depender,
+        stmt_match_depender=within, stmt_match_dependee=false_id_match)
+    kernel = map_dependency_maps(
+        kernel, _split_iname_in_dependee,
+        stmt_match_depender=false_id_match, stmt_match_dependee=within)
+
+    # }}}
+
     from pymbolic import var
     inner = var(inner_iname)
     outer = var(outer_iname)
@@ -308,6 +1381,20 @@ def _split_iname_backend(kernel, iname_to_split,
             new_prio = new_prio + (prio_iname,)
         new_priorities.append(new_prio)
 
+    # {{{ update nest constraints
+
+    # Add {inner,outer} wherever iname_to_split is found in constraints, while
+    # still keeping the original around. 
Then let remove_unused_inames handle + # removal of the old iname if necessary + + # update must_nest, must_not_nest, and must_nest_graph + kernel = replace_inames_in_all_nest_constraints( + kernel, + set([iname_to_split, ]), [iname_to_split, inner_iname, outer_iname], + ) + + # }}} + kernel = kernel.copy( domains=new_domains, iname_slab_increments=iname_slab_increments, @@ -454,15 +1541,15 @@ def make_new_loop_index(inner, outer): if split_iname not in var_dict: continue - dt, idx = var_dict[split_iname] - assert dt == dim_type.set + dim_type, idx = var_dict[split_iname] + assert dim_type == dt.set aff_zero = isl.Aff.zero_on_domain(dom.space) - aff_split_iname = aff_zero.set_coefficient_val(dim_type.in_, idx, 1) + aff_split_iname = aff_zero.set_coefficient_val(dt.in_, idx, 1) aligned_size = isl.align_spaces(size, aff_zero) box_dom = ( dom - .eliminate(dt, idx, 1) + .eliminate(dim_type, idx, 1) & aff_zero.le_set(aff_split_iname) & aff_split_iname.lt_set(aligned_size) ) @@ -545,7 +1632,7 @@ def join_inames(kernel, inames, new_iname=None, tag=None, within=None): from loopy.match import parse_match within = parse_match(within) - # {{{ return the same kernel if no kernel matches + # {{{ return the same kernel if no insn matches if not any(within(kernel, insn) for insn in kernel.instructions): return kernel @@ -566,9 +1653,9 @@ def join_inames(kernel, inames, new_iname=None, tag=None, within=None): "join's leaf domain" % iname) new_domain = domch.domain - new_dim_idx = new_domain.dim(dim_type.set) - new_domain = new_domain.add_dims(dim_type.set, 1) - new_domain = new_domain.set_dim_name(dim_type.set, new_dim_idx, new_iname) + new_dim_idx = new_domain.dim(dt.set) + new_domain = new_domain.add_dims(dt.set, 1) + new_domain = new_domain.set_dim_name(dt.set, new_dim_idx, new_iname) joint_aff = zero = isl.Aff.zero_on_domain(new_domain.space) subst_dict = {} @@ -640,6 +1727,37 @@ def subst_within_inames(fid): applied_iname_rewrites=kernel.applied_iname_rewrites + [subst_dict] )) + # {{{ update must_nest, must_not_nest, and must_nest_graph + + if kernel.loop_nest_constraints and ( + kernel.loop_nest_constraints.must_nest or + kernel.loop_nest_constraints.must_not_nest or + kernel.loop_nest_constraints.must_nest_graph): + + if within != parse_match(None): + raise NotImplementedError( + "join_inames() does not yet handle new loop nest " + "constraints when within is not None.") + + # When joining inames, we create several implied loop nestings. + # make sure that these implied nestings don't violate existing + # constraints. 
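+        # For example (illustrative), if joining implies that i nests outside
+        # j while an existing must-nest constraint requires j outside i, we
+        # raise an error rather than silently creating a cycle.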
+
+        # (will fail if cycle is created in must-nest graph)
+        implied_nestings = set()
+        inames_orig_order = inames[::-1]  # this was reversed above
+        for i, iname_before in enumerate(inames_orig_order[:-1]):
+            for iname_after in inames_orig_order[i+1:]:
+                implied_nestings.add((iname_before, iname_after))
+
+        kernel = replace_inames_in_all_nest_constraints(
+            kernel, set(inames), [new_iname],
+            coalesce_new_iname_duplicates=True,
+            pairs_that_must_not_violate_constraints=implied_nestings,
+            )
+
+    # }}}
+
     from loopy.match import parse_stack_match
     within = parse_stack_match(within)
 
@@ -791,6 +1909,7 @@ def parse_tag(tag):
 
     # }}}
 
+    from loopy.kernel.data import ConcurrentTag, VectorizeTag
     knl_inames = kernel.inames.copy()
     for name, new_tag in iname_to_tag.items():
         if not new_tag:
@@ -801,6 +1920,36 @@ def parse_tag(tag):
 
         knl_inames[name] = knl_inames[name].tagged(new_tag)
 
+        # {{{ loop nest constraint handling
+
+        if isinstance(new_tag, VectorizeTag):
+            # {{{ vec_inames will be nested innermost, check whether this
+            # conflicts with must-nest constraints
+            must_nest_graph = (kernel.loop_nest_constraints.must_nest_graph
+                if kernel.loop_nest_constraints else None)
+            if must_nest_graph and must_nest_graph.get(iname, set()):
+                # iname is not a leaf
+                raise ValueError(
+                    "Loop priorities provided specify that iname %s nest "
+                    "outside of inames %s, but vectorized inames "
+                    "must nest innermost. Cannot tag %s with 'vec' tag."
+                    % (iname, must_nest_graph.get(iname, set()), iname))
+            # }}}
+
+        elif isinstance(new_tag, ConcurrentTag) and kernel.loop_nest_constraints:
+            # {{{ Don't allow tagging of must_nest iname as concurrent
+            must_nest = kernel.loop_nest_constraints.must_nest
+            if must_nest:
+                for nesting in must_nest:
+                    for iname_set in nesting:
+                        if iname in iname_set.inames:
+                            raise ValueError("cannot tag '%s' as concurrent--"
+                                "iname involved in must-nest constraint %s."
% (iname, nesting))
+            # }}}
+
+        # }}}
+
     return kernel.copy(inames=knl_inames)
 
 # }}}
@@ -879,7 +2028,7 @@ def duplicate_inames(kernel, inames, within, new_inames=None, suffix=None,
         new_inames = [iname.strip() for iname in new_inames.split(",")]
 
     from loopy.match import parse_stack_match
-    within = parse_stack_match(within)
+    within_sm = parse_stack_match(within)
 
     if new_inames is None:
         new_inames = [None] * len(inames)
@@ -889,6 +2038,7 @@ def duplicate_inames(kernel, inames, within, new_inames=None, suffix=None,
 
     name_gen = kernel.get_var_name_generator()
 
+    # Generate new iname names
    for i, iname in enumerate(inames):
         new_iname = new_inames[i]
 
@@ -911,17 +2061,60 @@ def duplicate_inames(kernel, inames, within, new_inames=None, suffix=None,
 
     # }}}
 
-    # {{{ duplicate the inames
+    # {{{ duplicate the inames in domains
 
     for old_iname, new_iname in zip(inames, new_inames):
         from loopy.kernel.tools import DomainChanger
         domch = DomainChanger(kernel, frozenset([old_iname]))
 
+        # {{{ update nest constraints
+
+        # (don't remove any unused inames yet, that happens later)
+        kernel = replace_inames_in_all_nest_constraints(
+            kernel, set([old_iname, ]), [old_iname, new_iname])
+
+        # }}}
+
         from loopy.isl_helpers import duplicate_axes
         kernel = kernel.copy(
             domains=domch.get_domains_with(
                 duplicate_axes(domch.domain, [old_iname], [new_iname])))
 
+        # {{{ *Rename* iname in dependencies
+
+        # TODO use find_and_rename_dim for simpler code
+        # (see example in rename_iname)
+        from loopy.transform.instruction import map_dependency_maps
+        from loopy.schedule.checker.schedule import BEFORE_MARK
+        old_iname_p = old_iname+BEFORE_MARK
+        new_iname_p = new_iname+BEFORE_MARK
+
+        def _rename_iname_in_dim_out(dep):
+            # update iname in out-dim
+            out_idx = dep.find_dim_by_name(dt.out, old_iname)
+            if out_idx != -1:
+                dep = dep.set_dim_name(dt.out, out_idx, new_iname)
+            return dep
+
+        def _rename_iname_in_dim_in(dep):
+            # update iname in in-dim
+            in_idx = dep.find_dim_by_name(dt.in_, old_iname_p)
+            if in_idx != -1:
+                dep = dep.set_dim_name(dt.in_, in_idx, new_iname_p)
+            return dep
+
+        # TODO figure out proper way to match none
+        # TODO figure out match vs stack_match
+        false_id_match = "not id:*"
+        kernel = map_dependency_maps(
+            kernel, _rename_iname_in_dim_out,
+            stmt_match_depender=within, stmt_match_dependee=false_id_match)
+        kernel = map_dependency_maps(
+            kernel, _rename_iname_in_dim_in,
+            stmt_match_depender=false_id_match, stmt_match_dependee=within)
+
+        # }}}
+
     # }}}
 
     # {{{ change the inames in the code
@@ -930,10 +2123,10 @@ def duplicate_inames(kernel, inames, within, new_inames=None, suffix=None,
         kernel.substitutions, name_gen)
     indup = _InameDuplicator(rule_mapping_context,
             old_to_new=dict(list(zip(inames, new_inames))),
-            within=within)
+            within=within_sm)
 
     kernel = rule_mapping_context.finish_kernel(
-        indup.map_kernel(kernel, within=within))
+        indup.map_kernel(kernel, within=within_sm))
 
     # }}}
 
@@ -946,6 +2139,18 @@ def duplicate_inames(kernel, inames, within, new_inames=None, suffix=None,
 
     # }}}
 
+    # TODO why isn't remove_unused_inames called on kernel here? 
+ + # {{{ if there are any now unused inames, remove from nest constraints + + now_unused_inames = (set(inames) - get_used_inames(kernel)) & set(inames) + kernel = replace_inames_in_all_nest_constraints( + kernel, old_inames=now_unused_inames, new_inames=[], + coalesce_new_iname_duplicates=False, + ) + + # }}} + return kernel # }}} @@ -1139,6 +2344,16 @@ def rename_iname(kernel, old_iname, new_iname, existing_ok=False, within=None): "--cannot rename" % new_iname) if does_exist: + + # TODO implement this + if kernel.loop_nest_constraints and ( + kernel.loop_nest_constraints.must_nest or + kernel.loop_nest_constraints.must_not_nest or + kernel.loop_nest_constraints.must_nest_graph): + raise NotImplementedError( + "rename_iname() does not yet handle new loop nest " + "constraints when does_exist=True.") + # {{{ check that the domains match up dom = kernel.get_inames_domain(frozenset((old_iname, new_iname))) @@ -1147,21 +2362,21 @@ def rename_iname(kernel, old_iname, new_iname, existing_ok=False, within=None): _, old_idx = var_dict[old_iname] _, new_idx = var_dict[new_iname] - par_idx = dom.dim(dim_type.param) + par_idx = dom.dim(dt.param) dom_old = dom.move_dims( - dim_type.param, par_idx, dim_type.set, old_idx, 1) + dt.param, par_idx, dt.set, old_idx, 1) dom_old = dom_old.move_dims( - dim_type.set, dom_old.dim(dim_type.set), dim_type.param, par_idx, 1) + dt.set, dom_old.dim(dt.set), dt.param, par_idx, 1) dom_old = dom_old.project_out( - dim_type.set, new_idx if new_idx < old_idx else new_idx - 1, 1) + dt.set, new_idx if new_idx < old_idx else new_idx - 1, 1) - par_idx = dom.dim(dim_type.param) + par_idx = dom.dim(dt.param) dom_new = dom.move_dims( - dim_type.param, par_idx, dim_type.set, new_idx, 1) + dt.param, par_idx, dt.set, new_idx, 1) dom_new = dom_new.move_dims( - dim_type.set, dom_new.dim(dim_type.set), dim_type.param, par_idx, 1) + dt.set, dom_new.dim(dt.set), dt.param, par_idx, 1) dom_new = dom_new.project_out( - dim_type.set, old_idx if old_idx < new_idx else old_idx - 1, 1) + dt.set, old_idx if old_idx < new_idx else old_idx - 1, 1) if not (dom_old <= dom_new and dom_new <= dom_old): raise LoopyError( @@ -1198,6 +2413,44 @@ def rename_iname(kernel, old_iname, new_iname, existing_ok=False, within=None): kernel = kernel.copy(instructions=new_instructions) + # {{{ Rename iname in dependencies + + from loopy.transform.instruction import map_dependency_maps + from loopy.schedule.checker.schedule import BEFORE_MARK + from loopy.schedule.checker.utils import ( + find_and_rename_dim, + ) + old_iname_p = old_iname+BEFORE_MARK + new_iname_p = new_iname+BEFORE_MARK + + def _rename_iname_in_dim_out(dep): + # Update iname in out-dim (depender dim). + + # For now, out_idx should not be -1 because this will only + # be called on dependers + return find_and_rename_dim( + dep, [dt.out], old_iname, new_iname, must_exist=True) + + def _rename_iname_in_dim_in(dep): + # Update iname in in-dim (dependee dim). 
+
+            # For now, in_idx should not be -1 because this will only
+            # be called on dependees
+            return find_and_rename_dim(
+                dep, [dt.in_], old_iname_p, new_iname_p, must_exist=True)
+
+        # TODO figure out proper way to match none
+        # TODO figure out match vs stack_match
+        false_id_match = "not id:*"
+        kernel = map_dependency_maps(
+            kernel, _rename_iname_in_dim_out,
+            stmt_match_depender=within, stmt_match_dependee=false_id_match)
+        kernel = map_dependency_maps(
+            kernel, _rename_iname_in_dim_in,
+            stmt_match_depender=false_id_match, stmt_match_dependee=within)
+
+        # }}}
+
     else:
         kernel = duplicate_inames(
             kernel, [old_iname], within=within, new_inames=[new_iname])
@@ -1225,6 +2478,19 @@ def get_used_inames(kernel):
     return used_inames
 
 
+def remove_vars_from_set(s, remove_vars):
+    from copy import deepcopy
+    new_s = deepcopy(s)
+    for var in remove_vars:
+        try:
+            dim_type, idx = s.get_var_dict()[var]
+        except KeyError:
+            continue
+        else:
+            new_s = new_s.project_out(dim_type, idx, 1)
+    return new_s
+
+
 @for_each_kernel
 def remove_unused_inames(kernel, inames=None):
     """Delete those among *inames* that are unused, i.e. project them
@@ -1252,22 +2518,35 @@ def remove_unused_inames(kernel, inames=None):
 
     # {{{ remove them
 
-    domains = kernel.domains
-    for iname in unused_inames:
-        new_domains = []
+    new_domains = []
+    for dom in kernel.domains:
+        new_domains.append(remove_vars_from_set(dom, unused_inames))
+
+    kernel = kernel.copy(domains=new_domains)
 
-        for dom in domains:
-            try:
-                dt, idx = dom.get_var_dict()[iname]
-            except KeyError:
-                pass
-            else:
-                dom = dom.project_out(dt, idx, 1)
-            new_domains.append(dom)
+    # }}}
+
+    # {{{ Remove inames from deps
+
+    from loopy.transform.instruction import map_dependency_maps
+    from loopy.schedule.checker.schedule import BEFORE_MARK
+    from loopy.schedule.checker.utils import append_mark_to_strings
+    unused_inames_marked = append_mark_to_strings(unused_inames, BEFORE_MARK)
 
-        domains = new_domains
+    def _remove_iname_from_dep(dep):
+        return remove_vars_from_set(
+            remove_vars_from_set(dep, unused_inames), unused_inames_marked)
 
-    kernel = kernel.copy(domains=domains)
+    kernel = map_dependency_maps(kernel, _remove_iname_from_dep)
+
+    # }}}
+
+    # {{{ Remove inames from loop nest constraints
+
+    kernel = replace_inames_in_all_nest_constraints(
+        kernel, old_inames=unused_inames, new_inames=[],
+        coalesce_new_iname_duplicates=False,
+        )
 
     # }}}
 
@@ -1577,10 +2856,10 @@ def parse_equation(eqn):
         # add inames to domain with correct dim_types
         dom_new_inames = list(dom_new_inames)
         for iname in dom_new_inames:
-            dt = new_iname_dim_types[iname]
-            iname_idx = dom.dim(dt)
-            dom = dom.add_dims(dt, 1)
-            dom = dom.set_dim_name(dt, iname_idx, iname)
+            dim_type = new_iname_dim_types[iname]
+            iname_idx = dom.dim(dim_type)
+            dom = dom.add_dims(dim_type, 1)
+            dom = dom.set_dim_name(dim_type, iname_idx, iname)
 
         # add equations
         from loopy.symbolic import aff_from_expr
@@ -1591,8 +2870,8 @@ def parse_equation(eqn):
 
         # project out old inames
         for iname in dom_old_inames:
-            dt, idx = dom.get_var_dict()[iname]
-            dom = dom.project_out(dt, idx, 1)
+            dim_type, idx = dom.get_var_dict()[iname]
+            dom = dom.project_out(dim_type, idx, 1)
 
         new_domains.append(dom)
 
@@ -1832,6 +3111,467 @@ def add_inames_to_insn(kernel, inames, insn_match):
 
 # }}}
 
+
+# {{{ map_domain
+
+class _MapDomainMapper(RuleAwareIdentityMapper):
+    def __init__(self, rule_mapping_context, within, new_inames, substitutions):
+        super(_MapDomainMapper, self).__init__(rule_mapping_context)
+
+        self.within = within
+
+        self.old_inames = 
frozenset(substitutions)
+        self.new_inames = new_inames
+
+        self.substitutions = substitutions
+
+    def map_reduction(self, expr, expn_state):
+        red_overlap = frozenset(expr.inames) & self.old_inames
+        arg_ctx_overlap = frozenset(expn_state.arg_context) & self.old_inames
+        if (red_overlap
+                and self.within(
+                    expn_state.kernel,
+                    expn_state.instruction)):
+            if len(red_overlap) != len(self.old_inames):
+                raise LoopyError("reduction '%s' involves a part "
+                        "of the map domain inames. Reductions must "
+                        "either involve all or none of the map domain "
+                        "inames." % str(expr))
+
+            if arg_ctx_overlap:
+                if arg_ctx_overlap == red_overlap:
+                    # All variables are shadowed by context, that's OK.
+                    return super(_MapDomainMapper, self).map_reduction(
+                            expr, expn_state)
+                else:
+                    raise LoopyError("reduction '%s' has "
+                            "some of the reduction variables affected "
+                            "by the map_domain shadowed by context. "
+                            "Either all or none must be shadowed."
+                            % str(expr))
+
+            new_inames = list(expr.inames)
+            for old_iname in self.old_inames:
+                new_inames.remove(old_iname)
+            new_inames.extend(self.new_inames)
+
+            from loopy.symbolic import Reduction
+            return Reduction(expr.operation, tuple(new_inames),
+                    self.rec(expr.expr, expn_state),
+                    expr.allow_simultaneous)
+        else:
+            return super(_MapDomainMapper, self).map_reduction(expr, expn_state)
+
+    def map_variable(self, expr, expn_state):
+        if (expr.name in self.old_inames
+                and expr.name not in expn_state.arg_context
+                and self.within(
+                    expn_state.kernel,
+                    expn_state.instruction)):
+            return self.substitutions[expr.name]
+        else:
+            return super(_MapDomainMapper, self).map_variable(expr, expn_state)
+
+
+def _find_aff_subst_from_map(iname, isl_map):
+    if not isinstance(isl_map, isl.BasicMap):
+        raise RuntimeError("isl_map must be a BasicMap")
+
+    dim_type, dim_idx = isl_map.get_var_dict()[iname]
+
+    assert dim_type == dt.in_
+
+    # Force isl to solve for only this iname on its side of the map, by
+    # projecting out all other "in" variables.
+    isl_map = isl_map.project_out(
+        dim_type, dim_idx+1, isl_map.dim(dim_type)-(dim_idx+1))
+    isl_map = isl_map.project_out(dim_type, 0, dim_idx)
+    dim_idx = 0
+
+    # Convert map to set to avoid "domain of affine expression should be a set".
+    # The old "in" variable will be the last of the out_dims.
+    new_dim_idx = isl_map.dim(dt.out)
+    isl_map = isl_map.move_dims(
+        dt.out, isl_map.dim(dt.out),
+        dim_type, dim_idx, 1)
+    isl_map = isl_map.range()  # now a set
+    dim_type = dt.set
+    dim_idx = new_dim_idx
+    del new_dim_idx
+
+    for cns in isl_map.get_constraints():
+        if cns.is_equality() and cns.involves_dims(dim_type, dim_idx, 1):
+            coeff = cns.get_coefficient_val(dim_type, dim_idx)
+            cns_zeroed = cns.set_coefficient_val(dim_type, dim_idx, 0)
+            if cns_zeroed.involves_dims(dim_type, dim_idx, 1):
+                # not suitable, constraint still involves dim, perhaps in a div
+                continue
+
+            if coeff.is_one():
+                return -cns_zeroed.get_aff()
+            elif coeff.is_negone():
+                return cns_zeroed.get_aff()
+            else:
+                # not suitable: coefficient is not +1 or -1
+                continue
+
+    raise LoopyError("no suitable equation for '%s' found" % iname)
+
+
+def _apply_identity_for_missing_map_dims(mapping, desired_dims):
+    from loopy.schedule.checker.utils import (
+        add_and_name_isl_dims,
+        add_eq_isl_constraint_from_names,
+    )
+
+    # If dims in s are missing from transform map, they need to be added
+    # so that, e.g., intersect_domain doesn't remove them. 
+    # (assume ordering will be handled afterward)
+
+    missing_dims = list(
+        set(desired_dims) - set(mapping.get_var_names(dt.in_)))
+    augmented_mapping = add_and_name_isl_dims(
+        mapping, dt.in_, missing_dims)
+
+    # We want these missing inames to map to themselves so that the map
+    # has no effect on them. Unfortunately isl will break if the
+    # names of the out dims aren't unique, so we will temporarily rename them
+    # (and then plan to change the names back afterward).
+
+    # FIXME: need better way to make sure proxy dim names are unique within map
+    missing_dims_proxies = [d+"__prox" for d in missing_dims]
+    assert not set(missing_dims_proxies) & set(
+        augmented_mapping.get_var_dict().keys())
+
+    augmented_mapping = add_and_name_isl_dims(
+        augmented_mapping, dt.out, missing_dims_proxies)
+
+    proxy_name_pairs = list(zip(missing_dims, missing_dims_proxies))
+
+    # Set proxy iname equal to real iname with equality constraint
+    for real_iname, proxy_iname in proxy_name_pairs:
+        augmented_mapping = add_eq_isl_constraint_from_names(
+            augmented_mapping, proxy_iname, real_iname)
+
+    return augmented_mapping, proxy_name_pairs
+
+
+@for_each_kernel
+def map_domain(kernel, isl_map, within=None):
+    # FIXME: Express _split_iname_backend in terms of this
+    #   Missing/deleted for now:
+    #   - slab processing
+    #   - priorities processing
+    # FIXME: Process priorities
+    # FIXME: Express affine_map_inames in terms of this, deprecate
+    # FIXME: Document
+
+    # FIXME: Support within
+    # FIXME: Right now, this requires all inames in a domain (or none) to
+    # be mapped. That makes this awkward to use.
+
+    # {{{ within processing (disabled for now)
+    if within is not None:
+        raise NotImplementedError("within")
+
+    from loopy.match import parse_match
+    within = parse_match(within)
+
+    # {{{ return the same kernel if no insn matches
+
+    if not any(within(kernel, insn) for insn in kernel.instructions):
+        return kernel
+
+    # }}}
+
+    # }}}
+
+    if not isl_map.is_bijective():
+        raise LoopyError("isl_map must be bijective")
+
+    new_inames = frozenset(isl_map.get_var_dict(dt.out))
+    old_inames = frozenset(isl_map.get_var_dict(dt.in_))
+
+    # {{{ solve for representation of old inames in terms of new
+
+    substitutions = {}
+    var_substitutions = {}
+    applied_iname_rewrites = kernel.applied_iname_rewrites[:]
+
+    from loopy.symbolic import aff_to_expr
+    from pymbolic import var
+    for iname in old_inames:
+        substitutions[iname] = aff_to_expr(
+            _find_aff_subst_from_map(iname, isl_map))
+        var_substitutions[var(iname)] = aff_to_expr(
+            _find_aff_subst_from_map(iname, isl_map))
+
+    applied_iname_rewrites.append(var_substitutions)
+    del var_substitutions
+
+    # }}}
+
+    from loopy.schedule.checker.schedule import (
+        BEFORE_MARK,
+        STATEMENT_VAR_NAME,
+    )
+
+    def _check_overlap_condition_for_domain(s, transform_map_in_names):
+
+        names_to_ignore = set([STATEMENT_VAR_NAME, STATEMENT_VAR_NAME+BEFORE_MARK])
+        transform_map_in_inames = transform_map_in_names - names_to_ignore
+
+        var_dict = s.get_var_dict()
+
+        overlap = transform_map_in_inames & frozenset(var_dict)
+
+        # If there is any overlap in the inames in the transform map and s
+        # (note that we're ignoring the statement var name, which may have been
+        # added to a transform map or s), all of the transform map inames must be in
+        # the overlap.
+        if overlap and len(overlap) != len(transform_map_in_inames):
+            raise LoopyError("loop domain '%s' involves a part "
+                    "of the map domain inames. Domains must "
+                    "either involve all or none of the map domain "
+                    "inames." 
% s) + + return overlap + + from loopy.schedule.checker.utils import ( + find_and_rename_dim, + ) + + def process_set(s): + + overlap = _check_overlap_condition_for_domain(s, old_inames) + if not overlap: + # inames in s are not present in transform map, don't change s + return s + + # At this point, overlap condition check guarantees that the + # in-dims of the transform map are a subset of the dims we're + # about to change. + + # {{{ align dims of isl_map and s + + from islpy import _align_dim_type + + map_with_s_domain = isl.Map.from_domain(s) + + # If there are dims in s that are not mapped by isl_map, add them + # to the in/out space of isl_map so that they remain unchanged. + # (temporary proxy dim names are needed in out space of transform + # map because isl won't allow any dim names to match, i.e., instead + # of just mapping {[unused_iname]->[unused_iname]}, we have to map + # {[unused_name]->[unused_name__prox] : unused_name__prox = unused_name}, + # and then rename unused_name__prox afterward.) + augmented_isl_map, proxy_name_pairs = _apply_identity_for_missing_map_dims( + isl_map, s.get_var_names(dt.set)) + + # FIXME: Make this less gross + # FIXME: Make an exported/documented interface of this in islpy + dim_types = [dt.param, dt.in_, dt.out] + s_names = [ + map_with_s_domain.get_dim_name(dim_type, i) + for dim_type in dim_types + for i in range(map_with_s_domain.dim(dim_type)) + ] + map_names = [ + augmented_isl_map.get_dim_name(dim_type, i) + for dim_type in dim_types + for i in range(augmented_isl_map.dim(dim_type)) + ] + + # (order doesn't matter in s_names/map_names, + # _align_dim_type just converts these to sets + # to determine which names are in both the obj and template, + # not sure why this isn't just handled inside _align_dim_type) + aligned_map = _align_dim_type( + dt.param, + augmented_isl_map, map_with_s_domain, False, + map_names, s_names) + aligned_map = _align_dim_type( + dt.in_, + aligned_map, map_with_s_domain, False, + map_names, s_names) + + # }}} + + new_s = aligned_map.intersect_domain(s).range() + + # Now rename the proxy dims back to their original names + for real_iname, proxy_iname in proxy_name_pairs: + new_s = find_and_rename_dim( + new_s, [dt.set], proxy_iname, real_iname) + + return new_s + + # FIXME: Revive _project_out_only_if_all_instructions_in_within + + new_domains = [process_set(dom) for dom in kernel.domains] + + # {{{ update dependencies + + # Prep transform map to be applied to dependency + from loopy.transform.instruction import map_dependency_maps + from loopy.schedule.checker.utils import ( + append_mark_to_isl_map_var_names, + move_dim_to_index, + ) + + # Create version of transform map with before marks + # (for aligning when applying map to dependee portion of deps) + isl_map_marked = append_mark_to_isl_map_var_names( + append_mark_to_isl_map_var_names(isl_map, dt.in_, BEFORE_MARK), + dt.out, BEFORE_MARK) + + def _apply_transform_map_to_depender(dep_map): + # (since 'out' dim of dep is unmarked, use unmarked transform map) + + # Check overlap condition + overlap = _check_overlap_condition_for_domain( + dep_map.range(), set(isl_map.get_var_names(dt.in_))) + + if not overlap: + # Inames in s are not present in depender, don't change dep_map + return dep_map + else: + # At this point, overlap condition check guarantees that the + # in-dims of the transform map are a subset of the dims we're + # about to change. 
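+            # (Hypothetical example: once the missing dims, including the
+            # statement dim, are added as identities below, a dep
+            # {[s'=0, i'] -> [s=1, i] : i = i'} under the transform
+            # {[i] -> [i_new] : i_new = i + 1} becomes
+            # {[s'=0, i'] -> [s=1, i_new] : i_new = i' + 1} via apply_range.)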
+ + # If there are any out-dims (depender dims) in dep_map that are not + # mapped by the transform map, add them to the in/out space of the + # transform map so that they remain unchanged. + # (temporary proxy dim names are needed in out space of transform + # map because isl won't allow any dim names to match, i.e., instead + # of just mapping {[unused_name]->[unused_name]}, we have to map + # {[unused_name]->[unused_name__prox] : unused_name__prox = unused_name}, + # and then rename unused_name__prox afterward.) + ( + augmented_trans_map, proxy_name_pairs + ) = _apply_identity_for_missing_map_dims( + isl_map, dep_map.get_var_names(dt.out)) + + # Align 'in_' dim of transform map with 'out' dim of dep + from loopy.schedule.checker.utils import reorder_dims_by_name + augmented_trans_map_aligned = reorder_dims_by_name( + augmented_trans_map, dt.in_, dep_map.get_var_names(dt.out)) + + # Apply transform map to dep output dims + new_dep_map = dep_map.apply_range(augmented_trans_map_aligned) + + # Now rename the proxy dims back to their original names + for real_iname, proxy_iname in proxy_name_pairs: + new_dep_map = find_and_rename_dim( + new_dep_map, [dt.out], proxy_iname, real_iname) + + # Statement var may have moved, so put it back at the beginning + new_dep_map = move_dim_to_index( + new_dep_map, STATEMENT_VAR_NAME, dt.out, 0) + + return new_dep_map + + def _apply_transform_map_to_dependee(dep_map): + # (since 'in_' dim of dep is marked, use isl_map_marked) + + # Check overlap condition + overlap = _check_overlap_condition_for_domain( + dep_map.domain(), set(isl_map_marked.get_var_names(dt.in_))) + + if not overlap: + # Inames in s are not present in dependee, don't change dep_map + return dep_map + else: + # At this point, overlap condition check guarantees that the + # in-dims of the transform map are a subset of the dims we're + # about to change. + + # If there are any in-dims (dependee dims) in dep_map that are not + # mapped by the transform map, add them to the in/out space of the + # transform map so that they remain unchanged. + # (temporary proxy dim names are needed in out space of transform + # map because isl won't allow any dim names to match, i.e., instead + # of just mapping {[unused_name]->[unused_name]}, we have to map + # {[unused_name]->[unused_name__prox] : unused_name__prox = unused_name}, + # and then rename unused_name__prox afterward.) 
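+            # (This mirrors the depender case above, except that the
+            # BEFORE_MARK-ed copy of the transform map is applied to the
+            # dep's 'in_' (dependee) side via apply_domain below.)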
+            (
+                augmented_trans_map_marked, proxy_name_pairs
+            ) = _apply_identity_for_missing_map_dims(
+                isl_map_marked, dep_map.get_var_names(dt.in_))
+
+            # Align 'in_' dim of transform map with 'in_' dim of dep
+            from loopy.schedule.checker.utils import reorder_dims_by_name
+            augmented_trans_map_aligned = reorder_dims_by_name(
+                augmented_trans_map_marked, dt.in_,
+                dep_map.get_var_names(dt.in_))
+
+            # Apply transform map to dep input dims
+            new_dep_map = dep_map.apply_domain(augmented_trans_map_aligned)
+
+            # Now rename the proxy dims back to their original names
+            for real_iname, proxy_iname in proxy_name_pairs:
+                new_dep_map = find_and_rename_dim(
+                    new_dep_map, [dt.in_], proxy_iname, real_iname)
+
+            # Statement var may have moved, so put it back at the beginning
+            new_dep_map = move_dim_to_index(
+                new_dep_map, STATEMENT_VAR_NAME+BEFORE_MARK, dt.in_, 0)
+
+            return new_dep_map
+
+    # TODO figure out proper way to create a false match condition
+    false_id_match = "not id:*"
+    kernel = map_dependency_maps(
+        kernel, _apply_transform_map_to_depender,
+        stmt_match_depender=within, stmt_match_dependee=false_id_match)
+    kernel = map_dependency_maps(
+        kernel, _apply_transform_map_to_dependee,
+        stmt_match_depender=false_id_match, stmt_match_dependee=within)
+
+    # }}}
+
+    # {{{ update within_inames
+
+    new_insns = []
+    for insn in kernel.instructions:
+        overlap = old_inames & insn.within_inames
+        if overlap and within(kernel, insn):
+            if len(overlap) != len(old_inames):
+                raise LoopyError("instruction '%s' is within only a part "
+                        "of the map domain inames. Instructions must "
+                        "either be within all or none of the map domain "
+                        "inames." % insn.id)
+
+            insn = insn.copy(
+                within_inames=(insn.within_inames - old_inames) | new_inames)
+        else:
+            # leave insn unmodified
+            pass
+
+        new_insns.append(insn)
+
+    # }}}
+
+    kernel = kernel.copy(
+        domains=new_domains,
+        instructions=new_insns,
+        applied_iname_rewrites=applied_iname_rewrites)
+
+    rule_mapping_context = SubstitutionRuleMappingContext(
+        kernel.substitutions, kernel.get_var_name_generator())
+    ins = _MapDomainMapper(rule_mapping_context, within,
+        new_inames, substitutions)
+
+    kernel = ins.map_kernel(kernel)
+    kernel = rule_mapping_context.finish_kernel(kernel)
+
+    return kernel
+
+# }}}
+
+
+# {{{ add_inames_for_unused_hw_axes
+
 @for_each_kernel
 def add_inames_for_unused_hw_axes(kernel, within=None):
     """
@@ -1942,4 +3682,6 @@ def add_inames_for_unused_hw_axes(kernel, within=None):
     return kernel.copy(instructions=new_insns)
 
+# }}}
+
 # vim: foldmethod=marker
diff --git a/loopy/transform/instruction.py b/loopy/transform/instruction.py
index c7598c356..013555749 100644
--- a/loopy/transform/instruction.py
+++ b/loopy/transform/instruction.py
@@ -144,6 +144,152 @@ def add_dep(insn):
 
 # }}}
 
+
+# {{{ map dependencies
+
+# Terminology:
+# stmtX.dependencies:  # <- "stmt dependencies" = full dict of deps
+#     {stmt0: [dep_map00, dep_map01, ...],  # <- "one dependency"
+#      stmt1: [dep_map10, dep_map11, ...],
+#      ...}
+# One dependency includes one "dependency list", which contains "dep maps".
+
+
+def map_stmt_dependencies(kernel, stmt_match, f):
+    # Set stmt.dependencies = f(stmt.dependencies, stmt) for stmts matching
+    # stmt_match.
+    # Only modifies dependencies for the depender!
+    # Does not search for matching dependees of non-matching depender statements!
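+    #
+    # Usage sketch (hypothetical statement id and callable):
+    #     kernel = map_stmt_dependencies(
+    #         kernel, "id:stmt_5",
+    #         lambda deps, stmt: {})  # drop all of stmt_5's deps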
+
+    def _update_deps(stmt):
+        # pass stmt to f because f might need info from it
+        new_deps = f(stmt.dependencies, stmt)
+        return stmt.copy(dependencies=new_deps)
+
+    return map_instructions(kernel, stmt_match, _update_deps)
+
+
+def _parse_match_if_necessary(match_candidate):
+    from loopy.match import (
+        MatchExpressionBase,
+        StackMatch,
+    )
+    if not isinstance(
+            match_candidate, (MatchExpressionBase, StackMatch)):
+        from loopy.match import parse_match
+        # TODO assumes StackMatches are already parsed
+        # TODO determine when to use parse_stack_match (AKQ)
+        return parse_match(match_candidate)
+    else:
+        return match_candidate
+
+
+def map_dependency_lists(
+        kernel, f, stmt_match_depender="id:*", stmt_match_dependee="id:*"):
+    # Set dependency = f(dependency) for:
+    #     All deps of stmts matching stmt_match_depender
+    #     All deps ON stmts matching stmt_match_dependee
+    # (but doesn't call f() twice if a dep matches both depender and dependee)
+    from loopy.match import (
+        StackMatch,
+    )
+
+    match_depender = _parse_match_if_necessary(stmt_match_depender)
+    match_dependee = _parse_match_if_necessary(stmt_match_dependee)
+
+    # TODO figure out right way to simultaneously handle
+    # both MatchExpressionBase and StackMatch
+    if isinstance(match_depender, StackMatch):
+        extra_match_depender_args = [()]
+    else:
+        extra_match_depender_args = []
+    if isinstance(match_dependee, StackMatch):
+        extra_match_dependee_args = [()]
+    else:
+        extra_match_dependee_args = []
+
+    new_stmts = []
+    for stmt in kernel.instructions:
+        new_deps = {}
+        if match_depender(kernel, stmt, *extra_match_depender_args):
+            # Stmt matches as depender
+            # Replace all deps
+            for dep_id, dep_maps in stmt.dependencies.items():
+                new_deps[dep_id] = f(dep_maps)
+        else:
+            # Stmt didn't match as a depender
+            # Replace deps matching dependees
+            for dep_id, dep_maps in stmt.dependencies.items():
+                if match_dependee(
+                        kernel, kernel.id_to_insn[dep_id],
+                        *extra_match_dependee_args):
+                    new_deps[dep_id] = f(dep_maps)
+                else:
+                    new_deps[dep_id] = dep_maps
+        new_stmts.append(stmt.copy(dependencies=new_deps))
+
+    return kernel.copy(instructions=new_stmts)
+
+
+def map_dependency_maps(
+        kernel, f, stmt_match_depender="id:*", stmt_match_dependee="id:*"):
+    # Set dep_map = f(dep_map) for dep_map in:
+    #     All dependencies of stmts matching stmt_match_depender
+    #     All dependencies ON stmts matching stmt_match_dependee
+
+    def _update_dep_maps(dep_maps):
+        return [f(dep_map) for dep_map in dep_maps]
+
+    return map_dependency_lists(
+        kernel, _update_dep_maps, stmt_match_depender, stmt_match_dependee)
+
+# }}}
+
+
+# {{{ add_dependency_v2
+
+@for_each_kernel
+def add_dependency_v2(
+        kernel, stmt_id, depends_on_id, new_dependency):
+    """Add the statement instance dependency `new_dependency` to the statement with
+    id `stmt_id`.
+
+    :arg kernel: A :class:`loopy.kernel.LoopKernel`.
+
+    :arg stmt_id: The :class:`str` statement identifier of the statement to
+        which the dependency will be added.
+
+    :arg depends_on_id: The :class:`str` identifier of the statement that is
+        depended on, i.e., the statement with statement instances that must
+        happen before those of `stmt_id`.
+
+    :arg new_dependency: A :class:`islpy.Map` mapping each instance of the
+        depended-on statement (`depends_on_id`) to all instances of the
+        depending statement (`stmt_id`) that must occur later.
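+
+    Example (a sketch with hypothetical statement ids; `make_dep_map` from
+    :mod:`loopy.schedule.checker.utils` handles the statement-variable and
+    before-mark dim-naming conventions)::
+
+        dep = make_dep_map(
+            "[n] -> { [i'] -> [i] : i' = i and 0 <= i < n }",
+            knl_with_domains=knl)
+        knl = add_dependency_v2(knl, "stmt_after", "stmt_before", dep)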
+
+    """
+    # TODO make this accept multiple deps and/or multiple stmts so that
+    # these can be added in fewer passes through the instructions
+
+    if stmt_id not in kernel.id_to_insn:
+        raise LoopyError("no instructions found matching '%s', "
+                "cannot add dependency %s->%s"
+                % (stmt_id, depends_on_id, stmt_id))
+    if depends_on_id not in kernel.id_to_insn:
+        raise LoopyError("no instructions found matching '%s', "
+                "cannot add dependency %s->%s"
+                % (depends_on_id, depends_on_id, stmt_id))
+
+    def _add_dep(stmt_deps, stmt):
+        # stmt_deps: dict mapping depends-on ids to lists of dep maps.
+        # Copy first so we don't mutate the instruction's existing dict
+        # (and its lists) in place.
+        stmt_deps = {dep_id: list(dep_maps)
+                for dep_id, dep_maps in stmt_deps.items()}
+        stmt_deps.setdefault(depends_on_id, []).append(new_dependency)
+        return stmt_deps
+
+    result = map_stmt_dependencies(kernel, "id:%s" % (stmt_id), _add_dep)
+
+    return result
+
+# }}}
+
+
 # {{{ remove_instructions
 
 def _toposort_of_subset_of_insns(kernel, subset_insns):
@@ -234,13 +380,24 @@ def remove_instructions(kernel, insn_ids):
 
         assert (new_deps & insn_ids) == frozenset()
 
+        # {{{ Remove any new-world stmt inst dependencies on removed stmts
+
+        # (copy the dict so we don't mutate the instruction's dict in place)
+        new_dependencies = dict(insn.dependencies)
+        for removed_id in insn_ids:
+            # TODO propagate these intelligently?
+            new_dependencies.pop(removed_id, None)
+
+        # }}}
+
         # update no_sync_with
         new_no_sync_with = frozenset((insn_id, scope)
                 for insn_id, scope in insn.no_sync_with
                 if insn_id not in insn_ids)
 
-        new_insns.append(
-            insn.copy(depends_on=new_deps, no_sync_with=new_no_sync_with))
+        new_insns.append(insn.copy(
+            depends_on=new_deps,
+            dependencies=new_dependencies,
+            no_sync_with=new_no_sync_with))
 
     return kernel.copy(
             instructions=new_insns)
diff --git a/loopy/transform/parameter.py b/loopy/transform/parameter.py
index 4916dd4e7..433558fd1 100644
--- a/loopy/transform/parameter.py
+++ b/loopy/transform/parameter.py
@@ -91,6 +91,34 @@ def process_set(s):
 
     new_domains = [process_set(dom) for dom in kernel.domains]
 
+    # {{{ Fix parameter in deps
+
+    from loopy.transform.instruction import map_dependency_maps
+    from loopy.schedule.checker.utils import convert_map_to_set
+
+    def _fix_parameter_in_dep(dep):
+        # For efficiency: could check for param presence first
+        dt = isl.dim_type
+
+        # Temporarily convert map to set for processing
+        set_from_map, n_in_dims, n_out_dims = convert_map_to_set(dep)
+
+        # Fix param
+        set_from_map = process_set(set_from_map)
+
+        # Now set dims look like [inames' ..., inames ...]
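+        # (e.g., a dep map {[s', i'] -> [s, i]} flattens to the set
+        # {[s', i', s, i]}; the original out dims are the trailing
+        # n_out_dims dims)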
+ # Convert set back to map + map_from_set = isl.Map.from_domain(set_from_map) + # Move original out dims back + map_from_set = map_from_set.move_dims( + dt.out, 0, dt.in_, n_in_dims, n_out_dims) + + return map_from_set + + kernel = map_dependency_maps(kernel, _fix_parameter_in_dep) + + # }}} + from pymbolic.mapper.substitutor import make_subst_func subst_func = make_subst_func({name: value}) diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py index 730f21542..6150f6e30 100644 --- a/loopy/transform/precompute.py +++ b/loopy/transform/precompute.py @@ -156,6 +156,9 @@ def __init__(self, rule_mapping_context, subst_name, subst_tag, within, self.compute_read_variables = compute_read_variables self.compute_insn_depends_on = set() + # TODO determine whether there's a better strategy for this + self.things_replaced = set() + def map_substitution(self, name, tag, arguments, expn_state): if not ( name == self.subst_name @@ -236,6 +239,9 @@ def map_kernel(self, kernel): insn.depends_on | frozenset([self.compute_dep_id]))) + if hasattr(insn, "id"): + self.things_replaced.add(insn.id) + for dep in insn.depends_on: if dep in excluded_insn_ids: continue @@ -1058,6 +1064,117 @@ def add_assumptions(d): from loopy.kernel.tools import assign_automatic_axes kernel = assign_automatic_axes(kernel, callables_table) + # {{{ update dependencies + # FIXME Handle deps in precompute + """ + # Get some values that will be useful later + fetch_stmt_id = compute_insn_id + fetch_stmt = kernel.id_to_insn[compute_insn_id] + fetch_inames = fetch_stmt.within_inames + + # Go through all stmts that now use the fetch stuff + for usage_stmt_id in invr.things_replaced: + from loopy.schedule.checker.utils import ( + make_dep_map, + append_mark_to_strings, + remove_dim_by_name, + add_and_name_isl_dims, + insert_and_name_isl_dims, + ) + from loopy.schedule.checker.schedule import ( + BEFORE_MARK, + STATEMENT_VAR_NAME, + ) + # Get some values that will be useful later + usage_stmt = kernel.id_to_insn[usage_stmt_id] + usage_inames = usage_stmt.within_inames + shared_inames = fetch_inames & usage_inames + # TODO understand why this isn't true: + # assert shared_inames == usage_stmt.within_inames - set(sweep_inames) + fetch_inames_not_shared = fetch_inames - shared_inames + + # {{{ create dep fetch_stmt->usage_stmt : SAME(shared_inames) + + dep_in_names = list(fetch_inames) # want a copy anyway + dep_in_names_marked = append_mark_to_strings(dep_in_names, BEFORE_MARK) + dep_out_names = usage_inames + + in_space_str = ", ".join(dep_in_names_marked) + out_space_str = ", ".join(dep_out_names) + constraint_str = " and ".join([ + "{0}{1} = {0}".format(iname, BEFORE_MARK) for iname in shared_inames]) + dep_usage_on_fetch = make_dep_map( + f"{{ [{in_space_str}] -> [{out_space_str}] : {constraint_str} }}", + knl_with_domains=kernel) + # (add this dep below after next step) + + # }}} + + from islpy import dim_type as dt + for dependee_id, old_deps in usage_stmt.dependencies.items(): + for old_dep in old_deps: + # old dep: dependee->usage_stmt + # {{{ create dep dependee->fetch_stmt + + new_dep = old_dep.copy() + + old_out_inames = old_dep.get_var_names(dt.out) + assert ( + set(old_out_inames) - set([STATEMENT_VAR_NAME, ]) == + set(usage_inames)) + + non_shared_inames = set(usage_inames) - shared_inames + # Remove inames from old out dims that won't appear in new out dims + for non_shared_iname in non_shared_inames: + new_dep = remove_dim_by_name(new_dep, dt.out, non_shared_iname) + + # These new out inames will take on 
full domain values + assert ( + (set(usage_inames) - non_shared_inames) | fetch_inames_not_shared + == fetch_inames) + + # Add new_unconstrained_out_names to out dims + new_dep = add_and_name_isl_dims( + new_dep, dt.out, fetch_inames_not_shared) + + # Intersect dom for fetch_inames_not_shared + dom_to_intersect = kernel.get_inames_domain( + fetch_inames_not_shared + ).project_out_except(fetch_inames_not_shared, [dt.set]) + + dom_to_intersect_aligned = isl.align_spaces( + dom_to_intersect, new_dep.range(), + obj_bigger_ok=True) # e.g., params might exist? + + new_dep = new_dep.intersect_range(dom_to_intersect_aligned) + + # {{{ Old dep might have been self-dep, set stmt var correctly + + # add and remove stmt dim + new_dep = remove_dim_by_name(new_dep, dt.out, STATEMENT_VAR_NAME) + new_dep = insert_and_name_isl_dims( + new_dep, dt.out, [STATEMENT_VAR_NAME], 0) + # set stmt dim value + sid_out = 0 if fetch_stmt_id == dependee_id else 1 + new_dep = new_dep.add_constraint( + isl.Constraint.eq_from_names( + new_dep.space, + {1: sid_out, STATEMENT_VAR_NAME: -1})) + # }}} + + # Add this dep: dependee->fetch : dep + kernel = lp.add_dependency_v2( + kernel, fetch_stmt_id, dependee_id, new_dep) + + # }}} + + # Add other new dep from above: fetch->usage + kernel = lp.add_dependency_v2( + kernel, usage_stmt_id, fetch_stmt_id, dep_usage_on_fetch) + + """ + # }}} + return kernel diff --git a/loopy/transform/subst.py b/loopy/transform/subst.py index fd6d93f09..7e3d89d6b 100644 --- a/loopy/transform/subst.py +++ b/loopy/transform/subst.py @@ -234,9 +234,9 @@ def __init__(self, rule_mapping_context, lhs_name, definition_insn_ids, self.definition_insn_id_to_subst_name = {} - self.saw_unmatched_usage_sites = {} + self.unmatched_usage_sites_found = {} for def_id in self.definition_insn_ids: - self.saw_unmatched_usage_sites[def_id] = False + self.unmatched_usage_sites_found[def_id] = set() def get_subst_name(self, def_insn_id): try: @@ -278,7 +278,7 @@ def transform_access(self, index, expn_state): expn_state.kernel, expn_state.instruction, expn_state.stack): - self.saw_unmatched_usage_sites[my_def_id] = True + self.unmatched_usage_sites_found[my_def_id].add(my_insn_id) return None subst_name = self.get_subst_name(my_def_id) @@ -362,6 +362,7 @@ def get_relevant_definition_insn_id(usage_insn_id): return def_id usage_to_definition = {} + definition_to_usage_ids = {} for insn in dep_kernel.instructions: if lhs_name not in insn.read_dependency_names(): @@ -374,11 +375,26 @@ def get_relevant_definition_insn_id(usage_insn_id): % (lhs_name, insn.id)) usage_to_definition[insn.id] = def_id + definition_to_usage_ids.setdefault(def_id, set()).add(insn.id) + # these insns may be removed so can't get within_inames later + definition_to_within_inames = {} + for def_id in definition_to_usage_ids.keys(): + definition_to_within_inames[def_id] = kernel.id_to_insn[def_id].within_inames + + # Get deps for subst_def statements before any of them get removed + definition_id_to_deps = {} + from copy import deepcopy definition_insn_ids = set() for insn in kernel.instructions: if lhs_name in insn.write_dependency_names(): definition_insn_ids.add(insn.id) + definition_id_to_deps[insn.id] = deepcopy(insn.dependencies) + + # usage_to_definition maps each usage to the most recent assignment to the var, + # (most recent "definition"), + # so set(usage_to_definition.values()) is a subset of definition_insn_ids, + # which contains ALL the insns where the var is assigned # }}} @@ -443,7 +459,7 @@ def 
get_relevant_definition_insn_id(usage_insn_id):
     new_args = kernel.args
 
     if lhs_name in kernel.temporary_variables:
-        if not any(tts.saw_unmatched_usage_sites.values()):
+        if not any(tts.unmatched_usage_sites_found.values()):
             # All usage sites matched--they're now substitution rules.
             # We can get rid of the variable.
 
@@ -451,7 +467,7 @@ def get_relevant_definition_insn_id(usage_insn_id):
             del new_temp_vars[lhs_name]
 
     if lhs_name in kernel.arg_dict and not force_retain_argument:
-        if not any(tts.saw_unmatched_usage_sites.values()):
+        if not any(tts.unmatched_usage_sites_found.values()):
             # All usage sites matched--they're now substitution rules.
             # We can get rid of the argument
 
@@ -464,13 +480,90 @@ def get_relevant_definition_insn_id(usage_insn_id):
     # }}}
 
     import loopy as lp
+    # Remove defs if the subst expression is not still used anywhere
     kernel = lp.remove_instructions(
             kernel,
             {
                 insn_id
-                for insn_id, still_used in tts.saw_unmatched_usage_sites.items()
+                for insn_id, still_used in tts.unmatched_usage_sites_found.items()
                 if not still_used})
 
+    # {{{ update dependencies
+
+    from loopy.transform.instruction import map_stmt_dependencies
+
+    # Add dependencies from each subst_def to any statement where its
+    # LHS was found and the subst was performed
+    for subst_def_id, subst_usage_ids in definition_to_usage_ids.items():
+
+        unmatched_usage_ids = tts.unmatched_usage_sites_found[subst_def_id]
+        matched_usage_ids = subst_usage_ids - unmatched_usage_ids
+        if matched_usage_ids:
+            import islpy as isl
+            dt = isl.dim_type
+            # Create match condition string:
+            match_any_matched_usage_id = " or ".join(
+                ["id:%s" % (usage_id) for usage_id in matched_usage_ids])
+
+            subst_def_deps_dict = definition_id_to_deps[subst_def_id]
+            old_dep_out_inames = definition_to_within_inames[subst_def_id]
+
+            def _add_deps_to_stmt(old_dep_dict, stmt):
+                # old_dep_dict: prev dep dict for this stmt
+
+                # We want to add the old dep from the def stmt to the usage
+                # stmt. If the inames of the def stmt don't match those of the
+                # usage stmt, remove unwanted inames from the old dep's out
+                # dims and add any missing inames (usage-stmt inames not
+                # present in the def stmt).
+                new_dep_out_inames = stmt.within_inames
+                out_inames_to_project_out = old_dep_out_inames - new_dep_out_inames
+                out_inames_to_add = new_dep_out_inames - old_dep_out_inames
+                # inames_domain for new inames to add
+                dom_for_new_inames = kernel.get_inames_domain(
+                    out_inames_to_add
+                    ).project_out_except(out_inames_to_add, [dt.set])
+
+                # process and add the old deps
+                for depends_on_id, old_dep_list in subst_def_deps_dict.items():
+
+                    new_dep_list = []
+                    for old_dep in old_dep_list:
+                        # TODO figure out when copies are necessary
+                        new_dep = deepcopy(old_dep)
+
+                        # project out inames from old dep (out dim) that don't apply
+                        # to this statement
+                        # (look up each index on new_dep, since earlier
+                        # projections shift the remaining dims)
+                        for old_iname in out_inames_to_project_out:
+                            idx_of_old_iname = new_dep.find_dim_by_name(
+                                dt.out, old_iname)
+                            assert idx_of_old_iname != -1
+                            new_dep = new_dep.project_out(
+                                dt.out, idx_of_old_iname, 1)
+
+                        # add inames from this stmt that were not present in old dep
+                        from loopy.schedule.checker.utils import (
+                            add_and_name_isl_dims,
+                        )
+                        new_dep = add_and_name_isl_dims(
+                            new_dep, dt.out, out_inames_to_add)
+
+                        # add inames domain for new inames
+                        dom_aligned = isl.align_spaces(
+                            dom_for_new_inames, new_dep.range())
+
+                        # Intersect domain with dep
+                        new_dep = new_dep.intersect_range(dom_aligned)
+                        new_dep_list.append(new_dep)
+
+                    old_dep_dict.setdefault(depends_on_id, []).extend(new_dep_list)
+                return old_dep_dict
+
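+            # Apply _add_deps_to_stmt to every statement that matched a
+            # usage site
+            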
kernel = map_stmt_dependencies( + kernel, match_any_matched_usage_id, _add_deps_to_stmt) + + # }}} + return kernel.copy( substitutions=new_substs, temporary_variables=new_temp_vars, diff --git a/setup.py b/setup.py index 2e907c1b9..701f796d5 100644 --- a/setup.py +++ b/setup.py @@ -90,6 +90,7 @@ def write_git_revision(package_name): # https://github.com/inducer/loopy/pull/419 "numpy>=1.19", + "dataclasses>=0.7;python_version<='3.6'", "cgen>=2016.1", "islpy>=2019.1", diff --git a/test/test_linearization_checker.py b/test/test_linearization_checker.py new file mode 100644 index 000000000..8d96e09a0 --- /dev/null +++ b/test/test_linearization_checker.py @@ -0,0 +1,3038 @@ +from __future__ import division, print_function + +__copyright__ = "Copyright (C) 2019 James Stevens" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + +import six # noqa: F401 +import sys +import numpy as np +import loopy as lp +import islpy as isl +from pyopencl.tools import ( # noqa + pytest_generate_tests_for_pyopencl + as pytest_generate_tests) +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa +import logging +from loopy import ( + preprocess_kernel, + get_one_linearized_kernel, +) +from loopy.schedule.checker.schedule import ( + LEX_VAR_PREFIX, + STATEMENT_VAR_NAME, + LTAG_VAR_NAMES, + GTAG_VAR_NAMES, + BEFORE_MARK, +) +from loopy.schedule.checker.utils import ( + ensure_dim_names_match_and_align, + make_dep_map, +) + +logger = logging.getLogger(__name__) + + +# {{{ Helper functions for map creation/handling + +def _align_and_compare_maps(maps): + from loopy.schedule.checker.utils import prettier_map_string + + for map1, map2 in maps: + # Align maps and compare + map1_aligned = ensure_dim_names_match_and_align(map1, map2) + if map1_aligned != map2: + print("Maps not equal:") + print(prettier_map_string(map1_aligned)) + print(prettier_map_string(map2)) + assert map1_aligned == map2 + + +def _lex_point_string(dim_vals, lid_inames=(), gid_inames=()): + # Return a string describing a point in a lex space + # by assigning values to lex dimension variables + # (used to create maps below) + + return ", ".join( + ["%s%d=%s" % (LEX_VAR_PREFIX, idx, str(val)) + for idx, val in enumerate(dim_vals)] + + ["%s=%s" % (LTAG_VAR_NAMES[idx], iname) + for idx, iname in enumerate(lid_inames)] + + ["%s=%s" % (GTAG_VAR_NAMES[idx], iname) + for idx, iname in enumerate(gid_inames)] + ) + + +def _isl_map_with_marked_dims(s, placeholder_mark="'"): + # For creating legible tests, map strings may be created with a placeholder + # for the 'before' mark. Replace this placeholder with BEFORE_MARK before + # creating the map. + # ALSO, if BEFORE_MARK == "'", ISL will ignore this mark when creating + # variable names, so it must be added manually. 
+ from loopy.schedule.checker.utils import ( + append_mark_to_isl_map_var_names, + ) + dt = isl.dim_type + if BEFORE_MARK == "'": + # ISL will ignore the apostrophe; manually name the in_ vars + return append_mark_to_isl_map_var_names( + isl.Map(s.replace(placeholder_mark, BEFORE_MARK)), + dt.in_, + BEFORE_MARK) + else: + return isl.Map(s.replace(placeholder_mark, BEFORE_MARK)) + + +def _check_orderings_for_stmt_pair( + stmt_id_before, + stmt_id_after, + all_sios, + sio_intra_thread_exp=None, + sched_before_intra_thread_exp=None, + sched_after_intra_thread_exp=None, + sio_intra_group_exp=None, + sched_before_intra_group_exp=None, + sched_after_intra_group_exp=None, + sio_global_exp=None, + sched_before_global_exp=None, + sched_after_global_exp=None, + ): + + order_info = all_sios[(stmt_id_before, stmt_id_after)] + + # Get pairs of maps to compare for equality + map_candidates = zip([ + sio_intra_thread_exp, + sched_before_intra_thread_exp, sched_after_intra_thread_exp, + sio_intra_group_exp, + sched_before_intra_group_exp, sched_after_intra_group_exp, + sio_global_exp, + sched_before_global_exp, sched_after_global_exp, + ], [ + order_info.sio_intra_thread, + order_info.pwsched_intra_thread[0], order_info.pwsched_intra_thread[1], + order_info.sio_intra_group, + order_info.pwsched_intra_group[0], order_info.pwsched_intra_group[1], + order_info.sio_global, + order_info.pwsched_global[0], order_info.pwsched_global[1], + ]) + + # Only compare to maps that were passed + maps_to_compare = [(m1, m2) for m1, m2 in map_candidates if m1 is not None] + _align_and_compare_maps(maps_to_compare) + + +def _process_and_linearize(prog, knl_name="loopy_kernel"): + # Return linearization items along with the preprocessed kernel and + # linearized kernel + proc_prog = preprocess_kernel(prog) + lin_prog = get_one_linearized_kernel( + proc_prog[knl_name], proc_prog.callables_table) + return lin_prog.linearization, proc_prog[knl_name], lin_prog + + +def _get_runinstruction_ids_from_linearization(lin_items): + from loopy.schedule import RunInstruction + return [ + lin_item.insn_id for lin_item in lin_items + if isinstance(lin_item, RunInstruction)] + +# }}} + + +# {{{ Helper functions for dependency tests + + +def _compare_dependencies( + prog, deps_expected, return_unsatisfied=False, knl_name="loopy_kernel"): + + deps_found = {} + for stmt in prog[knl_name].instructions: + if hasattr(stmt, "dependencies") and stmt.dependencies: + deps_found[stmt.id] = stmt.dependencies + + assert deps_found.keys() == deps_expected.keys() + + for stmt_id_after, dep_dict_found in deps_found.items(): + + dep_dict_expected = deps_expected[stmt_id_after] + + # Ensure deps for stmt_id_after match + assert dep_dict_found.keys() == dep_dict_expected.keys() + + for stmt_id_before, dep_list_found in dep_dict_found.items(): + + # Ensure deps from (stmt_id_before -> stmt_id_after) match + dep_list_expected = dep_dict_expected[stmt_id_before] + print("comparing deps %s->%s" % (stmt_id_before, stmt_id_after)) + assert len(dep_list_found) == len(dep_list_expected) + _align_and_compare_maps(zip(dep_list_found, dep_list_expected)) + + if not return_unsatisfied: + return + + # Get unsatisfied deps + lin_items, proc_prog, lin_prog = _process_and_linearize(prog, knl_name) + unsatisfied_deps = lp.find_unsatisfied_dependencies(proc_prog, lin_items) + + # Make sure dep checking also works with just linearized kernel + unsatisfied_deps_2 = lp.find_unsatisfied_dependencies(lin_prog) + assert len(unsatisfied_deps) == len(unsatisfied_deps_2) + + return 
unsatisfied_deps
+
+# }}}
+
+
+# {{{ test_intra_thread_pairwise_schedule_creation()
+
+def test_intra_thread_pairwise_schedule_creation():
+    from loopy.schedule.checker import (
+        get_pairwise_statement_orderings,
+    )
+
+    # Example kernel
+    # stmt_c depends on stmt_b only to create deterministic order
+    # stmt_d depends on stmt_c only to create deterministic order
+    knl = lp.make_kernel(
+        [
+            "{[i]: 0<=i<pi}",
+            "{[k]: 0<=k<pk}",
+            "{[j]: 0<=j<pj}",
+            "{[t]: 0<=t<pt}",
+        ],
+        """
+        for i
+            for k
+                <>temp = b[i,k] {id=stmt_a}
+            end
+            for j
+                a[i,j] = temp + 1 {id=stmt_b,dep=stmt_a}
+                c[i,j] = d[i,j] {id=stmt_c,dep=stmt_b}
+            end
+        end
+        for t
+            e[t] = f[t] {id=stmt_d, dep=stmt_c}
+        end
+        """,
+        assumptions="pi,pj,pk,pt >= 1",
+        )
+    knl = lp.add_and_infer_dtypes(
+        knl,
+        {"b": np.float32, "d": np.float32, "f": np.float32})
+    knl = lp.prioritize_loops(knl, "i,k")
+    knl = lp.prioritize_loops(knl, "i,j")
+
+    # Get a linearization
+    lin_items, proc_knl, lin_knl = _process_and_linearize(knl)
+
+    stmt_id_pairs = [
+        ("stmt_a", "stmt_b"),
+        ("stmt_a", "stmt_c"),
+        ("stmt_a", "stmt_d"),
+        ("stmt_b", "stmt_c"),
+        ("stmt_b", "stmt_d"),
+        ("stmt_c", "stmt_d"),
+        ]
+    pworders = get_pairwise_statement_orderings(
+        lin_knl,
+        lin_items,
+        stmt_id_pairs,
+        )
+
+    # {{{ Relationship between stmt_a and stmt_b
+
+    # Create expected maps and compare
+
+    sched_stmt_a_intra_thread_exp = isl.Map(
+        "[pi, pk] -> { [%s=0, i, k] -> [%s] : 0 <= i < pi and 0 <= k < pk }"
+        % (
+            STATEMENT_VAR_NAME,
+            _lex_point_string(["i", "0"]),
+            )
+        )
+
+    sched_stmt_b_intra_thread_exp = isl.Map(
+        "[pi, pj] -> { [%s=1, i, j] -> [%s] : 0 <= i < pi and 0 <= j < pj }"
+        % (
+            STATEMENT_VAR_NAME,
+            _lex_point_string(["i", "1"]),
+            )
+        )
+
+    _check_orderings_for_stmt_pair(
+        "stmt_a", "stmt_b", pworders,
+        sched_before_intra_thread_exp=sched_stmt_a_intra_thread_exp,
+        sched_after_intra_thread_exp=sched_stmt_b_intra_thread_exp,
+        )
+
+    # }}}
+
+    # {{{ Relationship between stmt_a and stmt_c
+
+    # Create expected maps and compare
+
+    sched_stmt_a_intra_thread_exp = isl.Map(
+        "[pi, pk] -> { [%s=0, i, k] -> [%s] : 0 <= i < pi and 0 <= k < pk }"
+        % (
+            STATEMENT_VAR_NAME,
+            _lex_point_string(["i", "0"]),
+            )
+        )
+
+    sched_stmt_c_intra_thread_exp = isl.Map(
+        "[pi, pj] -> { [%s=1, i, j] -> [%s] : 0 <= i < pi and 0 <= j < pj }"
+        % (
+            STATEMENT_VAR_NAME,
+            _lex_point_string(["i", "1"]),
+            )
+        )
+
+    _check_orderings_for_stmt_pair(
+        "stmt_a", "stmt_c", pworders,
+        sched_before_intra_thread_exp=sched_stmt_a_intra_thread_exp,
+        sched_after_intra_thread_exp=sched_stmt_c_intra_thread_exp,
+        )
+
+    # }}}
+
+    # {{{ Relationship between stmt_a and stmt_d
+
+    # Create expected maps and compare
+
+    sched_stmt_a_intra_thread_exp = isl.Map(
+        "[pi, pk] -> { [%s=0, i, k] -> [%s] : 0 <= i < pi and 0 <= k < pk }"
+        % (
+            STATEMENT_VAR_NAME,
+            _lex_point_string([0, ]),
+            )
+        )
+
+    sched_stmt_d_intra_thread_exp = isl.Map(
+        "[pt] -> { [%s=1, t] -> [%s] : 0 <= t < pt }"
+        % (
+            STATEMENT_VAR_NAME,
+            _lex_point_string([1, ]),
+            )
+        )
+
+    _check_orderings_for_stmt_pair(
+        "stmt_a", "stmt_d", pworders,
+        sched_before_intra_thread_exp=sched_stmt_a_intra_thread_exp,
+        sched_after_intra_thread_exp=sched_stmt_d_intra_thread_exp,
+        )
+
+    # }}}
+
+    # {{{ Relationship between stmt_b and stmt_c
+
+    # Create expected maps and compare
+
+    sched_stmt_b_intra_thread_exp = isl.Map(
+        "[pi, pj] -> { [%s=0, i, j] -> [%s] : 0 <= i < pi and 0 <= j < pj }"
+        % (
+            STATEMENT_VAR_NAME,
+            _lex_point_string(["i", "j", 0]),
+            )
+        )
+
+    sched_stmt_c_intra_thread_exp = isl.Map(
+        "[pi, pj] -> { [%s=1, i, j] -> [%s] : 0 <= i < pi and 0 <= j < pj }"
+        % (
+            STATEMENT_VAR_NAME,
_lex_point_string(["i", "j", 1]),
+            )
+        )
+
+    _check_orderings_for_stmt_pair(
+        "stmt_b", "stmt_c", pworders,
+        sched_before_intra_thread_exp=sched_stmt_b_intra_thread_exp,
+        sched_after_intra_thread_exp=sched_stmt_c_intra_thread_exp,
+        )
+
+    # }}}
+
+    # {{{ Relationship between stmt_b and stmt_d
+
+    # Create expected maps and compare
+
+    sched_stmt_b_intra_thread_exp = isl.Map(
+        "[pi, pj] -> { [%s=0, i, j] -> [%s] : 0 <= i < pi and 0 <= j < pj }"
+        % (
+            STATEMENT_VAR_NAME,
+            _lex_point_string([0, ]),
+            )
+        )
+
+    sched_stmt_d_intra_thread_exp = isl.Map(
+        "[pt] -> { [%s=1, t] -> [%s] : 0 <= t < pt }"
+        % (
+            STATEMENT_VAR_NAME,
+            _lex_point_string([1, ]),
+            )
+        )
+
+    _check_orderings_for_stmt_pair(
+        "stmt_b", "stmt_d", pworders,
+        sched_before_intra_thread_exp=sched_stmt_b_intra_thread_exp,
+        sched_after_intra_thread_exp=sched_stmt_d_intra_thread_exp,
+        )
+
+    # }}}
+
+    # {{{ Relationship between stmt_c and stmt_d
+
+    # Create expected maps and compare
+
+    sched_stmt_c_intra_thread_exp = isl.Map(
+        "[pi, pj] -> { [%s=0, i, j] -> [%s] : 0 <= i < pi and 0 <= j < pj }"
+        % (
+            STATEMENT_VAR_NAME,
+            _lex_point_string([0, ]),
+            )
+        )
+
+    sched_stmt_d_intra_thread_exp = isl.Map(
+        "[pt] -> { [%s=1, t] -> [%s] : 0 <= t < pt }"
+        % (
+            STATEMENT_VAR_NAME,
+            _lex_point_string([1, ]),
+            )
+        )
+
+    _check_orderings_for_stmt_pair(
+        "stmt_c", "stmt_d", pworders,
+        sched_before_intra_thread_exp=sched_stmt_c_intra_thread_exp,
+        sched_after_intra_thread_exp=sched_stmt_d_intra_thread_exp,
+        )
+
+    # }}}
+
+# }}}
+
+
+# {{{ test_pairwise_schedule_creation_with_hw_par_tags()
+
+def test_pairwise_schedule_creation_with_hw_par_tags():
+    # (further sched testing in SIO tests below)
+
+    from loopy.schedule.checker import (
+        get_pairwise_statement_orderings,
+    )
+
+    # Example kernel
+    knl = lp.make_kernel(
+        [
+            "{[i,ii]: 0<=i,ii<pi}",
+            "{[j,jj]: 0<=j,jj<pj}",
+        ],
+        """
+        for i
+            for ii
+                for j
+                    for jj
+                        <>temp = b[i,ii,j,jj] {id=stmt_a}
+                        a[i,ii,j,jj] = temp + 1 {id=stmt_b,dep=stmt_a}
+                    end
+                end
+            end
+        end
+        """,
+        assumptions="pi,pj >= 1",
+        lang_version=(2018, 2)
+        )
+    knl = lp.add_and_infer_dtypes(knl, {"a": np.float32, "b": np.float32})
+    knl = lp.tag_inames(knl, {"j": "l.1", "jj": "l.0", "i": "g.0"})
+
+    # Get a linearization
+    lin_items, proc_knl, lin_knl = _process_and_linearize(knl)
+
+    stmt_id_pairs = [
+        ("stmt_a", "stmt_b"),
+        ]
+    pworders = get_pairwise_statement_orderings(
+        lin_knl,
+        lin_items,
+        stmt_id_pairs,
+        )
+
+    # {{{ Relationship between stmt_a and stmt_b
+
+    # Create expected maps and compare
+
+    sched_stmt_a_intra_thread_exp = isl.Map(
+        "[pi,pj] -> {[%s=0,i,ii,j,jj] -> [%s] : 0 <= i,ii < pi and 0 <= j,jj < pj}"
+        % (
+            STATEMENT_VAR_NAME,
+            _lex_point_string(
+                ["ii", "0"],
+                lid_inames=["jj", "j"], gid_inames=["i"],
+                ),
+            )
+        )
+
+    sched_stmt_b_intra_thread_exp = isl.Map(
+        "[pi,pj] -> {[%s=1,i,ii,j,jj] -> [%s] : 0 <= i,ii < pi and 0 <= j,jj < pj}"
+        % (
+            STATEMENT_VAR_NAME,
+            _lex_point_string(
+                ["ii", "1"],
+                lid_inames=["jj", "j"], gid_inames=["i"],
+                ),
+            )
+        )
+
+    _check_orderings_for_stmt_pair(
+        "stmt_a", "stmt_b", pworders,
+        sched_before_intra_thread_exp=sched_stmt_a_intra_thread_exp,
+        sched_after_intra_thread_exp=sched_stmt_b_intra_thread_exp,
+        )
+
+    # }}}
+
+# }}}
+
+
+# {{{ test_lex_order_map_creation()
+
+def test_lex_order_map_creation():
+    from loopy.schedule.checker.lexicographic_order_map import (
+        create_lex_order_map,
+    )
+
+    def _check_lex_map(exp_lex_order_map, n_dims):
+
+        lex_order_map = create_lex_order_map(
+            dim_names=["%s%d" % (LEX_VAR_PREFIX, i) for i in range(n_dims)],
+            in_dim_mark=BEFORE_MARK,
+            )
+
+        assert lex_order_map
== exp_lex_order_map
+        assert lex_order_map.get_var_dict() == exp_lex_order_map.get_var_dict()
+
+    exp_lex_order_map = _isl_map_with_marked_dims(
+        "{{ "
+        "[{0}0', {0}1', {0}2', {0}3', {0}4'] -> [{0}0, {0}1, {0}2, {0}3, {0}4] :"
+        "("
+        "{0}0' < {0}0 "
+        ") or ("
+        "{0}0'={0}0 and {0}1' < {0}1 "
+        ") or ("
+        "{0}0'={0}0 and {0}1'={0}1 and {0}2' < {0}2 "
+        ") or ("
+        "{0}0'={0}0 and {0}1'={0}1 and {0}2'={0}2 and {0}3' < {0}3 "
+        ") or ("
+        "{0}0'={0}0 and {0}1'={0}1 and {0}2'={0}2 and {0}3'={0}3 and {0}4' < {0}4"
+        ")"
+        "}}".format(LEX_VAR_PREFIX))
+
+    _check_lex_map(exp_lex_order_map, 5)
+
+    exp_lex_order_map = _isl_map_with_marked_dims(
+        "{{ "
+        "[{0}0'] -> [{0}0] :"
+        "("
+        "{0}0' < {0}0 "
+        ")"
+        "}}".format(LEX_VAR_PREFIX))
+
+    _check_lex_map(exp_lex_order_map, 1)
+
+# }}}
+
+
+# {{{ test_intra_thread_statement_instance_ordering()
+
+def test_intra_thread_statement_instance_ordering():
+    from loopy.schedule.checker import (
+        get_pairwise_statement_orderings,
+    )
+
+    # Example kernel (add deps to fix loop order)
+    knl = lp.make_kernel(
+        [
+            "{[i]: 0<=i<pi}",
+            "{[k]: 0<=k<pk}",
+            "{[j]: 0<=j<pj}",
+            "{[t]: 0<=t<pt}",
+        ],
+        """
+        for i
+            for k
+                <>temp = b[i,k] {id=stmt_a}
+            end
+            for j
+                a[i,j] = temp + 1 {id=stmt_b,dep=stmt_a}
+                c[i,j] = d[i,j] {id=stmt_c,dep=stmt_b}
+            end
+        end
+        for t
+            e[t] = f[t] {id=stmt_d, dep=stmt_c}
+        end
+        """,
+        assumptions="pi,pj,pk,pt >= 1",
+        lang_version=(2018, 2)
+        )
+    knl = lp.add_and_infer_dtypes(
+        knl,
+        {"b": np.float32, "d": np.float32, "f": np.float32})
+    knl = lp.prioritize_loops(knl, "i,k")
+    knl = lp.prioritize_loops(knl, "i,j")
+
+    # Get a linearization
+    lin_items, proc_knl, lin_knl = _process_and_linearize(knl)
+
+    # Get pairwise schedules
+    stmt_id_pairs = [
+        ("stmt_a", "stmt_b"),
+        ("stmt_a", "stmt_c"),
+        ("stmt_a", "stmt_d"),
+        ("stmt_b", "stmt_c"),
+        ("stmt_b", "stmt_d"),
+        ("stmt_c", "stmt_d"),
+        ]
+    pworders = get_pairwise_statement_orderings(
+        proc_knl,
+        lin_items,
+        stmt_id_pairs,
+        )
+
+    # {{{ Relationship between stmt_a and stmt_b
+
+    sio_intra_thread_exp = _isl_map_with_marked_dims(
+        "[pi, pj, pk] -> {{ "
+        "[{0}'=0, i', k'] -> [{0}=1, i, j] : "
+        "0 <= i,i' < pi and 0 <= k' < pk and 0 <= j < pj and i >= i' "
+        "}}".format(STATEMENT_VAR_NAME)
+        )
+
+    _check_orderings_for_stmt_pair(
+        "stmt_a", "stmt_b", pworders, sio_intra_thread_exp=sio_intra_thread_exp)
+
+    # }}}
+
+    # {{{ Relationship between stmt_a and stmt_c
+
+    sio_intra_thread_exp = _isl_map_with_marked_dims(
+        "[pi, pj, pk] -> {{ "
+        "[{0}'=0, i', k'] -> [{0}=1, i, j] : "
+        "0 <= i,i' < pi and 0 <= k' < pk and 0 <= j < pj and i >= i' "
+        "}}".format(STATEMENT_VAR_NAME)
+        )
+
+    _check_orderings_for_stmt_pair(
+        "stmt_a", "stmt_c", pworders, sio_intra_thread_exp=sio_intra_thread_exp)
+
+    # }}}
+
+    # {{{ Relationship between stmt_a and stmt_d
+
+    sio_intra_thread_exp = _isl_map_with_marked_dims(
+        "[pt, pi, pk] -> {{ "
+        "[{0}'=0, i', k'] -> [{0}=1, t] : "
+        "0 <= i' < pi and 0 <= k' < pk and 0 <= t < pt "
+        "}}".format(STATEMENT_VAR_NAME)
+        )
+
+    _check_orderings_for_stmt_pair(
+        "stmt_a", "stmt_d", pworders, sio_intra_thread_exp=sio_intra_thread_exp)
+
+    # }}}
+
+    # {{{ Relationship between stmt_b and stmt_c
+
+    sio_intra_thread_exp = _isl_map_with_marked_dims(
+        "[pi, pj] -> {{ "
+        "[{0}'=0, i', j'] -> [{0}=1, i, j] : "
+        "0 <= i,i' < pi and 0 <= j,j' < pj and i > i'; "
+        "[{0}'=0, i', j'] -> [{0}=1, i=i', j] : "
+        "0 <= i' < pi and 0 <= j,j' < pj and j >= j'; "
+        "}}".format(STATEMENT_VAR_NAME)
+        )
+
+    _check_orderings_for_stmt_pair(
+        "stmt_b", "stmt_c", pworders, sio_intra_thread_exp=sio_intra_thread_exp)
+
+    # }}}
+
+    # {{{ Relationship between stmt_b and
stmt_d
+
+    sio_intra_thread_exp = _isl_map_with_marked_dims(
+        "[pt, pi, pj] -> {{ "
+        "[{0}'=0, i', j'] -> [{0}=1, t] : "
+        "0 <= i' < pi and 0 <= j' < pj and 0 <= t < pt "
+        "}}".format(STATEMENT_VAR_NAME)
+        )
+
+    _check_orderings_for_stmt_pair(
+        "stmt_b", "stmt_d", pworders, sio_intra_thread_exp=sio_intra_thread_exp)
+
+    # }}}
+
+    # {{{ Relationship between stmt_c and stmt_d
+
+    sio_intra_thread_exp = _isl_map_with_marked_dims(
+        "[pt, pi, pj] -> {{ "
+        "[{0}'=0, i', j'] -> [{0}=1, t] : "
+        "0 <= i' < pi and 0 <= j' < pj and 0 <= t < pt "
+        "}}".format(STATEMENT_VAR_NAME)
+        )
+
+    _check_orderings_for_stmt_pair(
+        "stmt_c", "stmt_d", pworders, sio_intra_thread_exp=sio_intra_thread_exp)
+
+    # }}}
+
+# }}}
+
+
+# {{{ test_statement_instance_ordering_with_hw_par_tags()
+
+def test_statement_instance_ordering_with_hw_par_tags():
+    from loopy.schedule.checker import (
+        get_pairwise_statement_orderings,
+    )
+    from loopy.schedule.checker.utils import (
+        partition_inames_by_concurrency,
+    )
+
+    # Example kernel
+    knl = lp.make_kernel(
+        [
+            "{[i,ii]: 0<=i,ii<pi}",
+            "{[j,jj]: 0<=j,jj<pj}",
+        ],
+        """
+        for i
+            for ii
+                for j
+                    for jj
+                        <>temp = b[i,ii,j,jj] {id=stmt_a}
+                        a[i,ii,j,jj] = temp + 1 {id=stmt_b,dep=stmt_a}
+                    end
+                end
+            end
+        end
+        """,
+        assumptions="pi,pj >= 1",
+        lang_version=(2018, 2)
+        )
+    knl = lp.add_and_infer_dtypes(knl, {"a": np.float32, "b": np.float32})
+    knl = lp.tag_inames(knl, {"j": "l.1", "jj": "l.0", "i": "g.0"})
+
+    # Get a linearization
+    lin_items, proc_knl, lin_knl = _process_and_linearize(knl)
+
+    # Get pairwise schedules
+    stmt_id_pairs = [
+        ("stmt_a", "stmt_b"),
+        ]
+    pworders = get_pairwise_statement_orderings(
+        lin_knl,
+        lin_items,
+        stmt_id_pairs,
+        )
+
+    # Create string for representing parallel iname condition in sio
+    conc_inames, _ = partition_inames_by_concurrency(knl["loopy_kernel"])
+    par_iname_condition = " and ".join(
+        "{0} = {0}'".format(iname) for iname in conc_inames)
+
+    # {{{ Relationship between stmt_a and stmt_b
+
+    sio_intra_thread_exp = _isl_map_with_marked_dims(
+        "[pi, pj] -> {{ "
+        "[{0}'=0, i', ii', j', jj'] -> [{0}=1, i, ii, j, jj] : "
+        "0 <= i,ii,i',ii' < pi and 0 <= j,jj,j',jj' < pj and ii >= ii' "
+        "and {1} "
+        "}}".format(
+            STATEMENT_VAR_NAME,
+            par_iname_condition,
+            )
+        )
+
+    _check_orderings_for_stmt_pair(
+        "stmt_a", "stmt_b", pworders, sio_intra_thread_exp=sio_intra_thread_exp)
+
+    # }}}
+
+# }}}
+
+
+# {{{ test_statement_instance_ordering_of_barriers()
+
+def test_statement_instance_ordering_of_barriers():
+    from loopy.schedule.checker import (
+        get_pairwise_statement_orderings,
+    )
+    from loopy.schedule.checker.utils import (
+        partition_inames_by_concurrency,
+    )
+
+    # Example kernel
+    knl = lp.make_kernel(
+        [
+            "{[i,ii]: 0<=i,ii<pi}",
+            "{[j,jj]: 0<=j,jj<pj}",
+        ],
+        """
+        for i
+            for ii
+                ... gbarrier {id=gbar}
+                for j
+                    for jj
+                        <>temp = b[i,ii,j,jj] {id=stmt_a,dep=gbar}
+                        ... lbarrier {id=lbar0,dep=stmt_a}
+                        a[i,ii,j,jj] = temp + 1 {id=stmt_b,dep=lbar0}
+                        ... 
lbarrier {id=lbar1,dep=stmt_b}
+                    end
+                end
+            end
+        end
+        <>temp2 = 0.5 {id=stmt_c,dep=lbar1}
+        """,
+        assumptions="pi,pj >= 1",
+        lang_version=(2018, 2)
+        )
+    knl = lp.add_and_infer_dtypes(knl, {"a,b": np.float32})
+    knl = lp.tag_inames(knl, {"j": "l.0", "i": "g.0"})
+    knl = lp.prioritize_loops(knl, "ii,jj")
+
+    # Get a linearization
+    lin_items, proc_knl, lin_knl = _process_and_linearize(knl)
+
+    # Get pairwise schedules
+    stmt_id_pairs = [
+        ("stmt_a", "stmt_b"),
+        ("gbar", "stmt_a"),
+        ("stmt_b", "lbar1"),
+        ("lbar1", "stmt_c"),
+        ]
+    pworders = get_pairwise_statement_orderings(
+        lin_knl,
+        lin_items,
+        stmt_id_pairs,
+        )
+
+    # Create string for representing parallel iname SAME condition in sio
+    conc_inames, _ = partition_inames_by_concurrency(knl["loopy_kernel"])
+    par_iname_condition = " and ".join(
+        "{0} = {0}'".format(iname) for iname in conc_inames)
+
+    # {{{ Intra-thread relationship between stmt_a and stmt_b
+
+    sio_intra_thread_exp = _isl_map_with_marked_dims(
+        "[pi, pj] -> {{ "
+        "[{0}'=0, i', ii', j', jj'] -> [{0}=1, i, ii, j, jj] : "
+        "0 <= i,ii,i',ii' < pi and 0 <= j,jj,j',jj' < pj "
+        "and (ii > ii' or (ii = ii' and jj >= jj')) "
+        "and {1} "
+        "}}".format(
+            STATEMENT_VAR_NAME,
+            par_iname_condition,
+            )
+        )
+
+    _check_orderings_for_stmt_pair(
+        "stmt_a", "stmt_b", pworders,
+        sio_intra_thread_exp=sio_intra_thread_exp)
+
+    # }}}
+
+    # {{{ Relationship between gbar and stmt_a
+
+    # intra-thread case
+
+    sio_intra_thread_exp = _isl_map_with_marked_dims(
+        "[pi, pj] -> {{ "
+        "[{0}'=0, i', ii'] -> [{0}=1, i, ii, j, jj] : "
+        "0 <= i,ii,i',ii' < pi and 0 <= j,jj < pj "  # domains
+        "and i = i' "  # parallel inames must be same
+        "and ii >= ii' "  # before->after condition
+        "}}".format(
+            STATEMENT_VAR_NAME,
+            )
+        )
+
+    # intra-group case
+    # TODO figure out what this should be
+    """
+    sio_intra_group_exp = _isl_map_with_marked_dims(
+        "[pi, pj] -> {{ "
+        "[{0}'=0, i', ii'] -> [{0}=1, i, ii, j, jj] : "
+        "0 <= i,ii,i',ii' < pi and 0 <= j,jj < pj "  # domains
+        "and i = i' "  # GID inames must be same
+        "and (ii > ii' or (ii = ii' and jj = 0))"  # before->after condition
+        "}}".format(
+            STATEMENT_VAR_NAME,
+            )
+        )
+    """
+
+    # global case
+
+    sio_global_exp = _isl_map_with_marked_dims(
+        "[pi, pj] -> {{ "
+        "[{0}'=0, i', ii'] -> [{0}=1, i, ii, j, jj] : "
+        "0 <= i,ii,i',ii' < pi and 0 <= j,jj < pj "  # domains
+        "and ii >= ii' "  # before->after condition
+        "}}".format(
+            STATEMENT_VAR_NAME,
+            )
+        )
+
+    _check_orderings_for_stmt_pair(
+        "gbar", "stmt_a", pworders,
+        sio_intra_thread_exp=sio_intra_thread_exp,
+        # sio_intra_group_exp=sio_intra_group_exp,
+        sio_global_exp=sio_global_exp)
+
+    # }}}
+
+    # {{{ Relationship between stmt_b and lbar1
+
+    # intra-thread case
+
+    sio_intra_thread_exp = _isl_map_with_marked_dims(
+        "[pi, pj] -> {{ "
+        "[{0}'=0, i', ii', j', jj'] -> [{0}=1, i, ii, j, jj] : "
+        "0 <= i,ii,i',ii' < pi and 0 <= j,jj,j',jj' < pj "  # domains
+        "and i = i' and j = j'"  # parallel inames must be same
+        "and (ii > ii' or (ii = ii' and jj >= jj'))"  # before->after condition
+        "}}".format(
+            STATEMENT_VAR_NAME,
+            )
+        )
+
+    # intra-group case
+
+    sio_intra_group_exp = _isl_map_with_marked_dims(
+        "[pi, pj] -> {{ "
+        "[{0}'=0, i', ii', j', jj'] -> [{0}=1, i, ii, j, jj] : "
+        "0 <= i,ii,i',ii' < pi and 0 <= j,jj,j',jj' < pj "  # domains
+        "and i = i' "  # GID parallel inames must be same
+        "and (ii > ii' or (ii = ii' and jj >= jj'))"  # before->after condition
+        "}}".format(
+            STATEMENT_VAR_NAME,
+            )
+        )
+
+    # global case
+
+    sio_global_exp = _isl_map_with_marked_dims(
+        "[pi, pj]
-> {{ "
+        "[{0}'=0, i', ii', j', jj'] -> [{0}=1, i, ii, j, jj] : "
+        "0 <= i,ii,i',ii' < pi and 0 <= j,jj,j',jj' < pj "  # domains
+        "and ii > ii'"  # before->after condition
+        "}}".format(
+            STATEMENT_VAR_NAME,
+            )
+        )
+
+    _check_orderings_for_stmt_pair(
+        "stmt_b", "lbar1", pworders,
+        sio_intra_thread_exp=sio_intra_thread_exp,
+        sio_intra_group_exp=sio_intra_group_exp,
+        sio_global_exp=sio_global_exp,
+        )
+
+    # }}}
+
+    # {{{ Intra-group relationship between stmt_a and stmt_b
+
+    # intra-thread case
+
+    sio_intra_thread_exp = _isl_map_with_marked_dims(
+        "[pi, pj] -> {{ "
+        "[{0}'=0, i', ii', j', jj'] -> [{0}=1, i, ii, j, jj] : "
+        "0 <= i,ii,i',ii' < pi and 0 <= j,jj,j',jj' < pj "  # domains
+        "and i = i' and j = j'"  # parallel inames must be same
+        "and (ii > ii' or (ii = ii' and jj >= jj'))"  # before->after condition
+        "}}".format(
+            STATEMENT_VAR_NAME,
+            )
+        )
+
+    # intra-group case
+
+    sio_intra_group_exp = _isl_map_with_marked_dims(
+        "[pi, pj] -> {{ "
+        "[{0}'=0, i', ii', j', jj'] -> [{0}=1, i, ii, j, jj] : "
+        "0 <= i,ii,i',ii' < pi and 0 <= j,jj,j',jj' < pj "  # domains
+        "and i = i' "  # GID parallel inames must be same
+        "and (ii > ii' or (ii = ii' and jj >= jj'))"  # before->after condition
+        "}}".format(
+            STATEMENT_VAR_NAME,
+            )
+        )
+
+    _check_orderings_for_stmt_pair(
+        "stmt_a", "stmt_b", pworders,
+        sio_intra_thread_exp=sio_intra_thread_exp,
+        sio_intra_group_exp=sio_intra_group_exp,
+        )
+
+    # }}}
+
+    # {{{ Relationship between lbar1 and stmt_c
+
+    # intra-thread case
+
+    sio_intra_thread_exp = _isl_map_with_marked_dims(
+        "[pi, pj] -> {{ "
+        "[{0}'=0, i', ii', j', jj'] -> [{0}=1] : "
+        "0 <= i',ii' < pi and 0 <= j',jj' < pj "  # domains
+        "}}".format(
+            STATEMENT_VAR_NAME,
+            )
+        )
+
+    # intra-group case
+
+    sio_intra_group_exp = _isl_map_with_marked_dims(
+        "[pi, pj] -> {{ "
+        "[{0}'=0, i', ii', j', jj'] -> [{0}=1] : "
+        "0 <= i',ii' < pi and 0 <= j',jj' < pj "  # domains
+        "}}".format(
+            STATEMENT_VAR_NAME,
+            )
+        )
+
+    # global case
+
+    # (only happens before if not last iteration of ii)
+    sio_global_exp = _isl_map_with_marked_dims(
+        "[pi, pj] -> {{ "
+        "[{0}'=0, i', ii', j', jj'] -> [{0}=1] : "
+        "0 <= i',ii' < pi and 0 <= j',jj' < pj "  # domains
+        "and ii' < pi-1"
+        "}}".format(
+            STATEMENT_VAR_NAME,
+            )
+        )
+
+    _check_orderings_for_stmt_pair(
+        "lbar1", "stmt_c", pworders,
+        sio_intra_thread_exp=sio_intra_thread_exp,
+        sio_intra_group_exp=sio_intra_group_exp,
+        sio_global_exp=sio_global_exp,
+        )
+
+    # }}}
+
+# }}}
+
+
+# {{{ test_sios_and_schedules_with_barriers()
+
+def test_sios_and_schedules_with_barriers():
+    from loopy.schedule.checker import (
+        get_pairwise_statement_orderings,
+    )
+
+    assumptions = "ij_end >= ij_start + 1 and lg_end >= 1"
+    knl = lp.make_kernel(
+        [
+            "{[i,j]: ij_start<=i,j<ij_end}",
+            "{[l0,l1,g0]: 0<=l0,l1,g0<lg_end}",
+        ],
+        """
+        for g0
+            for l0
+                for l1
+                    <>temp0 = 0 {id=stmt_0}
+                    ... lbarrier {id=stmt_b0,dep=stmt_0}
+                    <>temp1 = 1 {id=stmt_1,dep=stmt_b0}
+                    for i
+                        <>tempi0 = 0 {id=stmt_i0,dep=stmt_1}
+                        ... lbarrier {id=stmt_ib0,dep=stmt_i0}
+                        ... gbarrier {id=stmt_ibb0,dep=stmt_i0}
+                        <>tempi1 = 0 {id=stmt_i1,dep=stmt_ib0}
+                        <>tempi2 = 0 {id=stmt_i2,dep=stmt_i1}
+                        for j
+                            <>tempj0 = 0 {id=stmt_j0,dep=stmt_i2}
+                            ... 
lbarrier {id=stmt_jb0,dep=stmt_j0} + <>tempj1 = 0 {id=stmt_j1,dep=stmt_jb0} + end + end + <>temp2 = 0 {id=stmt_2,dep=stmt_i0} + end + end + end + """, + assumptions=assumptions, + lang_version=(2018, 2) + ) + knl = lp.tag_inames(knl, {"l0": "l.0", "l1": "l.1", "g0": "g.0"}) + + # Get a linearization + lin_items, proc_knl, lin_knl = _process_and_linearize(knl) + + stmt_id_pairs = [("stmt_j1", "stmt_2"), ("stmt_1", "stmt_i0")] + pworders = get_pairwise_statement_orderings( + lin_knl, lin_items, stmt_id_pairs) + + # {{{ Relationship between stmt_j1 and stmt_2 + + # Create expected maps and compare + + # Iname bound strings to facilitate creation of expected maps + i_bound_str = "ij_start <= i < ij_end" + i_bound_str_p = "ij_start <= i' < ij_end" + j_bound_str = "ij_start <= j < ij_end" + j_bound_str_p = "ij_start <= j' < ij_end" + ij_bound_str = i_bound_str + " and " + j_bound_str + ij_bound_str_p = i_bound_str_p + " and " + j_bound_str_p + conc_iname_bound_str = "0 <= l0,l1,g0 < lg_end" + conc_iname_bound_str_p = "0 <= l0',l1',g0' < lg_end" + + # {{{ Intra-group + + sched_stmt_j1_intra_group_exp = isl.Map( + "[ij_start, ij_end, lg_end] -> {" + "[%s=0, i, j, l0, l1, g0] -> [%s] : " + "%s and %s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["2", "i", "2", "j", "1"], # lex points + lid_inames=["l0", "l1"], gid_inames=["g0"], + ), + ij_bound_str, + conc_iname_bound_str, + ) + ) + + sched_stmt_2_intra_group_exp = isl.Map( + "[lg_end] -> {[%s=1, l0, l1, g0] -> [%s] : %s}" + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["3", "0", "0", "0", "0"], # lex points + lid_inames=["l0", "l1"], gid_inames=["g0"], + ), + conc_iname_bound_str, + ) + ) + + sio_intra_group_exp = _isl_map_with_marked_dims( + "[ij_start, ij_end, lg_end] -> {{ " + "[{0}'=0, i', j', l0', l1', g0'] -> [{0}=1, l0, l1, g0] : " + "(ij_start <= j' < ij_end-1 or " # not last iteration of j + " ij_start <= i' < ij_end-1) " # not last iteration of i + "and g0 = g0' " # within a single group + "and {1} and {2} and {3} " # iname bounds + "and {4}" # param assumptions + "}}".format( + STATEMENT_VAR_NAME, + ij_bound_str_p, + conc_iname_bound_str, + conc_iname_bound_str_p, + assumptions, + ) + ) + + # }}} + + # {{{ Global + + sched_stmt_j1_global_exp = isl.Map( + "[ij_start, ij_end, lg_end] -> {" + "[%s=0, i, j, l0, l1, g0] -> [%s] : " + "%s and %s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["1", "i", "1"], # lex points + lid_inames=["l0", "l1"], gid_inames=["g0"], + ), + ij_bound_str, + conc_iname_bound_str, + ) + ) + + sched_stmt_2_global_exp = isl.Map( + "[lg_end] -> {[%s=1, l0, l1, g0] -> [%s] : " + "%s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["2", "0", "0"], # lex points + lid_inames=["l0", "l1"], gid_inames=["g0"], + ), + conc_iname_bound_str, + ) + ) + + sio_global_exp = _isl_map_with_marked_dims( + "[ij_start,ij_end,lg_end] -> {{ " + "[{0}'=0, i', j', l0', l1', g0'] -> [{0}=1, l0, l1, g0] : " + "ij_start <= i' < ij_end-1 " # not last iteration of i + "and {1} and {2} and {3} " # iname bounds + "and {4}" # param assumptions + "}}".format( + STATEMENT_VAR_NAME, + ij_bound_str_p, + conc_iname_bound_str, + conc_iname_bound_str_p, + assumptions, + ) + ) + + # }}} + + _check_orderings_for_stmt_pair( + "stmt_j1", "stmt_2", pworders, + sio_intra_group_exp=sio_intra_group_exp, + sched_before_intra_group_exp=sched_stmt_j1_intra_group_exp, + sched_after_intra_group_exp=sched_stmt_2_intra_group_exp, + sio_global_exp=sio_global_exp, + 
sched_before_global_exp=sched_stmt_j1_global_exp, + sched_after_global_exp=sched_stmt_2_global_exp, + ) + + # {{{ Check for some key example pairs in the sio_intra_group map + + # Get maps + order_info = pworders[("stmt_j1", "stmt_2")] + + # As long as this is not the last iteration of the i loop, then there + # should be a barrier between the last instance of statement stmt_j1 + # and statement stmt_2: + ij_end_val = 7 + last_i_val = ij_end_val - 1 + max_non_last_i_val = last_i_val - 1 # max i val that isn't the last iteration + + wanted_pairs = _isl_map_with_marked_dims( + "[ij_start, ij_end, lg_end] -> {{" + "[{0}' = 0, i', j'=ij_end-1, g0', l0', l1'] -> [{0} = 1, l0, l1, g0] : " + "ij_start <= i' <= {1} " # constrain i + "and ij_end >= {2} " # constrain ij_end + "and g0 = g0' " # within a single group + "and {3} and {4} " # conc iname bounds + "}}".format( + STATEMENT_VAR_NAME, + max_non_last_i_val, + ij_end_val, + conc_iname_bound_str, + conc_iname_bound_str_p, + ) + ) + wanted_pairs = ensure_dim_names_match_and_align( + wanted_pairs, order_info.sio_intra_group) + + assert wanted_pairs.is_subset(order_info.sio_intra_group) + + # If this IS the last iteration of the i loop, then there + # should NOT be a barrier between the last instance of statement stmt_j1 + # and statement stmt_2: + unwanted_pairs = _isl_map_with_marked_dims( + "[ij_start, ij_end, lg_end] -> {{" + "[{0}' = 0, i', j'=ij_end-1, g0', l0', l1'] -> [{0} = 1, l0, l1, g0] : " + "ij_start <= i' <= {1} " # constrain i + "and ij_end >= {2} " # constrain p + "and g0 = g0' " # within a single group + "and {3} and {4} " # conc iname bounds + "}}".format( + STATEMENT_VAR_NAME, + last_i_val, + ij_end_val, + conc_iname_bound_str, + conc_iname_bound_str_p, + ) + ) + unwanted_pairs = ensure_dim_names_match_and_align( + unwanted_pairs, order_info.sio_intra_group) + + assert not unwanted_pairs.is_subset(order_info.sio_intra_group) + + # }}} + + # }}} + + # {{{ Relationship between stmt_1 and stmt_i0 + + # Create expected maps and compare + + # {{{ Intra-group + + sched_stmt_1_intra_group_exp = isl.Map( + "[lg_end] -> {[%s=0, l0, l1, g0] -> [%s] : " + "%s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["1", "0", "0", "0", "0"], # lex points + lid_inames=["l0", "l1"], gid_inames=["g0"], + ), + conc_iname_bound_str, + ) + ) + + sched_stmt_i0_intra_group_exp = isl.Map( + "[ij_start, ij_end, lg_end] -> {" + "[%s=1, i, l0, l1, g0] -> [%s] : " + "%s and %s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["2", "i", "0", "0", "0"], # lex points + lid_inames=["l0", "l1"], gid_inames=["g0"], + ), + i_bound_str, + conc_iname_bound_str, + ) + ) + + sio_intra_group_exp = _isl_map_with_marked_dims( + "[ij_start, ij_end, lg_end] -> {{ " + "[{0}'=0, l0', l1', g0'] -> [{0}=1, i, l0, l1, g0] : " + "ij_start + 1 <= i < ij_end " # not first iteration of i + "and g0 = g0' " # within a single group + "and {1} and {2} and {3} " # iname bounds + "and {4}" # param assumptions + "}}".format( + STATEMENT_VAR_NAME, + i_bound_str, + conc_iname_bound_str, + conc_iname_bound_str_p, + assumptions, + ) + ) + + # }}} + + # {{{ Global + + sched_stmt_1_global_exp = isl.Map( + "[lg_end] -> {[%s=0, l0, l1, g0] -> [%s] : " + "%s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["0", "0", "0"], # lex points + lid_inames=["l0", "l1"], gid_inames=["g0"], + ), + conc_iname_bound_str, + ) + ) + + sched_stmt_i0_global_exp = isl.Map( + "[ij_start, ij_end, lg_end] -> {" + "[%s=1, i, l0, l1, g0] -> [%s] : " + "%s and %s}" # 
iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["1", "i", "0"], # lex points + lid_inames=["l0", "l1"], gid_inames=["g0"], + ), + i_bound_str, + conc_iname_bound_str, + ) + ) + + sio_global_exp = _isl_map_with_marked_dims( + "[ij_start, ij_end, lg_end] -> {{ " + "[{0}'=0, l0', l1', g0'] -> [{0}=1, i, l0, l1, g0] : " + "ij_start + 1 <= i < ij_end " # not first iteration of i + "and {1} and {2} and {3} " # iname bounds + "and {4}" # param assumptions + "}}".format( + STATEMENT_VAR_NAME, + i_bound_str, + conc_iname_bound_str, + conc_iname_bound_str_p, + assumptions, + ) + ) + + # }}} + + _check_orderings_for_stmt_pair( + "stmt_1", "stmt_i0", pworders, + sio_intra_group_exp=sio_intra_group_exp, + sched_before_intra_group_exp=sched_stmt_1_intra_group_exp, + sched_after_intra_group_exp=sched_stmt_i0_intra_group_exp, + sio_global_exp=sio_global_exp, + sched_before_global_exp=sched_stmt_1_global_exp, + sched_after_global_exp=sched_stmt_i0_global_exp, + ) + + # }}} + +# }}} + + +# {{{ test_sios_and_schedules_with_vec_and_barriers() + +def test_sios_and_schedules_with_vec_and_barriers(): + from loopy.schedule.checker import ( + get_pairwise_statement_orderings, + ) + + knl = lp.make_kernel( + "{[i, j, l0] : 0 <= i < 4 and 0 <= j < n and 0 <= l0 < 32}", + """ + for l0 + for i + for j + b[i,j,l0] = 1 {id=stmt_1} + ... lbarrier {id=b,dep=stmt_1} + c[i,j,l0] = 2 {id=stmt_2, dep=b} + end + end + end + """) + knl = lp.add_and_infer_dtypes(knl, {"b": "float32", "c": "float32"}) + + knl = lp.tag_inames(knl, {"i": "vec", "l0": "l.0"}) + + # Get a linearization + lin_items, proc_knl, lin_knl = _process_and_linearize(knl) + + stmt_id_pairs = [("stmt_1", "stmt_2")] + pworders = get_pairwise_statement_orderings( + lin_knl, lin_items, stmt_id_pairs) + + # {{{ Relationship between stmt_1 and stmt_2 + + # Create expected maps and compare + + # Iname bound strings to facilitate creation of expected maps + ij_bound_str = "0 <= i < 4 and 0 <= j < n" + ij_bound_str_p = "0 <= i' < 4 and 0 <= j' < n" + conc_iname_bound_str = "0 <= l0 < 32" + conc_iname_bound_str_p = "0 <= l0' < 32" + + # {{{ Intra-thread + + sched_stmt_1_intra_thread_exp = isl.Map( + "[n] -> {" + "[%s=0, i, j, l0] -> [%s] : " + "%s and %s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["j", "0"], # lex points (initial matching dim gets removed) + lid_inames=["l0"], + ), + ij_bound_str, + conc_iname_bound_str, + ) + ) + + sched_stmt_2_intra_thread_exp = isl.Map( + "[n] -> {" + "[%s=1, i, j, l0] -> [%s] : " + "%s and %s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["j", "1"], # lex points (initial matching dim gets removed) + lid_inames=["l0"], + ), + ij_bound_str, + conc_iname_bound_str, + ) + ) + + sio_intra_thread_exp = _isl_map_with_marked_dims( + "[n] -> {{ " + "[{0}'=0, i', j', l0'] -> [{0}=1, i, j, l0] : " + "j' <= j " + "and l0 = l0' " # within a single thread + "and {1} and {2} and {3} and {4}" # iname bounds + "}}".format( + STATEMENT_VAR_NAME, + ij_bound_str, + ij_bound_str_p, + conc_iname_bound_str, + conc_iname_bound_str_p, + ) + ) + + # }}} + + # {{{ Intra-group + + # Intra-group scheds would be same due to lbarrier, + # but since lex tuples are not simplified in intra-group/global + # cases, there's an extra lex dim: + + sched_stmt_1_intra_group_exp = isl.Map( + "[n] -> {" + "[%s=0, i, j, l0] -> [%s] : " + "%s and %s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["1", "j", "0"], # lex points + lid_inames=["l0"], + ), + ij_bound_str, + conc_iname_bound_str, + ) 
+ ) + + sched_stmt_2_intra_group_exp = isl.Map( + "[n] -> {" + "[%s=1, i, j, l0] -> [%s] : " + "%s and %s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["1", "j", "1"], # lex points + lid_inames=["l0"], + ), + ij_bound_str, + conc_iname_bound_str, + ) + ) + + sio_intra_group_exp = _isl_map_with_marked_dims( + "[n] -> {{ " + "[{0}'=0, i', j', l0'] -> [{0}=1, i, j, l0] : " + "j' <= j " + "and {1} and {2} and {3} and {4}" # iname bounds + "}}".format( + STATEMENT_VAR_NAME, + ij_bound_str, + ij_bound_str_p, + conc_iname_bound_str, + conc_iname_bound_str_p, + ) + ) + + # }}} + + # {{{ Global + + sched_stmt_1_global_exp = isl.Map( + "[n] -> {" + "[%s=0, i, j, l0] -> [%s] : " + "%s and %s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["0"], # lex points + lid_inames=["l0"], + ), + ij_bound_str, + conc_iname_bound_str, + ) + ) + + # (same as stmt_1 except for statement id because no global barriers) + sched_stmt_2_global_exp = isl.Map( + "[n] -> {" + "[%s=1, i, j, l0] -> [%s] : " + "%s and %s}" # iname bounds + % ( + STATEMENT_VAR_NAME, + _lex_point_string( + ["0"], # lex points + lid_inames=["l0"], + ), + ij_bound_str, + conc_iname_bound_str, + ) + ) + + sio_global_exp = _isl_map_with_marked_dims( + "[n] -> {{ " + "[{0}'=0, i', j', l0'] -> [{0}=1, i, j, l0] : " + "False " + "and {1} and {2} and {3} and {4}" # iname bounds + "}}".format( + STATEMENT_VAR_NAME, + ij_bound_str, + ij_bound_str_p, + conc_iname_bound_str, + conc_iname_bound_str_p, + ) + ) + + # }}} + + _check_orderings_for_stmt_pair( + "stmt_1", "stmt_2", pworders, + sio_intra_thread_exp=sio_intra_thread_exp, + sched_before_intra_thread_exp=sched_stmt_1_intra_thread_exp, + sched_after_intra_thread_exp=sched_stmt_2_intra_thread_exp, + sio_intra_group_exp=sio_intra_group_exp, + sched_before_intra_group_exp=sched_stmt_1_intra_group_exp, + sched_after_intra_group_exp=sched_stmt_2_intra_group_exp, + sio_global_exp=sio_global_exp, + sched_before_global_exp=sched_stmt_1_global_exp, + sched_after_global_exp=sched_stmt_2_global_exp, + ) + + # }}} + +# }}} + + +# {{{ test_sios_with_matmul + +def test_sios_with_matmul(): + from loopy.schedule.checker import ( + get_pairwise_statement_orderings, + ) + # For now, this test just ensures all pairwise SIOs can be created + # for a complex parallel kernel without any errors/exceptions. Later PRs + # will examine this kernel's SIOs and related dependencies for accuracy. 
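+    # A later accuracy check would presumably mirror the single-pair checks
+    # above, roughly (a sketch only; "stmt_x"/"stmt_y" and expected_pairs
+    # are hypothetical):
+    #
+    #     order_info = pworders[("stmt_x", "stmt_y")]
+    #     expected_pairs = ensure_dim_names_match_and_align(
+    #         expected_pairs, order_info.sio_intra_group)
+    #     assert expected_pairs.is_subset(order_info.sio_intra_group)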
+ + bsize = 16 + knl = lp.make_kernel( + "{[i,k,j]: 0<=i {{ [i'] -> [i] : i > i' " + "and {0} }}".format(assumptions_str), + self_dep=False, knl_with_domains=knl["loopy_kernel"]) + + # test make_dep_map while we're here: + dep_b_on_a_test = _isl_map_with_marked_dims( + "[pi] -> {{ [{3}'=0, i'] -> [{3}=1, i] : i > i' " + "and {0} and {1} and {2} }}".format( + i_range_str, + i_range_str_p, + assumptions_str, + STATEMENT_VAR_NAME, + )) + _align_and_compare_maps([(dep_b_on_a, dep_b_on_a_test)]) + + knl = lp.add_dependency_v2(knl, "stmt_b", "stmt_a", dep_b_on_a) + + _compare_dependencies( + knl, + {"stmt_b": { + "stmt_a": [dep_b_on_a, ]}}) + + # Add a second dependency to stmt_b + dep_b_on_a_2 = make_dep_map( + "[pi] -> {{ [i'] -> [i] : i = i' " + "and {0}}}".format(assumptions_str), + self_dep=False, knl_with_domains=knl["loopy_kernel"]) + + # test make_dep_map while we're here: + dep_b_on_a_2_test = _isl_map_with_marked_dims( + "[pi] -> {{ [{3}'=0, i'] -> [{3}=1, i] : i = i' " + "and {0} and {1} and {2} }}".format( + i_range_str, + i_range_str_p, + assumptions_str, + STATEMENT_VAR_NAME, + )) + _align_and_compare_maps([(dep_b_on_a_2, dep_b_on_a_2_test)]) + + knl = lp.add_dependency_v2(knl, "stmt_b", "stmt_a", dep_b_on_a_2) + + _compare_dependencies( + knl, + {"stmt_b": { + "stmt_a": [dep_b_on_a, dep_b_on_a_2]}}) + + # Add dependencies to stmt_c + # TODO use make_dep_map instead of _isl_map_with_marked_dims where possible + + dep_c_on_a = _isl_map_with_marked_dims( + "[pi] -> {{ [{0}'=0, i'] -> [{0}=1, i] : i >= i' " + "and {1} and {2} and {3} }}".format( + STATEMENT_VAR_NAME, + i_range_str, + i_range_str_p, + assumptions_str, + )) + dep_c_on_b = _isl_map_with_marked_dims( + "[pi] -> {{ [{0}'=0, i'] -> [{0}=1, i] : i >= i' " + "and {1} and {2} and {3} }}".format( + STATEMENT_VAR_NAME, + i_range_str, + i_range_str_p, + assumptions_str, + )) + + knl = lp.add_dependency_v2(knl, "stmt_c", "stmt_a", dep_c_on_a) + knl = lp.add_dependency_v2(knl, "stmt_c", "stmt_b", dep_c_on_b) + + # Compare deps and make sure they are satisfied + unsatisfied_deps = _compare_dependencies( + knl, + { + "stmt_b": { + "stmt_a": [dep_b_on_a, dep_b_on_a_2]}, + "stmt_c": { + "stmt_a": [dep_c_on_a, ], "stmt_b": [dep_c_on_b, ]}, + }, + return_unsatisfied=True) + + assert not unsatisfied_deps + +# }}} + + +# {{{ test_make_dep_map + +def test_make_dep_map(): + # This is also tested inside other test functions, but + # here we specifically test case where the statement inames + # don't match + + # Make kernel and use OLD deps to control linearization order for now + i_range_str = "0 <= i < n" + i_range_str_p = "0 <= i' < n" + j_range_str = "0 <= j < n" + j_range_str_p = "0 <= j' < n" + k_range_str = "0 <= k < n" + # k_range_str_p = "0 <= k' < n" # (not used) + knl = lp.make_kernel( + "{[i,j,k]: %s}" % (" and ".join([i_range_str, j_range_str, k_range_str])), + """ + a[i,j] = 3.14 {id=stmt_a} + b[k] = a[i,k] {id=stmt_b, dep=stmt_a} + """, + lang_version=(2018, 2) + ) + knl = lp.add_and_infer_dtypes(knl, {"a,b": np.float32}) + + for stmt in knl["loopy_kernel"].instructions: + assert not stmt.dependencies + + # Add a dependency to stmt_b + dep_b_on_a = make_dep_map( + "[n] -> { [i',j'] -> [i,k] : i > i' and j' < k}", + self_dep=False, knl_with_domains=knl["loopy_kernel"]) + + # Create expected dep + dep_b_on_a_test = _isl_map_with_marked_dims( + "[n] -> {{ [{0}'=0, i', j'] -> [{0}=1, i, k] : i > i' and j' < k" + " and {1} }}".format( + STATEMENT_VAR_NAME, + " and ".join([ + i_range_str, + i_range_str_p, + j_range_str_p, + 
k_range_str,
+            ])
+        ))
+    _align_and_compare_maps([(dep_b_on_a, dep_b_on_a_test)])
+
+# }}}
+
+
+# {{{ test_new_dependencies_finite_diff
+
+def test_new_dependencies_finite_diff():
+
+    # Define kernel
+    knl = lp.make_kernel(
+        "[nx,nt] -> {[x, t]: 0<=x {{ [x', t'] -> [x, t] : "
+        "((x = x' and t = t'+2) or "
+        " (x'-1 <= x <= x'+1 and t = t' + 1)) and "
+        "{0} and {1} }}".format(
+            xt_range_str,
+            xt_range_str_p,
+        ),
+        self_dep=True)
+
+    knl = lp.add_dependency_v2(knl, "stmt", "stmt", dep)
+
+    ref_knl = knl
+
+    # {{{ Check with correct loop nest order
+
+    # Prioritize loops correctly
+    knl = lp.prioritize_loops(knl, "t,x")
+
+    # Compare deps and make sure they are satisfied
+    unsatisfied_deps = _compare_dependencies(
+        knl,
+        {"stmt": {"stmt": [dep, ]}, },
+        return_unsatisfied=True)
+
+    assert not unsatisfied_deps
+
+    # }}}
+    # {{{ Check with incorrect loop nest order
+
+    # Now prioritize loops incorrectly
+    knl = ref_knl
+    knl = lp.prioritize_loops(knl, "x,t")
+
+    # Compare deps and make sure unsatisfied deps are caught
+    unsatisfied_deps = _compare_dependencies(
+        knl,
+        {"stmt": {"stmt": [dep, ]}, },
+        return_unsatisfied=True)
+
+    assert len(unsatisfied_deps) == 1
+
+    # }}}
+    # {{{ Check with parallel x and no barrier
+
+    # Parallelize the x loop
+    knl = ref_knl
+    knl = lp.prioritize_loops(knl, "t,x")
+    knl = lp.tag_inames(knl, "x:l.0")
+
+    # Make sure unsatisfied deps are caught
+    lin_items, proc_knl, lin_knl = _process_and_linearize(knl)
+
+    # Without a barrier, deps not satisfied
+    # Make sure there is no barrier, and that unsatisfied deps are caught
+    from loopy.schedule import Barrier
+    for lin_item in lin_items:
+        assert not isinstance(lin_item, Barrier)
+
+    unsatisfied_deps = lp.find_unsatisfied_dependencies(
+        proc_knl, lin_items)
+
+    assert len(unsatisfied_deps) == 1
+
+    # }}}
+    # {{{ Check with parallel x and included barrier
+
+    # Insert a barrier to satisfy deps
+    knl = lp.make_kernel(
+        "[nx,nt] -> {[x, t]: 0<=x= 1 and nx >= 1")
+    # knl = lp.tag_inames(knl, "x_outer:g.0, x_inner:l.0")
+
+# }}}
+
+# }}}
+
+
+# {{{ Dependency handling during transformations
+
+# {{{ test_fix_parameters_with_dependencies
+
+def test_fix_parameters_with_dependencies():
+    knl = lp.make_kernel(
+        "{[i,j]: 0 <= i < n and 0 <= j < m}",
+        """
+        <>temp0 = 0.1*i+j {id=stmt0}
+        <>tsq = temp0**2+i+j {id=stmt1,dep=stmt0}
+        a[i,j] = 23*tsq + 25*tsq+j {id=stmt2,dep=stmt1}
+        """)
+
+    knl = lp.add_and_infer_dtypes(knl, {"a": np.float32})
+
+    dep_orig = _isl_map_with_marked_dims(
+        "[n,m] -> {{ [{0}'=0, i', j']->[{0}=1, i, j] : "
+        "0 <= i,i' < n and 0 <= j,j' < m "
+        "and i' = i and j' = j"
+        "}}".format(STATEMENT_VAR_NAME))
+
+    from copy import deepcopy
+    knl = lp.add_dependency_v2(knl, "stmt1", "stmt0", deepcopy(dep_orig))
+    knl = lp.add_dependency_v2(knl, "stmt2", "stmt1", deepcopy(dep_orig))
+
+    fix_val = 64
+    knl = lp.fix_parameters(knl, m=fix_val)
+
+    dep_exp = _isl_map_with_marked_dims(
+        "[n] -> {{ [{0}'=0, i', j']->[{0}=1, i, j] : "
+        "0 <= i,i' < n and 0 <= j,j' < {1} "
+        "and i' = i and j' = j"
+        "}}".format(STATEMENT_VAR_NAME, fix_val))
+
+    # Compare deps and make sure they are satisfied
+    unsatisfied_deps = _compare_dependencies(
+        knl,
+        {
+            "stmt1": {"stmt0": [dep_exp, ]},
+            "stmt2": {"stmt1": [dep_exp, ]},
+        },
+        return_unsatisfied=True)
+
+    assert not unsatisfied_deps
+
+# }}}
+
+
+# {{{ test_assignment_to_subst_with_dependencies
+
+def test_assignment_to_subst_with_dependencies():
+    knl = lp.make_kernel(
+        "{[i]: 0 <= i < n}",
+        """
+        <>temp0 = 0.1*i {id=stmt0}
+        <>tsq = temp0**2
{id=stmt1,dep=stmt0} + a[i] = 23*tsq + 25*tsq {id=stmt2,dep=stmt1} + <>temp3 = 3*tsq {id=stmt3,dep=stmt1} + <>temp4 = 5.5*i {id=stmt4,dep=stmt1} + """) + + # TODO test with multiple subst definition sites + knl = lp.add_and_infer_dtypes(knl, {"a": np.float32}) + + dep_eq = _isl_map_with_marked_dims( + "[n] -> {{ [{0}'=0, i']->[{0}=1, i] : " + "0 <= i,i' < n and i' = i" + "}}".format(STATEMENT_VAR_NAME)) + dep_le = _isl_map_with_marked_dims( + "[n] -> {{ [{0}'=0, i']->[{0}=1, i] : " + "0 <= i,i' < n and i' <= i" + "}}".format(STATEMENT_VAR_NAME)) + + from copy import deepcopy + knl = lp.add_dependency_v2(knl, "stmt1", "stmt0", deepcopy(dep_le)) + knl = lp.add_dependency_v2(knl, "stmt2", "stmt1", deepcopy(dep_eq)) + knl = lp.add_dependency_v2(knl, "stmt3", "stmt1", deepcopy(dep_eq)) + knl = lp.add_dependency_v2(knl, "stmt4", "stmt1", deepcopy(dep_eq)) + + knl = lp.assignment_to_subst(knl, "tsq") + + # Compare deps and make sure they are satisfied + unsatisfied_deps = _compare_dependencies( + knl, + { + "stmt2": {"stmt0": [dep_le, ]}, + "stmt3": {"stmt0": [dep_le, ]}, + }, + return_unsatisfied=True) + # (stmt4 dep was removed because dependee was removed, but dependee's + # deps were not added to stmt4 because the substitution was not made + # in stmt4) TODO this behavior will change when we propagate deps properly + + assert not unsatisfied_deps + + # Test using 'within' -------------------------------------------------- + + knl = lp.make_kernel( + "{[i]: 0 <= i < n}", + """ + <>temp0 = 0.1*i {id=stmt0} + <>tsq = temp0**2 {id=stmt1,dep=stmt0} + a[i] = 23*tsq + 25*tsq {id=stmt2,dep=stmt1} + <>temp3 = 3*tsq {id=stmt3,dep=stmt1} + <>temp4 = 5.5*i {id=stmt4,dep=stmt1} + <>temp5 = 5.6*tsq*i {id=stmt5,dep=stmt1} + """) + + knl = lp.add_and_infer_dtypes(knl, {"a": np.float32}) + + knl = lp.add_dependency_v2(knl, "stmt1", "stmt0", deepcopy(dep_le)) + knl = lp.add_dependency_v2(knl, "stmt2", "stmt1", deepcopy(dep_eq)) + knl = lp.add_dependency_v2(knl, "stmt3", "stmt1", deepcopy(dep_eq)) + knl = lp.add_dependency_v2(knl, "stmt4", "stmt1", deepcopy(dep_eq)) + knl = lp.add_dependency_v2(knl, "stmt5", "stmt1", deepcopy(dep_eq)) + + knl = lp.assignment_to_subst(knl, "tsq", within="id:stmt2 or id:stmt3") + + # Replacement will not be made in stmt5, so stmt1 will not be removed, + # which means no deps will be removed, and the statements where the replacement + # *was* made (stmt2 and stmt3) will still receive the deps from stmt1 + # TODO this behavior may change when we propagate deps properly + + # Compare deps and make sure they are satisfied + unsatisfied_deps = _compare_dependencies( + knl, + { + "stmt1": {"stmt0": [dep_le, ]}, + "stmt2": { + "stmt0": [dep_le, ], "stmt1": [dep_eq, ]}, + "stmt3": { + "stmt0": [dep_le, ], "stmt1": [dep_eq, ]}, + "stmt4": {"stmt1": [dep_eq, ]}, + "stmt5": {"stmt1": [dep_eq, ]}, + }, + return_unsatisfied=True) + + assert not unsatisfied_deps + + # test case where subst def is removed, has deps, and + # inames of subst_def don't match subst usage + + knl = lp.make_kernel( + "{[i,j,k,m]: 0 <= i,j,k,m < n}", + """ + for i,j + <>temp0 = 0.1*i {id=stmt0} + end + for k + <>tsq = temp0**2 {id=stmt1,dep=stmt0} + end + for m + <>res = 23*tsq + 25*tsq {id=stmt2,dep=stmt1} + end + """) + knl = lp.add_and_infer_dtypes(knl, {"temp0,tsq,res": np.float32}) + + dep_1_on_0 = make_dep_map( + "[n] -> { [i', j']->[k] : 0 <= i',j',k < n }", self_dep=False) + dep_2_on_1 = make_dep_map( + "[n] -> { [k']->[m] : 0 <= k',m < n }", self_dep=False) + + from copy import deepcopy + knl = 
lp.add_dependency_v2(knl, "stmt1", "stmt0", deepcopy(dep_1_on_0)) + knl = lp.add_dependency_v2(knl, "stmt2", "stmt1", deepcopy(dep_2_on_1)) + + knl = lp.assignment_to_subst(knl, "tsq") + + dep_exp = make_dep_map( + "[n] -> { [i', j']->[m] : 0 <= i',j',m < n }", self_dep=False) + + # Compare deps and make sure they are satisfied + unsatisfied_deps = _compare_dependencies( + knl, + { + "stmt2": {"stmt0": [dep_exp, ]}, + }, + return_unsatisfied=True) + + assert not unsatisfied_deps + +# }}} + + +# {{{ test_duplicate_inames_with_dependencies + +def test_duplicate_inames_with_dependencies(): + + knl = lp.make_kernel( + "{[i,j]: 0 <= i,j < n}", + """ + b[i,j] = a[i,j] {id=stmtb} + c[i,j] = a[i,j] {id=stmtc,dep=stmtb} + """) + knl = lp.add_and_infer_dtypes(knl, {"a": np.float32}) + + dep_eq = _isl_map_with_marked_dims( + "[n] -> {{ [{0}'=0, i', j']->[{0}=1, i, j] : " + "0 <= i,i',j,j' < n and i' = i and j' = j" + "}}".format(STATEMENT_VAR_NAME)) + + # Create dep stmtb->stmtc + knl = lp.add_dependency_v2(knl, "stmtc", "stmtb", dep_eq) + + ref_knl = knl + + # {{{ Duplicate j within stmtc + + knl = lp.duplicate_inames(knl, ["j"], within="id:stmtc", new_inames=["j_new"]) + + dep_exp = _isl_map_with_marked_dims( + "[n] -> {{ [{0}'=0, i', j']->[{0}=1, i, j_new] : " + "0 <= i,i',j_new,j' < n and i' = i and j' = j_new" + "}}".format(STATEMENT_VAR_NAME)) + + # Compare deps and make sure they are satisfied + unsatisfied_deps = _compare_dependencies( + knl, + {"stmtc": {"stmtb": [dep_exp, ]}}, + return_unsatisfied=True) + + assert not unsatisfied_deps + + # }}} + + # {{{ Duplicate j within stmtb + + knl = ref_knl + knl = lp.duplicate_inames(knl, ["j"], within="id:stmtb", new_inames=["j_new"]) + + dep_exp = _isl_map_with_marked_dims( + "[n] -> {{ [{0}'=0, i', j_new']->[{0}=1, i, j] : " + "0 <= i,i',j,j_new' < n and i' = i and j_new' = j" + "}}".format(STATEMENT_VAR_NAME)) + + # Compare deps and make sure they are satisfied + unsatisfied_deps = _compare_dependencies( + knl, + {"stmtc": {"stmtb": [dep_exp, ]}}, + return_unsatisfied=True) + + assert not unsatisfied_deps + + # }}} + + # {{{ Duplicate j within stmtb and stmtc + + knl = ref_knl + knl = lp.duplicate_inames( + knl, ["j"], within="id:stmtb or id:stmtc", new_inames=["j_new"]) + + dep_exp = _isl_map_with_marked_dims( + "[n] -> {{ [{0}'=0, i', j_new']->[{0}=1, i, j_new] : " + "0 <= i,i',j_new,j_new' < n and i' = i and j_new' = j_new" + "}}".format(STATEMENT_VAR_NAME)) + + # Compare deps and make sure they are satisfied + unsatisfied_deps = _compare_dependencies( + knl, + {"stmtc": {"stmtb": [dep_exp, ]}}, + return_unsatisfied=True) + + assert not unsatisfied_deps + + # }}} + +# }}} + + +# {{{ test_rename_inames_with_dependencies + +def test_rename_inames_with_dependencies(): + # When rename_iname is called and the new iname + # *doesn't* already exist, then duplicate_inames is called, + # and we test that elsewhere. Here we test the case where + # rename_iname is called and the new iname already exists. 
+ + knl = lp.make_kernel( + "{[i,j,m,j_new]: 0 <= i,j,m,j_new < n}", + """ + b[i,j] = a[i,j] {id=stmtb} + c[i,j] = a[i,j] {id=stmtc,dep=stmtb} + e[i,j_new] = 1.1 + d[m] = 5.5 {id=stmtd,dep=stmtc} + """) + knl = lp.add_and_infer_dtypes(knl, {"a,d": np.float32}) + + dep_c_on_b = make_dep_map( + "[n] -> { [i', j']->[i, j] : 0 <= i,i',j,j' < n and i' = i and j' = j }", + self_dep=False) + dep_c_on_c = make_dep_map( + "[n] -> { [i', j']->[i, j] : 0 <= i,i',j,j' < n and i' < i and j' < j }", + self_dep=True) + dep_d_on_c = make_dep_map( + "[n] -> { [i', j']->[m] : 0 <= m,i',j' < n }", + self_dep=False) + + # Create dep stmtb->stmtc + knl = lp.add_dependency_v2(knl, "stmtc", "stmtb", dep_c_on_b) + knl = lp.add_dependency_v2(knl, "stmtc", "stmtc", dep_c_on_c) + knl = lp.add_dependency_v2(knl, "stmtd", "stmtc", dep_d_on_c) + + # Rename j within stmtc + + knl = lp.rename_iname( + knl, "j", "j_new", within="id:stmtc", existing_ok=True) + + dep_c_on_b_exp = make_dep_map( + "[n] -> { [i', j']->[i, j_new] : " + "0 <= i,i',j_new,j' < n and i' = i and j' = j_new}", + self_dep=False) + dep_c_on_c_exp = make_dep_map( + "[n] -> { [i', j_new']->[i, j_new] : " + "0 <= i,i',j_new,j_new' < n and i' < i and j_new' < j_new }", + self_dep=True) + dep_d_on_c_exp = make_dep_map( + "[n] -> { [i', j_new']->[m] : 0 <= m,i',j_new' < n }", + self_dep=False) + + # Compare deps and make sure they are satisfied + unsatisfied_deps = _compare_dependencies( + knl, + { + "stmtc": {"stmtb": [dep_c_on_b_exp, ], "stmtc": [dep_c_on_c_exp, ]}, + "stmtd": {"stmtc": [dep_d_on_c_exp, ]}, + }, + return_unsatisfied=True) + + assert not unsatisfied_deps + +# }}} + + +# {{{ test_split_iname_with_dependencies + +def test_split_iname_with_dependencies(): + knl = lp.make_kernel( + "{[i]: 0<=i { %s : 0 <= i < p and i' = i }" + % (dep_inout_space_str)) + + knl = lp.add_dependency_v2(knl, "stmt1", "stmt0", dep_satisfied) + knl = lp.split_iname(knl, "i", 32) + + dep_exp = _isl_map_with_marked_dims( + "[p] -> {{ [{0}'=0, i_outer', i_inner'] -> [{0}=1, i_outer, i_inner] : " + "0 <= i_inner, i_inner' < 32" # new bounds + " and 0 <= 32*i_outer + i_inner < p" # transformed bounds (0 <= i < p) + " and 0 <= 32*i_outer' + i_inner' < p" # transformed bounds (0 <= i' < p) + " and i_inner + 32*i_outer = 32*i_outer' + i_inner'" # i = i' + "}}".format(STATEMENT_VAR_NAME)) + + # Compare deps and make sure they are satisfied + unsatisfied_deps = _compare_dependencies( + knl, + {"stmt1": {"stmt0": [dep_exp, ]}}, + return_unsatisfied=True) + + assert not unsatisfied_deps + + # }}} + + # {{{ Split iname within stmt1 and make sure dep is correct + + knl = deepcopy(ref_knl) + + knl = lp.add_dependency_v2(knl, "stmt1", "stmt0", dep_satisfied) + knl = lp.split_iname(knl, "i", 32, within="id:stmt1") + + dep_exp = _isl_map_with_marked_dims( + "[p] -> {{ [{0}'=0, i'] -> [{0}=1, i_outer, i_inner] : " + "0 <= i_inner < 32" # new bounds + " and 0 <= 32*i_outer + i_inner < p" # transformed bounds (0 <= i < p) + " and 0 <= i' < p" # original bounds + " and i_inner + 32*i_outer = i'" # transform {i = i'} + "}}".format(STATEMENT_VAR_NAME)) + + # Compare deps and make sure they are satisfied + unsatisfied_deps = _compare_dependencies( + knl, + {"stmt1": {"stmt0": [dep_exp, ]}}, + return_unsatisfied=True) + + assert not unsatisfied_deps + + # }}} + + # {{{ Split iname within stmt0 and make sure dep is correct + + knl = deepcopy(ref_knl) + + knl = lp.add_dependency_v2(knl, "stmt1", "stmt0", dep_satisfied) + knl = lp.split_iname(knl, "i", 32, within="id:stmt0") + + dep_exp = 
_isl_map_with_marked_dims( + "[p] -> {{ [{0}'=0, i_outer', i_inner'] -> [{0}=1, i] : " + "0 <= i_inner' < 32" # new bounds + " and 0 <= i < p" # original bounds + " and 0 <= 32*i_outer' + i_inner' < p" # transformed bounds (0 <= i' < p) + " and i = 32*i_outer' + i_inner'" # transform {i = i'} + "}}".format(STATEMENT_VAR_NAME)) + + # Compare deps and make sure they are satisfied + unsatisfied_deps = _compare_dependencies( + knl, + {"stmt1": {"stmt0": [dep_exp, ]}}, + return_unsatisfied=True) + + assert not unsatisfied_deps + + # }}} + + # {{{ Check dep that should not be satisfied + + knl = deepcopy(ref_knl) + + dep_unsatisfied = _isl_map_with_marked_dims( + "[p] -> { %s : 0 <= i < p and i' = i + 1 }" + % (dep_inout_space_str)) + + knl = lp.add_dependency_v2(knl, "stmt1", "stmt0", dep_unsatisfied) + knl = lp.split_iname(knl, "i", 32) + + dep_exp = _isl_map_with_marked_dims( + "[p] -> {{ [{0}'=0, i_outer', i_inner'] -> [{0}=1, i_outer, i_inner] : " + "0 <= i_inner, i_inner' < 32" # new bounds + " and 0 <= 32*i_outer + i_inner < p" # transformed bounds (0 <= i < p) + " and 0 <= 32*i_outer' + i_inner' - 1 < p" # trans. bounds (0 <= i'-1 < p) + " and i_inner + 32*i_outer + 1 = 32*i_outer' + i_inner'" # i' = i + 1 + "}}".format(STATEMENT_VAR_NAME)) + + # Compare deps and make sure they are satisfied + unsatisfied_deps = _compare_dependencies( + knl, + {"stmt1": {"stmt0": [dep_exp, ]}}, + return_unsatisfied=True) + + assert len(unsatisfied_deps) == 1 + + # }}} + + # {{{ Deps that should be satisfied after gratuitous splitting + + knl = lp.make_kernel( + "{[i,j,k,m]: 0<=i,j,k,m { %s : %s and i' = i and k' = k}" + % (dep_ik_space_str, ik_bounds_str)) + dep_stmt1_on_stmt0_lt = _isl_map_with_marked_dims( + "[p] -> { %s : %s and i' < i and k' < k}" + % (dep_ik_space_str, ik_bounds_str)) + dep_stmt3_on_stmt2_eq = _isl_map_with_marked_dims( + "[p] -> { %s : %s and i' = i and k' = k and j' = j and m' = m}" + % (dep_ijkm_space_str, ijkm_bounds_str)) + + knl = lp.add_dependency_v2(knl, "stmt1", "stmt0", dep_stmt1_on_stmt0_eq) + knl = lp.add_dependency_v2(knl, "stmt1", "stmt0", dep_stmt1_on_stmt0_lt) + knl = lp.add_dependency_v2(knl, "stmt3", "stmt2", dep_stmt3_on_stmt2_eq) + + # Gratuitous splitting + knl = lp.split_iname(knl, "i", 64) + knl = lp.split_iname(knl, "j", 64) + knl = lp.split_iname(knl, "k", 64) + knl = lp.split_iname(knl, "m", 64) + knl = lp.split_iname(knl, "i_inner", 8) + knl = lp.split_iname(knl, "j_inner", 8) + knl = lp.split_iname(knl, "k_inner", 8) + knl = lp.split_iname(knl, "m_inner", 8) + knl = lp.split_iname(knl, "i_outer", 4) + knl = lp.split_iname(knl, "j_outer", 4) + knl = lp.split_iname(knl, "k_outer", 4) + knl = lp.split_iname(knl, "m_outer", 4) + + # Get a linearization + lin_items, proc_knl, lin_knl = _process_and_linearize(knl) + + unsatisfied_deps = lp.find_unsatisfied_dependencies( + proc_knl, lin_items) + + assert not unsatisfied_deps + + # }}} + +# }}} + + +# {{{ test map domain with dependencies + +# {{{ test_map_domain_with_only_partial_dep_pair_affected + +def test_map_domain_with_only_partial_dep_pair_affected(): + + # Split an iname using map_domain, and have (misaligned) deps + # where only the dependee uses the split iname + + # {{{ Make kernel + + knl = lp.make_kernel( + [ + "[nx,nt] -> {[x, t]: 0 <= x < nx and 0 <= t < nt}", + "[ni] -> {[i]: 0 <= i < ni}", + ], + """ + a[x,t] = b[x,t] {id=stmta} + c[x,t] = d[x,t] {id=stmtc,dep=stmta} + e[i] = f[i] {id=stmte,dep=stmtc} + """, + lang_version=(2018, 2), + ) + knl = lp.add_and_infer_dtypes(knl, {"b,d,f": 
np.float32}) + + # }}} + + # {{{ Add dependencies + + dep_c_on_a = _isl_map_with_marked_dims( + "[nx, nt] -> {{" + "[{0}' = 0, x', t'] -> [{0} = 1, x, t] : " + "0 <= x,x' < nx and 0 <= t,t' < nt and " + "t' <= t and x' <= x" + "}}".format(STATEMENT_VAR_NAME)) + + knl = lp.add_dependency_v2( + knl, "stmtc", "stmta", dep_c_on_a) + + # Intentionally make order of x and t different from transform_map below + # to test alignment steps in map_domain + dep_e_on_c = _isl_map_with_marked_dims( + "[nx, nt, ni] -> {{" + "[{0}' = 0, t', x'] -> [{0} = 1, i] : " + "0 <= x' < nx and 0 <= t' < nt and 0 <= i < ni" + "}}".format(STATEMENT_VAR_NAME)) + + knl = lp.add_dependency_v2( + knl, "stmte", "stmtc", dep_e_on_c) + + # }}} + + # {{{ Apply domain change mapping + + # Create map_domain mapping: + import islpy as isl + transform_map = isl.BasicMap( + "[nt] -> {[t] -> [t_outer, t_inner]: " + "0 <= t_inner < 32 and " + "32*t_outer + t_inner = t and " + "0 <= 32*t_outer + t_inner < nt}") + + # Call map_domain to transform kernel + knl = lp.map_domain(knl, transform_map) + + # Prioritize loops (prio should eventually be updated in map_domain?) + knl = lp.prioritize_loops(knl, "x, t_outer, t_inner") + + # }}} + + # {{{ Create expected dependencies + + dep_c_on_a_exp = _isl_map_with_marked_dims( + "[nx, nt] -> {{" + "[{0}' = 0, x', t_outer', t_inner'] -> [{0} = 1, x, t_outer, t_inner] : " + "0 <= x,x' < nx and " # old bounds + "0 <= t_inner,t_inner' < 32 and " # new bounds + "0 <= 32*t_outer + t_inner < nt and " # new bounds + "0 <= 32*t_outer' + t_inner' < nt and " # new bounds + "32*t_outer' + t_inner' <= 32*t_outer + t_inner and " # new constraint t'<=t + "x' <= x" # old constraint + "}}".format(STATEMENT_VAR_NAME)) + + dep_e_on_c_exp = _isl_map_with_marked_dims( + "[nx, nt, ni] -> {{" + "[{0}' = 0, x', t_outer', t_inner'] -> [{0} = 1, i] : " + "0 <= x' < nx and 0 <= i < ni and " # old bounds + "0 <= t_inner' < 32 and " # new bounds + "0 <= 32*t_outer' + t_inner' < nt" # new bounds + "}}".format(STATEMENT_VAR_NAME)) + + # }}} + + # {{{ Make sure deps are correct and satisfied + + # Compare deps and make sure they are satisfied + unsatisfied_deps = _compare_dependencies( + knl, + { + "stmtc": { + "stmta": [dep_c_on_a_exp, ]}, + "stmte": { + "stmtc": [dep_e_on_c_exp, ]}, + }, + return_unsatisfied=True) + + assert not unsatisfied_deps + + # }}} + +# }}} + + +# {{{ test_map_domain_with_inames_missing_in_transform_map + +def test_map_domain_with_inames_missing_in_transform_map(): + + # Make sure map_domain updates deps correctly when the mapping doesn't + # include all the dims in the domain. 
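+    # Presumably the transform map is first extended to the full domain
+    # space, acting as the identity on the unlisted inames, i.e. roughly
+    #
+    #     [nx,nt] -> {[x, z, t, y] -> [x, z, t_outer, t_inner, y_new] :
+    #         0 <= t_inner < 32 and 32*t_outer + t_inner = t
+    #         and 0 <= 32*t_outer + t_inner < nt and y = y_new},
+    #
+    # and that extended map is then applied to the in/out space of each dep.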
+ + # {{{ Make kernel + + knl = lp.make_kernel( + "[nx,nt] -> {[x, y, z, t]: 0 <= x,y,z < nx and 0 <= t < nt}", + """ + a[y,x,t,z] = b[y,x,t,z] {id=stmta} + """, + lang_version=(2018, 2), + ) + knl = lp.add_and_infer_dtypes(knl, {"b": np.float32}) + + # }}} + + # {{{ Create dependency + + dep = _isl_map_with_marked_dims( + "[nx, nt] -> {{" + "[{0}' = 0, x', y', z', t'] -> [{0} = 0, x, y, z, t] : " + "0 <= x,y,z,x',y',z' < nx and 0 <= t,t' < nt and " + "t' < t and x' < x and y' < y and z' < z" + "}}".format(STATEMENT_VAR_NAME)) + + knl = lp.add_dependency_v2(knl, "stmta", "stmta", dep) + + # }}} + + # {{{ Apply domain change mapping + + # Create map_domain mapping that only includes t and y + # (x and z should be unaffected) + transform_map = isl.BasicMap( + "[nx,nt] -> {[t, y] -> [t_outer, t_inner, y_new]: " + "0 <= t_inner < 32 and " + "32*t_outer + t_inner = t and " + "0 <= 32*t_outer + t_inner < nt and " + "y = y_new" + "}") + + # Call map_domain to transform kernel + knl = lp.map_domain(knl, transform_map) + + # }}} + + # {{{ Create expected dependency after transformation + + dep_exp = _isl_map_with_marked_dims( + "[nx, nt] -> {{" + "[{0}' = 0, x', y_new', z', t_outer', t_inner'] -> " + "[{0} = 0, x, y_new, z, t_outer, t_inner] : " + "0 <= x,z,x',z' < nx " # old bounds + "and 0 <= t_inner,t_inner' < 32 and 0 <= y_new,y_new' < nx " # new bounds + "and 0 <= 32*t_outer + t_inner < nt " # new bounds + "and 0 <= 32*t_outer' + t_inner' < nt " # new bounds + "and x' < x and z' < z " # old constraints + "and y_new' < y_new " # new constraint + "and 32*t_outer' + t_inner' < 32*t_outer + t_inner" # new constraint + "}}".format(STATEMENT_VAR_NAME)) + + # }}} + + # {{{ Make sure deps are correct and satisfied + + # Compare deps and make sure they are satisfied + unsatisfied_deps = _compare_dependencies( + knl, + {"stmta": {"stmta": [dep_exp, ]}}, + return_unsatisfied=True) + + assert not unsatisfied_deps + + # }}} + + +# }}} + + +# {{{ test_map_domain_with_stencil_dependencies + +def test_map_domain_with_stencil_dependencies(): + + # {{{ Make kernel + + knl = lp.make_kernel( + "[nx,nt] -> {[ix, it]: 1<=ix {{" + "[{0}' = 0, ix', it'] -> [{0} = 0, ix, it = 1 + it'] : " + "0 < ix' <= -2 + nx and 0 <= it' <= -2 + nt and ix >= -1 + ix' and " + "0 < ix <= 1 + ix' and ix <= -2 + nx; " + "[statement' = 0, ix', it'] -> [statement = 0, ix = ix', it = 2 + it'] : " + "0 < ix' <= -2 + nx and 0 <= it' <= -3 + nt" + "}}".format(STATEMENT_VAR_NAME)) + + knl = lp.add_dependency_v2( + knl, stmt_after, stmt_before, dep_map) + + # }}} + + # {{{ Check deps *without* map_domain transformation + + ref_knl = knl + + # Prioritize loops + knl = lp.prioritize_loops(knl, ("it", "ix")) # valid + #knl = lp.prioritize_loops(knl, ("ix", "it")) # invalid + + # Compare deps and make sure they are satisfied + unsatisfied_deps = _compare_dependencies( + knl, + {"stmt": {"stmt": [dep_map, ]}}, + return_unsatisfied=True) + + assert not unsatisfied_deps + + # }}} + + # {{{ Check dependency after domain change mapping + + knl = ref_knl # loop priority goes away, deps stay + + # Create map_domain mapping: + transform_map = isl.BasicMap( + "[nx,nt] -> {[ix, it] -> [tx, tt, tparity, itt, itx]: " + "16*(tx - tt) + itx - itt = ix - it and " + "16*(tx + tt + tparity) + itt + itx = ix + it and " + "0<=tparity<2 and 0 <= itx - itt < 16 and 0 <= itt+itx < 16}") + + # Call map_domain to transform kernel + knl = lp.map_domain(knl, transform_map) + + # Prioritize loops (prio should eventually be updated in map_domain?) 
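+    # (The old priority involved ix/it, which no longer exist after
+    # map_domain, so a nesting for the new inames is specified by hand here.)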
+ knl = lp.prioritize_loops(knl, "tt,tparity,tx,itt,itx") + + # {{{ Create expected dependency + + # Prep transform map to be applied to dependency + from loopy.schedule.checker.utils import ( + insert_and_name_isl_dims, + add_eq_isl_constraint_from_names, + append_mark_to_isl_map_var_names, + ) + dt = isl.dim_type + # Insert 'statement' dim into transform map + transform_map = insert_and_name_isl_dims( + transform_map, dt.in_, [STATEMENT_VAR_NAME+BEFORE_MARK], 0) + transform_map = insert_and_name_isl_dims( + transform_map, dt.out, [STATEMENT_VAR_NAME], 0) + # Add stmt = stmt' constraint + transform_map = add_eq_isl_constraint_from_names( + transform_map, STATEMENT_VAR_NAME, STATEMENT_VAR_NAME+BEFORE_MARK) + + # Apply transform map to dependency + mapped_dep_map = dep_map.apply_range(transform_map).apply_domain(transform_map) + mapped_dep_map = append_mark_to_isl_map_var_names( + mapped_dep_map, dt.in_, BEFORE_MARK) + + # }}} + + # Compare deps and make sure they are satisfied + unsatisfied_deps = _compare_dependencies( + knl, + {"stmt": {"stmt": [mapped_dep_map, ]}}, + return_unsatisfied=True) + + assert not unsatisfied_deps + + # }}} + +# }}} + +# }}} + + +# {{{ test_add_prefetch_with_dependencies + +# FIXME handle deps during prefetch + +''' + +def test_add_prefetch_with_dependencies(): + + lp.set_caching_enabled(False) + knl = lp.make_kernel( + "[p] -> { [i,j,k,m] : 0 <= i,j < p and 0 <= k,m < 16}", + """ + for i,j,k,m + a[i+1,j+1,k+1,m+1] = a[i,j,k,m] {id=stmt} + end + """, + assumptions="p >= 1", + lang_version=(2018, 2) + ) + knl = lp.add_and_infer_dtypes(knl, {"a": np.float32}) + + dep_init = make_dep_map( + "{ [i',j',k',m'] -> [i,j,k,m] : " + "i' + 1 = i and j' + 1 = j and k' + 1 = k and m' + 1 = m }", + self_dep=True, knl_with_domains=knl) + knl = lp.add_dependency_v2(knl, "stmt", "stmt", dep_init) + + # Compare deps and make sure they are satisfied + unsatisfied_deps = _compare_dependencies( + knl, + {"stmt": {"stmt": [dep_init, ]}}, + return_unsatisfied=True) + + assert not unsatisfied_deps + + knl = lp.add_prefetch( + knl, "a", sweep_inames=["k", "m"], + fetch_outer_inames=frozenset({"i", "j"}), + # dim_arg_names=["k_fetch", "m_fetch"], # TODO not sure why these don't work + ) + + # create expected deps + dep_stmt_on_fetch_exp = make_dep_map( + "{ [i',j',a_dim_2',a_dim_3'] -> [i,j,k,m] : " + "i' = i and j' = j }", + knl_with_domains=knl) + dep_fetch_on_stmt_exp = make_dep_map( + "{ [i',j',k',m'] -> [i,j,a_dim_2,a_dim_3] : " + "i' + 1 = i and j' + 1 = j " + "and 0 <= k',m' < 15 " + "}", + knl_with_domains=knl) + # (make_dep_map will set k',m' upper bound to 16, so add manually^) + + # Why is this necessary to avoid dependency cycle? 
+ knl.id_to_insn["a_fetch_rule"].depends_on_is_final = True + + # Compare deps and make sure they are satisfied + unsatisfied_deps = _compare_dependencies( + knl, + { + "stmt": {"stmt": [dep_init], "a_fetch_rule": [dep_stmt_on_fetch_exp]}, + "a_fetch_rule": {"stmt": [dep_fetch_on_stmt_exp]}, + }, + return_unsatisfied=True) + + assert not unsatisfied_deps + +''' + +# }}} + +# }}} + + +# {{{ Dependency handling during linearization + +# {{{ test_filtering_deps_by_same + +def test_filtering_deps_by_same(): + + # Make a kernel (just need something that can carry deps) + knl = lp.make_kernel( + "{[i,j,k,m] : 0 <= i,j,k,m < n}", + """ + a[i,j,k,m] = 5 {id=s5} + a[i,j,k,m] = 4 {id=s4} + a[i,j,k,m] = 3 {id=s3} + a[i,j,k,m] = 2 {id=s2} + a[i,j,k,m] = 1 {id=s1} + """) + knl = lp.add_and_infer_dtypes(knl, {"a": np.float32}) + knl = lp.tag_inames(knl, "m:l.0") + + # Make some deps + + def _dep_with_condition(stmt_before, stmt_after, cond): + sid_after = 0 if stmt_before == stmt_after else 1 + return _isl_map_with_marked_dims( + "[n] -> {{" + "[{0}'=0, i', j', k', m'] -> [{0}={1}, i, j, k, m] : " + "0 <= i,j,k,m,i',j',k',m' < n and {2}" + "}}".format( + STATEMENT_VAR_NAME, sid_after, cond)) + + dep_s2_on_s1_1 = _dep_with_condition(2, 1, "i'< i and j'<=j and k'=k and m't5 = 5 {id=s5} + <>t3 = 3 {id=s3} + <>t4 = 4 {id=s4} + <>t1 = 1 {id=s1} + <>t2 = 2 {id=s2} + end + """) + knl = lp.tag_inames(knl, "m:l.0") + + stmt_ids_ordered_desired = ["s1", "s2", "s3", "s4", "s5"] + + # {{{ Add some deps + + def _dep_with_condition(stmt_before, stmt_after, cond): + sid_after = 0 if stmt_before == stmt_after else 1 + return _isl_map_with_marked_dims( + "[n] -> {{" + "[{0}'=0, i', j', k', m'] -> [{0}={1}, i, j, k, m] : " + "0 <= i,j,k,m,i',j',k',m' < n and {2}" + "}}".format( + STATEMENT_VAR_NAME, sid_after, cond)) + + # Should NOT create an edge: + dep_s2_on_s1_1 = _dep_with_condition(2, 1, "i'< i and j'<=j and k' =k and m'=m") + # Should create an edge: + dep_s2_on_s1_2 = _dep_with_condition(2, 1, "i'<=i and j'<=j and k' =k and m'=m") + # Should NOT create an edge: + dep_s2_on_s2_1 = _dep_with_condition(2, 2, "i'< i and j'<=j and k' =k and m'=m") + # Should NOT create an edge: + dep_s2_on_s2_2 = _dep_with_condition(2, 2, "i'<=i and j'<=j and k'< k and m'=m") + # Should create an edge: + dep_s3_on_s2_1 = _dep_with_condition(3, 2, "i'<=i and j'<=j and k' =k and m'=m") + # Should create an edge: + dep_s4_on_s3_1 = _dep_with_condition(4, 3, "i'<=i and j'<=j and k' =k and m'=m") + # Should create an edge: + dep_s5_on_s4_1 = _dep_with_condition(5, 4, "i' =i and j' =j and k' =k and m'=m") + + knl = lp.add_dependency_v2(knl, "s2", "s1", dep_s2_on_s1_1) + knl = lp.add_dependency_v2(knl, "s2", "s1", dep_s2_on_s1_2) + knl = lp.add_dependency_v2(knl, "s2", "s2", dep_s2_on_s2_1) + knl = lp.add_dependency_v2(knl, "s2", "s2", dep_s2_on_s2_2) + knl = lp.add_dependency_v2(knl, "s3", "s2", dep_s3_on_s2_1) + knl = lp.add_dependency_v2(knl, "s4", "s3", dep_s4_on_s3_1) + knl = lp.add_dependency_v2(knl, "s5", "s4", dep_s5_on_s4_1) + + # }}} + + # {{{ Test filteringn of deps by intersection with SAME + + from loopy.schedule.checker.dependency import ( + filter_deps_by_intersection_with_SAME, + ) + filtered_depends_on_dict = filter_deps_by_intersection_with_SAME( + knl["loopy_kernel"]) + + # Make sure filtered edges are correct + + # (m is concurrent so shouldn't matter) + depends_on_dict_expected = { + "s2": set(["s1"]), + "s3": set(["s2"]), + "s4": set(["s3"]), + "s5": set(["s4"]), + } + + assert filtered_depends_on_dict == 
depends_on_dict_expected + + # }}} + + # {{{ Get a linearization WITHOUT using the simplified dep graph + + knl = lp.set_options(knl, use_dependencies_v2=False) + lin_items, proc_knl, lin_knl = _process_and_linearize(knl) + + # Check stmt order (should be wrong) + stmt_ids_ordered = _get_runinstruction_ids_from_linearization(lin_items) + assert stmt_ids_ordered != stmt_ids_ordered_desired + + # Check dep satisfaction (should not all be satisfied) + unsatisfied_deps = lp.find_unsatisfied_dependencies(proc_knl, lin_items) + assert unsatisfied_deps + + # }}} + + # {{{ Get a linearization using the simplified dep graph + + knl = lp.set_options(knl, use_dependencies_v2=True) + lin_items, proc_knl, lin_knl = _process_and_linearize(knl) + + # Check stmt order + stmt_ids_ordered = _get_runinstruction_ids_from_linearization(lin_items) + assert stmt_ids_ordered == stmt_ids_ordered_desired + + # Check dep satisfaction + unsatisfied_deps = lp.find_unsatisfied_dependencies(proc_knl, lin_items) + assert not unsatisfied_deps + + # }}} + +# }}} + +# }}} + +# }}} + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from pytest import main + main([__file__]) + +# vim: foldmethod=marker diff --git a/test/test_nest_constraints.py b/test/test_nest_constraints.py new file mode 100644 index 000000000..a931e9e72 --- /dev/null +++ b/test/test_nest_constraints.py @@ -0,0 +1,1160 @@ +__copyright__ = "Copyright (C) 2021 James Stevens" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + +import sys +import loopy as lp +import numpy as np +import pyopencl as cl +from loopy import preprocess_kernel, get_one_linearized_kernel + +import logging +logger = logging.getLogger(__name__) + +try: + import faulthandler +except ImportError: + pass +else: + faulthandler.enable() + +from pyopencl.tools import pytest_generate_tests_for_pyopencl \ + as pytest_generate_tests + +__all__ = [ + "pytest_generate_tests", + "cl" # "cl.create_some_context" + ] + + +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa + + +# {{{ Helper functions + +def _process_and_linearize(prog, knl_name="loopy_kernel"): + # Return linearized kernel + proc_prog = preprocess_kernel(prog) + lin_prog = get_one_linearized_kernel( + proc_prog[knl_name], proc_prog.callables_table) + return lin_prog + + +def _linearize_and_get_nestings(prog, knl_name="loopy_kernel"): + from loopy.transform.iname import get_iname_nestings + lin_knl = _process_and_linearize(prog, knl_name) + return get_iname_nestings(lin_knl.linearization) + +# }}} + + +# {{{ test_loop_constraint_string_parsing + +def test_loop_constraint_string_parsing(): + ref_knl = lp.make_kernel( + "{ [g,h,i,j,k,xx]: 0<=g,h,i,j,k,xx 1: + exec(sys.argv[1]) + else: + from pytest import main + main([__file__]) + +# vim: foldmethod=marker diff --git a/test/test_transform.py b/test/test_transform.py index 51e7c2636..feda064e3 100644 --- a/test/test_transform.py +++ b/test/test_transform.py @@ -574,7 +574,6 @@ def test_split_iname_only_if_in_within(): def test_nested_substs_in_insns(ctx_factory): ctx = ctx_factory() - import loopy as lp ref_prg = lp.make_kernel( "{[i]: 0<=i<10}", @@ -594,6 +593,232 @@ def test_nested_substs_in_insns(ctx_factory): lp.auto_test_vs_ref(ref_prg, ctx, t_unit) +# {{{ test_map_domain_vs_split_iname + +def test_map_domain_vs_split_iname(): + + # {{{ Make kernel + + knl = lp.make_kernel( + [ + "[nx,nt] -> {[x, t]: 0 <= x < nx and 0 <= t < nt}", + "[ni] -> {[i]: 0 <= i < ni}", + ], + """ + a[x,t] = b[x,t] {id=stmta} + c[x,t] = d[x,t] {id=stmtc} + e[i] = f[i] + """, + lang_version=(2018, 2), + ) + knl = lp.add_and_infer_dtypes(knl, {"b,d,f": np.float32}) + ref_knl = knl + + # }}} + + # {{{ Apply domain change mapping + + knl_map_dom = ref_knl # loop priority goes away, deps stay + + # Create map_domain mapping: + import islpy as isl + transform_map = isl.BasicMap( + "[nt] -> {[t] -> [t_outer, t_inner]: " + "0 <= t_inner < 32 and " + "32*t_outer + t_inner = t and " + "0 <= 32*t_outer + t_inner < nt}") + + # Call map_domain to transform kernel + knl_map_dom = lp.map_domain(knl_map_dom, transform_map) + + # Prioritize loops (prio should eventually be updated in map_domain?) 
+ knl_map_dom = lp.prioritize_loops(knl_map_dom, "x, t_outer, t_inner") + + # Get a linearization + proc_knl_map_dom = lp.preprocess_kernel(knl_map_dom) + lin_knl_map_dom = lp.get_one_linearized_kernel( + proc_knl_map_dom["loopy_kernel"], proc_knl_map_dom.callables_table) + + # }}} + + # {{{ Split iname and see if we get the same result + + knl_split_iname = ref_knl + knl_split_iname = lp.split_iname(knl_split_iname, "t", 32) + knl_split_iname = lp.prioritize_loops(knl_split_iname, "x, t_outer, t_inner") + proc_knl_split_iname = lp.preprocess_kernel(knl_split_iname) + lin_knl_split_iname = lp.get_one_linearized_kernel( + proc_knl_split_iname["loopy_kernel"], proc_knl_split_iname.callables_table) + + from loopy.schedule.checker.utils import ( + ensure_dim_names_match_and_align, + ) + for d_map_domain, d_split_iname in zip( + knl_map_dom["loopy_kernel"].domains, + knl_split_iname["loopy_kernel"].domains): + d_map_domain_aligned = ensure_dim_names_match_and_align( + d_map_domain, d_split_iname) + assert d_map_domain_aligned == d_split_iname + + for litem_map_domain, litem_split_iname in zip( + lin_knl_map_dom.linearization, lin_knl_split_iname.linearization): + assert litem_map_domain == litem_split_iname + + # Can't easily compare instructions because equivalent subscript + # expressions may have different orders + + # }}} + +# }}} + + +# {{{ test_map_domain_with_transform_map_missing_dims + +def test_map_domain_with_transform_map_missing_dims(): + # Make sure map_domain works correctly when the mapping doesn't include + # all the dims in the domain. + + # {{{ Make kernel + + knl = lp.make_kernel( + [ + "[nx,nt] -> {[x, y, z, t]: 0 <= x,y,z < nx and 0 <= t < nt}", + ], + """ + a[y,x,t,z] = b[y,x,t,z] {id=stmta} + """, + lang_version=(2018, 2), + ) + knl = lp.add_and_infer_dtypes(knl, {"b": np.float32}) + ref_knl = knl + + # }}} + + # {{{ Apply domain change mapping + + knl_map_dom = ref_knl # loop priority goes away, deps stay + + # Create map_domain mapping that only includes t and y + # (x and z should be unaffected) + import islpy as isl + transform_map = isl.BasicMap( + "[nx,nt] -> {[t, y] -> [t_outer, t_inner, y_new]: " + "0 <= t_inner < 32 and " + "32*t_outer + t_inner = t and " + "0 <= 32*t_outer + t_inner < nt and " + "y = y_new" + "}") + + # Call map_domain to transform kernel + knl_map_dom = lp.map_domain(knl_map_dom, transform_map) + + # Prioritize loops (prio should eventually be updated in map_domain?) 
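+    # (The try/except below probes whether this loopy branch provides
+    # constrain_loop_nesting; if not, it falls back to prioritize_loops with
+    # an ordering that prioritize_loops is known to handle.)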
+ try: + # Use constrain_loop_nesting if it's available + desired_prio = "x, t_outer, t_inner, z, y_new" + knl_map_dom = lp.constrain_loop_nesting(knl_map_dom, desired_prio) + except AttributeError: + # For some reason, prioritize_loops can't handle the ordering above + # when linearizing knl_split_iname below + desired_prio = "z, y_new, x, t_outer, t_inner" + knl_map_dom = lp.prioritize_loops(knl_map_dom, desired_prio) + + # Get a linearization + proc_knl_map_dom = lp.preprocess_kernel(knl_map_dom) + lin_knl_map_dom = lp.get_one_linearized_kernel( + proc_knl_map_dom["loopy_kernel"], proc_knl_map_dom.callables_table) + + # }}} + + # {{{ Split iname and see if we get the same result + + knl_split_iname = ref_knl + knl_split_iname = lp.split_iname(knl_split_iname, "t", 32) + knl_split_iname = lp.rename_iname(knl_split_iname, "y", "y_new") + try: + # Use constrain_loop_nesting if it's available + knl_split_iname = lp.constrain_loop_nesting(knl_split_iname, desired_prio) + except AttributeError: + knl_split_iname = lp.prioritize_loops(knl_split_iname, desired_prio) + proc_knl_split_iname = lp.preprocess_kernel(knl_split_iname) + lin_knl_split_iname = lp.get_one_linearized_kernel( + proc_knl_split_iname["loopy_kernel"], proc_knl_split_iname.callables_table) + + from loopy.schedule.checker.utils import ( + ensure_dim_names_match_and_align, + ) + for d_map_domain, d_split_iname in zip( + knl_map_dom["loopy_kernel"].domains, + knl_split_iname["loopy_kernel"].domains): + d_map_domain_aligned = ensure_dim_names_match_and_align( + d_map_domain, d_split_iname) + assert d_map_domain_aligned == d_split_iname + + for litem_map_domain, litem_split_iname in zip( + lin_knl_map_dom.linearization, lin_knl_split_iname.linearization): + assert litem_map_domain == litem_split_iname + + # Can't easily compare instructions because equivalent subscript + # expressions may have different orders + + # }}} + +# }}} + + +def test_diamond_tiling(ctx_factory, interactive=False): + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + + ref_knl = lp.make_kernel( + "[nx,nt] -> {[ix, it]: 1<=ix {[ix, it] -> [tx, tt, tparity, itt, itx]: " + "16*(tx - tt) + itx - itt = ix - it and " + "16*(tx + tt + tparity) + itt + itx = ix + it and " + "0<=tparity<2 and 0 <= itx - itt < 16 and 0 <= itt+itx < 16}") + knl = lp.map_domain(knl_for_transform, m) + knl = lp.prioritize_loops(knl, "tt,tparity,tx,itt,itx") + + if interactive: + nx = 43 + u = np.zeros((nx, 200)) + x = np.linspace(-1, 1, nx) + dx = x[1] - x[0] + u[:, 0] = u[:, 1] = np.exp(-100*x**2) + + u_dev = cl.array.to_device(queue, u) + knl(queue, u=u_dev, dx=dx, dt=dx) + + u = u_dev.get() + import matplotlib.pyplot as plt + plt.imshow(u.T) + plt.show() + else: + types = {"dt,dx,u": np.float64} + knl = lp.add_and_infer_dtypes(knl, types) + ref_knl = lp.add_and_infer_dtypes(ref_knl, types) + + lp.auto_test_vs_ref(ref_knl, ctx, knl, + parameters={ + "nx": 200, "nt": 300, + "dx": 1, "dt": 1 + }) + + def test_extract_subst_with_iname_deps_in_templ(ctx_factory): knl = lp.make_kernel( "{[i, j, k]: 0<=i<100 and 0<=j,k<5}",