Skip to content

Commit f0dd15a

Browse files
committed
Fix ispc streaming store generation
1 parent 79ca6f8 commit f0dd15a

File tree

1 file changed

+80
-71
lines changed

1 file changed

+80
-71
lines changed

loopy/target/ispc.py

Lines changed: 80 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -29,45 +29,72 @@
2929
from typing import TYPE_CHECKING, Iterable, Sequence, cast
3030

3131
import numpy as np
32+
from typing_extensions import Never
3233

3334
import pymbolic.primitives as p
3435
from cgen import Collection, Const, Declarator, Generable
3536
from pymbolic import var
3637
from pymbolic.mapper.stringifier import PREC_NONE
38+
from pymbolic.mapper.substitutor import make_subst_func
3739
from pytools import memoize_method
3840

3941
from loopy.diagnostic import LoopyError
40-
from loopy.kernel.data import AddressSpace, ArrayArg, TemporaryVariable
41-
from loopy.symbolic import CombineMapper, Literal
42+
from loopy.kernel.data import AddressSpace, ArrayArg, LocalInameTag, TemporaryVariable
43+
from loopy.symbolic import (
44+
CoefficientCollector,
45+
CombineMapper,
46+
GroupHardwareAxisIndex,
47+
Literal,
48+
LocalHardwareAxisIndex,
49+
SubstitutionMapper,
50+
flatten,
51+
)
4252
from loopy.target.c import CFamilyASTBuilder, CFamilyTarget
4353
from loopy.target.c.codegen.expression import ExpressionToCExpressionMapper
4454

4555

4656
if TYPE_CHECKING:
4757
from loopy.codegen import CodeGenerationState
4858
from loopy.codegen.result import CodeGenerationResult
59+
from loopy.kernel import LoopKernel
60+
from loopy.kernel.instruction import Assignment
4961
from loopy.schedule import CallKernel
5062
from loopy.types import LoopyType
5163
from loopy.typing import Expression
5264

5365

5466
class IsVaryingMapper(CombineMapper[bool, []]):
67+
# FIXME: Update this if/when ispc reduction support is added.
68+
69+
def __init__(self, kernel: LoopKernel) -> None:
70+
self.kernel = kernel
71+
super().__init__()
72+
5573
def combine(self, values: Iterable[bool]) -> bool:
5674
return reduce(operator.or_, values, False)
5775

5876
def map_constant(self, expr):
5977
return False
6078

61-
def map_group_hw_index(self, expr):
62-
return False
63-
64-
def map_local_hw_index(self, expr):
65-
if expr.axis == 0:
66-
return True
67-
else:
68-
raise LoopyError("ISPC only supports one local axis")
79+
def map_group_hw_index(self, expr: GroupHardwareAxisIndex) -> Never:
80+
# These only exist for a brief blip in time inside the expr-to-cexpr
81+
# mapper. We should never see them.
82+
raise AssertionError()
83+
84+
def map_local_hw_index(self, expr: LocalHardwareAxisIndex) -> Never:
85+
# These only exist for a brief blip in time inside the expr-to-cexpr
86+
# mapper. We should never see them.
87+
raise AssertionError()
88+
89+
def map_variable(self, expr: p.Variable) -> bool:
90+
iname = self.kernel.inames.get(expr.name)
91+
if iname is not None:
92+
ltags = iname.tags_of_type(LocalInameTag)
93+
if ltags:
94+
ltag, = ltags
95+
assert ltag.axis == 0
96+
return True
6997

70-
def map_variable(self, expr):
7198
return False
7299

73100

@@ -127,8 +154,7 @@ def map_variable(self, expr, type_context):
127154
return expr
128155

129156
else:
130-
return super().map_variable(
131-
expr, type_context)
157+
return super().map_variable(expr, type_context)
132158

133159
def map_subscript(self, expr, type_context):
134160
from loopy.kernel.data import TemporaryVariable
@@ -175,8 +201,8 @@ def rec(self, expr, type_context=None, needed_type: LoopyType | None = None): #
175201
else:
176202
actual_type = self.infer_type(expr)
177203
if actual_type != needed_type:
178-
# FIXME: problematic: quadratic complexity
179-
is_varying = IsVaryingMapper()(expr)
204+
# FIXME: problematic: potential quadratic complexity
205+
is_varying = IsVaryingMapper(self.kernel)(expr)
180206
registry = self.codegen_state.ast_builder.target.get_dtype_registry()
181207
cast = var("("
182208
f"{'varying' if is_varying else 'uniform'} "
@@ -409,7 +435,12 @@ def get_temporary_var_declarator(self,
409435
# }}}
410436

411437
# {{{ emit_...
412-
def emit_assignment(self, codegen_state, insn):
438+
439+
def emit_assignment(
440+
self,
441+
codegen_state: CodeGenerationState,
442+
insn: Assignment
443+
):
413444
kernel = codegen_state.kernel
414445
ecm = codegen_state.expression_to_code_mapper
415446

@@ -442,83 +473,61 @@ def emit_assignment(self, codegen_state, insn):
442473

443474
from loopy.kernel.array import get_access_info
444475
from loopy.symbolic import simplify_using_aff
445-
index_tuple = tuple(
446-
simplify_using_aff(kernel, idx) for idx in lhs.index_tuple)
447476

448-
access_info = get_access_info(kernel, ary, index_tuple,
449-
lambda expr: evaluate(expr, codegen_state.var_subst_map),
450-
codegen_state.vectorization_info)
477+
if not isinstance(lhs, p.Subscript):
478+
raise LoopyError("streaming store must have a subscript as argument")
451479

452480
from loopy.kernel.data import ArrayArg, TemporaryVariable
453-
454481
if not isinstance(ary, (ArrayArg, TemporaryVariable)):
455482
raise LoopyError("array type not supported in ISPC: %s"
456483
% type(ary).__name__)
457484

485+
index_tuple = tuple(
486+
simplify_using_aff(kernel, idx) for idx in lhs.index_tuple)
487+
488+
access_info = get_access_info(kernel, ary, index_tuple,
489+
lambda expr: cast("int",
490+
evaluate(expr, codegen_state.var_subst_map)),
491+
codegen_state.vectorization_info)
492+
493+
l0_inames = {
494+
iname for iname in insn.within_inames
495+
if kernel.inames[iname].tags_of_type(LocalInameTag)}
496+
458497
if len(access_info.subscripts) != 1:
459498
raise LoopyError("streaming stores must have a subscript")
460499
subscript, = access_info.subscripts
461500

462-
from pymbolic.primitives import Sum, Variable, flattened_sum
463-
if isinstance(subscript, Sum):
464-
terms = subscript.children
465-
else:
466-
terms = (subscript.children,)
467-
468-
new_terms = []
469-
470-
from loopy.kernel.data import LocalInameTag, filter_iname_tags_by_type
471-
from loopy.symbolic import get_dependencies
472-
473-
saw_l0 = False
474-
for term in terms:
475-
if (isinstance(term, Variable)
476-
and kernel.iname_tags_of_type(term.name, LocalInameTag)):
477-
tag, = kernel.iname_tags_of_type(
478-
term.name, LocalInameTag, min_num=1, max_num=1)
479-
if tag.axis == 0:
480-
if saw_l0:
481-
raise LoopyError(
482-
"streaming store must have stride 1 in "
483-
"local index, got: %s" % subscript)
484-
saw_l0 = True
485-
continue
486-
else:
487-
for dep in get_dependencies(term):
488-
if dep in kernel.all_inames() and (
489-
filter_iname_tags_by_type(kernel.inames[dep].tags,
490-
LocalInameTag)):
491-
tag, = filter_iname_tags_by_type(
492-
kernel.inames[dep].tags, LocalInameTag, 1)
493-
if tag.axis == 0:
494-
raise LoopyError(
495-
"streaming store must have stride 1 in "
496-
"local index, got: %s" % subscript)
497-
498-
new_terms.append(term)
499-
500-
if not saw_l0:
501-
raise LoopyError("streaming store must have stride 1 in "
502-
"local index, got: %s" % subscript)
501+
if l0_inames:
502+
l0_iname, = l0_inames
503+
coeffs = CoefficientCollector([l0_iname])(subscript)
504+
if coeffs[p.Variable(l0_iname)] != 1:
505+
raise ValueError("coefficient of streaming store index "
506+
"in l.0 variable must be 1")
507+
508+
subscript = flatten(
509+
SubstitutionMapper(make_subst_func({l0_iname: 0}))(subscript))
510+
del l0_iname
503511

504512
if access_info.vector_index is not None:
505513
raise LoopyError("streaming store may not use a short-vector "
506514
"data type")
507515

508-
rhs_has_programindex = any(
509-
isinstance(tag, LocalInameTag) and tag.axis == 0
510-
for tag in kernel.iname_tags(dep)
511-
for dep in get_dependencies(insn.expression))
512-
513-
if not rhs_has_programindex:
514-
rhs_code = "broadcast(%s, 0)" % rhs_code
516+
if (l0_inames
517+
and not IsVaryingMapper(codegen_state.kernel)(insn.expression)):
518+
# rhs is uniform, must be cast to varying in order for streaming_store
519+
# to perform a vector store.
520+
registry = codegen_state.ast_builder.target.get_dtype_registry()
521+
rhs_code = var("(varying "
522+
f"{registry.dtype_to_ctype(lhs_dtype)}"
523+
f") ({rhs_code})")
515524

516525
from cgen import Statement
517526
return Statement(
518527
"streaming_store(%s + %s, %s)"
519528
% (
520529
access_info.array_name,
521-
ecm(flattened_sum(new_terms), PREC_NONE, "i"),
530+
ecm(subscript, PREC_NONE, "i"),
522531
rhs_code))
523532

524533
# }}}

0 commit comments

Comments
 (0)