Skip to content
4 changes: 2 additions & 2 deletions loopy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
affine_map_inames, find_unused_axis_tag,
make_reduction_inames_unique,
has_schedulable_iname_nesting, get_iname_duplication_options,
add_inames_to_insn, add_inames_for_unused_hw_axes)
add_inames_to_insn, add_inames_for_unused_hw_axes, map_domain)

from loopy.transform.instruction import (
find_instructions, map_instructions,
Expand Down Expand Up @@ -192,7 +192,7 @@
"affine_map_inames", "find_unused_axis_tag",
"make_reduction_inames_unique",
"has_schedulable_iname_nesting", "get_iname_duplication_options",
"add_inames_to_insn", "add_inames_for_unused_hw_axes",
"add_inames_to_insn", "add_inames_for_unused_hw_axes", "map_domain",

"add_prefetch", "change_arg_to_image",
"tag_array_axes", "tag_data_axes",
Expand Down
259 changes: 259 additions & 0 deletions loopy/transform/iname.py
Original file line number Diff line number Diff line change
Expand Up @@ -1785,6 +1785,265 @@ def add_inames_to_insn(kernel, inames, insn_match):
# }}}


# {{{ map_domain

class _MapDomainMapper(RuleAwareIdentityMapper):
def __init__(self, rule_mapping_context, within, new_inames, substitutions):
super(_MapDomainMapper, self).__init__(rule_mapping_context)

self.within = within

self.old_inames = frozenset(substitutions)
self.new_inames = new_inames

self.substitutions = substitutions

def map_reduction(self, expr, expn_state):
red_overlap = frozenset(expr.inames) & self.old_inames
arg_ctx_overlap = frozenset(expn_state.arg_context) & self.old_inames
if (red_overlap
and self.within(
expn_state.kernel,
expn_state.instruction)):
if len(red_overlap) != len(self.old_inames):
raise LoopyError("reduction '%s' involves a part "
"of the map domain inames. Reductions must "
"either involve all or none of the map domain "
"inames." % str(expr))

if arg_ctx_overlap:
if arg_ctx_overlap == red_overlap:
# All variables are shadowed by context, that's OK.
return super(_MapDomainMapper, self).map_reduction(
expr, expn_state)
else:
raise LoopyError("reduction '%s' has"
"some of the reduction variables affected "
"by the map_domain shadowed by context. "
"Either all or none must be shadowed."
% str(expr))

new_inames = list(expr.inames)
for old_iname in self.old_inames:
new_inames.remove(old_iname)
new_inames.extend(self.new_inames)

from loopy.symbolic import Reduction
return Reduction(expr.operation, tuple(new_inames),
self.rec(expr.expr, expn_state),
expr.allow_simultaneous)
else:
return super(_MapDomainMapper, self).map_reduction(expr, expn_state)

def map_variable(self, expr, expn_state):
if (expr.name in self.old_inames
and expr.name not in expn_state.arg_context
and self.within(
expn_state.kernel,
expn_state.instruction)):
return self.substitutions[expr.name]
else:
return super(_MapDomainMapper, self).map_variable(expr, expn_state)


def _find_aff_subst_from_map(iname, isl_map):
if not isinstance(isl_map, isl.BasicMap):
raise RuntimeError("isl_map must be a BasicMap")

dt, dim_idx = isl_map.get_var_dict()[iname]

assert dt == dim_type.in_

# Force isl to solve for only this iname on its side of the map, by
# projecting out all other "in" variables.
isl_map = isl_map.project_out(dt, dim_idx+1, isl_map.dim(dt)-(dim_idx+1))
isl_map = isl_map.project_out(dt, 0, dim_idx)
dim_idx = 0

# Convert map to set to avoid "domain of affine expression should be a set".
# The old "in" variable will be the last of the out_dims.
new_dim_idx = isl_map.dim(dim_type.out)
isl_map = isl_map.move_dims(
dim_type.out, isl_map.dim(dim_type.out),
dt, dim_idx, 1)
isl_map = isl_map.range() # now a set
dt = dim_type.set
dim_idx = new_dim_idx
del new_dim_idx

for cns in isl_map.get_constraints():
if cns.is_equality() and cns.involves_dims(dt, dim_idx, 1):
coeff = cns.get_coefficient_val(dt, dim_idx)
cns_zeroed = cns.set_coefficient_val(dt, dim_idx, 0)
if cns_zeroed.involves_dims(dt, dim_idx, 1):
# not suitable, constraint still involves dim, perhaps in a div
continue

if coeff.is_one():
return -cns_zeroed.get_aff()
elif coeff.is_negone():
return cns_zeroed.get_aff()
else:
# not suitable, coefficient does not have unit coefficient
continue

raise LoopyError("no suitable equation for '%s' found" % iname)


def map_domain(kernel, isl_map, within=None):
# FIXME: Express _split_iname_backend in terms of this
# Missing/deleted for now:
# - slab processing
# - priorities processing
# FIXME: Process priorities
# FIXME: Express affine_map_inames in terms of this, deprecate
# FIXME: Document

# FIXME: Support within

# {{{ within processing (disabled for now)
if within is not None:
raise NotImplementedError("within")

from loopy.match import parse_match
within = parse_match(within)

# {{{ return the same kernel if no kernel matches

def _do_not_transform_if_no_within_matches():
for insn in kernel.instructions:
if within(kernel, insn):
return

return kernel

_do_not_transform_if_no_within_matches()

# }}}

# }}}

if not isl_map.is_bijective():
raise LoopyError("isl_map must be bijective")

new_inames = frozenset(isl_map.get_var_dict(dim_type.out))
old_inames = frozenset(isl_map.get_var_dict(dim_type.in_))

# {{{ solve for representation of old inames in terms of new

substitutions = {}
var_substitutions = {}
applied_iname_rewrites = kernel.applied_iname_rewrites[:]

from loopy.symbolic import aff_to_expr
from pymbolic import var
for iname in old_inames:
substitutions[iname] = aff_to_expr(
_find_aff_subst_from_map(iname, isl_map))
var_substitutions[var(iname)] = aff_to_expr(
_find_aff_subst_from_map(iname, isl_map))

applied_iname_rewrites.append(var_substitutions)
del var_substitutions

# }}}

def process_set(s):
var_dict = s.get_var_dict()

overlap = old_inames & frozenset(var_dict)
if overlap and len(overlap) != len(old_inames):
raise LoopyError("loop domain '%s' involves a part "
"of the map domain inames. Domains must "
"either involve all or none of the map domain "
"inames." % s)

# {{{ align dims of isl_map and s

# FIXME: Make this less gross
# FIXME: Make an exported/documented interface of this in islpy
from islpy import _align_dim_type

map_with_s_domain = isl.Map.from_domain(s)

dim_types = [dim_type.param, dim_type.in_, dim_type.out]
s_names = [
map_with_s_domain.get_dim_name(dt, i)
for dt in dim_types
for i in range(map_with_s_domain.dim(dt))
]
map_names = [
isl_map.get_dim_name(dt, i)
for dt in dim_types
for i in range(isl_map.dim(dt))
]
aligned_map = _align_dim_type(
dim_type.param,
isl_map, map_with_s_domain, False,
map_names, s_names)
aligned_map = _align_dim_type(
dim_type.in_,
isl_map, map_with_s_domain, False,
map_names, s_names)
# Old code
"""
aligned_map = _align_dim_type(
dim_type.param,
isl_map, map_with_s_domain, obj_bigger_ok=False,
obj_names=map_names, tgt_names=s_names)
aligned_map = _align_dim_type(
dim_type.in_,
isl_map, map_with_s_domain, obj_bigger_ok=False,
obj_names=map_names, tgt_names=s_names)
"""
# }}}

return aligned_map.intersect_domain(s).range()

# FIXME: Revive _project_out_only_if_all_instructions_in_within

new_domains = [process_set(dom) for dom in kernel.domains]

# {{{ update within_inames

new_insns = []
for insn in kernel.instructions:
overlap = old_inames & insn.within_inames
if overlap and within(kernel, insn):
if len(overlap) != len(old_inames):
raise LoopyError("instruction '%s' is within only a part "
"of the map domain inames. Instructions must "
"either be within all or none of the map domain "
"inames." % insn.id)

insn = insn.copy(
within_inames=(insn.within_inames - old_inames) | new_inames)
else:
# leave insn unmodified
pass

new_insns.append(insn)

# }}}

kernel = kernel.copy(
domains=new_domains,
instructions=new_insns,
applied_iname_rewrites=applied_iname_rewrites)

rule_mapping_context = SubstitutionRuleMappingContext(
kernel.substitutions, kernel.get_var_name_generator())
ins = _MapDomainMapper(rule_mapping_context, within,
new_inames, substitutions)

kernel = ins.map_kernel(kernel)
kernel = rule_mapping_context.finish_kernel(kernel)

return kernel

# }}}


def add_inames_for_unused_hw_axes(kernel, within=None):
"""
Returns a kernel with inames added to each instruction
Expand Down
54 changes: 53 additions & 1 deletion test/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,6 @@ def test_split_iname_only_if_in_within():

def test_nested_substs_in_insns(ctx_factory):
ctx = ctx_factory()
import loopy as lp

ref_knl = lp.make_kernel(
"{[i]: 0<=i<10}",
Expand All @@ -568,6 +567,59 @@ def test_nested_substs_in_insns(ctx_factory):
lp.auto_test_vs_ref(ref_knl, ctx, knl)


def test_diamond_tiling(ctx_factory, interactive=False):
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)

ref_knl = lp.make_kernel(
"[nx,nt] -> {[ix, it]: 1<=ix<nx-1 and 0<=it<nt}",
"""
u[ix, it+2] = (
2*u[ix, it+1]
+ dt**2/dx**2 * (u[ix+1, it+1] - 2*u[ix, it+1] + u[ix-1, it+1])
- u[ix, it])
""")

# FIXME: Handle priorities in map_domain
knl_for_transform = ref_knl

ref_knl = lp.prioritize_loops(ref_knl, "it, ix")

import islpy as isl
m = isl.BasicMap(
"[nx,nt] -> {[ix, it] -> [tx, tt, tparity, itt, itx]: "
"16*(tx - tt) + itx - itt = ix - it and "
"16*(tx + tt + tparity) + itt + itx = ix + it and "
"0<=tparity<2 and 0 <= itx - itt < 16 and 0 <= itt+itx < 16}")
knl = lp.map_domain(knl_for_transform, m)
knl = lp.prioritize_loops(knl, "tt,tparity,tx,itt,itx")

if interactive:
nx = 43
u = np.zeros((nx, 200))
x = np.linspace(-1, 1, nx)
dx = x[1] - x[0]
u[:, 0] = u[:, 1] = np.exp(-100*x**2)

u_dev = cl.array.to_device(queue, u)
knl(queue, u=u_dev, dx=dx, dt=dx)

u = u_dev.get()
import matplotlib.pyplot as plt
plt.imshow(u.T)
plt.show()
else:
types = {"dt,dx,u": np.float64}
knl = lp.add_and_infer_dtypes(knl, types)
ref_knl = lp.add_and_infer_dtypes(ref_knl, types)

lp.auto_test_vs_ref(ref_knl, ctx, knl,
parameters={
"nx": 200, "nt": 300,
"dx": 1, "dt": 1
})


def test_extract_subst_with_iname_deps_in_templ(ctx_factory):
knl = lp.make_kernel(
"{[i, j, k]: 0<=i<100 and 0<=j,k<5}",
Expand Down