|
29 | 29 | from typing import TYPE_CHECKING, Iterable, Sequence, cast |
30 | 30 |
|
31 | 31 | import numpy as np |
| 32 | +from typing_extensions import Never |
32 | 33 |
|
33 | 34 | import pymbolic.primitives as p |
34 | 35 | from cgen import Collection, Const, Declarator, Generable |
35 | 36 | from pymbolic import var |
36 | 37 | from pymbolic.mapper.stringifier import PREC_NONE |
| 38 | +from pymbolic.mapper.substitutor import make_subst_func |
37 | 39 | from pytools import memoize_method |
38 | 40 |
|
39 | 41 | from loopy.diagnostic import LoopyError |
40 | | -from loopy.kernel.data import AddressSpace, ArrayArg, TemporaryVariable |
41 | | -from loopy.symbolic import CombineMapper, Literal |
| 42 | +from loopy.kernel.data import AddressSpace, ArrayArg, LocalInameTag, TemporaryVariable |
| 43 | +from loopy.symbolic import ( |
| 44 | + CoefficientCollector, |
| 45 | + CombineMapper, |
| 46 | + GroupHardwareAxisIndex, |
| 47 | + Literal, |
| 48 | + LocalHardwareAxisIndex, |
| 49 | + SubstitutionMapper, |
| 50 | + flatten, |
| 51 | +) |
42 | 52 | from loopy.target.c import CFamilyASTBuilder, CFamilyTarget |
43 | 53 | from loopy.target.c.codegen.expression import ExpressionToCExpressionMapper |
44 | 54 |
|
45 | 55 |
|
46 | 56 | if TYPE_CHECKING: |
47 | 57 | from loopy.codegen import CodeGenerationState |
48 | 58 | from loopy.codegen.result import CodeGenerationResult |
| 59 | + from loopy.kernel import LoopKernel |
| 60 | + from loopy.kernel.instruction import Assignment |
49 | 61 | from loopy.schedule import CallKernel |
50 | 62 | from loopy.types import LoopyType |
51 | 63 | from loopy.typing import Expression |
52 | 64 |
|
53 | 65 |
|
54 | 66 | class IsVaryingMapper(CombineMapper[bool, []]): |
| 67 | + # FIXME: Update this if/when ispc reduction support is added. |
| 68 | + |
| 69 | + def __init__(self, kernel: LoopKernel) -> None: |
| 70 | + self.kernel = kernel |
| 71 | + super().__init__() |
| 72 | + |
55 | 73 | def combine(self, values: Iterable[bool]) -> bool: |
56 | 74 | return reduce(operator.or_, values, False) |
57 | 75 |
|
58 | 76 | def map_constant(self, expr): |
59 | 77 | return False |
60 | 78 |
|
61 | | - def map_group_hw_index(self, expr): |
62 | | - return False |
63 | | - |
64 | | - def map_local_hw_index(self, expr): |
65 | | - if expr.axis == 0: |
66 | | - return True |
67 | | - else: |
68 | | - raise LoopyError("ISPC only supports one local axis") |
| 79 | + def map_group_hw_index(self, expr: GroupHardwareAxisIndex) -> Never: |
| 80 | + # These only exist for a brief blip in time inside the expr-to-cexpr |
| 81 | + # mapper. We should never see them. |
| 82 | + raise AssertionError() |
| 83 | + |
| 84 | + def map_local_hw_index(self, expr: LocalHardwareAxisIndex) -> Never: |
| 85 | + # These only exist for a brief blip in time inside the expr-to-cexpr |
| 86 | + # mapper. We should never see them. |
| 87 | + raise AssertionError() |
| 88 | + |
| 89 | + def map_variable(self, expr: p.Variable) -> bool: |
| 90 | + iname = self.kernel.inames.get(expr.name) |
| 91 | + if iname is not None: |
| 92 | + ltags = iname.tags_of_type(LocalInameTag) |
| 93 | + if ltags: |
| 94 | + ltag, = ltags |
| 95 | + assert ltag.axis == 0 |
| 96 | + return True |
69 | 97 |
|
70 | | - def map_variable(self, expr): |
71 | 98 | return False |
72 | 99 |
|
73 | 100 |
|
@@ -127,8 +154,7 @@ def map_variable(self, expr, type_context): |
127 | 154 | return expr |
128 | 155 |
|
129 | 156 | else: |
130 | | - return super().map_variable( |
131 | | - expr, type_context) |
| 157 | + return super().map_variable(expr, type_context) |
132 | 158 |
|
133 | 159 | def map_subscript(self, expr, type_context): |
134 | 160 | from loopy.kernel.data import TemporaryVariable |
@@ -175,8 +201,8 @@ def rec(self, expr, type_context=None, needed_type: LoopyType | None = None): # |
175 | 201 | else: |
176 | 202 | actual_type = self.infer_type(expr) |
177 | 203 | if actual_type != needed_type: |
178 | | - # FIXME: problematic: quadratic complexity |
179 | | - is_varying = IsVaryingMapper()(expr) |
| 204 | + # FIXME: problematic: potential quadratic complexity |
| 205 | + is_varying = IsVaryingMapper(self.kernel)(expr) |
180 | 206 | registry = self.codegen_state.ast_builder.target.get_dtype_registry() |
181 | 207 | cast = var("(" |
182 | 208 | f"{'varying' if is_varying else 'uniform'} " |
@@ -409,7 +435,12 @@ def get_temporary_var_declarator(self, |
409 | 435 | # }}} |
410 | 436 |
|
411 | 437 | # {{{ emit_... |
412 | | - def emit_assignment(self, codegen_state, insn): |
| 438 | + |
| 439 | + def emit_assignment( |
| 440 | + self, |
| 441 | + codegen_state: CodeGenerationState, |
| 442 | + insn: Assignment |
| 443 | + ): |
413 | 444 | kernel = codegen_state.kernel |
414 | 445 | ecm = codegen_state.expression_to_code_mapper |
415 | 446 |
|
@@ -442,83 +473,61 @@ def emit_assignment(self, codegen_state, insn): |
442 | 473 |
|
443 | 474 | from loopy.kernel.array import get_access_info |
444 | 475 | from loopy.symbolic import simplify_using_aff |
445 | | - index_tuple = tuple( |
446 | | - simplify_using_aff(kernel, idx) for idx in lhs.index_tuple) |
447 | 476 |
|
448 | | - access_info = get_access_info(kernel, ary, index_tuple, |
449 | | - lambda expr: evaluate(expr, codegen_state.var_subst_map), |
450 | | - codegen_state.vectorization_info) |
| 477 | + if not isinstance(lhs, p.Subscript): |
| 478 | + raise LoopyError("streaming store must have a subscript as argument") |
451 | 479 |
|
452 | 480 | from loopy.kernel.data import ArrayArg, TemporaryVariable |
453 | | - |
454 | 481 | if not isinstance(ary, (ArrayArg, TemporaryVariable)): |
455 | 482 | raise LoopyError("array type not supported in ISPC: %s" |
456 | 483 | % type(ary).__name) |
457 | 484 |
|
| 485 | + index_tuple = tuple( |
| 486 | + simplify_using_aff(kernel, idx) for idx in lhs.index_tuple) |
| 487 | + |
| 488 | + access_info = get_access_info(kernel, ary, index_tuple, |
| 489 | + lambda expr: cast("int", |
| 490 | + evaluate(expr, codegen_state.var_subst_map)), |
| 491 | + codegen_state.vectorization_info) |
| 492 | + |
| 493 | + l0_inames = { |
| 494 | + iname for iname in insn.within_inames |
| 495 | + if kernel.inames[iname].tags_of_type(LocalInameTag)} |
| 496 | + |
458 | 497 | if len(access_info.subscripts) != 1: |
459 | 498 | raise LoopyError("streaming stores must have a subscript") |
460 | 499 | subscript, = access_info.subscripts |
461 | 500 |
|
462 | | - from pymbolic.primitives import Sum, Variable, flattened_sum |
463 | | - if isinstance(subscript, Sum): |
464 | | - terms = subscript.children |
465 | | - else: |
466 | | - terms = (subscript.children,) |
467 | | - |
468 | | - new_terms = [] |
469 | | - |
470 | | - from loopy.kernel.data import LocalInameTag, filter_iname_tags_by_type |
471 | | - from loopy.symbolic import get_dependencies |
472 | | - |
473 | | - saw_l0 = False |
474 | | - for term in terms: |
475 | | - if (isinstance(term, Variable) |
476 | | - and kernel.iname_tags_of_type(term.name, LocalInameTag)): |
477 | | - tag, = kernel.iname_tags_of_type( |
478 | | - term.name, LocalInameTag, min_num=1, max_num=1) |
479 | | - if tag.axis == 0: |
480 | | - if saw_l0: |
481 | | - raise LoopyError( |
482 | | - "streaming store must have stride 1 in " |
483 | | - "local index, got: %s" % subscript) |
484 | | - saw_l0 = True |
485 | | - continue |
486 | | - else: |
487 | | - for dep in get_dependencies(term): |
488 | | - if dep in kernel.all_inames() and ( |
489 | | - filter_iname_tags_by_type(kernel.inames[dep].tags, |
490 | | - LocalInameTag)): |
491 | | - tag, = filter_iname_tags_by_type( |
492 | | - kernel.inames[dep].tags, LocalInameTag, 1) |
493 | | - if tag.axis == 0: |
494 | | - raise LoopyError( |
495 | | - "streaming store must have stride 1 in " |
496 | | - "local index, got: %s" % subscript) |
497 | | - |
498 | | - new_terms.append(term) |
499 | | - |
500 | | - if not saw_l0: |
501 | | - raise LoopyError("streaming store must have stride 1 in " |
502 | | - "local index, got: %s" % subscript) |
| 501 | + if l0_inames: |
| 502 | + l0_iname, = l0_inames |
| 503 | + coeffs = CoefficientCollector([l0_iname])(subscript) |
| 504 | + if coeffs[p.Variable(l0_iname)] != 1: |
| 505 | + raise ValueError("coefficient of streaming store index " |
| 506 | + "in l.0 variable must be 1") |
| 507 | + |
| 508 | + subscript = flatten( |
| 509 | + SubstitutionMapper(make_subst_func({l0_iname: 0}))(subscript)) |
| 510 | + del l0_iname |
503 | 511 |
|
504 | 512 | if access_info.vector_index is not None: |
505 | 513 | raise LoopyError("streaming store may not use a short-vector " |
506 | 514 | "data type") |
507 | 515 |
|
508 | | - rhs_has_programindex = any( |
509 | | - isinstance(tag, LocalInameTag) and tag.axis == 0 |
510 | | - for tag in kernel.iname_tags(dep) |
511 | | - for dep in get_dependencies(insn.expression)) |
512 | | - |
513 | | - if not rhs_has_programindex: |
514 | | - rhs_code = "broadcast(%s, 0)" % rhs_code |
| 516 | + if (l0_inames |
| 517 | + and not IsVaryingMapper(codegen_state.kernel)(insn.expression)): |
| 518 | + # rhs is uniform, must be cast to varying in order for streaming_store |
| 519 | + # to perform a vector store. |
| 520 | + registry = codegen_state.ast_builder.target.get_dtype_registry() |
| 521 | + rhs_code = var("(varying " |
| 522 | + f"{registry.dtype_to_ctype(lhs_dtype)}" |
| 523 | + f") ({rhs_code})") |
515 | 524 |
|
516 | 525 | from cgen import Statement |
517 | 526 | return Statement( |
518 | 527 | "streaming_store(%s + %s, %s)" |
519 | 528 | % ( |
520 | 529 | access_info.array_name, |
521 | | - ecm(flattened_sum(new_terms), PREC_NONE, "i"), |
| 530 | + ecm(subscript, PREC_NONE, "i"), |
522 | 531 | rhs_code)) |
523 | 532 |
|
524 | 533 | # }}} |
|
0 commit comments