From d6e873245206f24427dada14324c92692db43470 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 25 Oct 2024 16:47:20 +0200 Subject: [PATCH 001/124] Add concat_where frontend and domain inference --- src/gt4py/next/ffront/experimental.py | 2 +- src/gt4py/next/ffront/fbuiltins.py | 2 + .../ffront/foast_passes/type_deduction.py | 34 +++++- src/gt4py/next/ffront/foast_to_gtir.py | 14 ++- src/gt4py/next/iterator/ir.py | 3 + .../next/iterator/ir_utils/domain_utils.py | 46 +++++++ src/gt4py/next/iterator/ir_utils/ir_makers.py | 5 + .../iterator/transforms/constant_folding.py | 31 ++++- .../next/iterator/transforms/infer_domain.py | 31 +++++ .../iterator/transforms/infer_domain_ops.py | 81 +++++++++++++ .../next/iterator/transforms/pass_manager.py | 2 + .../type_system/type_specifications.py | 4 - .../iterator/type_system/type_synthesizer.py | 23 +++- .../next/type_system/type_specifications.py | 6 +- .../ffront_tests/test_concat_where.py | 2 +- .../iterator_tests/test_type_inference.py | 1 + .../transforms_tests/test_constant_folding.py | 7 ++ .../transforms_tests/test_domain_inference.py | 114 ++++++++++++++++++ 18 files changed, 390 insertions(+), 18 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/infer_domain_ops.py diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index bd22aebe57..c9bea908a8 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -20,7 +20,7 @@ def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivi @WhereBuiltinFunction def concat_where( - mask: common.Field, + mask: common.Domain, true_field: common.Field | core_defs.ScalarT | Tuple, false_field: common.Field | core_defs.ScalarT | Tuple, /, diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 1210e96efc..d55af4fa29 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -58,6 +58,8 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp return ts.OffsetType elif t is core_defs.ScalarT: return ts.ScalarType + elif t is common.Domain: + return ts.DomainType elif t is type: return ( ts.FunctionType diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 6b40cbb77f..ae3d3c6437 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -20,6 +20,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.foast_passes.utils import compute_assign_indices +from gt4py.next.iterator import ir as itir from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -570,6 +571,19 @@ def _deduce_compare_type( self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any ) -> Optional[ts.TypeSpec]: # check both types compatible + if ( + isinstance(left.type, ts.DimensionType) + and isinstance(right.type, ts.ScalarType) + and right.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + ): + return ts.DomainType(dims=[left.type.dim]) + if ( + isinstance(right.type, ts.DimensionType) + and isinstance(left.type, ts.ScalarType) + and left.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + ): + return ts.DomainType(dims=[right.type.dim]) + # TODO for arg in (left, right): if not type_info.is_arithmetic(arg.type): raise errors.DSLError( @@ -908,6 +922,7 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: ) try: + # TODO(tehrengruber): the construct_tuple_type function doesn't look correct if isinstance(true_branch_type, ts.TupleType) and isinstance( false_branch_type, ts.TupleType ): @@ -943,7 +958,24 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: location=node.location, ) - _visit_concat_where = _visit_where + def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: + true_branch_type = node.args[1].type + false_branch_type = node.args[2].type + if true_branch_type != false_branch_type: + raise errors.DSLError( + node.location, + f"Incompatible argument in call to '{node.func!s}': expected " + f"'{true_branch_type}' and '{false_branch_type}' to be equal.", + ) + return_type = true_branch_type + + return foast.Call( + func=node.func, + args=node.args, + kwargs=node.kwargs, + type=return_type, + location=node.location, + ) def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> foast.Call: arg_type = cast(ts.FieldType | ts.ScalarType, node.args[0].type) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 4519b4e571..82d16202c4 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -225,6 +225,8 @@ def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym: return im.sym(node.id) def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef: + if isinstance(node.type, ts.DimensionType): + return itir.AxisLiteral(value=node.type.dim.value, kind=node.type.dim.kind) return im.ref(node.id) def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: @@ -261,6 +263,7 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC ) def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall: + # TODO: double-check if we need the changes in the original PR return self._lower_and_map(node.op.value, node.left, node.right) def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: @@ -394,7 +397,13 @@ def create_if( return im.let(cond_symref_name, cond_)(result) - _visit_concat_where = _visit_where # TODO(havogt): upgrade concat_where + def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: + if not isinstance(node.type, ts.TupleType): # to keep the IR simpler + return im.call("concat_where")(*self.visit(node.args)) + else: + raise NotImplementedError() + + # TODO: tuple case def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: expr = self.visit(node.args[0], **kwargs) @@ -476,8 +485,9 @@ def _map( """ Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists. """ + # TODO double-check that this code is consistent with the changes in the original PR if all( - isinstance(t, ts.ScalarType) + isinstance(t, ts.ScalarType, ts.DimensionType) for arg_type in original_arg_types for t in type_info.primitive_constituents(arg_type) ): diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index e875709631..61ac0aee74 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -169,6 +169,9 @@ class FunctionDefinition(Node, SymbolTableTrait): "if_", "index", # `index(dim)` creates a dim-field that has the current index at each point "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) + "concat_where", + "inf", # TODO: discuss + "neg_inf", # TODO: discuss *ARITHMETIC_BUILTINS, *TYPEBUILTINS, } diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 4a023f7535..8e549828eb 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -174,3 +174,49 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: new_domain_ranges[dim] = SymbolicRange(start, stop) return SymbolicDomain(domains[0].grid_type, new_domain_ranges) + + +def domain_intersection(*domains: SymbolicDomain) -> SymbolicDomain: + """Return the (set) intersection of a list of domains.""" + new_domain_ranges = {} + assert all(domain.grid_type == domains[0].grid_type for domain in domains) + for dim in domains[0].ranges.keys(): + start = functools.reduce( + lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), + [domain.ranges[dim].start for domain in domains], + ) + stop = functools.reduce( + lambda current_expr, el_expr: im.call("minimum")(current_expr, el_expr), + [domain.ranges[dim].stop for domain in domains], + ) + new_domain_ranges[dim] = SymbolicRange(start, stop) + + return SymbolicDomain(domains[0].grid_type, new_domain_ranges) + + +def domain_complement(domain: SymbolicDomain) -> SymbolicDomain: + """Return the (set) complement of a domain.""" + dims_dict = {} + for dim in domain.ranges.keys(): + lb, ub = domain.ranges[dim].start, domain.ranges[dim].stop + if lb == im.ref("neg_inf"): + dims_dict[dim] = SymbolicRange(int(ub.value), "inf") + elif ub == im.ref("inf"): + dims_dict[dim] = SymbolicRange("neg_inf", int(lb.value)) + else: + raise ValueError("Invalid domain ranges") + return SymbolicDomain(domain.grid_type, dims_dict) + + +def promote_to_same_dimensions( + domain_small: SymbolicDomain, domain_large: SymbolicDomain +) -> SymbolicDomain: + """Return an extended domain based on a smaller input domain and a larger domain containing the target dimensions.""" + dims_dict = {} + for dim in domain_large.ranges.keys(): + if dim in domain_small.ranges.keys(): + lb, ub = domain_small.ranges[dim].start, domain_small.ranges[dim].stop + dims_dict[dim] = SymbolicRange(lb, ub) + else: + dims_dict[dim] = SymbolicRange("neg_inf", "inf") + return SymbolicDomain(domain_small.grid_type, dims_dict) # TODO: fix for unstructured diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 0839e95b5b..91bfd8b50d 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -246,6 +246,11 @@ def if_(cond, true_val, false_val): return call("if_")(cond, true_val, false_val) +def concat_where(cond, true_field, false_field): + """Create a concat_where FunCall, shorthand for ``call("concat_where")(expr)``.""" + return call("concat_where")(cond, true_field, false_field) + + def lift(expr): """Create a lift FunCall, shorthand for ``call(call("lift")(expr))``.""" return call(call("lift")(expr)) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 2084ab2518..f8e86670ed 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -8,7 +8,7 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import embedded, ir -from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im class ConstantFolding(PreserveLocationVisitor, NodeTranslator): @@ -21,12 +21,35 @@ def visit_FunCall(self, node: ir.FunCall): new_node = self.generic_visit(node) if ( - isinstance(new_node.fun, ir.SymRef) - and new_node.fun.id in ["minimum", "maximum"] + cpm.is_call_to(new_node, ("minimum", "maximum")) and new_node.args[0] == new_node.args[1] ): # `minimum(a, a)` -> `a` return new_node.args[0] + if cpm.is_call_to(new_node, "minimum"): + # `minimum(neg_inf, neg_inf)` -> `neg_inf` + if cpm.is_ref_to(new_node.args[0], "neg_inf") or cpm.is_ref_to( + new_node.args[1], "neg_inf" + ): + return im.ref("neg_inf") + # `minimum(inf, a)` -> `a` + elif cpm.is_ref_to(new_node.args[0], "inf"): + return new_node.args[1] + # `minimum(a, inf)` -> `a` + elif cpm.is_ref_to(new_node.args[1], "inf"): + return new_node.args[0] + + if cpm.is_call_to(new_node, "maximum"): + # `minimum(inf, inf)` -> `inf` + if cpm.is_ref_to(new_node.args[0], "inf") or cpm.is_ref_to(new_node.args[1], "inf"): + return im.ref("inf") + # `minimum(neg_inf, a)` -> `a` + elif cpm.is_ref_to(new_node.args[0], "neg_inf"): + return new_node.args[1] + # `minimum(a, neg_inf)` -> `a` + elif cpm.is_ref_to(new_node.args[1], "neg_inf"): + return new_node.args[0] + if ( isinstance(new_node.fun, ir.SymRef) and new_node.fun.id == "if_" @@ -52,6 +75,6 @@ def visit_FunCall(self, node: ir.FunCall): ] new_node = im.literal_from_value(fun(*arg_values)) except ValueError: - pass # happens for inf and neginf + pass # happens for SymRefs which are not inf or neg_inf return new_node diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index f26d3f9ec2..f2044e4b6f 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -363,6 +363,35 @@ def _infer_if( return result_expr, actual_domains +def _infer_concat_where( + expr: itir.Expr, + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: + assert cpm.is_call_to(expr, "concat_where") + infered_args_expr = [] + actual_domains: AccessedDomains = {} + cond, true_field, false_field = expr.args + symbolic_cond = domain_utils.SymbolicDomain.from_expr(cond) + for arg in [true_field, false_field]: + if arg == true_field: + extended_cond = domain_utils.promote_to_same_dimensions(symbolic_cond, domain) + domain_ = domain_utils.domain_intersection(domain, extended_cond) + elif arg == false_field: + cond_complement = domain_utils.domain_complement(symbolic_cond) + extended_cond_complement = domain_utils.promote_to_same_dimensions( + cond_complement, domain + ) + domain_ = domain_utils.domain_intersection(domain, extended_cond_complement) + + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain_, **kwargs) + infered_args_expr.append(infered_arg_expr) + actual_domains = _merge_domains(actual_domains, actual_domains_arg) + + result_expr = im.call(expr.fun)(cond, *infered_args_expr) + return result_expr, actual_domains + + def _infer_expr( expr: itir.Expr, domain: DomainAccess, @@ -382,6 +411,8 @@ def _infer_expr( return _infer_tuple_get(expr, domain, **kwargs) elif cpm.is_call_to(expr, "if_"): return _infer_if(expr, domain, **kwargs) + elif cpm.is_call_to(expr, "concat_where"): + return _infer_concat_where(expr, domain, **kwargs) elif ( cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) or cpm.is_call_to(expr, itir.TYPEBUILTINS) diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py new file mode 100644 index 0000000000..f4422d506a --- /dev/null +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -0,0 +1,81 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next import common +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im + + +class InferDomainOps(PreserveLocationVisitor, NodeTranslator): + @classmethod + def apply(cls, node: ir.Node): + return cls().visit(node) + + def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + if isinstance(node, ir.FunCall) and cpm.is_call_to( + node, ir.BINARY_MATH_COMPARISON_BUILTINS + ): + if isinstance(node.args[0], ir.AxisLiteral) and isinstance(node.args[1], ir.Literal): + dim = common.Dimension(value=node.args[0].value, kind=common.DimensionKind.VERTICAL) + value = int(node.args[1].value) + reverse = False + elif isinstance(node.args[0], ir.Literal) and isinstance(node.args[1], ir.AxisLiteral): + dim = common.Dimension(value=node.args[1].value, kind=common.DimensionKind.VERTICAL) + value = int(node.args[0].value) + reverse = True + else: + raise ValueError(f"{node.args} need to be a 'ir.AxisLiteral' and an 'ir.Literal'.") + + match node.fun.id: + case ir.SymbolRef("less"): + if reverse: + min = value + 1 + max = "inf" + else: + min = "neg_inf" + max = value - 1 + case ir.SymbolRef("less_equal"): + if reverse: + min = value + max = "inf" + else: + min = "neg_inf" + max = value + case ir.SymbolRef("greater"): + if reverse: + min = "neg_inf" + max = value - 1 + else: + min = value + 1 + max = "inf" + case ir.SymbolRef("greater_equal"): + if reverse: + min = "neg_inf" + max = value + else: + min = value + max = "inf" + case ir.SymbolRef("eq"): + min = max = value + case ir.SymbolRef("not_eq"): + min1 = "neg_inf" + max1 = value - 1 + min2 = value + 1 + max2 = "inf" + return im.call("and_")( + im.domain(common.GridType.CARTESIAN, {dim: (min1, max1)}), + im.domain(common.GridType.CARTESIAN, {dim: (min2, max2)}), + ) + case _: + raise NotImplementedError + + return im.domain(common.GridType.CARTESIAN, {dim: (min, max)}) + + return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 6906f81e3f..e92b4b1fd0 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -15,6 +15,7 @@ fuse_as_fieldop, global_tmps, infer_domain, + infer_domain_ops, inline_dynamic_shifts, inline_fundefs, inline_lifts, @@ -83,6 +84,7 @@ def apply_common_transforms( ir = inline_dynamic_shifts.InlineDynamicShifts.apply( ir ) # domain inference does not support dynamic offsets yet + ir = infer_domain_ops.InferDomainOps.apply(ir) ir = infer_domain.infer_program( ir, offset_provider=offset_provider, diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index 7825bf1c98..30c79c7c94 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -16,10 +16,6 @@ class NamedRangeType(ts.TypeSpec): dim: common.Dimension -class DomainType(ts.DataType): - dims: list[common.Dimension] | Literal["unknown"] - - class OffsetLiteralType(ts.TypeSpec): value: ts.ScalarType | common.Dimension diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 6e9936c4af..7e4e36da98 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -112,7 +112,12 @@ def _(arg: ts.ScalarType) -> ts.ScalarType: @_register_builtin_type_synthesizer( fun_names=itir.BINARY_MATH_COMPARISON_BUILTINS | itir.BINARY_LOGICAL_BUILTINS ) -def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType | ts.TupleType: +def _(lhs, rhs) -> ts.ScalarType | ts.TupleType | ts.DomainType: + if isinstance(lhs, ts.ScalarType) and isinstance(rhs, ts.DimensionType): + return ts.DomainType(dims=[rhs.dim]) + if isinstance(lhs, ts.DimensionType) and isinstance(rhs, ts.ScalarType): + return ts.DomainType(dims=[lhs.dim]) + assert isinstance(lhs, ts.ScalarType) and isinstance(rhs, ts.ScalarType) return ts.ScalarType(kind=ts.ScalarKind.BOOL) @@ -183,9 +188,9 @@ def named_range( @_register_builtin_type_synthesizer(fun_names=["cartesian_domain", "unstructured_domain"]) -def _(*args: it_ts.NamedRangeType) -> it_ts.DomainType: +def _(*args: it_ts.NamedRangeType) -> ts.DomainType: assert all(isinstance(arg, it_ts.NamedRangeType) for arg in args) - return it_ts.DomainType(dims=[arg.dim for arg in args]) + return ts.DomainType(dims=[arg.dim for arg in args]) @_register_builtin_type_synthesizer @@ -202,7 +207,17 @@ def index(arg: ts.DimensionType) -> ts.FieldType: @_register_builtin_type_synthesizer -def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> ts.ListType: +def concat_where( + domain: ts.DomainType, + true_field: ts.FieldType | ts.TupleType, + false_field: ts.FieldType | ts.TupleType, +) -> ts.FieldType: + assert true_field == false_field + return true_field + + +@_register_builtin_type_synthesizer +def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> it_ts.ListType: assert ( isinstance(offset_literal, it_ts.OffsetLiteralType) and isinstance(offset_literal.value, common.Dimension) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index c1c0f0b5e1..fd946075d1 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Iterator, Optional, Sequence, Union +from typing import Iterator, Literal, Optional, Sequence, Union from gt4py.eve import datamodels as eve_datamodels, type_definitions as eve_types from gt4py.next import common @@ -127,3 +127,7 @@ def __str__(self) -> str: kwarg_strs = [f"{key}: {value}" for key, value in self.pos_or_kw_args.items()] args_str = ", ".join((*arg_strs, *kwarg_strs)) return f"({args_str}) -> {self.returns}" + + +class DomainType(DataType): + dims: list[common.Dimension] | Literal["unknown"] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 364434029f..27e6988744 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -23,7 +23,7 @@ def test_boundary_same_size_fields(cartesian_case): def testee( k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField ) -> cases.IJKField: - return concat_where(k == 0, boundary, interior) + return concat_where(KDim <= 2, boundary, interior) k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index d4d7c60d69..8af157ebcc 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -72,6 +72,7 @@ def expression_test_cases(): return ( # itir expr, type + # TODO: write test for IDim < 10, concat_where (im.call("abs")(1), int_type), (im.call("power")(2.0, 2), float64_type), (im.plus(1, 2), int_type), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 0bf8dcb65d..e0a46b48f6 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -60,3 +60,10 @@ def test_constant_folding_literal_maximum(): expected = im.literal_from_value(2) actual = ConstantFolding.apply(testee) assert actual == expected + + +def test_constant_folding_inf_maximum(): + testee = im.call("maximum")(im.literal_from_value(1), im.ref("inf")) + expected = im.ref("inf") + actual = ConstantFolding.apply(testee) + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 779ab738cb..034e4993d8 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -1093,3 +1093,117 @@ def test_never_accessed_domain_tuple(offset_provider): "in_field2": infer_domain.DomainAccessDescriptor.NEVER, } run_test_expr(testee, testee, domain, expected_domains, offset_provider) + + +def test_concat_where(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: ("neg_inf", 4)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 4)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (4, 11)}) + testee = im.concat_where( + domain_cond, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ) + + expected = im.concat_where( + domain_cond, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +# Todo: 2 dimensional test with cond im.domain(common.GridType.CARTESIAN, {IDim: ("neg_inf", 4)}) +# Todo: nested concat wheres + + +def test_concat_where_two_dimensions(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 30)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: ("neg_inf", 10)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10), JDim: (10, 30)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (10, 20), JDim: (10, 30)}) + testee = im.concat_where( + domain_cond, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ) + + expected = im.concat_where( + domain_cond, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_concat_where_two_dimensions_J(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 30)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {JDim: (20, "inf")}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (20, 30)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 20)}) + testee = im.concat_where( + domain_cond, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ) + + expected = im.concat_where( + domain_cond, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_nested_concat_where_two_dimensions(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 20)}) + domain_cond1 = im.domain(common.GridType.CARTESIAN, {JDim: (10, "inf")}) + domain_cond2 = im.domain(common.GridType.CARTESIAN, {IDim: ("neg_inf", 20)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 20)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (20, 30), JDim: (10, 20)}) + domain3 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 10)}) + testee = im.concat_where( + domain_cond1, + im.concat_where( + domain_cond2, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ), + im.as_fieldop("deref")("in_field3"), + ) + + expected = im.concat_where( + domain_cond1, # 0, 30; 10,20 + im.concat_where( + domain_cond2, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ), + im.as_fieldop("deref", domain3)("in_field3"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2, "in_field3": domain3} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) From 69f6b118549586f51c1e15d9b0ba1dec48c71330 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 25 Oct 2024 19:18:45 +0200 Subject: [PATCH 002/124] Finish domain inference for (nested) concat_where and transform to as_fieldop --- src/gt4py/next/iterator/ir.py | 1 + .../next/iterator/ir_utils/domain_utils.py | 8 +-- .../iterator/transforms/constant_folding.py | 13 +++- .../transforms/expand_library_functions.py | 39 ++++++++++++ .../next/iterator/transforms/infer_domain.py | 1 + .../iterator/transforms/infer_domain_ops.py | 60 ++++++++++--------- .../next/iterator/transforms/pass_manager.py | 5 ++ .../transforms/transform_concat_where.py | 34 +++++++++++ .../next/iterator/type_system/inference.py | 2 +- .../iterator/type_system/type_synthesizer.py | 8 +-- .../ffront_tests/test_concat_where.py | 21 ++++++- 11 files changed, 150 insertions(+), 42 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/expand_library_functions.py create mode 100644 src/gt4py/next/iterator/transforms/transform_concat_where.py diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 61ac0aee74..0521d027c4 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -170,6 +170,7 @@ class FunctionDefinition(Node, SymbolTableTrait): "index", # `index(dim)` creates a dim-field that has the current index at each point "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) "concat_where", + "in", "inf", # TODO: discuss "neg_inf", # TODO: discuss *ARITHMETIC_BUILTINS, diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 8e549828eb..b66f21cf60 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -200,12 +200,12 @@ def domain_complement(domain: SymbolicDomain) -> SymbolicDomain: for dim in domain.ranges.keys(): lb, ub = domain.ranges[dim].start, domain.ranges[dim].stop if lb == im.ref("neg_inf"): - dims_dict[dim] = SymbolicRange(int(ub.value), "inf") + dims_dict[dim] = SymbolicRange(start=ub, stop=im.ref("inf")) elif ub == im.ref("inf"): - dims_dict[dim] = SymbolicRange("neg_inf", int(lb.value)) + dims_dict[dim] = SymbolicRange(start=im.ref("neg_inf"), stop=lb) else: raise ValueError("Invalid domain ranges") - return SymbolicDomain(domain.grid_type, dims_dict) + return SymbolicDomain(domain.grid_type, dims_dict) def promote_to_same_dimensions( @@ -218,5 +218,5 @@ def promote_to_same_dimensions( lb, ub = domain_small.ranges[dim].start, domain_small.ranges[dim].stop dims_dict[dim] = SymbolicRange(lb, ub) else: - dims_dict[dim] = SymbolicRange("neg_inf", "inf") + dims_dict[dim] = SymbolicRange(im.ref("neg_inf"), im.ref("inf")) return SymbolicDomain(domain_small.grid_type, dims_dict) # TODO: fix for unstructured diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index f8e86670ed..6fe26f886b 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -26,7 +26,7 @@ def visit_FunCall(self, node: ir.FunCall): ): # `minimum(a, a)` -> `a` return new_node.args[0] - if cpm.is_call_to(new_node, "minimum"): + if cpm.is_call_to(new_node, "minimum"): # TODO: add tests # `minimum(neg_inf, neg_inf)` -> `neg_inf` if cpm.is_ref_to(new_node.args[0], "neg_inf") or cpm.is_ref_to( new_node.args[1], "neg_inf" @@ -39,7 +39,7 @@ def visit_FunCall(self, node: ir.FunCall): elif cpm.is_ref_to(new_node.args[1], "inf"): return new_node.args[0] - if cpm.is_call_to(new_node, "maximum"): + if cpm.is_call_to(new_node, "maximum"): # TODO: add tests # `minimum(inf, inf)` -> `inf` if cpm.is_ref_to(new_node.args[0], "inf") or cpm.is_ref_to(new_node.args[1], "inf"): return im.ref("inf") @@ -49,7 +49,14 @@ def visit_FunCall(self, node: ir.FunCall): # `minimum(a, neg_inf)` -> `a` elif cpm.is_ref_to(new_node.args[1], "neg_inf"): return new_node.args[0] - + if cpm.is_call_to(new_node, ("less", "less_equal")) and cpm.is_ref_to( + new_node.args[0], "neg_inf" + ): + return im.literal_from_value(True) # TODO: add tests + if cpm.is_call_to(new_node, ("greater", "greater_equal")) and cpm.is_ref_to( + new_node.args[0], "inf" + ): + return im.literal_from_value(True) # TODO: add tests if ( isinstance(new_node.fun, ir.SymRef) and new_node.fun.id == "if_" diff --git a/src/gt4py/next/iterator/transforms/expand_library_functions.py b/src/gt4py/next/iterator/transforms/expand_library_functions.py new file mode 100644 index 0000000000..0da3ff925c --- /dev/null +++ b/src/gt4py/next/iterator/transforms/expand_library_functions.py @@ -0,0 +1,39 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from functools import reduce + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) + + +class ExpandLibraryFunctions(PreserveLocationVisitor, NodeTranslator): + @classmethod + def apply(cls, node: ir.Node): + return cls().visit(node) + + def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + if cpm.is_call_to(node, "in"): + ret = [] + pos, domain = node.args + for i, (k, v) in enumerate( + domain_utils.SymbolicDomain.from_expr(node.args[1]).ranges.items() + ): + ret.append( + im.and_( + im.less_equal(v.start, im.tuple_get(i, pos)), + im.less(im.tuple_get(i, pos), v.stop), + ) + ) # TODO: avoid pos duplication + return reduce(im.and_, ret) + return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index f2044e4b6f..cab17e0202 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -369,6 +369,7 @@ def _infer_concat_where( **kwargs: Unpack[InferenceOptions], ) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "concat_where") + assert isinstance(domain, domain_utils.SymbolicDomain) infered_args_expr = [] actual_domains: AccessedDomains = {} cond, true_field, false_field = expr.args diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index f4422d506a..f86070c1a1 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -19,51 +19,53 @@ def apply(cls, node: ir.Node): return cls().visit(node) def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: - if isinstance(node, ir.FunCall) and cpm.is_call_to( - node, ir.BINARY_MATH_COMPARISON_BUILTINS - ): - if isinstance(node.args[0], ir.AxisLiteral) and isinstance(node.args[1], ir.Literal): - dim = common.Dimension(value=node.args[0].value, kind=common.DimensionKind.VERTICAL) - value = int(node.args[1].value) + if cpm.is_call_to(node, ir.BINARY_MATH_COMPARISON_BUILTINS): # TODO: add tests + arg1, arg2 = node.args + fun = node.fun + if isinstance(arg1, ir.AxisLiteral) and isinstance(arg2, ir.Literal): + dim = common.Dimension(value=arg2.value, kind=common.DimensionKind.VERTICAL) + value = int(arg2.value) reverse = False - elif isinstance(node.args[0], ir.Literal) and isinstance(node.args[1], ir.AxisLiteral): - dim = common.Dimension(value=node.args[1].value, kind=common.DimensionKind.VERTICAL) - value = int(node.args[0].value) + elif isinstance(arg1, ir.Literal) and isinstance(arg2, ir.AxisLiteral): + dim = common.Dimension(value=arg2.value, kind=common.DimensionKind.VERTICAL) + value = int(arg1.value) reverse = True else: raise ValueError(f"{node.args} need to be a 'ir.AxisLiteral' and an 'ir.Literal'.") - - match node.fun.id: + assert isinstance(fun, ir.SymRef) + min_: int | str + max_: int | str + match fun.id: case ir.SymbolRef("less"): if reverse: - min = value + 1 - max = "inf" + min_ = value + 1 + max_ = "inf" else: - min = "neg_inf" - max = value - 1 + min_ = "neg_inf" + max_ = value - 1 case ir.SymbolRef("less_equal"): if reverse: - min = value - max = "inf" + min_ = value + max_ = "inf" else: - min = "neg_inf" - max = value + min_ = "neg_inf" + max_ = value case ir.SymbolRef("greater"): if reverse: - min = "neg_inf" - max = value - 1 + min_ = "neg_inf" + max_ = value - 1 else: - min = value + 1 - max = "inf" + min_ = value + 1 + max_ = "inf" case ir.SymbolRef("greater_equal"): if reverse: - min = "neg_inf" - max = value + min_ = "neg_inf" + max_ = value else: - min = value - max = "inf" + min_ = value + max_ = "inf" case ir.SymbolRef("eq"): - min = max = value + min_ = max_ = value case ir.SymbolRef("not_eq"): min1 = "neg_inf" max1 = value - 1 @@ -76,6 +78,6 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: case _: raise NotImplementedError - return im.domain(common.GridType.CARTESIAN, {dim: (min, max)}) + return im.domain(common.GridType.CARTESIAN, {dim: (min_, max_)}) return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index e92b4b1fd0..edc0fd0fe3 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -12,6 +12,7 @@ from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import ( + expand_library_functions, fuse_as_fieldop, global_tmps, infer_domain, @@ -19,6 +20,7 @@ inline_dynamic_shifts, inline_fundefs, inline_lifts, + transform_concat_where, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -90,6 +92,9 @@ def apply_common_transforms( offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, ) + ir = transform_concat_where.TransformConcatWhere.apply(ir) + ir = expand_library_functions.ExpandLibraryFunctions.apply(ir) + # ir = ConstantFolding.apply(ir) # todo: remove for _ in range(10): inlined = ir diff --git a/src/gt4py/next/iterator/transforms/transform_concat_where.py b/src/gt4py/next/iterator/transforms/transform_concat_where.py new file mode 100644 index 0000000000..92fff34592 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/transform_concat_where.py @@ -0,0 +1,34 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) + + +class TransformConcatWhere(PreserveLocationVisitor, NodeTranslator): + @classmethod + def apply(cls, node: ir.Node): + return cls().visit(node) + + def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + if cpm.is_call_to(node, "concat_where"): + cond_expr, field_a, field_b = node.args + cond = domain_utils.SymbolicDomain.from_expr(cond_expr).ranges.keys() + dims = [im.call("index")(ir.AxisLiteral(value=k.value, kind=k.kind)) for k in cond] + return im.as_fieldop( + im.lambda_("pos", "a", "b")( + im.if_(im.call("in")(im.deref("pos"), cond_expr), im.deref("a"), im.deref("b")) + ) + )(im.make_tuple(*dims), field_a, field_b) + + return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index d0d39cbd34..4c96377895 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -509,7 +509,7 @@ def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType: def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.TupleType: domain = self.visit(node.domain, ctx=ctx) - assert isinstance(domain, it_ts.DomainType) + assert isinstance(domain, ts.DomainType) assert domain.dims != "unknown" assert node.dtype return type_info.apply_to_primitive_constituents( diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 7e4e36da98..79831d2064 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -213,7 +213,7 @@ def concat_where( false_field: ts.FieldType | ts.TupleType, ) -> ts.FieldType: assert true_field == false_field - return true_field + return true_field # TODO: tuples? @_register_builtin_type_synthesizer @@ -259,7 +259,7 @@ def apply_lift( def _convert_as_fieldop_input_to_iterator( - domain: it_ts.DomainType, input_: ts.TypeSpec + domain: ts.DomainType, input_: ts.TypeSpec ) -> it_ts.IteratorType: # get the dimensions of all non-zero-dimensional field inputs and check they agree all_input_dims = ( @@ -299,7 +299,7 @@ def _convert_as_fieldop_input_to_iterator( @_register_builtin_type_synthesizer def as_fieldop( stencil: TypeSynthesizer, - domain: Optional[it_ts.DomainType] = None, + domain: Optional[ts.DomainType] = None, *, offset_provider_type: common.OffsetProviderType, ) -> TypeSynthesizer: @@ -314,7 +314,7 @@ def as_fieldop( # `as_fieldop(it1, it2 -> deref(it1) + deref(it2))(i_field, j_field)` # it is unclear if the result has dimension I, J or J, I. if domain is None: - domain = it_ts.DomainType(dims="unknown") + domain = ts.DomainType(dims="unknown") @TypeSynthesizer def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 27e6988744..cf96bbe885 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -23,7 +23,7 @@ def test_boundary_same_size_fields(cartesian_case): def testee( k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField ) -> cases.IJKField: - return concat_where(KDim <= 2, boundary, interior) + return concat_where(k == 0, boundary, interior) k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() @@ -37,6 +37,25 @@ def testee( cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) +def test_dimension(cartesian_case): + @gtx.field_operator + def testee( + k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField + ) -> cases.IJKField: + return concat_where(KDim <= 2, boundary, interior) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] <= 0, boundary.asnumpy(), interior.asnumpy() + ) + + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) # TODO + + def test_boundary_horizontal_slice(cartesian_case): @gtx.field_operator def testee( From 05e74c29838defa729f66e6cc89fb01b512be8ae Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 20 Jan 2025 14:16:18 +0100 Subject: [PATCH 003/124] fix merge conflicts --- src/gt4py/next/ffront/fbuiltins.py | 2 +- src/gt4py/next/ffront/foast_to_gtir.py | 4 ++-- src/gt4py/next/iterator/type_system/type_synthesizer.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index d55af4fa29..028761e9fa 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -129,7 +129,7 @@ def __gt_type__(self) -> ts.FunctionType: ) -MaskT = TypeVar("MaskT", bound=common.Field) +MaskT = TypeVar("MaskT", bound=Union[common.Field, common.Domain]) FieldT = TypeVar("FieldT", bound=Union[common.Field, core_defs.Scalar, Tuple]) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 82d16202c4..636222aa95 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -224,7 +224,7 @@ def visit_Assign( def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym: return im.sym(node.id) - def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef: + def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef | itir.AxisLiteral: if isinstance(node.type, ts.DimensionType): return itir.AxisLiteral(value=node.type.dim.value, kind=node.type.dim.kind) return im.ref(node.id) @@ -487,7 +487,7 @@ def _map( """ # TODO double-check that this code is consistent with the changes in the original PR if all( - isinstance(t, ts.ScalarType, ts.DimensionType) + isinstance(t, (ts.ScalarType, ts.DimensionType)) for arg_type in original_arg_types for t in type_info.primitive_constituents(arg_type) ): diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 79831d2064..d740fc58a4 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -211,13 +211,13 @@ def concat_where( domain: ts.DomainType, true_field: ts.FieldType | ts.TupleType, false_field: ts.FieldType | ts.TupleType, -) -> ts.FieldType: +) -> ts.FieldType | ts.TupleType: assert true_field == false_field - return true_field # TODO: tuples? + return true_field @_register_builtin_type_synthesizer -def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> it_ts.ListType: +def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> ts.ListType: assert ( isinstance(offset_literal, it_ts.OffsetLiteralType) and isinstance(offset_literal.value, common.Dimension) From ba8343bd9b44f3b852aa167cb190db09c6dd1fca Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 30 Jan 2025 19:32:41 +0100 Subject: [PATCH 004/124] Extend concat_where, now also working for nested concat_wheres and expressions connected by logical operators (they are resolved by NestConcatWheres): running fine in roundtrip_default, gtfn_run_gtfn and gtfn_run_gtfn_imperative --- .../ffront/foast_passes/type_deduction.py | 60 +++++++------ src/gt4py/next/ffront/foast_to_gtir.py | 23 ++++- src/gt4py/next/iterator/builtins.py | 2 - src/gt4py/next/iterator/ir.py | 10 +++ .../next/iterator/ir_utils/domain_utils.py | 10 +-- src/gt4py/next/iterator/pretty_printer.py | 6 ++ .../iterator/transforms/constant_folding.py | 50 ++++++----- src/gt4py/next/iterator/transforms/cse.py | 16 +++- .../transforms/expand_library_functions.py | 2 +- .../iterator/transforms/infer_domain_ops.py | 62 ++++++++++---- .../iterator/transforms/nest_concat_wheres.py | 36 ++++++++ .../next/iterator/transforms/pass_manager.py | 4 +- .../transforms/transform_concat_where.py | 6 +- .../next/iterator/type_system/inference.py | 6 ++ .../iterator/type_system/type_synthesizer.py | 20 ++++- .../codegens/gtfn/gtfn_ir.py | 6 +- src/gt4py/next/type_system/type_info.py | 3 +- .../ffront_tests/test_concat_where.py | 56 ++++++++++++- .../transforms_tests/test_constant_folding.py | 84 ++++++++++++++++++- .../transforms_tests/test_domain_inference.py | 8 +- 20 files changed, 381 insertions(+), 89 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/nest_concat_wheres.py diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 1b2437fae6..cc4b2863d7 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -11,7 +11,7 @@ import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, NodeVisitor, traits from gt4py.next import errors -from gt4py.next.common import DimensionKind +from gt4py.next.common import DimensionKind, promote_dims from gt4py.next.ffront import ( # noqa dialect_ast_enums, experimental, @@ -20,7 +20,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.foast_passes.utils import compute_assign_indices -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -574,13 +574,13 @@ def _deduce_compare_type( if ( isinstance(left.type, ts.DimensionType) and isinstance(right.type, ts.ScalarType) - and right.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + and right.type.kind == getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) ): return ts.DomainType(dims=[left.type.dim]) if ( isinstance(right.type, ts.DimensionType) and isinstance(left.type, ts.ScalarType) - and left.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) + and left.type.kind == getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) ): return ts.DomainType(dims=[right.type.dim]) # TODO @@ -626,7 +626,11 @@ def _deduce_binop_type( dialect_ast_enums.BinaryOperator.BIT_OR, dialect_ast_enums.BinaryOperator.BIT_XOR, } - is_compatible = type_info.is_logical if node.op in logical_ops else type_info.is_arithmetic + + def tmp(arg): + return type_info.is_logical(arg) or isinstance(arg, ts.DomainType) + + is_compatible = tmp if node.op in logical_ops else type_info.is_arithmetic # check both types compatible for arg in (left, right): @@ -634,29 +638,35 @@ def _deduce_binop_type( raise errors.DSLError( arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'." ) + if isinstance(left.type, (ts.ScalarType, ts.FieldType)) and isinstance( + right.type, (ts.ScalarType, ts.FieldType) + ): + left_type = cast(ts.FieldType | ts.ScalarType, left.type) + right_type = cast(ts.FieldType | ts.ScalarType, right.type) - left_type = cast(ts.FieldType | ts.ScalarType, left.type) - right_type = cast(ts.FieldType | ts.ScalarType, right.type) - - if node.op == dialect_ast_enums.BinaryOperator.POW: - return left_type + if node.op == dialect_ast_enums.BinaryOperator.POW: + return left_type - if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral( - right_type - ): - raise errors.DSLError( - arg.location, - f"Type '{right_type}' can not be used in operator '{node.op}', it only accepts 'int'.", - ) + if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral( + right_type + ): + raise errors.DSLError( + arg.location, + f"Type '{right_type}' can not be used in operator '{node.op}', it only accepts 'int'.", + ) - try: - return type_info.promote(left_type, right_type) - except ValueError as ex: - raise errors.DSLError( - node.location, - f"Could not promote '{left_type}' and '{right_type}' to common type" - f" in call to '{node.op}'.", - ) from ex + try: + return type_info.promote(left_type, right_type) + except ValueError as ex: + raise errors.DSLError( + node.location, + f"Could not promote '{left_type}' and '{right_type}' to common type" + f" in call to '{node.op}'.", + ) from ex + elif isinstance(left.type, ts.DomainType) and isinstance(right.type, ts.DomainType): + return ts.DomainType(dims=promote_dims(left.type.dims, right.type.dims)) + else: + raise ValueError("TODO") def _check_operand_dtypes_match( self, node: foast.BinOp | foast.Compare, left: foast.Expr, right: foast.Expr diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index c04afd844c..dd936d7995 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -251,7 +251,28 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: raise NotImplementedError(f"Unary operator '{node.op}' is not supported.") def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: - return self._lower_and_map(node.op.value, node.left, node.right) + if ( + node.op == dialect_ast_enums.BinaryOperator.BIT_AND + and isinstance(node.left.type, ts.DomainType) + and isinstance(node.right.type, ts.DomainType) + ): + return im.and_(self.visit(node.left), self.visit(node.right)) + if ( + node.op == dialect_ast_enums.BinaryOperator.BIT_OR + and isinstance(node.left.type, ts.DomainType) + and isinstance(node.right.type, ts.DomainType) + ): + return im.or_(self.visit(node.left), self.visit(node.right)) + if ( + node.op == dialect_ast_enums.BinaryOperator.BIT_XOR + and isinstance(node.left.type, ts.DomainType) + and isinstance(node.right.type, ts.DomainType) + ): + raise NotImplementedError( + f"Binary operator '{node.op}' is not supported for '{node.right.type}' and '{node.right.type}'." + ) + else: + return self._lower_and_map(node.op.value, node.left, node.right) def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall: assert ( diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index a2dcb62231..4ebc9a388c 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -490,8 +490,6 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "unstructured_domain", "concat_where", "in", - "inf", # TODO: discuss - "neg_inf", # TODO: discuss *ARITHMETIC_BUILTINS, *TYPE_BUILTINS, } diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index ea5cf84d86..7ccd86faab 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -63,6 +63,14 @@ class NoneLiteral(Expr): _none_literal: int = 0 +class InfinityLiteral(Expr): + pass + + +class NegInfinityLiteral(Expr): + pass + + class OffsetLiteral(Expr): value: Union[int, str] @@ -142,3 +150,5 @@ class Program(Node, ValidatedSymbolTableTrait): Program.__hash__ = Node.__hash__ # type: ignore[method-assign] SetAt.__hash__ = Node.__hash__ # type: ignore[method-assign] IfStmt.__hash__ = Node.__hash__ # type: ignore[method-assign] +InfinityLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign] +NegInfinityLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign] diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index e4604bde03..e3ab788033 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -199,10 +199,10 @@ def domain_complement(domain: SymbolicDomain) -> SymbolicDomain: dims_dict = {} for dim in domain.ranges.keys(): lb, ub = domain.ranges[dim].start, domain.ranges[dim].stop - if lb == im.ref("neg_inf"): - dims_dict[dim] = SymbolicRange(start=ub, stop=im.ref("inf")) - elif ub == im.ref("inf"): - dims_dict[dim] = SymbolicRange(start=im.ref("neg_inf"), stop=lb) + if isinstance(lb, itir.NegInfinityLiteral): + dims_dict[dim] = SymbolicRange(start=ub, stop=itir.InfinityLiteral()) + elif isinstance(ub, itir.InfinityLiteral): + dims_dict[dim] = SymbolicRange(start=itir.NegInfinityLiteral(), stop=lb) else: raise ValueError("Invalid domain ranges") return SymbolicDomain(domain.grid_type, dims_dict) @@ -218,5 +218,5 @@ def promote_to_same_dimensions( lb, ub = domain_small.ranges[dim].start, domain_small.ranges[dim].stop dims_dict[dim] = SymbolicRange(lb, ub) else: - dims_dict[dim] = SymbolicRange(im.ref("neg_inf"), im.ref("inf")) + dims_dict[dim] = SymbolicRange(itir.NegInfinityLiteral(), itir.InfinityLiteral()) return SymbolicDomain(domain_small.grid_type, dims_dict) # TODO: fix for unstructured diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 7acbf5d23d..8f29b3ce9c 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -133,6 +133,12 @@ def visit_Sym(self, node: ir.Sym, *, prec: int) -> list[str]: def visit_Literal(self, node: ir.Literal, *, prec: int) -> list[str]: return [str(node.value)] + def visit_InfinityLiteral(self, node: ir.Literal, *, prec: int) -> list[str]: + return ["INF"] + + def visit_NegInfinityLiteral(self, node: ir.Literal, *, prec: int) -> list[str]: + return ["NEG"] + def visit_OffsetLiteral(self, node: ir.OffsetLiteral, *, prec: int) -> list[str]: return [str(node.value) + "ₒ"] diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 3282e98919..8444f3276b 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -26,37 +26,49 @@ def visit_FunCall(self, node: ir.FunCall): ): # `minimum(a, a)` -> `a` return new_node.args[0] - if cpm.is_call_to(new_node, "minimum"): # TODO: add tests + if cpm.is_call_to(new_node, "minimum"): # `minimum(neg_inf, neg_inf)` -> `neg_inf` - if cpm.is_ref_to(new_node.args[0], "neg_inf") or cpm.is_ref_to( - new_node.args[1], "neg_inf" + if isinstance(new_node.args[0], ir.NegInfinityLiteral) or isinstance( + new_node.args[1], ir.NegInfinityLiteral ): - return im.ref("neg_inf") + return ir.NegInfinityLiteral() # `minimum(inf, a)` -> `a` - elif cpm.is_ref_to(new_node.args[0], "inf"): + elif isinstance(new_node.args[0], ir.InfinityLiteral): return new_node.args[1] # `minimum(a, inf)` -> `a` - elif cpm.is_ref_to(new_node.args[1], "inf"): + elif isinstance(new_node.args[1], ir.InfinityLiteral): return new_node.args[0] - if cpm.is_call_to(new_node, "maximum"): # TODO: add tests + if cpm.is_call_to(new_node, "maximum"): # `minimum(inf, inf)` -> `inf` - if cpm.is_ref_to(new_node.args[0], "inf") or cpm.is_ref_to(new_node.args[1], "inf"): - return im.ref("inf") + if isinstance(new_node.args[0], ir.InfinityLiteral) or isinstance( + new_node.args[1], ir.InfinityLiteral + ): + return ir.InfinityLiteral() # `minimum(neg_inf, a)` -> `a` - elif cpm.is_ref_to(new_node.args[0], "neg_inf"): + elif isinstance(new_node.args[0], ir.NegInfinityLiteral): return new_node.args[1] # `minimum(a, neg_inf)` -> `a` - elif cpm.is_ref_to(new_node.args[1], "neg_inf"): + elif isinstance(new_node.args[1], ir.NegInfinityLiteral): return new_node.args[0] - if cpm.is_call_to(new_node, ("less", "less_equal")) and cpm.is_ref_to( - new_node.args[0], "neg_inf" - ): - return im.literal_from_value(True) # TODO: add tests - if cpm.is_call_to(new_node, ("greater", "greater_equal")) and cpm.is_ref_to( - new_node.args[0], "inf" - ): - return im.literal_from_value(True) # TODO: add tests + if cpm.is_call_to(new_node, ("less", "less_equal")): + if isinstance(new_node.args[0], ir.NegInfinityLiteral) or isinstance( + new_node.args[1], ir.InfinityLiteral + ): + return im.literal_from_value(True) + if isinstance(new_node.args[0], ir.InfinityLiteral) or isinstance( + new_node.args[1], ir.NegInfinityLiteral + ): + return im.literal_from_value(False) + if cpm.is_call_to(new_node, ("greater", "greater_equal")): + if isinstance(new_node.args[0], ir.NegInfinityLiteral) or isinstance( + new_node.args[1], ir.InfinityLiteral + ): + return im.literal_from_value(False) + if isinstance(new_node.args[0], ir.InfinityLiteral) or isinstance( + new_node.args[1], ir.NegInfinityLiteral + ): + return im.literal_from_value(True) if ( isinstance(new_node.fun, ir.SymRef) and new_node.fun.id == "if_" diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index ccaaf563f5..955f428fc4 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -86,7 +86,21 @@ def _is_collectable_expr(node: itir.Node) -> bool: # conceptual problems (other parts of the tool chain rely on the arguments being present directly # on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative backend # backend (single pass eager depth first visit approach) - if isinstance(node.fun, itir.SymRef) and node.fun.id in ["lift", "shift", "reduce", "map_"]: + # do also not collect index nodes because otherwise the right hand side of SetAts becomes a let statement + # instead of an as_fieldop + if isinstance(node.fun, itir.SymRef) and node.fun.id in [ + "lift", + "shift", + "reduce", + "map_", + "index", + ]: + return False + # do also not collect make_tuple(index) nodes because otherwise the right hand side of SetAts becomes a let statement + # instead of an as_fieldop + if cpm.is_call_to(node, "make_tuple") and all( + cpm.is_call_to(arg, "index") for arg in node.args + ): return False return True elif isinstance(node, itir.Lambda): diff --git a/src/gt4py/next/iterator/transforms/expand_library_functions.py b/src/gt4py/next/iterator/transforms/expand_library_functions.py index 0da3ff925c..9fab9e053f 100644 --- a/src/gt4py/next/iterator/transforms/expand_library_functions.py +++ b/src/gt4py/next/iterator/transforms/expand_library_functions.py @@ -26,7 +26,7 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: if cpm.is_call_to(node, "in"): ret = [] pos, domain = node.args - for i, (k, v) in enumerate( + for i, (_, v) in enumerate( domain_utils.SymbolicDomain.from_expr(node.args[1]).ranges.items() ): ret.append( diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index f86070c1a1..addb4047ef 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -9,8 +9,13 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next import common -from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator import builtins, ir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding class InferDomainOps(PreserveLocationVisitor, NodeTranslator): @@ -19,11 +24,16 @@ def apply(cls, node: ir.Node): return cls().visit(node) def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: - if cpm.is_call_to(node, ir.BINARY_MATH_COMPARISON_BUILTINS): # TODO: add tests + node = self.generic_visit(node) + if ( + cpm.is_call_to(node, builtins.BINARY_MATH_COMPARISON_BUILTINS) + and any(isinstance(arg, ir.AxisLiteral) for arg in node.args) + and any(isinstance(arg, ir.Literal) for arg in node.args) + ): # TODO: add tests arg1, arg2 = node.args fun = node.fun if isinstance(arg1, ir.AxisLiteral) and isinstance(arg2, ir.Literal): - dim = common.Dimension(value=arg2.value, kind=common.DimensionKind.VERTICAL) + dim = common.Dimension(value=arg1.value, kind=common.DimensionKind.VERTICAL) value = int(arg2.value) reverse = False elif isinstance(arg1, ir.Literal) and isinstance(arg2, ir.AxisLiteral): @@ -33,44 +43,44 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: else: raise ValueError(f"{node.args} need to be a 'ir.AxisLiteral' and an 'ir.Literal'.") assert isinstance(fun, ir.SymRef) - min_: int | str - max_: int | str + min_: int | ir.NegInfinityLiteral + max_: int | ir.InfinityLiteral match fun.id: case ir.SymbolRef("less"): if reverse: min_ = value + 1 - max_ = "inf" + max_ = ir.InfinityLiteral() else: - min_ = "neg_inf" + min_ = ir.NegInfinityLiteral() max_ = value - 1 case ir.SymbolRef("less_equal"): if reverse: min_ = value - max_ = "inf" + max_ = ir.InfinityLiteral() else: - min_ = "neg_inf" + min_ = ir.NegInfinityLiteral() max_ = value case ir.SymbolRef("greater"): if reverse: - min_ = "neg_inf" + min_ = ir.NegInfinityLiteral() max_ = value - 1 else: min_ = value + 1 - max_ = "inf" + max_ = ir.InfinityLiteral() case ir.SymbolRef("greater_equal"): if reverse: - min_ = "neg_inf" + min_ = ir.NegInfinityLiteral() max_ = value else: min_ = value - max_ = "inf" + max_ = ir.InfinityLiteral() case ir.SymbolRef("eq"): min_ = max_ = value case ir.SymbolRef("not_eq"): - min1 = "neg_inf" + min1 = ir.NegInfinityLiteral() max1 = value - 1 min2 = value + 1 - max2 = "inf" + max2 = ir.InfinityLiteral() return im.call("and_")( im.domain(common.GridType.CARTESIAN, {dim: (min1, max1)}), im.domain(common.GridType.CARTESIAN, {dim: (min2, max2)}), @@ -78,6 +88,24 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: case _: raise NotImplementedError - return im.domain(common.GridType.CARTESIAN, {dim: (min_, max_)}) + return im.domain( + common.GridType.CARTESIAN, + {dim: (min_, max_ + 1)} + if not isinstance(max_, ir.InfinityLiteral) + else {dim: (min_, max_)}, + ) + if cpm.is_call_to(node, builtins.BINARY_LOGICAL_BUILTINS) and all( + isinstance(arg, (ir.Literal, ir.FunCall)) for arg in node.args + ): + if cpm.is_call_to(node, "and_"): + # TODO: domain promotion + return ConstantFolding.apply( + domain_utils.domain_intersection( + *[domain_utils.SymbolicDomain.from_expr(arg) for arg in node.args] + ).as_expr() + ) + + else: + raise NotImplementedError return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/nest_concat_wheres.py b/src/gt4py/next/iterator/transforms/nest_concat_wheres.py new file mode 100644 index 0000000000..ee8197579d --- /dev/null +++ b/src/gt4py/next/iterator/transforms/nest_concat_wheres.py @@ -0,0 +1,36 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) + + +class NestConcatWheres(PreserveLocationVisitor, NodeTranslator): + + @classmethod + def apply(cls, node: ir.Node): + return cls().visit(node) + + def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + node = self.generic_visit(node) + if cpm.is_call_to(node, "concat_where"): + cond_expr, field_a, field_b = node.args + if cpm.is_call_to(cond_expr, ("and_")): + conds = cond_expr.args + return im.concat_where(conds[0], im.concat_where(conds[1],field_a, field_b), field_b) + if cpm.is_call_to(cond_expr, ("or_")): + conds = cond_expr.args + return im.concat_where(conds[0], field_a, im.concat_where(conds[1],field_a, field_b)) + + + return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index c90a719125..43a3e98f47 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -20,6 +20,7 @@ inline_dynamic_shifts, inline_fundefs, inline_lifts, + nest_concat_wheres, transform_concat_where, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet @@ -86,7 +87,9 @@ def apply_common_transforms( ir = inline_dynamic_shifts.InlineDynamicShifts.apply( ir ) # domain inference does not support dynamic offsets yet + ir = nest_concat_wheres.NestConcatWheres.apply(ir) ir = infer_domain_ops.InferDomainOps.apply(ir) + ir = infer_domain.infer_program( ir, offset_provider=offset_provider, @@ -94,7 +97,6 @@ def apply_common_transforms( ) ir = transform_concat_where.TransformConcatWhere.apply(ir) ir = expand_library_functions.ExpandLibraryFunctions.apply(ir) - # ir = ConstantFolding.apply(ir) # todo: remove for _ in range(10): inlined = ir diff --git a/src/gt4py/next/iterator/transforms/transform_concat_where.py b/src/gt4py/next/iterator/transforms/transform_concat_where.py index 92fff34592..a33cfcab5a 100644 --- a/src/gt4py/next/iterator/transforms/transform_concat_where.py +++ b/src/gt4py/next/iterator/transforms/transform_concat_where.py @@ -16,11 +16,14 @@ class TransformConcatWhere(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("domain",) + @classmethod def apply(cls, node: ir.Node): return cls().visit(node) def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + node = self.generic_visit(node) if cpm.is_call_to(node, "concat_where"): cond_expr, field_a, field_b = node.args cond = domain_utils.SymbolicDomain.from_expr(cond_expr).ranges.keys() @@ -28,7 +31,8 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: return im.as_fieldop( im.lambda_("pos", "a", "b")( im.if_(im.call("in")(im.deref("pos"), cond_expr), im.deref("a"), im.deref("b")) - ) + ), + node.annex.domain.as_expr(), )(im.make_tuple(*dims), field_a, field_b) return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 8760c7226c..c2f25e2e89 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -579,6 +579,12 @@ def visit_Literal(self, node: itir.Literal, **kwargs) -> ts.ScalarType: assert isinstance(node.type, ts.ScalarType) return node.type + def visit_InfinityLiteral(self, node: itir.InfinityLiteral, **kwargs) -> ts.ScalarType: + return ts.ScalarType(kind=ts.ScalarKind.INT32) + + def visit_NegInfinityLiteral(self, node: itir.InfinityLiteral, **kwargs) -> ts.ScalarType: + return ts.ScalarType(kind=ts.ScalarKind.INT32) + def visit_SymRef( self, node: itir.SymRef, *, ctx: dict[str, ts.TypeSpec] ) -> ts.TypeSpec | type_synthesizer.TypeSynthesizer: diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 403eec0955..f733a229be 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -109,10 +109,9 @@ def _(arg: ts.ScalarType) -> ts.ScalarType: return ts.ScalarType(kind=ts.ScalarKind.BOOL) -@_register_builtin_type_synthesizer( - fun_names=builtins.BINARY_MATH_COMPARISON_BUILTINS | builtins.BINARY_LOGICAL_BUILTINS -) -def _(lhs, rhs) -> ts.ScalarType | ts.TupleType | ts.DomainType: +def synthesize_binary_math_comparison_builtins( + lhs, rhs +) -> ts.ScalarType | ts.TupleType | ts.DomainType: if isinstance(lhs, ts.ScalarType) and isinstance(rhs, ts.DimensionType): return ts.DomainType(dims=[rhs.dim]) if isinstance(lhs, ts.DimensionType) and isinstance(rhs, ts.ScalarType): @@ -121,6 +120,19 @@ def _(lhs, rhs) -> ts.ScalarType | ts.TupleType | ts.DomainType: return ts.ScalarType(kind=ts.ScalarKind.BOOL) +@_register_builtin_type_synthesizer(fun_names=builtins.BINARY_MATH_COMPARISON_BUILTINS) +def _(lhs, rhs) -> ts.ScalarType | ts.TupleType | ts.DomainType: + return synthesize_binary_math_comparison_builtins(lhs, rhs) + + +@_register_builtin_type_synthesizer(fun_names=builtins.BINARY_LOGICAL_BUILTINS) +def _(lhs, rhs) -> ts.ScalarType | ts.TupleType | ts.DomainType: + if isinstance(lhs, ts.DomainType) and isinstance(rhs, ts.DomainType): + return ts.DomainType(dims=common.promote_dims(lhs.dims, rhs.dims)) + else: + return synthesize_binary_math_comparison_builtins(lhs, rhs) + + @_register_builtin_type_synthesizer def deref(it: it_ts.IteratorType | ts.DeferredType) -> ts.DataType | ts.DeferredType: if isinstance(it, ts.DeferredType): diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 831694791a..6ca9bde77f 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -124,7 +124,11 @@ def _values_validator( ) -> None: if not all( isinstance(el, (SidFromScalar, SidComposite)) - or _is_tuple_expr_of(lambda expr: isinstance(expr, (SymRef, Literal)), el) + or _is_tuple_expr_of( + lambda expr: isinstance(expr, (SymRef, Literal)) + or (isinstance(expr, FunCall) and expr.fun == SymRef(id="index")), + el, + ) for el in value ): raise ValueError( diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 26373c647f..fa73748df6 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -304,7 +304,8 @@ def is_number(symbol_type: ts.TypeSpec) -> bool: def is_logical(symbol_type: ts.TypeSpec) -> bool: return ( - isinstance(dtype := extract_dtype(symbol_type), ts.ScalarType) + isinstance(symbol_type, (ts.FieldType, ts.ScalarType)) + and isinstance(dtype := extract_dtype(symbol_type), ts.ScalarType) and dtype.kind is ts.ScalarKind.BOOL ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index cf96bbe885..7d03eb4fc1 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -9,7 +9,7 @@ import numpy as np from typing import Tuple import pytest -from next_tests.integration_tests.cases import KDim, cartesian_case +from next_tests.integration_tests.cases import KDim, IDim, cartesian_case from gt4py import next as gtx from gt4py.next.ffront.experimental import concat_where from next_tests.integration_tests import cases @@ -42,7 +42,7 @@ def test_dimension(cartesian_case): def testee( k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField ) -> cases.IJKField: - return concat_where(KDim <= 2, boundary, interior) + return concat_where(KDim >= 2, boundary, interior) k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() @@ -50,10 +50,58 @@ def testee( out = cases.allocate(cartesian_case, testee, cases.RETURN)() ref = np.where( - k.asnumpy()[np.newaxis, np.newaxis, :] <= 0, boundary.asnumpy(), interior.asnumpy() + k.asnumpy()[np.newaxis, np.newaxis, :] >= 2, boundary.asnumpy(), interior.asnumpy() ) + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + + +def test_dimension_two_nested_conditions(cartesian_case): + @gtx.field_operator + def testee( + k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField + ) -> cases.IJKField: + return concat_where((KDim < 2), boundary, concat_where((KDim >= 5), boundary, interior)) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where( + (k.asnumpy()[np.newaxis, np.newaxis, :] < 2) + | (k.asnumpy()[np.newaxis, np.newaxis, :] >= 5), + boundary.asnumpy(), + interior.asnumpy(), + ) + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + - cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) # TODO +def test_dimension_two_conditions_and(cartesian_case): + @gtx.field_operator + def testee(k: cases.KField, interior: cases.KField, boundary: cases.KField) -> cases.KField: + return concat_where(((KDim > 2) & (KDim <= 5)), interior, boundary) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where((k.asnumpy() > 2) & (k.asnumpy() <= 5), interior.asnumpy(), boundary.asnumpy()) + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + + +def test_dimension_two_conditions_or(cartesian_case): + @gtx.field_operator + def testee(k: cases.KField, interior: cases.KField, boundary: cases.KField) -> cases.KField: + return concat_where(((KDim < 2) | (KDim >= 5)), boundary, interior) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where((k.asnumpy() < 2) | (k.asnumpy() >= 5), boundary.asnumpy(), interior.asnumpy()) + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) def test_boundary_horizontal_slice(cartesian_case): diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index fbe9d40154..794a93090b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -8,6 +8,7 @@ from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.constant_folding import ConstantFolding +from gt4py.next.iterator import ir def test_constant_folding_boolean(): @@ -63,7 +64,86 @@ def test_constant_folding_literal_maximum(): def test_constant_folding_inf_maximum(): - testee = im.call("maximum")(im.literal_from_value(1), im.ref("inf")) - expected = im.ref("inf") + testee = im.call("maximum")(im.literal_from_value(1), ir.InfinityLiteral()) + expected = ir.InfinityLiteral() + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("maximum")(ir.InfinityLiteral(), im.literal_from_value(1)) + expected = ir.InfinityLiteral() + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("maximum")(im.literal_from_value(1), ir.NegInfinityLiteral()) + expected = im.literal_from_value(1) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("maximum")(ir.NegInfinityLiteral(), im.literal_from_value(1)) + expected = im.literal_from_value(1) + actual = ConstantFolding.apply(testee) + assert actual == expected + + +def test_constant_folding_inf_minimum(): + testee = im.call("minimum")(im.literal_from_value(1), ir.InfinityLiteral()) + expected = im.literal_from_value(1) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("minimum")(ir.InfinityLiteral(), im.literal_from_value(1)) + expected = im.literal_from_value(1) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("minimum")(im.literal_from_value(1), ir.NegInfinityLiteral()) + expected = ir.NegInfinityLiteral() + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("minimum")(ir.NegInfinityLiteral(), im.literal_from_value(1)) + expected = ir.NegInfinityLiteral() + actual = ConstantFolding.apply(testee) + assert actual == expected + + +def test_constant_greater_less(): + testee = im.call("greater")(im.literal_from_value(1), ir.InfinityLiteral()) + expected = im.literal_from_value(False) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("greater")(im.literal_from_value(1), ir.NegInfinityLiteral()) + expected = im.literal_from_value(True) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("less")(im.literal_from_value(1), ir.InfinityLiteral()) + expected = im.literal_from_value(True) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("less")(im.literal_from_value(1), ir.NegInfinityLiteral()) + expected = im.literal_from_value(False) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("greater")(ir.InfinityLiteral(), im.literal_from_value(1)) + expected = im.literal_from_value(True) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("greater")(ir.NegInfinityLiteral(), im.literal_from_value(1)) + expected = im.literal_from_value(False) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("less")(ir.InfinityLiteral(), im.literal_from_value(1)) + expected = im.literal_from_value(False) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("less")(ir.NegInfinityLiteral(), im.literal_from_value(1)) + expected = im.literal_from_value(True) actual = ConstantFolding.apply(testee) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 8279d8223c..2e014ffdb8 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -1097,7 +1097,7 @@ def test_never_accessed_domain_tuple(offset_provider): def test_concat_where(offset_provider): domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: ("neg_inf", 4)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: (itir.NegInfinityLiteral(), 4)}) domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 4)}) domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (4, 11)}) testee = im.concat_where( @@ -1120,13 +1120,13 @@ def test_concat_where(offset_provider): assert expected_domains == constant_fold_accessed_domains(actual_domains) -# Todo: 2 dimensional test with cond im.domain(common.GridType.CARTESIAN, {IDim: ("neg_inf", 4)}) +# Todo: 2 dimensional test with cond im.domain(common.GridType.CARTESIAN, {IDim: (itir.NegInfinityLiteral(), 4)}) # Todo: nested concat wheres def test_concat_where_two_dimensions(offset_provider): domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 30)}) - domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: ("neg_inf", 10)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: (itir.NegInfinityLiteral(), 10)}) domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10), JDim: (10, 30)}) domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (10, 20), JDim: (10, 30)}) testee = im.concat_where( @@ -1177,7 +1177,7 @@ def test_concat_where_two_dimensions_J(offset_provider): def test_nested_concat_where_two_dimensions(offset_provider): domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 20)}) domain_cond1 = im.domain(common.GridType.CARTESIAN, {JDim: (10, "inf")}) - domain_cond2 = im.domain(common.GridType.CARTESIAN, {IDim: ("neg_inf", 20)}) + domain_cond2 = im.domain(common.GridType.CARTESIAN, {IDim: (itir.NegInfinityLiteral(), 20)}) domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 20)}) domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (20, 30), JDim: (10, 20)}) domain3 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 10)}) From f90329eb5b1f737e13c03fafb3d669a7250a5e7d Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 31 Jan 2025 15:10:58 +0100 Subject: [PATCH 005/124] Some fixes, tuples still not supported --- .../ffront/foast_passes/type_deduction.py | 67 +++++++++++-------- src/gt4py/next/iterator/pretty_printer.py | 2 +- .../iterator/transforms/constant_folding.py | 2 +- .../iterator/transforms/infer_domain_ops.py | 6 +- .../iterator/transforms/nest_concat_wheres.py | 20 +++--- .../iterator/type_system/type_synthesizer.py | 3 +- tests/next_tests/integration_tests/cases.py | 1 + .../ffront_tests/test_concat_where.py | 63 +++++++++++++---- 8 files changed, 106 insertions(+), 58 deletions(-) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index cc4b2863d7..dc8d36af5e 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -571,18 +571,35 @@ def _deduce_compare_type( self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any ) -> Optional[ts.TypeSpec]: # check both types compatible + left_t, right_t = left.type, right.type + integer_kind = getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) if ( - isinstance(left.type, ts.DimensionType) - and isinstance(right.type, ts.ScalarType) - and right.type.kind == getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) + isinstance(left_t, ts.DimensionType) + and isinstance(right_t, ts.ScalarType) + and right_t.kind == integer_kind + ): + return ts.DomainType(dims=[left_t.dim]) + if ( + isinstance(right_t, ts.DimensionType) + and isinstance(left_t, ts.ScalarType) + and left_t.kind == integer_kind ): - return ts.DomainType(dims=[left.type.dim]) + return ts.DomainType(dims=[right_t.dim]) if ( - isinstance(right.type, ts.DimensionType) - and isinstance(left.type, ts.ScalarType) - and left.type.kind == getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) + isinstance(left_t, ts.OffsetType) + and left.op == dialect_ast_enums.BinaryOperator.MOD + and isinstance(right_t, ts.ScalarType) + and right_t.kind == integer_kind + ) or ( + isinstance(right_t, ts.OffsetType) + and right.op == dialect_ast_enums.BinaryOperator.MOD + and isinstance(left_t, ts.ScalarType) + and left_t.kind == integer_kind ): - return ts.DomainType(dims=[right.type.dim]) + raise errors.DSLError( + left.location, "Type 'ts.OffsetType' can not be used in operator 'mod'." + ) + # TODO for arg in (left, right): if not type_info.is_arithmetic(arg.type): @@ -596,13 +613,13 @@ def _deduce_compare_type( # transform operands to have bool dtype and use regular promotion # mechanism to handle dimension promotion return type_info.promote( - with_altered_scalar_kind(left.type, ts.ScalarKind.BOOL), - with_altered_scalar_kind(right.type, ts.ScalarKind.BOOL), + with_altered_scalar_kind(left_t, ts.ScalarKind.BOOL), + with_altered_scalar_kind(right_t, ts.ScalarKind.BOOL), ) except ValueError as ex: raise errors.DSLError( node.location, - f"Could not promote '{left.type}' and '{right.type}' to common type" + f"Could not promote '{left_t}' and '{right_t}' to common type" f" in call to '{node.op}'.", ) from ex @@ -627,10 +644,10 @@ def _deduce_binop_type( dialect_ast_enums.BinaryOperator.BIT_XOR, } - def tmp(arg): + def is_logical_or_domain(arg: ts.TypeSpec) -> bool: return type_info.is_logical(arg) or isinstance(arg, ts.DomainType) - is_compatible = tmp if node.op in logical_ops else type_info.is_arithmetic + is_compatible = is_logical_or_domain if node.op in logical_ops else type_info.is_arithmetic # check both types compatible for arg in (left, right): @@ -641,26 +658,23 @@ def tmp(arg): if isinstance(left.type, (ts.ScalarType, ts.FieldType)) and isinstance( right.type, (ts.ScalarType, ts.FieldType) ): - left_type = cast(ts.FieldType | ts.ScalarType, left.type) - right_type = cast(ts.FieldType | ts.ScalarType, right.type) - if node.op == dialect_ast_enums.BinaryOperator.POW: - return left_type + return left.type if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral( - right_type + right.type ): raise errors.DSLError( arg.location, - f"Type '{right_type}' can not be used in operator '{node.op}', it only accepts 'int'.", + f"Type '{right.type}' can not be used in operator '{node.op}', it only accepts 'int'.", ) try: - return type_info.promote(left_type, right_type) + return type_info.promote(left.type, right.type) except ValueError as ex: raise errors.DSLError( node.location, - f"Could not promote '{left_type}' and '{right_type}' to common type" + f"Could not promote '{left.type}' and '{right.type}' to common type" f" in call to '{node.op}'.", ) from ex elif isinstance(left.type, ts.DomainType) and isinstance(right.type, ts.DomainType): @@ -971,13 +985,10 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: true_branch_type = node.args[1].type false_branch_type = node.args[2].type - if true_branch_type != false_branch_type: - raise errors.DSLError( - node.location, - f"Incompatible argument in call to '{node.func!s}': expected " - f"'{true_branch_type}' and '{false_branch_type}' to be equal.", - ) - return_type = true_branch_type + true_branch_fieldtype = cast(ts.FieldType, true_branch_type) + false_branch_fieldtype = cast(ts.FieldType, false_branch_type) + promoted_type = type_info.promote(true_branch_fieldtype, false_branch_fieldtype) + return_type = promoted_type return foast.Call( func=node.func, diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 8f29b3ce9c..1d97878257 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -137,7 +137,7 @@ def visit_InfinityLiteral(self, node: ir.Literal, *, prec: int) -> list[str]: return ["INF"] def visit_NegInfinityLiteral(self, node: ir.Literal, *, prec: int) -> list[str]: - return ["NEG"] + return ["-INF"] def visit_OffsetLiteral(self, node: ir.OffsetLiteral, *, prec: int) -> list[str]: return [str(node.value) + "ₒ"] diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 8444f3276b..b3980e70ed 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -16,7 +16,7 @@ class ConstantFolding(PreserveLocationVisitor, NodeTranslator): def apply(cls, node: ir.Node) -> ir.Node: return cls().visit(node) - def visit_FunCall(self, node: ir.FunCall): + def visit_FunCall(self, node: ir.FunCall) -> ir.Node: # visit depth-first such that nested constant expressions (e.g. `(1+2)+3`) are properly folded new_node = self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index addb4047ef..a5da214ae3 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -23,7 +23,7 @@ class InferDomainOps(PreserveLocationVisitor, NodeTranslator): def apply(cls, node: ir.Node): return cls().visit(node) - def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + def visit_FunCall(self, node: ir.FunCall) -> ir.Node: node = self.generic_visit(node) if ( cpm.is_call_to(node, builtins.BINARY_MATH_COMPARISON_BUILTINS) @@ -33,11 +33,11 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: arg1, arg2 = node.args fun = node.fun if isinstance(arg1, ir.AxisLiteral) and isinstance(arg2, ir.Literal): - dim = common.Dimension(value=arg1.value, kind=common.DimensionKind.VERTICAL) + dim = common.Dimension(value=arg1.value, kind=arg1.kind) value = int(arg2.value) reverse = False elif isinstance(arg1, ir.Literal) and isinstance(arg2, ir.AxisLiteral): - dim = common.Dimension(value=arg2.value, kind=common.DimensionKind.VERTICAL) + dim = common.Dimension(value=arg2.value, kind=arg2.kind) value = int(arg1.value) reverse = True else: diff --git a/src/gt4py/next/iterator/transforms/nest_concat_wheres.py b/src/gt4py/next/iterator/transforms/nest_concat_wheres.py index ee8197579d..258494e0c4 100644 --- a/src/gt4py/next/iterator/transforms/nest_concat_wheres.py +++ b/src/gt4py/next/iterator/transforms/nest_concat_wheres.py @@ -8,15 +8,10 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import ( - common_pattern_matcher as cpm, - domain_utils, - ir_makers as im, -) +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im class NestConcatWheres(PreserveLocationVisitor, NodeTranslator): - @classmethod def apply(cls, node: ir.Node): return cls().visit(node) @@ -27,10 +22,17 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: cond_expr, field_a, field_b = node.args if cpm.is_call_to(cond_expr, ("and_")): conds = cond_expr.args - return im.concat_where(conds[0], im.concat_where(conds[1],field_a, field_b), field_b) + return im.concat_where( + conds[0], im.concat_where(conds[1], field_a, field_b), field_b + ) if cpm.is_call_to(cond_expr, ("or_")): conds = cond_expr.args - return im.concat_where(conds[0], field_a, im.concat_where(conds[1],field_a, field_b)) - + return im.concat_where( + conds[0], field_a, im.concat_where(conds[1], field_a, field_b) + ) + if cpm.is_call_to(cond_expr, ("eq")): + cond1 = im.less(cond_expr.args[0], cond_expr.args[1]) + cond2 = im.greater(cond_expr.args[0], cond_expr.args[1]) + return im.concat_where(cond1, field_b, im.concat_where(cond2, field_b, field_a)) return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index f733a229be..c6f31d0a51 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -224,8 +224,7 @@ def concat_where( true_field: ts.FieldType | ts.TupleType, false_field: ts.FieldType | ts.TupleType, ) -> ts.FieldType | ts.TupleType: - assert true_field == false_field - return true_field + return type_info.promote(true_field, false_field) @_register_builtin_type_synthesizer diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 89ad556476..66330016ef 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -62,6 +62,7 @@ IField: TypeAlias = gtx.Field[[IDim], np.int32] # type: ignore [valid-type] IFloatField: TypeAlias = gtx.Field[[IDim], np.float64] # type: ignore [valid-type] IBoolField: TypeAlias = gtx.Field[[IDim], bool] # type: ignore [valid-type] +JField: TypeAlias = gtx.Field[[JDim], np.int32] # type: ignore [valid-type] KField: TypeAlias = gtx.Field[[KDim], np.int32] # type: ignore [valid-type] IJField: TypeAlias = gtx.Field[[IDim, JDim], np.int32] # type: ignore [valid-type] IKField: TypeAlias = gtx.Field[[IDim, KDim], np.int32] # type: ignore [valid-type] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 7d03eb4fc1..7db29bc088 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -9,8 +9,9 @@ import numpy as np from typing import Tuple import pytest -from next_tests.integration_tests.cases import KDim, IDim, cartesian_case +from next_tests.integration_tests.cases import IDim, JDim, KDim, cartesian_case from gt4py import next as gtx +from gt4py.next import errors from gt4py.next.ffront.experimental import concat_where from next_tests.integration_tests import cases from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -23,7 +24,7 @@ def test_boundary_same_size_fields(cartesian_case): def testee( k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField ) -> cases.IJKField: - return concat_where(k == 0, boundary, interior) + return concat_where(KDim == 0, boundary, interior) k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() @@ -55,6 +56,22 @@ def testee( cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) +def test_dimension_different_dims(cartesian_case): + @gtx.field_operator + def testee(j: cases.JField, interior: cases.IJField, boundary: cases.JField) -> cases.IJField: + return concat_where(IDim >= 2, boundary, interior) + + j = cases.allocate(cartesian_case, testee, "j", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where( + j.asnumpy()[:, np.newaxis] >= 2, boundary.asnumpy()[np.newaxis, :], interior.asnumpy() + ) + cases.verify(cartesian_case, testee, j, interior, boundary, out=out, ref=ref) + + def test_dimension_two_nested_conditions(cartesian_case): @gtx.field_operator def testee( @@ -90,6 +107,20 @@ def testee(k: cases.KField, interior: cases.KField, boundary: cases.KField) -> c cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) +def test_dimension_two_conditions_eq(cartesian_case): + @gtx.field_operator + def testee(k: cases.KField, interior: cases.KField, boundary: cases.KField) -> cases.KField: + return concat_where((KDim == 2), interior, boundary) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where(k.asnumpy() == 2, interior.asnumpy(), boundary.asnumpy()) + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + + def test_dimension_two_conditions_or(cartesian_case): @gtx.field_operator def testee(k: cases.KField, interior: cases.KField, boundary: cases.KField) -> cases.KField: @@ -109,7 +140,7 @@ def test_boundary_horizontal_slice(cartesian_case): def testee( k: cases.KField, interior: cases.IJKField, boundary: cases.IJField ) -> cases.IJKField: - return concat_where(k == 0, boundary, interior) + return concat_where(KDim == 0, boundary, interior) k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() @@ -130,7 +161,7 @@ def test_boundary_single_layer(cartesian_case): def testee( k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField ) -> cases.IJKField: - return concat_where(k == 0, boundary, interior) + return concat_where(KDim == 0, boundary, interior) k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() @@ -147,18 +178,22 @@ def testee( def test_alternating_mask(cartesian_case): - @gtx.field_operator - def testee(k: cases.KField, f0: cases.IJKField, f1: cases.IJKField) -> cases.IJKField: - return concat_where(k % 2 == 0, f1, f0) + with pytest.raises( + errors.DSLError, match=("Type 'ts.OffsetType' can not be used in operator 'mod'.") + ): - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() - f0 = cases.allocate(cartesian_case, testee, "f0")() - f1 = cases.allocate(cartesian_case, testee, "f1")() - out = cases.allocate(cartesian_case, testee, cases.RETURN)() + @gtx.field_operator + def testee(k: cases.KField, f0: cases.IJKField, f1: cases.IJKField) -> cases.IJKField: + return concat_where(KDim % 2 == 0, f1, f0) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + f0 = cases.allocate(cartesian_case, testee, "f0")() + f1 = cases.allocate(cartesian_case, testee, "f1")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() - ref = np.where(k.asnumpy()[np.newaxis, np.newaxis, :] % 2 == 0, f1.asnumpy(), f0.asnumpy()) + ref = np.where(k.asnumpy()[np.newaxis, np.newaxis, :] % 2 == 0, f1.asnumpy(), f0.asnumpy()) - cases.verify(cartesian_case, testee, k, f0, f1, out=out, ref=ref) + cases.verify(cartesian_case, testee, k, f0, f1, out=out, ref=ref) @pytest.mark.uses_tuple_returns @@ -171,7 +206,7 @@ def testee( interior1: cases.IJKField, boundary1: cases.IJField, ) -> Tuple[cases.IJKField, cases.IJKField]: - return concat_where(k == 0, (boundary0, boundary1), (interior0, interior1)) + return concat_where(KDim == 0, (boundary0, boundary1), (interior0, interior1)) k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior0 = cases.allocate(cartesian_case, testee, "interior0")() From b49a82d68875acb3fd54f1a19772eead93ef8cc4 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 5 Feb 2025 15:42:48 +0100 Subject: [PATCH 006/124] Some updates for concat where, which were necessary when using it in PMAP-LES, DerefCond is still a draft, but works for the tested cases --- .../iterator/transforms/constant_folding.py | 1 + .../iterator/transforms/infer_domain_ops.py | 31 +++++++++++-------- .../next/iterator/transforms/pass_manager.py | 2 ++ .../transforms/transform_concat_where.py | 23 ++++++++++++-- .../iterator/type_system/type_synthesizer.py | 8 +++-- 5 files changed, 46 insertions(+), 19 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index b3980e70ed..a8cee011d7 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -12,6 +12,7 @@ class ConstantFolding(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("domain",) @classmethod def apply(cls, node: ir.Node) -> ir.Node: return cls().visit(node) diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index a5da214ae3..e6ae6557dc 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -28,18 +28,24 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node: if ( cpm.is_call_to(node, builtins.BINARY_MATH_COMPARISON_BUILTINS) and any(isinstance(arg, ir.AxisLiteral) for arg in node.args) - and any(isinstance(arg, ir.Literal) for arg in node.args) + and any(isinstance(arg, (ir.Literal, ir.SymRef)) for arg in node.args) ): # TODO: add tests arg1, arg2 = node.args fun = node.fun - if isinstance(arg1, ir.AxisLiteral) and isinstance(arg2, ir.Literal): + if isinstance(arg1, ir.AxisLiteral): dim = common.Dimension(value=arg1.value, kind=arg1.kind) - value = int(arg2.value) reverse = False - elif isinstance(arg1, ir.Literal) and isinstance(arg2, ir.AxisLiteral): + if isinstance(arg2, ir.Literal): + value = int(arg2.value) + elif isinstance(arg2, ir.SymRef): + value = arg2 + elif isinstance(arg2, ir.AxisLiteral): dim = common.Dimension(value=arg2.value, kind=arg2.kind) - value = int(arg1.value) reverse = True + if isinstance(arg1, ir.Literal): + value = int(arg1.value) + elif isinstance(arg1, ir.SymRef): + value = arg1 else: raise ValueError(f"{node.args} need to be a 'ir.AxisLiteral' and an 'ir.Literal'.") assert isinstance(fun, ir.SymRef) @@ -48,11 +54,11 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node: match fun.id: case ir.SymbolRef("less"): if reverse: - min_ = value + 1 + min_ = im.plus(value, 1) max_ = ir.InfinityLiteral() else: min_ = ir.NegInfinityLiteral() - max_ = value - 1 + max_ = im.minus(value, 1) case ir.SymbolRef("less_equal"): if reverse: min_ = value @@ -63,9 +69,9 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node: case ir.SymbolRef("greater"): if reverse: min_ = ir.NegInfinityLiteral() - max_ = value - 1 + max_ = im.minus(value, 1) else: - min_ = value + 1 + min_ = im.plus(value, 1) max_ = ir.InfinityLiteral() case ir.SymbolRef("greater_equal"): if reverse: @@ -78,8 +84,8 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node: min_ = max_ = value case ir.SymbolRef("not_eq"): min1 = ir.NegInfinityLiteral() - max1 = value - 1 - min2 = value + 1 + max1 = im.minus(value, 1) + min2 = im.plus(value, 1) max2 = ir.InfinityLiteral() return im.call("and_")( im.domain(common.GridType.CARTESIAN, {dim: (min1, max1)}), @@ -87,10 +93,9 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node: ) case _: raise NotImplementedError - return im.domain( common.GridType.CARTESIAN, - {dim: (min_, max_ + 1)} + {dim: (min_, im.plus(max_, 1))} if not isinstance(max_, ir.InfinityLiteral) else {dim: (min_, max_)}, ) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 43a3e98f47..104251aba8 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -95,7 +95,9 @@ def apply_common_transforms( offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, ) + ir = ConstantFolding.apply(ir) # TODO: remove ir = transform_concat_where.TransformConcatWhere.apply(ir) + ir = ConstantFolding.apply(ir) # TODO: remove ir = expand_library_functions.ExpandLibraryFunctions.apply(ir) for _ in range(10): diff --git a/src/gt4py/next/iterator/transforms/transform_concat_where.py b/src/gt4py/next/iterator/transforms/transform_concat_where.py index a33cfcab5a..518742ba6f 100644 --- a/src/gt4py/next/iterator/transforms/transform_concat_where.py +++ b/src/gt4py/next/iterator/transforms/transform_concat_where.py @@ -8,12 +8,27 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir +from gt4py.next.type_system import type_specifications as ts +from gt4py.next.iterator.transforms import symbol_ref_utils from gt4py.next.iterator.ir_utils import ( common_pattern_matcher as cpm, domain_utils, ir_makers as im, ) +class DerefCond(PreserveLocationVisitor, NodeTranslator): + + @classmethod + def apply(cls, node: ir.Node, symbol_refs: list): + return cls().visit(node, symbol_refs=symbol_refs) + + def visit_SymRef(self, node: ir.FunCall, symbol_refs: list) -> ir.FunCall: + node = self.generic_visit(node, symbol_refs=symbol_refs) + if node.id in symbol_refs and isinstance(node.type, ts.ScalarType): + node.type = None + return im.deref(node) + return node + class TransformConcatWhere(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("domain",) @@ -28,11 +43,13 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: cond_expr, field_a, field_b = node.args cond = domain_utils.SymbolicDomain.from_expr(cond_expr).ranges.keys() dims = [im.call("index")(ir.AxisLiteral(value=k.value, kind=k.kind)) for k in cond] + refs = symbol_ref_utils.collect_symbol_refs(cond_expr) + cond_expr = DerefCond.apply(cond_expr, refs) return im.as_fieldop( - im.lambda_("pos", "a", "b")( - im.if_(im.call("in")(im.deref("pos"), cond_expr), im.deref("a"), im.deref("b")) + im.lambda_("_tcw_pos", "_tcw_arg0", "_tcw_arg1", *refs)( + im.if_(im.call("in")(im.deref("_tcw_pos"), cond_expr), im.deref("_tcw_arg0"), im.deref("_tcw_arg1")) ), node.annex.domain.as_expr(), - )(im.make_tuple(*dims), field_a, field_b) + )(im.make_tuple(*dims), field_a, field_b, *refs) return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index c6f31d0a51..26cbb5cb60 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -221,9 +221,11 @@ def index(arg: ts.DimensionType) -> ts.FieldType: @_register_builtin_type_synthesizer def concat_where( domain: ts.DomainType, - true_field: ts.FieldType | ts.TupleType, - false_field: ts.FieldType | ts.TupleType, -) -> ts.FieldType | ts.TupleType: + true_field: ts.FieldType | ts.TupleType | ts.DeferredType, + false_field: ts.FieldType | ts.TupleType | ts.DeferredType, +) -> ts.FieldType | ts.TupleType | ts.DeferredType: + if isinstance(true_field, ts.DeferredType) or isinstance(false_field, ts.DeferredType): + return ts.DeferredType(constraint=None) return type_info.promote(true_field, false_field) From d16bbd5d69aea6972d4e8d2e62828fe06265602d Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Sat, 15 Feb 2025 14:25:19 +0100 Subject: [PATCH 007/124] ITIR type inference: store param type in Lambda --- src/gt4py/next/iterator/type_system/inference.py | 4 ++++ .../unit_tests/iterator_tests/test_type_inference.py | 1 + 2 files changed, 5 insertions(+) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index fe450625db..cc7a7123b9 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -37,6 +37,10 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: assert type_info.is_compatible_type( node.type, type_ ), "Node already has a type which differs." + if isinstance(node, itir.Lambda): + assert isinstance(type_, ts.FunctionType) + for param, param_type in zip(node.params, type_.pos_only_args): + _set_node_type(param, param_type) node.type = type_ diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index a39fe3c6d8..577c7bce1c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -245,6 +245,7 @@ def test_aliased_function(): assert result.args[0].type == ts.FunctionType( pos_only_args=[int_type], pos_or_kw_args={}, kw_only_args={}, returns=int_type ) + assert result.args[0].params[0].type == int_type assert result.type == int_type From 813f3285bbf9b3d2876544e082baaf3394569bcf Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 18 Feb 2025 15:32:25 +0100 Subject: [PATCH 008/124] Flatten as_fieldop tuple arguments --- src/gt4py/next/iterator/ir_utils/misc.py | 22 ++- .../iterator/transforms/collapse_tuple.py | 151 +++++++++++++----- .../iterator/transforms/fuse_as_fieldop.py | 25 +-- .../transforms_tests/test_collapse_tuple.py | 29 +++- 4 files changed, 165 insertions(+), 62 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index 03652cdf16..e04ccd7dd3 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -12,7 +12,7 @@ from gt4py import eve from gt4py.eve import utils as eve_utils from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im @dataclasses.dataclass(frozen=True) @@ -71,3 +71,23 @@ def is_equal(a: itir.Expr, b: itir.Expr): return a == b or ( CannonicalizeBoundSymbolNames.apply(a) == CannonicalizeBoundSymbolNames.apply(b) ) + + +def canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: + """ + Canonicalize applied `as_fieldop`s. + + In case the stencil argument is a `deref` wrap it into a lambda such that we have a unified + format to work with (e.g. each parameter has a name without the need to special case). + """ + assert cpm.is_applied_as_fieldop(expr) + + stencil = expr.fun.args[0] # type: ignore[attr-defined] + domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] + if cpm.is_ref_to(stencil, "deref"): + stencil = im.lambda_("arg")(im.deref("arg")) + new_expr = im.as_fieldop(stencil, domain)(*expr.args) + + return new_expr + + return expr diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 462f87b600..d22a92faf8 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -17,39 +17,48 @@ from gt4py import eve from gt4py.eve import utils as eve_utils from gt4py.next import common -from gt4py.next.iterator import ir +from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ( common_pattern_matcher as cpm, ir_makers as im, misc as ir_misc, ) -from gt4py.next.iterator.transforms import fixed_point_transformation +from gt4py.next.iterator.transforms import fixed_point_transformation, inline_lifts from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda -from gt4py.next.iterator.type_system import inference as itir_type_inference +from gt4py.next.iterator.type_system import ( + inference as itir_type_inference, + type_specifications as it_ts, +) from gt4py.next.type_system import type_info, type_specifications as ts -def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr | str): - """Given a itir.FunCall return a new call with one of its argument replaced.""" - return ir.FunCall( +def _with_altered_arg(node: itir.FunCall, arg_idx: int, new_arg: itir.Expr | str): + """Given a ititir.FunCall return a new call with one of its argument replaced.""" + return itir.FunCall( fun=node.fun, args=[arg if i != arg_idx else im.ensure_expr(new_arg) for i, arg in enumerate(node.args)], ) -def _is_trivial_make_tuple_call(node: ir.Expr): +def _with_altered_iterator_element_type(type_: it_ts.IteratorType, new_el_type: ts.DataType): + return it_ts.IteratorType( + position_dims=type_.position_dims, defined_dims=type_.defined_dims, element_type=new_el_type + ) + + +def _is_trivial_make_tuple_call(node: itir.Expr): """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" if not cpm.is_call_to(node, "make_tuple"): return False if not all( - isinstance(arg, (ir.SymRef, ir.Literal)) or _is_trivial_make_tuple_call(arg) + isinstance(arg, (itir.SymRef, itir.Literal)) or _is_trivial_make_tuple_call(arg) for arg in node.args ): return False return True -def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: +def _is_trivial_or_tuple_thereof_expr(node: itir.Node) -> bool: """ Return `true` if the expr is a trivial expression (`SymRef` or `Literal`) or tuple thereof. @@ -65,7 +74,7 @@ def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: ... ) True """ - if isinstance(node, (ir.SymRef, ir.Literal)): + if isinstance(node, (itir.SymRef, itir.Literal)): return True if cpm.is_call_to(node, "make_tuple"): return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args) @@ -138,6 +147,8 @@ class Transformation(enum.Flag): PROPAGATE_NESTED_LET = enum.auto() #: `let(a, 1)(a)` -> `1` or `let(a, b)(f(a))` -> `f(a)` INLINE_TRIVIAL_LET = enum.auto() + #: `as_fieldop(λ(t) → ·t[0]+·t[1])({a, b})` -> as_fieldop(λ(a, b) → ·a+·b)(a, b) + FLATTEN_AS_FIELDOP_ARGS = enum.auto() @classmethod def all(self) -> CollapseTuple.Transformation: @@ -152,7 +163,7 @@ def all(self) -> CollapseTuple.Transformation: @classmethod def apply( cls, - node: ir.Node, + node: itir.Node, *, ignore_tuple_size: bool = False, remove_letified_make_tuple_elements: bool = True, @@ -163,7 +174,7 @@ def apply( # allow sym references without a symbol declaration, mostly for testing allow_undeclared_symbols: bool = False, uids: Optional[eve_utils.UIDGenerator] = None, - ) -> ir.Node: + ) -> itir.Node: """ Simplifies `make_tuple`, `tuple_get` calls. @@ -181,7 +192,7 @@ def apply( offset_provider_type = offset_provider_type or {} uids = uids or eve_utils.UIDGenerator() - if isinstance(node, ir.Program): + if isinstance(node, itir.Program): within_stencil = False assert within_stencil in [ True, @@ -220,18 +231,18 @@ def visit(self, node, **kwargs): return super().visit(node, **kwargs) def transform_collapse_make_tuple_tuple_get( - self, node: ir.FunCall, **kwargs - ) -> Optional[ir.Node]: + self, node: itir.FunCall, **kwargs + ) -> Optional[itir.Node]: if cpm.is_call_to(node, "make_tuple") and all( cpm.is_call_to(arg, "tuple_get") for arg in node.args ): # `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t` - assert isinstance(node.args[0], ir.FunCall) + assert isinstance(node.args[0], itir.FunCall) first_expr = node.args[0].args[1] for i, v in enumerate(node.args): - assert isinstance(v, ir.FunCall) - assert isinstance(v.args[0], ir.Literal) + assert isinstance(v, itir.FunCall) + assert isinstance(v.args[0], itir.Literal) if not (int(v.args[0].value) == i and ir_misc.is_equal(v.args[1], first_expr)): # tuple argument differs, just continue with the rest of the tree return None @@ -248,11 +259,11 @@ def transform_collapse_make_tuple_tuple_get( return None def transform_collapse_tuple_get_make_tuple( - self, node: ir.FunCall, **kwargs - ) -> Optional[ir.Node]: + self, node: itir.FunCall, **kwargs + ) -> Optional[itir.Node]: if ( cpm.is_call_to(node, "tuple_get") - and isinstance(node.args[0], ir.Literal) + and isinstance(node.args[0], itir.Literal) and cpm.is_call_to(node.args[1], "make_tuple") ): # `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` @@ -265,8 +276,8 @@ def transform_collapse_tuple_get_make_tuple( return node.args[1].args[idx] return None - def transform_propagate_tuple_get(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - if cpm.is_call_to(node, "tuple_get") and isinstance(node.args[0], ir.Literal): + def transform_propagate_tuple_get(self, node: itir.FunCall, **kwargs) -> Optional[itir.Node]: + if cpm.is_call_to(node, "tuple_get") and isinstance(node.args[0], itir.Literal): # TODO(tehrengruber): extend to general symbols as long as the tail call in the let # does not capture # `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` @@ -289,12 +300,14 @@ def transform_propagate_tuple_get(self, node: ir.FunCall, **kwargs) -> Optional[ ) return None - def transform_letify_make_tuple_elements(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: + def transform_letify_make_tuple_elements( + self, node: itir.Node, **kwargs + ) -> Optional[itir.Node]: if cpm.is_call_to(node, "make_tuple"): # `make_tuple(expr1, expr1)` # -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))` - bound_vars: dict[ir.Sym, ir.Expr] = {} - new_args: list[ir.Expr] = [] + bound_vars: dict[itir.Sym, itir.Expr] = {} + new_args: list[itir.Expr] = [] for arg in node.args: if cpm.is_call_to(node, "make_tuple") and not _is_trivial_make_tuple_call(node): el_name = self.uids.sequential_id(prefix="__ct_el") @@ -309,7 +322,7 @@ def transform_letify_make_tuple_elements(self, node: ir.Node, **kwargs) -> Optio ) return None - def transform_inline_trivial_make_tuple(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: + def transform_inline_trivial_make_tuple(self, node: itir.Node, **kwargs) -> Optional[itir.Node]: if cpm.is_let(node): # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` # -> `foo(make_tuple(trivial_expr1, trivial_expr2))` @@ -318,13 +331,15 @@ def transform_inline_trivial_make_tuple(self, node: ir.Node, **kwargs) -> Option return self.visit(inline_lambda(node, eligible_params=eligible_params), **kwargs) return None - def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + def transform_propagate_to_if_on_tuples( + self, node: itir.FunCall, **kwargs + ) -> Optional[itir.Node]: if kwargs["within_stencil"]: # TODO(tehrengruber): This significantly increases the size of the tree. Skip transformation # in local-view for now. Revisit. return None - if isinstance(node, ir.FunCall) and not cpm.is_call_to(node, "if_"): + if isinstance(node, itir.FunCall) and not cpm.is_call_to(node, "if_"): # TODO(tehrengruber): Only inline if type of branch value is a tuple. # Examples: # `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` @@ -343,8 +358,8 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Opt return None def transform_propagate_to_if_on_tuples_cps( - self, node: ir.FunCall, **kwargs - ) -> Optional[ir.Node]: + self, node: itir.FunCall, **kwargs + ) -> Optional[itir.Node]: # The basic idea of this transformation is to remove tuples across if-stmts by rewriting # the expression in continuation passing style, e.g. something like a tuple reordering # ``` @@ -366,7 +381,7 @@ def transform_propagate_to_if_on_tuples_cps( # `if True then {2, 1} else {4, 3}`. The examples in the comments below all refer to this # tuple reordering example here. - if not isinstance(node, ir.FunCall) or cpm.is_call_to(node, "if_"): + if not isinstance(node, itir.FunCall) or cpm.is_call_to(node, "if_"): return None # The first argument that is eligible also transforms all remaining args (They will be @@ -438,7 +453,7 @@ def transform_propagate_to_if_on_tuples_cps( return None - def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + def transform_propagate_nested_let(self, node: itir.FunCall, **kwargs) -> Optional[itir.Node]: if cpm.is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` outer_vars = {} @@ -464,14 +479,76 @@ def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional ) return None - def transform_inline_trivial_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + def transform_inline_trivial_let(self, node: itir.FunCall, **kwargs) -> Optional[itir.Node]: if cpm.is_let(node): - if isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let + if isinstance(node.fun.expr, itir.SymRef): # type: ignore[attr-defined] # ensured by is_let # `let(a, 1)(a)` -> `1` for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let - if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let + if isinstance(node.fun.expr, itir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let return arg - if any(trivial_args := [isinstance(arg, (ir.SymRef, ir.Literal)) for arg in node.args]): + if any( + trivial_args := [isinstance(arg, (itir.SymRef, itir.Literal)) for arg in node.args] + ): return inline_lambda(node, eligible_params=trivial_args) return None + + # TODO(tehrengruber): This is a transformation that should be executed before visiting the children. Then + # revisiting the body would not be needed. + def transform_flatten_as_fieldop_args( + self, node: itir.FunCall, **kwargs + ) -> Optional[itir.Node]: + if not cpm.is_applied_as_fieldop(node): + return None + + for arg in node.args: + itir_type_inference.reinfer(arg) + + if not any(isinstance(arg.type, ts.TupleType) for arg in node.args): + return None + + node = ir_misc.canonicalize_as_fieldop(node) + stencil = node.fun.args[0] # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop + new_body = stencil.expr + domain = node.fun.args[1] if len(node.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop + orig_args_map: dict[itir.Sym, itir.Expr] = {} + new_params: list[itir.Sym] = [] + new_args: list[itir.Expr] = [] + for param, arg in zip(stencil.params, node.args, strict=True): + if isinstance(arg.type, ts.TupleType): + ref_to_orig_arg = im.ref(f"__ct_flat_orig_arg_{len(orig_args_map)}", arg.type) + orig_args_map[im.sym(ref_to_orig_arg.id, arg.type)] = arg + new_params_inner, new_args_inner = [], [] + for i, type_ in enumerate(param.type.element_type.types): + new_params_inner.append( + im.sym( + f"__ct_flat_el{i}_{param.id}", + _with_altered_iterator_element_type(param.type, type_), + ) + ) + new_args_inner.append(im.tuple_get(i, ref_to_orig_arg)) + + param_substitute = im.lift( + im.lambda_(*new_params_inner)( + im.make_tuple(*[im.deref(im.ref(p.id, p.type)) for p in new_params_inner]) + ) + )(*[im.ref(p.id, p.type) for p in new_params_inner]) + + new_body = im.let(param.id, param_substitute)(new_body) + # note: the lift is trivial so inlining it is not an issue with respect to tree size + new_body = inline_lambda(new_body, force_inline_lift_args=True) + new_params.extend(new_params_inner) + new_args.extend(new_args_inner) + else: + new_params.append(param) + new_args.append(arg) + + # remove lifts again + new_body = inline_lifts.InlineLifts( + flags=inline_lifts.InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT + ).visit(new_body) + new_body = self.visit(new_body, **kwargs) + + return im.let(*orig_args_map.items())( + im.as_fieldop(im.lambda_(*new_params)(new_body), domain)(*new_args) + ) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 81633dfb87..ff1a1b36e8 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -21,6 +21,7 @@ common_pattern_matcher as cpm, domain_utils, ir_makers as im, + misc as ir_misc, ) from gt4py.next.iterator.transforms import ( fixed_point_transformation, @@ -46,26 +47,6 @@ def _merge_arguments( return new_args -def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: - """ - Canonicalize applied `as_fieldop`s. - - In case the stencil argument is a `deref` wrap it into a lambda such that we have a unified - format to work with (e.g. each parameter has a name without the need to special case). - """ - assert cpm.is_applied_as_fieldop(expr) - - stencil = expr.fun.args[0] # type: ignore[attr-defined] - domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] - if cpm.is_ref_to(stencil, "deref"): - stencil = im.lambda_("arg")(im.deref("arg")) - new_expr = im.as_fieldop(stencil, domain)(*expr.args) - - return new_expr - - return expr - - def _is_tuple_expr_of_literals(expr: itir.Expr): if cpm.is_call_to(expr, "make_tuple"): return all(_is_tuple_expr_of_literals(arg) for arg in expr.args) @@ -78,7 +59,7 @@ def _inline_as_fieldop_arg( arg: itir.Expr, *, uids: eve_utils.UIDGenerator ) -> tuple[itir.Expr, dict[str, itir.Expr]]: assert cpm.is_applied_as_fieldop(arg) - arg = _canonicalize_as_fieldop(arg) + arg = ir_misc.canonicalize_as_fieldop(arg) stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` inner_args: list[itir.Expr] = arg.args @@ -411,7 +392,7 @@ def transform_fuse_make_tuple(self, node: itir.Node, **kwargs): def transform_fuse_as_fieldop(self, node: itir.Node, **kwargs): if cpm.is_applied_as_fieldop(node): - node = _canonicalize_as_fieldop(node) + node = ir_misc.canonicalize_as_fieldop(node) stencil = node.fun.args[0] # type: ignore[attr-defined] # ensure cpm.is_applied_as_fieldop assert isinstance(stencil, itir.Lambda) or cpm.is_call_to(stencil, "scan") args: list[itir.Expr] = node.args diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 916ae4e578..7813224f4d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -5,11 +5,15 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - +from gt4py.next import common from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple from gt4py.next.type_system import type_specifications as ts -from next_tests.unit_tests.iterator_tests.test_type_inference import int_type +from gt4py.next.iterator.type_system import type_specifications as it_ts + +bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) +int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) +Vertex = common.Dimension(value="Vertex", kind=common.DimensionKind.HORIZONTAL) def test_simple_make_tuple_tuple_get(): @@ -311,3 +315,24 @@ def test_if_make_tuple_reorder_cps_external(): within_stencil=False, ) assert actual == expected + + +def test_flatten_as_fieldop_args(): + it_type = it_ts.IteratorType( + position_dims=[Vertex], + defined_dims=[Vertex], + element_type=ts.TupleType(types=[bool_type, int_type]), + ) + testee = im.as_fieldop(im.lambda_(im.sym("it", it_type))(im.tuple_get(1, im.deref("it"))))( + im.make_tuple(1, 2) + ) + expected = im.as_fieldop( + im.lambda_("__ct_flat_el0_it", "__ct_flat_el1_it")(im.deref("__ct_flat_el1_it")) + )(1, 2) + actual = CollapseTuple.apply( + testee, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected From 374546133e96dd18ee01d4ee3e3ba54cdc841493 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Wed, 19 Feb 2025 02:10:34 +0100 Subject: [PATCH 009/124] Add support for scan and nested tuples --- src/gt4py/next/iterator/ir_utils/misc.py | 44 +++++++++++++ .../iterator/transforms/collapse_tuple.py | 26 ++++++-- .../iterator/transforms/fuse_as_fieldop.py | 46 +------------- .../transforms_tests/test_collapse_tuple.py | 62 ++++++++++++++++++- 4 files changed, 125 insertions(+), 53 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index e04ccd7dd3..bcffd5fe51 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -91,3 +91,47 @@ def canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: return new_expr return expr + + +def unwrap_scan(stencil: itir.Lambda | itir.FunCall): + """ + If given a scan, extract stencil part of its scan pass and a back-transformation into a scan. + + If a regular stencil is given the stencil is left as-is and the back-transformation is the + identity function. This function allows treating a scan stencil like a regular stencil during + a transformation avoiding the complexity introduced by the different IR format. + + >>> scan = im.call("scan")( + ... im.lambda_("state", "arg")(im.plus("state", im.deref("arg"))), True, 0.0 + ... ) + >>> stencil, back_trafo = _unwrap_scan(scan) + >>> str(stencil) + 'λ(arg) → state + ·arg' + >>> str(back_trafo(stencil)) + 'scan(λ(state, arg) → (λ(arg) → state + ·arg)(arg), True, 0.0)' + + In case a regular stencil is given it is returned as-is: + + >>> deref_stencil = im.lambda_("it")(im.deref("it")) + >>> stencil, back_trafo = _unwrap_scan(deref_stencil) + >>> assert stencil == deref_stencil + """ + if cpm.is_call_to(stencil, "scan"): + scan_pass, direction, init = stencil.args + assert isinstance(scan_pass, itir.Lambda) + # remove scan pass state to be used by caller + state_param = scan_pass.params[0] + stencil_like = im.lambda_(*scan_pass.params[1:])(scan_pass.expr) + + def restore_scan(transformed_stencil_like: itir.Lambda): + new_scan_pass = im.lambda_(state_param, *transformed_stencil_like.params)( + im.call(transformed_stencil_like)( + *(param.id for param in transformed_stencil_like.params) + ) + ) + return im.call("scan")(new_scan_pass, direction, init) + + return stencil_like, restore_scan + + assert isinstance(stencil, itir.Lambda) + return stencil, lambda s: s diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index d22a92faf8..f0afd5bafc 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -12,6 +12,7 @@ import enum import functools import operator +import re from typing import Optional from gt4py import eve @@ -91,6 +92,19 @@ def _is_trivial_or_tuple_thereof_expr(node: itir.Node) -> bool: return False +def _flattened_as_fieldop_param_el_name(param: str, idx: int) -> str: + prefix = "__ct_flat_el_" + + # keep the original param name, but skip prefix from previous flattenings + if param.startswith(prefix): + parent_idx, suffix = re.split(r"_(?!\d)", param[len(prefix) :], maxsplit=1) + prefix = f"{prefix}{parent_idx}_" + else: + suffix = param + + return f"{prefix}{idx}_{suffix}" + + # TODO(tehrengruber): Conceptually the structure of this pass makes sense: Visit depth first, # transform each node until no transformations apply anymore, whenever a node is to be transformed # go through all available transformation and apply them. However the final result here still @@ -508,7 +522,10 @@ def transform_flatten_as_fieldop_args( return None node = ir_misc.canonicalize_as_fieldop(node) - stencil = node.fun.args[0] # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop + stencil, restore_scan = ir_misc.unwrap_scan( + node.fun.args[0] # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop + ) + new_body = stencil.expr domain = node.fun.args[1] if len(node.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop orig_args_map: dict[itir.Sym, itir.Expr] = {} @@ -522,7 +539,7 @@ def transform_flatten_as_fieldop_args( for i, type_ in enumerate(param.type.element_type.types): new_params_inner.append( im.sym( - f"__ct_flat_el{i}_{param.id}", + _flattened_as_fieldop_param_el_name(param.id, i), _with_altered_iterator_element_type(param.type, type_), ) ) @@ -548,7 +565,6 @@ def transform_flatten_as_fieldop_args( flags=inline_lifts.InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT ).visit(new_body) new_body = self.visit(new_body, **kwargs) + new_stencil = restore_scan(im.lambda_(*new_params)(new_body)) - return im.let(*orig_args_map.items())( - im.as_fieldop(im.lambda_(*new_params)(new_body), domain)(*new_args) - ) + return im.let(*orig_args_map.items())(im.as_fieldop(new_stencil, domain)(*new_args)) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index ff1a1b36e8..b746369152 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -95,50 +95,6 @@ def _inline_as_fieldop_arg( ), extracted_args -def _unwrap_scan(stencil: itir.Lambda | itir.FunCall): - """ - If given a scan, extract stencil part of its scan pass and a back-transformation into a scan. - - If a regular stencil is given the stencil is left as-is and the back-transformation is the - identity function. This function allows treating a scan stencil like a regular stencil during - a transformation avoiding the complexity introduced by the different IR format. - - >>> scan = im.call("scan")( - ... im.lambda_("state", "arg")(im.plus("state", im.deref("arg"))), True, 0.0 - ... ) - >>> stencil, back_trafo = _unwrap_scan(scan) - >>> str(stencil) - 'λ(arg) → state + ·arg' - >>> str(back_trafo(stencil)) - 'scan(λ(state, arg) → (λ(arg) → state + ·arg)(arg), True, 0.0)' - - In case a regular stencil is given it is returned as-is: - - >>> deref_stencil = im.lambda_("it")(im.deref("it")) - >>> stencil, back_trafo = _unwrap_scan(deref_stencil) - >>> assert stencil == deref_stencil - """ - if cpm.is_call_to(stencil, "scan"): - scan_pass, direction, init = stencil.args - assert isinstance(scan_pass, itir.Lambda) - # remove scan pass state to be used by caller - state_param = scan_pass.params[0] - stencil_like = im.lambda_(*scan_pass.params[1:])(scan_pass.expr) - - def restore_scan(transformed_stencil_like: itir.Lambda): - new_scan_pass = im.lambda_(state_param, *transformed_stencil_like.params)( - im.call(transformed_stencil_like)( - *(param.id for param in transformed_stencil_like.params) - ) - ) - return im.call("scan")(new_scan_pass, direction, init) - - return stencil_like, restore_scan - - assert isinstance(stencil, itir.Lambda) - return stencil, lambda s: s - - def fuse_as_fieldop( expr: itir.Expr, eligible_args: list[bool], *, uids: eve_utils.UIDGenerator ) -> itir.Expr: @@ -146,7 +102,7 @@ def fuse_as_fieldop( stencil: itir.Lambda = expr.fun.args[0] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop assert isinstance(expr.fun.args[0], itir.Lambda) or cpm.is_call_to(stencil, "scan") # type: ignore[attr-defined] # ensured by is_applied_as_fieldop - stencil, restore_scan = _unwrap_scan(stencil) + stencil, restore_scan = ir_misc.unwrap_scan(stencil) domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by is_applied_as_fieldop diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 7813224f4d..08f1926277 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -11,7 +11,6 @@ from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.type_system import type_specifications as it_ts -bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) Vertex = common.Dimension(value="Vertex", kind=common.DimensionKind.HORIZONTAL) @@ -321,13 +320,70 @@ def test_flatten_as_fieldop_args(): it_type = it_ts.IteratorType( position_dims=[Vertex], defined_dims=[Vertex], - element_type=ts.TupleType(types=[bool_type, int_type]), + element_type=ts.TupleType(types=[int_type, int_type]), ) testee = im.as_fieldop(im.lambda_(im.sym("it", it_type))(im.tuple_get(1, im.deref("it"))))( im.make_tuple(1, 2) ) expected = im.as_fieldop( - im.lambda_("__ct_flat_el0_it", "__ct_flat_el1_it")(im.deref("__ct_flat_el1_it")) + im.lambda_("__ct_flat_el_0_it", "__ct_flat_el_1_it")(im.deref("__ct_flat_el_1_it")) + )(1, 2) + actual = CollapseTuple.apply( + testee, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected + + +def test_flatten_as_fieldop_args_nested(): + it_type = it_ts.IteratorType( + position_dims=[Vertex], + defined_dims=[Vertex], + element_type=ts.TupleType( + types=[ + int_type, + ts.TupleType(types=[int_type, int_type]), + ] + ), + ) + testee = im.as_fieldop( + im.lambda_(im.sym("it", it_type))(im.tuple_get(1, im.tuple_get(1, im.deref("it")))) + )(im.make_tuple(1, im.make_tuple(2, 3))) + expected = im.as_fieldop( + im.lambda_("__ct_flat_el_0_it", "__ct_flat_el_1_0_it", "__ct_flat_el_1_1_it")( + im.deref("__ct_flat_el_1_1_it") + ) + )(1, 2, 3) + actual = CollapseTuple.apply( + testee, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected + + +def test_flatten_as_fieldop_args_scan(): + it_type = it_ts.IteratorType( + position_dims=[Vertex], + defined_dims=[Vertex], + element_type=ts.TupleType(types=[int_type, int_type]), + ) + testee = im.as_fieldop( + im.scan( + im.lambda_("state", im.sym("it", it_type))(im.tuple_get(1, im.deref("it"))), True, 0 + ) + )(im.make_tuple(1, 2)) + expected = im.as_fieldop( + im.scan( + im.lambda_("state", "__ct_flat_el_0_it", "__ct_flat_el_1_it")( + im.deref("__ct_flat_el_1_it") + ), + True, + 0, + ) )(1, 2) actual = CollapseTuple.apply( testee, From 06806fbc468b6a1344b05e05ff0af7f09d7a8a0a Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Wed, 19 Feb 2025 02:39:41 +0100 Subject: [PATCH 010/124] Preserve annex on new nodes --- src/gt4py/eve/visitors.py | 28 +++++++++++++++------ tests/eve_tests/unit_tests/test_visitors.py | 19 ++++++++++++++ 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/gt4py/eve/visitors.py b/src/gt4py/eve/visitors.py index 59b4ef0881..e86174427e 100644 --- a/src/gt4py/eve/visitors.py +++ b/src/gt4py/eve/visitors.py @@ -121,6 +121,18 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: return None +def _preserve_annex( + node: concepts.Node, new_node: concepts.Node, preserved_annex_attrs: tuple[str, ...] +) -> None: + if preserved_annex_attrs and (old_annex := getattr(node, "__node_annex__", None)): + # access to `new_node.annex` implicitly creates the `__node_annex__` attribute in the property getter + new_annex_dict = new_node.annex.__dict__ + for key in preserved_annex_attrs: + if (value := getattr(old_annex, key, NOTHING)) is not NOTHING: + assert key not in new_annex_dict or new_annex_dict[key] == value + new_annex_dict[key] = value + + class NodeTranslator(NodeVisitor): """Special `NodeVisitor` to translate nodes and trees. @@ -158,13 +170,7 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: if (new_child := self.visit(child, **kwargs)) is not NOTHING } ) - if self.PRESERVED_ANNEX_ATTRS and (old_annex := getattr(node, "__node_annex__", None)): - # access to `new_node.annex` implicitly creates the `__node_annex__` attribute in the property getter - new_annex_dict = new_node.annex.__dict__ - for key in self.PRESERVED_ANNEX_ATTRS: - if (value := getattr(old_annex, key, NOTHING)) is not NOTHING: - assert key not in new_annex_dict - new_annex_dict[key] = value + _preserve_annex(node, new_node, self.PRESERVED_ANNEX_ATTRS) return new_node @@ -189,3 +195,11 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: ) return copy.deepcopy(node, memo=memo) + + def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: + new_node = super().visit(node, **kwargs) + + if isinstance(node, concepts.Node) and isinstance(new_node, concepts.Node): + _preserve_annex(node, new_node, self.PRESERVED_ANNEX_ATTRS) + + return new_node diff --git a/tests/eve_tests/unit_tests/test_visitors.py b/tests/eve_tests/unit_tests/test_visitors.py index 2eff41d749..0101aa38d8 100644 --- a/tests/eve_tests/unit_tests/test_visitors.py +++ b/tests/eve_tests/unit_tests/test_visitors.py @@ -8,6 +8,8 @@ from __future__ import annotations +import copy + from gt4py import eve @@ -24,3 +26,20 @@ class SampleTranslator(eve.NodeTranslator): assert translated_node.annex.foo == 1 assert translated_node.annex.bar is None assert not hasattr(translated_node.annex, "baz") + + +def test_annex_preservation_translated_node(compound_node: eve.Node): + compound_node.annex.foo = 1 + compound_node.annex.baz = 2 + + class SampleTranslator(eve.NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("foo",) + + def visit_Node(self, node: eve.Node): + # just return an empty node, we care about the annex only anyway + return eve.Node() + + translated_node = SampleTranslator().visit(compound_node) + + assert translated_node.annex.foo == 1 + assert not hasattr(translated_node.annex, "baz") From bab4fe10f9f35151b0f47b7e0e60a7861f7f2fac Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Wed, 19 Feb 2025 02:49:15 +0100 Subject: [PATCH 011/124] Fix unnecessary import --- tests/eve_tests/unit_tests/test_visitors.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/eve_tests/unit_tests/test_visitors.py b/tests/eve_tests/unit_tests/test_visitors.py index 0101aa38d8..28e83a885d 100644 --- a/tests/eve_tests/unit_tests/test_visitors.py +++ b/tests/eve_tests/unit_tests/test_visitors.py @@ -8,8 +8,6 @@ from __future__ import annotations -import copy - from gt4py import eve From 14b4bf3de45150a752202ae58c229e492640e182 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Wed, 19 Feb 2025 02:51:29 +0100 Subject: [PATCH 012/124] Cleanup --- src/gt4py/next/ffront/experimental.py | 10 +++++++--- src/gt4py/next/ffront/fbuiltins.py | 10 ++++++---- src/gt4py/next/iterator/transforms/fuse_as_fieldop.py | 3 --- src/gt4py/next/iterator/transforms/pass_manager.py | 10 ++++++---- src/gt4py/next/iterator/transforms/trace_shifts.py | 6 ++++++ .../next/iterator/transforms/transform_concat_where.py | 4 +--- .../unit_tests/iterator_tests/test_type_inference.py | 8 ++++---- 7 files changed, 30 insertions(+), 21 deletions(-) diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index c9bea908a8..6d477ad015 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -6,11 +6,11 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Tuple +from typing import Tuple, TypeVar from gt4py._core import definitions as core_defs from gt4py.next import common -from gt4py.next.ffront.fbuiltins import BuiltInFunction, FieldOffset, WhereBuiltinFunction +from gt4py.next.ffront.fbuiltins import BuiltInFunction, FieldOffset, WhereLikeBuiltinFunction, FieldT @BuiltInFunction @@ -18,7 +18,11 @@ def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivi raise NotImplementedError() -@WhereBuiltinFunction +_R = TypeVar("_R") +DomainT = TypeVar("DomainT", bound=common.Field) +ConcatWhereBuiltinFunction = WhereLikeBuiltinFunction[_R, DomainT, FieldT] + +@ConcatWhereBuiltinFunction def concat_where( mask: common.Domain, true_field: common.Field | core_defs.ScalarT | Tuple, diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index b50aaadc1e..927758b5d6 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -137,14 +137,14 @@ def __gt_type__(self) -> ts.FunctionType: ) -MaskT = TypeVar("MaskT", bound=Union[common.Field, common.Domain]) +MaskLikeT = TypeVar("MaskLikeT", bound=common.Field) FieldT = TypeVar("FieldT", bound=Union[common.Field, core_defs.Scalar, Tuple]) -class WhereBuiltinFunction( - BuiltInFunction[_R, [MaskT, FieldT, FieldT]], Generic[_R, MaskT, FieldT] +class WhereLikeBuiltinFunction( + BuiltInFunction[_R, [MaskLikeT, FieldT, FieldT]], Generic[_R, MaskLikeT, FieldT], ): - def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: + def __call__(self, mask: MaskLikeT, true_field: FieldT, false_field: FieldT) -> _R: if isinstance(true_field, tuple) or isinstance(false_field, tuple): if not (isinstance(true_field, tuple) and isinstance(false_field, tuple)): raise ValueError( @@ -158,6 +158,8 @@ def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: return tuple(self(mask, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R` return super().__call__(mask, true_field, false_field) +MaskT = TypeVar("MaskT", bound=common.Field) +WhereBuiltinFunction = WhereLikeBuiltinFunction[_R, MaskT, FieldT] @BuiltInFunction def neighbor_sum(field: common.Field, /, axis: common.Dimension) -> common.Field: diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index b746369152..88d2587e84 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -419,7 +419,4 @@ def visit(self, node, **kwargs): node = super().visit(node, **kwargs) - if isinstance(node, itir.Expr) and hasattr(node.annex, "domain"): - node.annex.domain = node.annex.domain - return node diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 38f44e6ab7..fdb144cffe 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -94,15 +94,17 @@ def apply_common_transforms( ir = nest_concat_wheres.NestConcatWheres.apply(ir) ir = infer_domain_ops.InferDomainOps.apply(ir) + + ir = ConstantFolding.apply(ir) # TODO: remove + ir = transform_concat_where.TransformConcatWhere.apply(ir) + ir = ConstantFolding.apply(ir) # TODO: remove + ir = expand_library_functions.ExpandLibraryFunctions.apply(ir) + ir = infer_domain.infer_program( ir, offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, ) - ir = ConstantFolding.apply(ir) # TODO: remove - ir = transform_concat_where.TransformConcatWhere.apply(ir) - ir = ConstantFolding.apply(ir) # TODO: remove - ir = expand_library_functions.ExpandLibraryFunctions.apply(ir) for _ in range(10): inlined = ir diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 0648df8363..8dc3d46b24 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -274,6 +274,12 @@ class TraceShifts(PreserveLocationVisitor, NodeTranslator): def visit_Literal(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: return Sentinel.VALUE + def visit_InfinityLiteral(self, node: ir.SymRef, *, ctx: dict[str, Any]): + return Sentinel.VALUE + + def visit_NegInfinityLiteral(self, node: ir.SymRef, *, ctx: dict[str, Any]): + return Sentinel.VALUE + def visit_SymRef(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: if node.id in ctx: return ctx[node.id] diff --git a/src/gt4py/next/iterator/transforms/transform_concat_where.py b/src/gt4py/next/iterator/transforms/transform_concat_where.py index 518742ba6f..ffa20908ab 100644 --- a/src/gt4py/next/iterator/transforms/transform_concat_where.py +++ b/src/gt4py/next/iterator/transforms/transform_concat_where.py @@ -31,8 +31,6 @@ def visit_SymRef(self, node: ir.FunCall, symbol_refs: list) -> ir.FunCall: class TransformConcatWhere(PreserveLocationVisitor, NodeTranslator): - PRESERVED_ANNEX_ATTRS = ("domain",) - @classmethod def apply(cls, node: ir.Node): return cls().visit(node) @@ -44,12 +42,12 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: cond = domain_utils.SymbolicDomain.from_expr(cond_expr).ranges.keys() dims = [im.call("index")(ir.AxisLiteral(value=k.value, kind=k.kind)) for k in cond] refs = symbol_ref_utils.collect_symbol_refs(cond_expr) + # TODO: this deref pass is not correct cond_expr = DerefCond.apply(cond_expr, refs) return im.as_fieldop( im.lambda_("_tcw_pos", "_tcw_arg0", "_tcw_arg1", *refs)( im.if_(im.call("in")(im.deref("_tcw_pos"), cond_expr), im.deref("_tcw_arg0"), im.deref("_tcw_arg1")) ), - node.annex.domain.as_expr(), )(im.make_tuple(*dims), field_a, field_b, *refs) return self.generic_visit(node) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 336551d4a9..6a0fe82f8f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -92,7 +92,7 @@ def expression_test_cases(): im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) ), - it_ts.DomainType(dims=[IDim]), + ts.DomainType(dims=[IDim]), ), ( im.call("unstructured_domain")( @@ -100,7 +100,7 @@ def expression_test_cases(): itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 ) ), - it_ts.DomainType(dims=[Vertex]), + ts.DomainType(dims=[Vertex]), ), # make_tuple ( @@ -298,7 +298,7 @@ def test_cartesian_fencil_definition(): program_type = it_ts.ProgramType(params={"inp": float_i_field, "out": float_i_field}) assert result.type == program_type - domain_type = it_ts.DomainType(dims=[IDim]) + domain_type = ts.DomainType(dims=[IDim]) assert result.body[0].domain.type == domain_type assert result.body[0].expr.type == float_i_field assert result.body[0].target.type == float_i_field @@ -337,7 +337,7 @@ def test_unstructured_fencil_definition(): params={"inp": float_edge_k_field, "out": float_vertex_k_field} ) assert result.type == program_type - domain_type = it_ts.DomainType(dims=[Vertex, KDim]) + domain_type = ts.DomainType(dims=[Vertex, KDim]) assert result.body[0].domain.type == domain_type assert result.body[0].expr.type == float_vertex_k_field assert result.body[0].target.type == float_vertex_k_field From fc20d7c5e2ff4d71abfd9c70f07434b8ab9a6b0f Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Wed, 19 Feb 2025 12:11:33 +0100 Subject: [PATCH 013/124] Fix doctest --- src/gt4py/next/iterator/ir_utils/misc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index bcffd5fe51..00a1ab5609 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -104,7 +104,7 @@ def unwrap_scan(stencil: itir.Lambda | itir.FunCall): >>> scan = im.call("scan")( ... im.lambda_("state", "arg")(im.plus("state", im.deref("arg"))), True, 0.0 ... ) - >>> stencil, back_trafo = _unwrap_scan(scan) + >>> stencil, back_trafo = unwrap_scan(scan) >>> str(stencil) 'λ(arg) → state + ·arg' >>> str(back_trafo(stencil)) @@ -113,7 +113,7 @@ def unwrap_scan(stencil: itir.Lambda | itir.FunCall): In case a regular stencil is given it is returned as-is: >>> deref_stencil = im.lambda_("it")(im.deref("it")) - >>> stencil, back_trafo = _unwrap_scan(deref_stencil) + >>> stencil, back_trafo = unwrap_scan(deref_stencil) >>> assert stencil == deref_stencil """ if cpm.is_call_to(stencil, "scan"): From c5fba83a526a8fe8c19b69744391785d9a022792 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Wed, 19 Feb 2025 12:48:16 +0100 Subject: [PATCH 014/124] Fix failing tests --- src/gt4py/next/iterator/transforms/infer_domain.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index f3c3185225..f662d055fc 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -21,7 +21,7 @@ domain_utils, ir_makers as im, ) -from gt4py.next.iterator.transforms import trace_shifts +from gt4py.next.iterator.transforms import constant_folding, trace_shifts from gt4py.next.utils import flatten_nested_tuple, tree_map @@ -436,8 +436,12 @@ def _infer_stmt( **kwargs: Unpack[InferenceOptions], ): if isinstance(stmt, itir.SetAt): + # constant fold once otherwise constant folding after domain inference might create (syntactic) differences + # between the domain stored in IR and in the annex + domain = constant_folding.ConstantFolding.apply(stmt.domain) + transformed_call, _ = infer_expr( - stmt.expr, domain_utils.SymbolicDomain.from_expr(stmt.domain), **kwargs + stmt.expr, domain_utils.SymbolicDomain.from_expr(domain), **kwargs ) return itir.SetAt( From 04ae43074b89d5df36589d36e5e3ba0380637560 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Wed, 19 Feb 2025 16:10:53 +0100 Subject: [PATCH 015/124] Fix tests --- src/gt4py/eve/visitors.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gt4py/eve/visitors.py b/src/gt4py/eve/visitors.py index e86174427e..e329246ce7 100644 --- a/src/gt4py/eve/visitors.py +++ b/src/gt4py/eve/visitors.py @@ -129,7 +129,10 @@ def _preserve_annex( new_annex_dict = new_node.annex.__dict__ for key in preserved_annex_attrs: if (value := getattr(old_annex, key, NOTHING)) is not NOTHING: - assert key not in new_annex_dict or new_annex_dict[key] == value + # note: the annex value of the new node might not be equal (in + # the sense that the equality comparison is false), but in + # the context of the pass they are equalivalent. Therefore we don't + # assert equality here. new_annex_dict[key] = value From 5136adc6cf50c31d256458f71d168eb517b3edd4 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Wed, 19 Feb 2025 16:13:58 +0100 Subject: [PATCH 016/124] Fix tests --- src/gt4py/eve/visitors.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gt4py/eve/visitors.py b/src/gt4py/eve/visitors.py index e329246ce7..bfffe8a6fc 100644 --- a/src/gt4py/eve/visitors.py +++ b/src/gt4py/eve/visitors.py @@ -129,10 +129,10 @@ def _preserve_annex( new_annex_dict = new_node.annex.__dict__ for key in preserved_annex_attrs: if (value := getattr(old_annex, key, NOTHING)) is not NOTHING: - # note: the annex value of the new node might not be equal (in - # the sense that the equality comparison is false), but in - # the context of the pass they are equalivalent. Therefore we don't - # assert equality here. + # Note: The annex value of the new node might not be equal + # (in the sense that an equality comparison returns false), + # but in the context of the pass, they are equivalent. + # Therefore, we don't assert equality here. new_annex_dict[key] = value From 59396182d150d0d348cf22bfb4455bf13e909b9c Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Wed, 19 Feb 2025 18:48:36 +0100 Subject: [PATCH 017/124] Cleanup frontend type deduction --- .../ffront/foast_passes/type_deduction.py | 127 ++++++++++-------- src/gt4py/next/type_system/type_info.py | 4 + .../next/type_system/type_specifications.py | 3 + .../ffront_tests/test_type_deduction.py | 38 +++++- 4 files changed, 111 insertions(+), 61 deletions(-) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index dc8d36af5e..bc62d34fe5 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -567,62 +567,67 @@ def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> foast.Compare: op=node.op, left=new_left, right=new_right, location=node.location, type=new_type ) - def _deduce_compare_type( + def _deduce_arithmetic_compare_type( self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any ) -> Optional[ts.TypeSpec]: - # check both types compatible - left_t, right_t = left.type, right.type - integer_kind = getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) - if ( - isinstance(left_t, ts.DimensionType) - and isinstance(right_t, ts.ScalarType) - and right_t.kind == integer_kind - ): - return ts.DomainType(dims=[left_t.dim]) - if ( - isinstance(right_t, ts.DimensionType) - and isinstance(left_t, ts.ScalarType) - and left_t.kind == integer_kind - ): - return ts.DomainType(dims=[right_t.dim]) - if ( - isinstance(left_t, ts.OffsetType) - and left.op == dialect_ast_enums.BinaryOperator.MOD - and isinstance(right_t, ts.ScalarType) - and right_t.kind == integer_kind - ) or ( - isinstance(right_t, ts.OffsetType) - and right.op == dialect_ast_enums.BinaryOperator.MOD - and isinstance(left_t, ts.ScalarType) - and left_t.kind == integer_kind - ): - raise errors.DSLError( - left.location, "Type 'ts.OffsetType' can not be used in operator 'mod'." - ) - - # TODO - for arg in (left, right): - if not type_info.is_arithmetic(arg.type): - raise errors.DSLError( - arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'." - ) - + # e.g. `1 < 2` self._check_operand_dtypes_match(node, left=left, right=right) try: # transform operands to have bool dtype and use regular promotion # mechanism to handle dimension promotion return type_info.promote( - with_altered_scalar_kind(left_t, ts.ScalarKind.BOOL), - with_altered_scalar_kind(right_t, ts.ScalarKind.BOOL), + with_altered_scalar_kind(left.type, ts.ScalarKind.BOOL), + with_altered_scalar_kind(right.type, ts.ScalarKind.BOOL), ) except ValueError as ex: raise errors.DSLError( node.location, - f"Could not promote '{left_t}' and '{right_t}' to common type" + f"Could not promote '{left.type}' and '{right.type}' to common type" f" in call to '{node.op}'.", ) from ex + def _deduce_dimension_compare_type( + self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any + ) -> Optional[ts.TypeSpec]: + # e.g. `IDim > 1` + index_type = ts.ScalarType( + kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())) + + if isinstance(left.type, ts.DimensionType): + if not right.type == index_type: + raise errors.DSLError( + right.location, + f"Expected an {index_type}, but got '{right.type}' instead.", + ) + return ts.DomainType(dims=[left.type.dim]) + elif isinstance(right.type, ts.DimensionType): + if not left.type == index_type: + raise errors.DSLError( + left.location, + f"Expected an {index_type}, but got '{right.type}' instead.", + ) + return ts.DomainType(dims=[right.type.dim]) + else: + raise AssertionError() + + def _deduce_compare_type( + self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any + ) -> Optional[ts.TypeSpec]: + # e.g. `1 < 1` + if all(type_info.is_arithmetic(arg) for arg in (left.type, right.type)): + return self._deduce_arithmetic_compare_type(node, left=left, right=right) + # e.g. `IDim > 1` + if any(isinstance(arg, ts.DimensionType) for arg in (left.type, right.type)): + return self._deduce_dimension_compare_type(node, left=left, right=right) + + raise errors.DSLError( + left.location, "Comparison operators can only be used between arithmetic types " + "(scalars, fields) or between a dimension and an index type " + "({builtins.INTEGER_INDEX_BUILTIN})." + ) + + def _deduce_binop_type( self, node: foast.BinOp, *, left: foast.Expr, right: foast.Expr, **kwargs: Any ) -> Optional[ts.TypeSpec]: @@ -644,20 +649,16 @@ def _deduce_binop_type( dialect_ast_enums.BinaryOperator.BIT_XOR, } - def is_logical_or_domain(arg: ts.TypeSpec) -> bool: - return type_info.is_logical(arg) or isinstance(arg, ts.DomainType) + err_msg = f"Unsupported operand type(s) for {node.op}: '{left.type}' and '{right.type}'." - is_compatible = is_logical_or_domain if node.op in logical_ops else type_info.is_arithmetic - - # check both types compatible - for arg in (left, right): - if not is_compatible(arg.type): - raise errors.DSLError( - arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'." - ) if isinstance(left.type, (ts.ScalarType, ts.FieldType)) and isinstance( right.type, (ts.ScalarType, ts.FieldType) ): + is_compatible = type_info.is_logical if node.op in logical_ops else type_info.is_arithmetic + for arg in (left, right): + if not is_compatible(arg.type): + raise errors.DSLError(arg.location, err_msg) + if node.op == dialect_ast_enums.BinaryOperator.POW: return left.type @@ -678,9 +679,12 @@ def is_logical_or_domain(arg: ts.TypeSpec) -> bool: f" in call to '{node.op}'.", ) from ex elif isinstance(left.type, ts.DomainType) and isinstance(right.type, ts.DomainType): + if node.op not in logical_ops: + raise errors.DSLError(node.location, f"{err_msg} Operator " + f"must be one of {', '.join((str(op) for op in logical_ops))}.") return ts.DomainType(dims=promote_dims(left.type.dims, right.type.dims)) else: - raise ValueError("TODO") + raise errors.DSLError(node.location, err_msg) def _check_operand_dtypes_match( self, node: foast.BinOp | foast.Compare, left: foast.Expr, right: foast.Expr @@ -983,12 +987,19 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: ) def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: - true_branch_type = node.args[1].type - false_branch_type = node.args[2].type - true_branch_fieldtype = cast(ts.FieldType, true_branch_type) - false_branch_fieldtype = cast(ts.FieldType, false_branch_type) - promoted_type = type_info.promote(true_branch_fieldtype, false_branch_fieldtype) - return_type = promoted_type + mask_type, true_branch_type, false_branch_type = (arg.type for arg in node.args) + + assert isinstance(mask_type, ts.DomainType) + assert all(isinstance(arg, (ts.FieldType, ts.ScalarType)) for arg in (true_branch_type, false_branch_type)) + + if (t_dtype := type_info.extract_dtype(true_branch_type)) != ( + f_dtype := type_info.extract_dtype(false_branch_type)): + raise errors.DSLError( + node.location, + f"Field arguments must be of same dtype, got '{t_dtype}' != '{f_dtype}'." + ) + + return_type = type_info.promote(mask_type, true_branch_type, false_branch_type) return foast.Call( func=node.func, diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 370ecfb998..0ba7427073 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -235,6 +235,8 @@ def is_floating_point(symbol_type: ts.TypeSpec) -> bool: >>> is_floating_point(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))) True """ + if not isinstance(symbol_type, (ts.ScalarType, ts.FieldType)): + return False return isinstance(dtype := extract_dtype(symbol_type), ts.ScalarType) and dtype.kind in [ ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64, @@ -278,6 +280,8 @@ def is_integral(symbol_type: ts.TypeSpec) -> bool: >>> is_integral(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))) True """ + if not isinstance(symbol_type, (ts.ScalarType, ts.FieldType)): + return False return is_integer(extract_dtype(symbol_type)) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index abc62885ae..839e89bd34 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -49,6 +49,9 @@ class VoidType(TypeSpec): class DimensionType(TypeSpec): dim: common.Dimension + def __str__(self) -> str: + return str(self.dim) + class OffsetType(TypeSpec): # TODO(havogt): replace by ConnectivityType diff --git a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py index 254772fd8a..7118054c8b 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py @@ -28,6 +28,7 @@ neighbor_sum, where, ) +from gt4py.next.ffront.experimental import concat_where from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.experimental import as_offset from gt4py.next.ffront.func_to_foast import FieldOperatorParser @@ -75,7 +76,7 @@ def add_bools(a: Field[[TDim], bool], b: Field[[TDim], bool]): return a + b with pytest.raises( - errors.DSLError, match=(r"Type 'Field\[\[TDim\], bool\]' can not be used in operator '\+'.") + errors.DSLError, match=(re.escape("Unsupported operand type(s) for +: 'Field[[TDim], bool]' and 'Field[[TDim], bool]'.")) ): _ = FieldOperatorParser.apply_to_function(add_bools) @@ -97,13 +98,13 @@ def nonmatching(a: Field[[X], float64], b: Field[[Y], float64]): _ = FieldOperatorParser.apply_to_function(nonmatching) -def test_bitopping_float(): +def test_bitop_float(): def float_bitop(a: Field[[TDim], float], b: Field[[TDim], float]): return a & b with pytest.raises( errors.DSLError, - match=(r"Type 'Field\[\[TDim\], float64\]' can not be used in operator '\&'."), + match=re.escape("Unsupported operand type(s) for &: 'Field[[TDim], float64]' and 'Field[[TDim], float64]'."), ): _ = FieldOperatorParser.apply_to_function(float_bitop) @@ -130,6 +131,37 @@ def not_int(a: Field[[TDim], int64]): _ = FieldOperatorParser.apply_to_function(not_int) +def test_concat_where(): + def simple_concat_where(a: Field[[TDim], float], b: Field[[TDim], float]): + return concat_where(TDim > 0, a, b) + + parsed = FieldOperatorParser.apply_to_function(simple_concat_where) + compare_node = parsed.body.stmts[0].value.args[0] + assert compare_node.type == ts.DomainType(dims=[TDim]) + + +def test_domain_comparison_failure(): + def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): + return concat_where(TDim > 1., a, b) + + with pytest.raises( + errors.DSLError, + match=(r"Expected an int32, but got 'float64' instead."), + ): + _ = FieldOperatorParser.apply_to_function(domain_comparison) + + +def test_concat_where_invalid_dtype(): + def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): + return concat_where(TDim > 0, 1., 2) + + with pytest.raises( + errors.DSLError, + match=re.escape("Field arguments must be of same dtype, got 'float64' != 'int32'."), + ): + _ = FieldOperatorParser.apply_to_function(domain_comparison) + + @pytest.fixture def premap_setup(): X = Dimension("X") From 157b0e2205e51541e30c6b81e487c7a904a4bde8 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Wed, 19 Feb 2025 18:52:18 +0100 Subject: [PATCH 018/124] Cleanup frontend type deduction --- src/gt4py/next/ffront/foast_passes/type_deduction.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index bc62d34fe5..fb2190172e 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -999,7 +999,14 @@ def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: f"Field arguments must be of same dtype, got '{t_dtype}' != '{f_dtype}'." ) - return_type = type_info.promote(mask_type, true_branch_type, false_branch_type) + return_dims = promote_dims( + mask_type.dims, + type_info.promote(true_branch_type, false_branch_type).dims + ) + return_type = ts.FieldType( + dims=return_dims, + dtype=type_info.promote(t_dtype, f_dtype) + ) return foast.Call( func=node.func, From 435d057a9bf4f6c8d7f78c50d2c8ebc2d8716487 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 20 Feb 2025 01:09:59 +0100 Subject: [PATCH 019/124] Cleanup concat where: - Cleanup infinity literal ir node and constant folding - Improve testing - Fix domain complement - Simplify & cleanup lowering, domain ops pass --- src/gt4py/next/ffront/foast_to_gtir.py | 27 +--- src/gt4py/next/iterator/ir.py | 17 +- .../next/iterator/ir_utils/domain_utils.py | 12 +- src/gt4py/next/iterator/pretty_printer.py | 11 +- .../iterator/transforms/constant_folding.py | 97 ++++++------ .../iterator/transforms/infer_domain_ops.py | 137 ++++++++--------- tests/next_tests/definitions.py | 9 +- tests/next_tests/integration_tests/cases.py | 62 ++++---- .../ffront_tests/test_concat_where.py | 145 ++++++++---------- .../ffront_tests/test_type_deduction.py | 13 +- .../transforms_tests/test_constant_folding.py | 40 ++--- .../transforms_tests/test_domain_inference.py | 12 +- 12 files changed, 290 insertions(+), 292 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index dd936d7995..514c7526c7 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -251,28 +251,7 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: raise NotImplementedError(f"Unary operator '{node.op}' is not supported.") def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: - if ( - node.op == dialect_ast_enums.BinaryOperator.BIT_AND - and isinstance(node.left.type, ts.DomainType) - and isinstance(node.right.type, ts.DomainType) - ): - return im.and_(self.visit(node.left), self.visit(node.right)) - if ( - node.op == dialect_ast_enums.BinaryOperator.BIT_OR - and isinstance(node.left.type, ts.DomainType) - and isinstance(node.right.type, ts.DomainType) - ): - return im.or_(self.visit(node.left), self.visit(node.right)) - if ( - node.op == dialect_ast_enums.BinaryOperator.BIT_XOR - and isinstance(node.left.type, ts.DomainType) - and isinstance(node.right.type, ts.DomainType) - ): - raise NotImplementedError( - f"Binary operator '{node.op}' is not supported for '{node.right.type}' and '{node.right.type}'." - ) - else: - return self._lower_and_map(node.op.value, node.left, node.right) + return self._lower_and_map(node.op.value, node.left, node.right) def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall: assert ( @@ -284,7 +263,6 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC ) def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall: - # TODO: double-check if we need the changes in the original PR return self._lower_and_map(node.op.value, node.left, node.right) def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: @@ -506,9 +484,8 @@ def _map( """ Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists. """ - # TODO double-check that this code is consistent with the changes in the original PR if all( - isinstance(t, (ts.ScalarType, ts.DimensionType)) + isinstance(t, (ts.ScalarType, ts.DimensionType, ts.DomainType)) for arg_type in original_arg_types for t in type_info.primitive_constituents(arg_type) ): diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 7ccd86faab..e40b8edbe1 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -5,8 +5,9 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations -from typing import ClassVar, List, Optional, Union +from typing import ClassVar, List, Optional, Union, TYPE_CHECKING import gt4py.eve as eve from gt4py.eve import Coerced, SymbolName, SymbolRef @@ -62,13 +63,18 @@ class Literal(Expr): class NoneLiteral(Expr): _none_literal: int = 0 - class InfinityLiteral(Expr): - pass + if TYPE_CHECKING: + POSITIVE: ClassVar[InfinityLiteral] # TODO(tehrengruber): should be `ClassVar[InfinityLiteral]`, but self-referential not supported in eve + NEGATIVE: ClassVar[InfinityLiteral] + + name: typing.Literal["POSITIVE", "NEGATIVE"] + def __str__(self): + return f"{type(self).__name__}.{self.name}" -class NegInfinityLiteral(Expr): - pass +InfinityLiteral.NEGATIVE = InfinityLiteral(name="NEGATIVE") +InfinityLiteral.POSITIVE = InfinityLiteral(name="POSITIVE") class OffsetLiteral(Expr): @@ -151,4 +157,3 @@ class Program(Node, ValidatedSymbolTableTrait): SetAt.__hash__ = Node.__hash__ # type: ignore[method-assign] IfStmt.__hash__ = Node.__hash__ # type: ignore[method-assign] InfinityLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign] -NegInfinityLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign] diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index e3ab788033..efd8ebb2c0 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -199,10 +199,12 @@ def domain_complement(domain: SymbolicDomain) -> SymbolicDomain: dims_dict = {} for dim in domain.ranges.keys(): lb, ub = domain.ranges[dim].start, domain.ranges[dim].stop - if isinstance(lb, itir.NegInfinityLiteral): - dims_dict[dim] = SymbolicRange(start=ub, stop=itir.InfinityLiteral()) - elif isinstance(ub, itir.InfinityLiteral): - dims_dict[dim] = SymbolicRange(start=itir.NegInfinityLiteral(), stop=lb) + # `]-inf, a[` -> `[a, inf[` + if lb == itir.InfinityLiteral.NEGATIVE: + dims_dict[dim] = SymbolicRange(start=ub, stop=itir.InfinityLiteral.POSITIVE) + # `[a, inf]` -> `]-inf, a]` + elif ub == itir.InfinityLiteral.POSITIVE: + dims_dict[dim] = SymbolicRange(start=itir.InfinityLiteral.NEGATIVE, stop=lb) else: raise ValueError("Invalid domain ranges") return SymbolicDomain(domain.grid_type, dims_dict) @@ -218,5 +220,5 @@ def promote_to_same_dimensions( lb, ub = domain_small.ranges[dim].start, domain_small.ranges[dim].stop dims_dict[dim] = SymbolicRange(lb, ub) else: - dims_dict[dim] = SymbolicRange(itir.NegInfinityLiteral(), itir.InfinityLiteral()) + dims_dict[dim] = SymbolicRange(itir.InfinityLiteral.NEGATIVE, itir.InfinityLiteral.POSITIVE) return SymbolicDomain(domain_small.grid_type, dims_dict) # TODO: fix for unstructured diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 1d97878257..5063e26392 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -133,11 +133,12 @@ def visit_Sym(self, node: ir.Sym, *, prec: int) -> list[str]: def visit_Literal(self, node: ir.Literal, *, prec: int) -> list[str]: return [str(node.value)] - def visit_InfinityLiteral(self, node: ir.Literal, *, prec: int) -> list[str]: - return ["INF"] - - def visit_NegInfinityLiteral(self, node: ir.Literal, *, prec: int) -> list[str]: - return ["-INF"] + def visit_InfinityLiteral(self, node: ir.InfinityLiteral, *, prec: int) -> list[str]: + if node == ir.InfinityLiteral.POSITIVE: + return ["∞"] + elif node == ir.InfinityLiteral.NEGATIVE: + return ["-∞"] + raise AssertionError() def visit_OffsetLiteral(self, node: ir.OffsetLiteral, *, prec: int) -> list[str]: return [str(node.value) + "ₒ"] diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 0286acac88..cec465ad68 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -31,49 +31,58 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node: ): # `minimum(a, a)` -> `a` return new_node.args[0] + if cpm.is_call_to(new_node, "plus"): + a, b = new_node.args + for arg, other_arg in ((a, b), (b, a)): + # `a + inf` -> `inf` + if arg == ir.InfinityLiteral.POSITIVE: + return ir.InfinityLiteral.POSITIVE + # `a + (-inf)` -> `-inf` + if arg == ir.InfinityLiteral.NEGATIVE: + return ir.InfinityLiteral.NEGATIVE + if cpm.is_call_to(new_node, "minimum"): - # `minimum(neg_inf, neg_inf)` -> `neg_inf` - if isinstance(new_node.args[0], ir.NegInfinityLiteral) or isinstance( - new_node.args[1], ir.NegInfinityLiteral - ): - return ir.NegInfinityLiteral() - # `minimum(inf, a)` -> `a` - elif isinstance(new_node.args[0], ir.InfinityLiteral): - return new_node.args[1] - # `minimum(a, inf)` -> `a` - elif isinstance(new_node.args[1], ir.InfinityLiteral): - return new_node.args[0] + a, b = new_node.args + for arg, other_arg in ((a, b), (b, a)): + # `minimum(inf, a)` -> `a` + if arg == ir.InfinityLiteral.POSITIVE: + return other_arg + # `minimum(-inf, a)` -> `-inf` + if arg == ir.InfinityLiteral.NEGATIVE: + return ir.InfinityLiteral.NEGATIVE if cpm.is_call_to(new_node, "maximum"): - # `minimum(inf, inf)` -> `inf` - if isinstance(new_node.args[0], ir.InfinityLiteral) or isinstance( - new_node.args[1], ir.InfinityLiteral - ): - return ir.InfinityLiteral() - # `minimum(neg_inf, a)` -> `a` - elif isinstance(new_node.args[0], ir.NegInfinityLiteral): - return new_node.args[1] - # `minimum(a, neg_inf)` -> `a` - elif isinstance(new_node.args[1], ir.NegInfinityLiteral): - return new_node.args[0] + a, b = new_node.args + for arg, other_arg in ((a, b), (b, a)): + # `maximum(inf, a)` -> `inf` + if arg == ir.InfinityLiteral.POSITIVE: + return ir.InfinityLiteral.POSITIVE + # `maximum(-inf, a)` -> `a` + if arg == ir.InfinityLiteral.NEGATIVE: + return other_arg + if cpm.is_call_to(new_node, ("less", "less_equal")): - if isinstance(new_node.args[0], ir.NegInfinityLiteral) or isinstance( - new_node.args[1], ir.InfinityLiteral - ): + a, b = new_node.args + # `-inf < v` -> `True` + # `v < inf` -> `True` + if a == ir.InfinityLiteral.NEGATIVE or b == ir.InfinityLiteral.POSITIVE: return im.literal_from_value(True) - if isinstance(new_node.args[0], ir.InfinityLiteral) or isinstance( - new_node.args[1], ir.NegInfinityLiteral - ): + # `inf < v` -> `False` + # `v < -inf ` -> `False` + if a == ir.InfinityLiteral.POSITIVE or b == ir.InfinityLiteral.NEGATIVE: return im.literal_from_value(False) + if cpm.is_call_to(new_node, ("greater", "greater_equal")): - if isinstance(new_node.args[0], ir.NegInfinityLiteral) or isinstance( - new_node.args[1], ir.InfinityLiteral - ): - return im.literal_from_value(False) - if isinstance(new_node.args[0], ir.InfinityLiteral) or isinstance( - new_node.args[1], ir.NegInfinityLiteral - ): + a, b = new_node.args + # `inf > v` -> `True` + # `v > -inf ` -> `True` + if a == ir.InfinityLiteral.POSITIVE or b == ir.InfinityLiteral.NEGATIVE: return im.literal_from_value(True) + # `-inf > v` -> `False` + # `v > inf` -> `False` + if a == ir.InfinityLiteral.NEGATIVE or b == ir.InfinityLiteral.POSITIVE: + return im.literal_from_value(False) + if ( isinstance(new_node.fun, ir.SymRef) and new_node.fun.id == "if_" @@ -90,15 +99,13 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node: and len(new_node.args) > 0 and all(isinstance(arg, ir.Literal) for arg in new_node.args) ): # `1 + 1` -> `2` - try: - if new_node.fun.id in builtins.ARITHMETIC_BUILTINS: - fun = getattr(embedded, str(new_node.fun.id)) - arg_values = [ - getattr(embedded, str(arg.type))(arg.value) # type: ignore[attr-defined] # arg type already established in if condition - for arg in new_node.args - ] - new_node = im.literal_from_value(fun(*arg_values)) - except ValueError: - pass # happens for SymRefs which are not inf or neg_inf + if new_node.fun.id in builtins.ARITHMETIC_BUILTINS: + fun = getattr(embedded, str(new_node.fun.id)) + arg_values = [ + getattr(embedded, str(arg.type))(arg.value) + # type: ignore[attr-defined] # arg type already established in if condition + for arg in new_node.args + ] + new_node = im.literal_from_value(fun(*arg_values)) return new_node diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index e6ae6557dc..db51aa2888 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -15,90 +15,83 @@ domain_utils, ir_makers as im, ) +from gt4py.next.type_system import type_specifications as ts +from gt4py.next.iterator.ir_utils.domain_utils import domain_complement from gt4py.next.iterator.transforms.constant_folding import ConstantFolding class InferDomainOps(PreserveLocationVisitor, NodeTranslator): @classmethod def apply(cls, node: ir.Node): - return cls().visit(node) + return cls().visit(node, recurse=True) - def visit_FunCall(self, node: ir.FunCall) -> ir.Node: - node = self.generic_visit(node) + def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: + if kwargs["recurse"]: + node = self.generic_visit(node, **kwargs) + + # IDim < a if ( cpm.is_call_to(node, builtins.BINARY_MATH_COMPARISON_BUILTINS) and any(isinstance(arg, ir.AxisLiteral) for arg in node.args) - and any(isinstance(arg, (ir.Literal, ir.SymRef)) for arg in node.args) + and any(isinstance(arg, ir.Expr) for arg in node.args) ): # TODO: add tests arg1, arg2 = node.args - fun = node.fun - if isinstance(arg1, ir.AxisLiteral): - dim = common.Dimension(value=arg1.value, kind=arg1.kind) - reverse = False - if isinstance(arg2, ir.Literal): - value = int(arg2.value) - elif isinstance(arg2, ir.SymRef): - value = arg2 - elif isinstance(arg2, ir.AxisLiteral): - dim = common.Dimension(value=arg2.value, kind=arg2.kind) - reverse = True - if isinstance(arg1, ir.Literal): - value = int(arg1.value) - elif isinstance(arg1, ir.SymRef): - value = arg1 + if isinstance(arg2, ir.AxisLiteral): + # take complementary operation if we have e.g. `IDim > 1` use `1 <= IDim` + complementary_op = { + "less": "greater_equal", + "less_equal": "greater", + "greater": "less_equal", + "greater_equal": "less", + "eq": "eq", + "not_eq": "not_eq", + } + return self.visit(im.call(complementary_op[node.fun.id])(arg2, arg1), **{**kwargs, "recurse":False}) + + assert isinstance(arg1.type, ts.DimensionType) + dim: common.Dimension = arg1.type.dim + value: ir.Expr = arg2 + + if cpm.is_call_to(node, ("less", "less_equal", "greater", "greater_equal", "eq")): + min_: int | ir.InfinityLiteral + max_: int | ir.InfinityLiteral + + # IDim < 1 + if cpm.is_call_to(node, "less"): + min_ = ir.InfinityLiteral.NEGATIVE + max_ = value + # IDim <= 1 + elif cpm.is_call_to(node, "less_equal"): + min_ = ir.InfinityLiteral.NEGATIVE + max_ = im.plus(value, 1) + # IDim > 1 + elif cpm.is_call_to(node, "greater"): + min_ = im.plus(value, 1) + max_ = ir.InfinityLiteral.POSITIVE + # IDim >= 1 + elif cpm.is_call_to(node, "greater_equal"): + min_ = value + max_ = ir.InfinityLiteral.POSITIVE + # IDim == 1 + elif cpm.is_call_to(node, "eq"): + min_ = value + max_ = im.plus(value, 1) + + domain = domain_utils.SymbolicDomain( + common.GridType.CARTESIAN, # TODO + ranges={dim: domain_utils.SymbolicRange(start=min_, stop=max_)} + ) + + return domain.as_expr() + elif cpm.is_call_to(node, "not_eq"): + # `IDim != a -> `IDim < a & IDim > a` + return im.call("and_")( + self.visit(im.less(dim, value), **kwargs), + self.visit(im.greater(dim, value), **kwargs) + ) else: - raise ValueError(f"{node.args} need to be a 'ir.AxisLiteral' and an 'ir.Literal'.") - assert isinstance(fun, ir.SymRef) - min_: int | ir.NegInfinityLiteral - max_: int | ir.InfinityLiteral - match fun.id: - case ir.SymbolRef("less"): - if reverse: - min_ = im.plus(value, 1) - max_ = ir.InfinityLiteral() - else: - min_ = ir.NegInfinityLiteral() - max_ = im.minus(value, 1) - case ir.SymbolRef("less_equal"): - if reverse: - min_ = value - max_ = ir.InfinityLiteral() - else: - min_ = ir.NegInfinityLiteral() - max_ = value - case ir.SymbolRef("greater"): - if reverse: - min_ = ir.NegInfinityLiteral() - max_ = im.minus(value, 1) - else: - min_ = im.plus(value, 1) - max_ = ir.InfinityLiteral() - case ir.SymbolRef("greater_equal"): - if reverse: - min_ = ir.NegInfinityLiteral() - max_ = value - else: - min_ = value - max_ = ir.InfinityLiteral() - case ir.SymbolRef("eq"): - min_ = max_ = value - case ir.SymbolRef("not_eq"): - min1 = ir.NegInfinityLiteral() - max1 = im.minus(value, 1) - min2 = im.plus(value, 1) - max2 = ir.InfinityLiteral() - return im.call("and_")( - im.domain(common.GridType.CARTESIAN, {dim: (min1, max1)}), - im.domain(common.GridType.CARTESIAN, {dim: (min2, max2)}), - ) - case _: - raise NotImplementedError - return im.domain( - common.GridType.CARTESIAN, - {dim: (min_, im.plus(max_, 1))} - if not isinstance(max_, ir.InfinityLiteral) - else {dim: (min_, max_)}, - ) + raise ValueError(f"{fun} is not a valid comparison operator.") + if cpm.is_call_to(node, builtins.BINARY_LOGICAL_BUILTINS) and all( isinstance(arg, (ir.Literal, ir.FunCall)) for arg in node.args ): @@ -113,4 +106,4 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node: else: raise NotImplementedError - return self.generic_visit(node) + return node diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index b412c0c273..e85283ee1d 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -106,6 +106,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_ORIGIN = "uses_origin" USES_REDUCE_WITH_LAMBDA = "uses_reduce_with_lambda" USES_SCAN = "uses_scan" +USES_FRONTEND_CONCAT_WHERE = "uses_frontend_concat_where" +USES_GTIR_CONCAT_WHERE = "uses_gtir_concat_where" USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator" USES_SCAN_IN_STENCIL = "uses_scan_in_stencil" USES_SCAN_WITHOUT_FIELD_ARGS = "uses_scan_without_field_args" @@ -155,6 +157,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLE_ITERATOR, XFAIL, UNSUPPORTED_MESSAGE), + (USES_FRONTEND_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), ] ) EMBEDDED_SKIP_LIST = [ @@ -165,10 +168,14 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): XFAIL, UNSUPPORTED_MESSAGE, ), # we can't extract the field type from scan args + (USES_FRONTEND_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), ] ROUNDTRIP_SKIP_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), ] +GTIR_EMBEDDED_SKIP_LIST = ROUNDTRIP_SKIP_LIST + [ + (USES_FRONTEND_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), +] GTFN_SKIP_TEST_LIST = ( COMMON_SKIP_TEST_LIST + DOMAIN_INFERENCE_SKIP_LIST @@ -211,5 +218,5 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (ALL, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), ], - ProgramBackendId.GTIR_EMBEDDED: ROUNDTRIP_SKIP_LIST, + ProgramBackendId.GTIR_EMBEDDED: GTIR_EMBEDDED_SKIP_LIST, } diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 66330016ef..8bbbed4ef9 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -8,6 +8,7 @@ from __future__ import annotations +import copy import dataclasses import functools import inspect @@ -104,7 +105,7 @@ def scalar(self, dtype: np.typing.DTypeLike) -> ScalarValue: def field( self, allocator: next_allocators.FieldBufferAllocatorProtocol, - sizes: dict[gtx.Dimension, int], + domain: gtx.Domain, dtype: np.typing.DTypeLike, ) -> FieldValue: ... @@ -137,11 +138,11 @@ def scalar_value(self) -> ScalarValue: def field( self, allocator: next_allocators.FieldBufferAllocatorProtocol, - sizes: dict[gtx.Dimension, int], + domain: gtx.Domain, dtype: np.typing.DTypeLike, ) -> FieldValue: return constructors.full( - domain=common.domain(sizes), fill_value=self.value, dtype=dtype, allocator=allocator + domain=domain, fill_value=self.value, dtype=dtype, allocator=allocator ) @@ -163,16 +164,15 @@ def scalar_value(self) -> ScalarValue: def field( self, allocator: next_allocators.FieldBufferAllocatorProtocol, - sizes: dict[gtx.Dimension, int], + domain: gtx.Domain, dtype: np.typing.DTypeLike, ) -> FieldValue: - if len(sizes) > 1: + if len(domain.dims) > 1: raise ValueError( - f"'IndexInitializer' only supports fields with a single 'Dimension', got {sizes}." + f"'IndexInitializer' only supports fields with a single 'Dimension', got {domain}." ) - n_data = list(sizes.values())[0] return constructors.as_field( - domain=common.domain(sizes), data=np.arange(0, n_data, dtype=dtype), allocator=allocator + domain=domain, data=np.arange(domain.ranges[0].start, domain.ranges[0].stop, dtype=dtype), allocator=allocator ) def from_case( @@ -204,16 +204,15 @@ def scalar_value(self) -> ScalarValue: def field( self, allocator: next_allocators.FieldBufferAllocatorProtocol, - sizes: dict[gtx.Dimension, int], + domain: common.Domain, dtype: np.typing.DTypeLike, ) -> FieldValue: start = self.start - svals = tuple(sizes.values()) - n_data = int(np.prod(svals)) - self.start += n_data + assert isinstance(domain.size, int) + self.start += domain.size return constructors.as_field( - common.domain(sizes), - np.arange(start, start + n_data, dtype=dtype).reshape(svals), + common.domain(domain), + np.arange(start, self.start, dtype=dtype).reshape(domain.shape), allocator=allocator, ) @@ -326,6 +325,7 @@ def allocate( name: str, *, sizes: Optional[dict[gtx.Dimension, int]] = None, + domain: Optional[dict[gtx.Dimension, tuple[int, int]] | gtx.Domain] = None, strategy: Optional[DataInitializer] = None, dtype: Optional[np.typing.DTypeLike] = None, extend: Optional[dict[gtx.Dimension, tuple[int, int]]] = None, @@ -347,9 +347,18 @@ def allocate( Useful for shifted fields, which must start off bigger than the output field in the shifted dimension. """ - sizes = extend_sizes( - case.default_sizes | (sizes or {}), extend - ) # TODO: this should take into account the Domain of the allocated field + if sizes: + assert not domain and all(dim in case.default_sizes for dim in sizes) + domain = {dim: (0, sizes[dim] if dim in sizes else default_size) for dim, default_size in case.default_sizes.items()} + + if not domain: + domain = {dim: (0, size) for dim, size in case.default_sizes.items()} + + if not isinstance(domain, gtx.Domain): + domain = gtx.domain(domain) + + domain = extend_domain(domain, extend) # TODO: this should take into account the Domain of the allocated field + arg_type = get_param_types(fieldview_prog)[name] if strategy is None: if name in ["out", RETURN]: @@ -359,7 +368,7 @@ def allocate( return _allocate_from_type( case=case, arg_type=arg_type, - sizes=sizes, + domain=domain, dtype=dtype, strategy=strategy.from_case(case=case, fieldview_prog=fieldview_prog, arg_name=name), ) @@ -524,7 +533,7 @@ def unstructured_case_3d(unstructured_case): def _allocate_from_type( case: Case, arg_type: ts.TypeSpec, - sizes: dict[gtx.Dimension, int], + domain: gtx.Domain, strategy: DataInitializer, dtype: Optional[np.typing.DTypeLike] = None, tuple_start: Optional[int] = None, @@ -534,7 +543,7 @@ def _allocate_from_type( case ts.FieldType(dims=dims, dtype=arg_dtype): return strategy.field( allocator=case.allocator, - sizes={dim: sizes[dim] for dim in dims}, + domain=common.domain(tuple(domain[dim] for dim in dims)), dtype=dtype or arg_dtype.kind.name.lower(), ) case ts.ScalarType(kind=kind): @@ -543,7 +552,7 @@ def _allocate_from_type( return tuple( ( _allocate_from_type( - case=case, arg_type=t, sizes=sizes, dtype=dtype, strategy=strategy + case=case, arg_type=t, domain=domain, dtype=dtype, strategy=strategy ) for t in types ) @@ -579,15 +588,16 @@ def get_param_size(param_type: ts.TypeSpec, sizes: dict[gtx.Dimension, int]) -> raise TypeError(f"Can not get size for parameter of type '{param_type}'.") -def extend_sizes( - sizes: dict[gtx.Dimension, int], extend: Optional[dict[gtx.Dimension, tuple[int, int]]] = None +def extend_domain( + domain: gtx.Domain, extend: Optional[dict[gtx.Dimension, tuple[int, int]]] = None ) -> dict[gtx.Dimension, int]: """Calculate the sizes per dimension given a set of extensions.""" - sizes = sizes.copy() if extend: + domain = copy.deepcopy(domain) for dim, (lower, upper) in extend.items(): - sizes[dim] += upper - lower - return sizes + domain[dim][0] += -lower + domain[dim][1] += upper + return domain def get_default_data( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 7db29bc088..2ed38f2993 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -19,185 +19,170 @@ ) -def test_boundary_same_size_fields(cartesian_case): +@pytest.mark.uses_frontend_concat_where +def test_concat_where(cartesian_case): @gtx.field_operator - def testee( - k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField - ) -> cases.IJKField: - return concat_where(KDim == 0, boundary, interior) + def testee(ground: cases.IJKField, air: cases.IJKField) -> cases.IJKField: + return concat_where(KDim == 0, ground, air) - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() - interior = cases.allocate(cartesian_case, testee, "interior")() - boundary = cases.allocate(cartesian_case, testee, "boundary")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() + ground = cases.allocate(cartesian_case, testee, "ground")() + air = cases.allocate(cartesian_case, testee, "air")() + k = np.arange(0, cartesian_case.default_sizes[KDim]) ref = np.where( - k.asnumpy()[np.newaxis, np.newaxis, :] == 0, boundary.asnumpy(), interior.asnumpy() + k[np.newaxis, np.newaxis, :] == 0, ground.asnumpy(), air.asnumpy() ) - - cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + cases.verify(cartesian_case, testee, ground, air, out=out, ref=ref) -def test_dimension(cartesian_case): +@pytest.mark.uses_frontend_concat_where +def test_concat_where_non_overlapping(cartesian_case): @gtx.field_operator - def testee( - k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField - ) -> cases.IJKField: - return concat_where(KDim >= 2, boundary, interior) + def testee(ground: cases.IJKField, air: cases.IJKField) -> cases.IJKField: + return concat_where(KDim == 0, ground, air) - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() - interior = cases.allocate(cartesian_case, testee, "interior")() - boundary = cases.allocate(cartesian_case, testee, "boundary")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() + ground = cases.allocate(cartesian_case, testee, "ground", domain=out.domain.slice_at[:, :, 0:1])() + air = cases.allocate(cartesian_case, testee, "air", domain=out.domain.slice_at[:, :, 1:])() - ref = np.where( - k.asnumpy()[np.newaxis, np.newaxis, :] >= 2, boundary.asnumpy(), interior.asnumpy() - ) - cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + ref = np.concatenate((ground.asnumpy(), air.asnumpy()), axis=2) + cases.verify(cartesian_case, testee, ground, air, out=out, ref=ref) -def test_dimension_different_dims(cartesian_case): +@pytest.mark.uses_frontend_concat_where +def test_concat_where_non_overlapping_different_dims(cartesian_case): @gtx.field_operator - def testee(j: cases.JField, interior: cases.IJField, boundary: cases.JField) -> cases.IJField: - return concat_where(IDim >= 2, boundary, interior) + def testee( + ground: cases.KField, # note: boundary field is only defined in K + air: cases.IJKField + ) -> cases.IJKField: + return concat_where(KDim == 0, ground, air) - j = cases.allocate(cartesian_case, testee, "j", strategy=cases.IndexInitializer())() - interior = cases.allocate(cartesian_case, testee, "interior")() - boundary = cases.allocate(cartesian_case, testee, "boundary")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() + ground = cases.allocate(cartesian_case, testee, "ground", domain=gtx.domain({KDim: (0, 1)}))() + air = cases.allocate(cartesian_case, testee, "air", domain=out.domain.slice_at[:, :, 1:])() - ref = np.where( - j.asnumpy()[:, np.newaxis] >= 2, boundary.asnumpy()[np.newaxis, :], interior.asnumpy() - ) - cases.verify(cartesian_case, testee, j, interior, boundary, out=out, ref=ref) + ref = np.concatenate((np.tile(ground.asnumpy(),(*air.domain.shape[0:2], len(ground.domain[KDim].unit_range))), air.asnumpy()), axis=2) + + cases.verify(cartesian_case, testee, ground, air, out=out, ref=ref) +@pytest.mark.uses_frontend_concat_where def test_dimension_two_nested_conditions(cartesian_case): @gtx.field_operator def testee( - k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField + interior: cases.IJKField, boundary: cases.IJKField ) -> cases.IJKField: return concat_where((KDim < 2), boundary, concat_where((KDim >= 5), boundary, interior)) - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() boundary = cases.allocate(cartesian_case, testee, "boundary")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() + k = np.arange(0, cartesian_case.default_sizes[KDim]) ref = np.where( - (k.asnumpy()[np.newaxis, np.newaxis, :] < 2) - | (k.asnumpy()[np.newaxis, np.newaxis, :] >= 5), + (k[np.newaxis, np.newaxis, :] < 2) + | (k[np.newaxis, np.newaxis, :] >= 5), boundary.asnumpy(), interior.asnumpy(), ) - cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) +@pytest.mark.uses_frontend_concat_where def test_dimension_two_conditions_and(cartesian_case): @gtx.field_operator - def testee(k: cases.KField, interior: cases.KField, boundary: cases.KField) -> cases.KField: + def testee(interior: cases.KField, boundary: cases.KField) -> cases.KField: return concat_where(((KDim > 2) & (KDim <= 5)), interior, boundary) - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() boundary = cases.allocate(cartesian_case, testee, "boundary")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() - ref = np.where((k.asnumpy() > 2) & (k.asnumpy() <= 5), interior.asnumpy(), boundary.asnumpy()) - cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + k = np.arange(0, cartesian_case.default_sizes[KDim]) + ref = np.where((k > 2) & (k <= 5), interior.asnumpy(), boundary.asnumpy()) + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) +@pytest.mark.uses_frontend_concat_where def test_dimension_two_conditions_eq(cartesian_case): @gtx.field_operator - def testee(k: cases.KField, interior: cases.KField, boundary: cases.KField) -> cases.KField: + def testee(interior: cases.KField, boundary: cases.KField) -> cases.KField: return concat_where((KDim == 2), interior, boundary) - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() boundary = cases.allocate(cartesian_case, testee, "boundary")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() - ref = np.where(k.asnumpy() == 2, interior.asnumpy(), boundary.asnumpy()) - cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + k = np.arange(0, cartesian_case.default_sizes[KDim]) + ref = np.where(k == 2, interior.asnumpy(), boundary.asnumpy()) + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) +@pytest.mark.uses_frontend_concat_where def test_dimension_two_conditions_or(cartesian_case): @gtx.field_operator - def testee(k: cases.KField, interior: cases.KField, boundary: cases.KField) -> cases.KField: + def testee(interior: cases.KField, boundary: cases.KField) -> cases.KField: return concat_where(((KDim < 2) | (KDim >= 5)), boundary, interior) - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() boundary = cases.allocate(cartesian_case, testee, "boundary")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() - ref = np.where((k.asnumpy() < 2) | (k.asnumpy() >= 5), boundary.asnumpy(), interior.asnumpy()) - cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + k = np.arange(0, cartesian_case.default_sizes[KDim]) + ref = np.where((k < 2) | (k >= 5), boundary.asnumpy(), interior.asnumpy()) + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) +@pytest.mark.uses_frontend_concat_where def test_boundary_horizontal_slice(cartesian_case): @gtx.field_operator def testee( - k: cases.KField, interior: cases.IJKField, boundary: cases.IJField + interior: cases.IJKField, boundary: cases.IJField ) -> cases.IJKField: return concat_where(KDim == 0, boundary, interior) - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() boundary = cases.allocate(cartesian_case, testee, "boundary")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() + k = np.arange(0, cartesian_case.default_sizes[KDim]) ref = np.where( - k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + k[np.newaxis, np.newaxis, :] == 0, boundary.asnumpy()[:, :, np.newaxis], interior.asnumpy(), ) - cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) +@pytest.mark.uses_frontend_concat_where def test_boundary_single_layer(cartesian_case): @gtx.field_operator def testee( - k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField + interior: cases.IJKField, boundary: cases.IJKField ) -> cases.IJKField: return concat_where(KDim == 0, boundary, interior) - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() boundary = cases.allocate(cartesian_case, testee, "boundary", sizes={KDim: 1})() out = cases.allocate(cartesian_case, testee, cases.RETURN)() + k = np.arange(0, cartesian_case.default_sizes[KDim]) ref = np.where( - k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + k[np.newaxis, np.newaxis, :] == 0, np.broadcast_to(boundary.asnumpy(), interior.shape), interior.asnumpy(), ) - cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) - - -def test_alternating_mask(cartesian_case): - with pytest.raises( - errors.DSLError, match=("Type 'ts.OffsetType' can not be used in operator 'mod'.") - ): - - @gtx.field_operator - def testee(k: cases.KField, f0: cases.IJKField, f1: cases.IJKField) -> cases.IJKField: - return concat_where(KDim % 2 == 0, f1, f0) - - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() - f0 = cases.allocate(cartesian_case, testee, "f0")() - f1 = cases.allocate(cartesian_case, testee, "f1")() - out = cases.allocate(cartesian_case, testee, cases.RETURN)() - - ref = np.where(k.asnumpy()[np.newaxis, np.newaxis, :] % 2 == 0, f1.asnumpy(), f0.asnumpy()) - - cases.verify(cartesian_case, testee, k, f0, f1, out=out, ref=ref) + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) +@pytest.mark.uses_frontend_concat_where @pytest.mark.uses_tuple_returns def test_with_tuples(cartesian_case): + pytest.skip("Not implemented in the frontend.") @gtx.field_operator def testee( k: cases.KField, @@ -208,20 +193,20 @@ def testee( ) -> Tuple[cases.IJKField, cases.IJKField]: return concat_where(KDim == 0, (boundary0, boundary1), (interior0, interior1)) - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior0 = cases.allocate(cartesian_case, testee, "interior0")() boundary0 = cases.allocate(cartesian_case, testee, "boundary0")() interior1 = cases.allocate(cartesian_case, testee, "interior1")() boundary1 = cases.allocate(cartesian_case, testee, "boundary1")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() + k = np.arange(0, cartesian_case.default_sizes[KDim]) ref0 = np.where( - k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + k[np.newaxis, np.newaxis, :] == 0, boundary0.asnumpy()[:, :, np.newaxis], interior0.asnumpy(), ) ref1 = np.where( - k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + k[np.newaxis, np.newaxis, :] == 0, boundary1.asnumpy()[:, :, np.newaxis], interior1.asnumpy(), ) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py index 7118054c8b..dd2758e5da 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py @@ -146,7 +146,18 @@ def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): with pytest.raises( errors.DSLError, - match=(r"Expected an int32, but got 'float64' instead."), + match=re.escape("Expected an int32, but got 'float64' instead."), + ): + _ = FieldOperatorParser.apply_to_function(domain_comparison) + + +def test_domain_comparison_checkerboard_failure(): + def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): + return concat_where(TDim % 2., a, b) + + with pytest.raises( + errors.DSLError, + match=re.escape("Unsupported operand type(s) for %."), ): _ = FieldOperatorParser.apply_to_function(domain_comparison) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 794a93090b..315bf61b31 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -64,86 +64,86 @@ def test_constant_folding_literal_maximum(): def test_constant_folding_inf_maximum(): - testee = im.call("maximum")(im.literal_from_value(1), ir.InfinityLiteral()) - expected = ir.InfinityLiteral() + testee = im.call("maximum")(im.literal_from_value(1), ir.InfinityLiteral.POSITIVE) + expected = ir.InfinityLiteral.POSITIVE actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("maximum")(ir.InfinityLiteral(), im.literal_from_value(1)) - expected = ir.InfinityLiteral() + testee = im.call("maximum")(ir.InfinityLiteral.POSITIVE, im.literal_from_value(1)) + expected = ir.InfinityLiteral.POSITIVE actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("maximum")(im.literal_from_value(1), ir.NegInfinityLiteral()) + testee = im.call("maximum")(im.literal_from_value(1), ir.InfinityLiteral.NEGATIVE) expected = im.literal_from_value(1) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("maximum")(ir.NegInfinityLiteral(), im.literal_from_value(1)) + testee = im.call("maximum")(ir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) expected = im.literal_from_value(1) actual = ConstantFolding.apply(testee) assert actual == expected def test_constant_folding_inf_minimum(): - testee = im.call("minimum")(im.literal_from_value(1), ir.InfinityLiteral()) + testee = im.call("minimum")(im.literal_from_value(1), ir.InfinityLiteral.POSITIVE) expected = im.literal_from_value(1) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("minimum")(ir.InfinityLiteral(), im.literal_from_value(1)) + testee = im.call("minimum")(ir.InfinityLiteral.POSITIVE, im.literal_from_value(1)) expected = im.literal_from_value(1) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("minimum")(im.literal_from_value(1), ir.NegInfinityLiteral()) - expected = ir.NegInfinityLiteral() + testee = im.call("minimum")(im.literal_from_value(1), ir.InfinityLiteral.NEGATIVE) + expected = ir.InfinityLiteral.NEGATIVE actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("minimum")(ir.NegInfinityLiteral(), im.literal_from_value(1)) - expected = ir.NegInfinityLiteral() + testee = im.call("minimum")(ir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) + expected = ir.InfinityLiteral.NEGATIVE actual = ConstantFolding.apply(testee) assert actual == expected def test_constant_greater_less(): - testee = im.call("greater")(im.literal_from_value(1), ir.InfinityLiteral()) + testee = im.call("greater")(im.literal_from_value(1), ir.InfinityLiteral.POSITIVE) expected = im.literal_from_value(False) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("greater")(im.literal_from_value(1), ir.NegInfinityLiteral()) + testee = im.call("greater")(im.literal_from_value(1), ir.InfinityLiteral.NEGATIVE) expected = im.literal_from_value(True) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("less")(im.literal_from_value(1), ir.InfinityLiteral()) + testee = im.call("less")(im.literal_from_value(1), ir.InfinityLiteral.POSITIVE) expected = im.literal_from_value(True) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("less")(im.literal_from_value(1), ir.NegInfinityLiteral()) + testee = im.call("less")(im.literal_from_value(1), ir.InfinityLiteral.NEGATIVE) expected = im.literal_from_value(False) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("greater")(ir.InfinityLiteral(), im.literal_from_value(1)) + testee = im.call("greater")(ir.InfinityLiteral.POSITIVE, im.literal_from_value(1)) expected = im.literal_from_value(True) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("greater")(ir.NegInfinityLiteral(), im.literal_from_value(1)) + testee = im.call("greater")(ir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) expected = im.literal_from_value(False) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("less")(ir.InfinityLiteral(), im.literal_from_value(1)) + testee = im.call("less")(ir.InfinityLiteral.POSITIVE, im.literal_from_value(1)) expected = im.literal_from_value(False) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("less")(ir.NegInfinityLiteral(), im.literal_from_value(1)) + testee = im.call("less")(ir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) expected = im.literal_from_value(True) actual = ConstantFolding.apply(testee) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 2e014ffdb8..32c41d6cef 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -1097,7 +1097,7 @@ def test_never_accessed_domain_tuple(offset_provider): def test_concat_where(offset_provider): domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: (itir.NegInfinityLiteral(), 4)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 4)}) domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 4)}) domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (4, 11)}) testee = im.concat_where( @@ -1120,13 +1120,13 @@ def test_concat_where(offset_provider): assert expected_domains == constant_fold_accessed_domains(actual_domains) -# Todo: 2 dimensional test with cond im.domain(common.GridType.CARTESIAN, {IDim: (itir.NegInfinityLiteral(), 4)}) +# Todo: 2 dimensional test with cond im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 4)}) # Todo: nested concat wheres def test_concat_where_two_dimensions(offset_provider): domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 30)}) - domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: (itir.NegInfinityLiteral(), 10)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 10)}) domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10), JDim: (10, 30)}) domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (10, 20), JDim: (10, 30)}) testee = im.concat_where( @@ -1151,7 +1151,7 @@ def test_concat_where_two_dimensions(offset_provider): def test_concat_where_two_dimensions_J(offset_provider): domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 30)}) - domain_cond = im.domain(common.GridType.CARTESIAN, {JDim: (20, "inf")}) + domain_cond = im.domain(common.GridType.CARTESIAN, {JDim: (20, itir.InfinityLiteral.POSITIVE)}) domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (20, 30)}) domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 20)}) testee = im.concat_where( @@ -1176,8 +1176,8 @@ def test_concat_where_two_dimensions_J(offset_provider): def test_nested_concat_where_two_dimensions(offset_provider): domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 20)}) - domain_cond1 = im.domain(common.GridType.CARTESIAN, {JDim: (10, "inf")}) - domain_cond2 = im.domain(common.GridType.CARTESIAN, {IDim: (itir.NegInfinityLiteral(), 20)}) + domain_cond1 = im.domain(common.GridType.CARTESIAN, {JDim: (10, itir.InfinityLiteral.POSITIVE)}) + domain_cond2 = im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 20)}) domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 20)}) domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (20, 30), JDim: (10, 20)}) domain3 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 10)}) From bd8dbaa6703a6b5e4e086770759c9f688db20fbe Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 20 Feb 2025 01:17:28 +0100 Subject: [PATCH 020/124] Fix iterator tests --- src/gt4py/next/iterator/builtins.py | 11 ++++++++++- src/gt4py/next/iterator/embedded.py | 8 ++++++++ .../iterator/transforms/expand_library_functions.py | 2 +- .../iterator/transforms/transform_concat_where.py | 2 +- 4 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 4ebc9a388c..f9b045747f 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -402,6 +402,15 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] raise BackendNotSelectedError() +@builtin_dispatch +def concat_where(*args): + raise BackendNotSelectedError() + +@builtin_dispatch +def in_(*args): + raise BackendNotSelectedError() + + UNARY_MATH_NUMBER_BUILTINS = {"abs", "neg"} UNARY_LOGICAL_BUILTINS = {"not_"} UNARY_MATH_FP_BUILTINS = { @@ -489,7 +498,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "tuple_get", "unstructured_domain", "concat_where", - "in", + "in_", *ARITHMETIC_BUILTINS, *TYPE_BUILTINS, } diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index da0516d26b..0ef868cac7 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1798,6 +1798,14 @@ def impl(*args): def index(axis: common.Dimension) -> common.Field: return IndexField(axis) +@builtins.concat_where.register(EMBEDDED) +def concat_where(*args): + raise NotImplementedError("To be implemented in frontend embedded.") + +@builtins.in_.register(EMBEDDED) +def in_(*args): + raise NotImplementedError("To be implemented in frontend embedded.") + def closure( domain_: runtime.CartesianDomain | runtime.UnstructuredDomain, diff --git a/src/gt4py/next/iterator/transforms/expand_library_functions.py b/src/gt4py/next/iterator/transforms/expand_library_functions.py index 9fab9e053f..0f3d005452 100644 --- a/src/gt4py/next/iterator/transforms/expand_library_functions.py +++ b/src/gt4py/next/iterator/transforms/expand_library_functions.py @@ -23,7 +23,7 @@ def apply(cls, node: ir.Node): return cls().visit(node) def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: - if cpm.is_call_to(node, "in"): + if cpm.is_call_to(node, "in_"): ret = [] pos, domain = node.args for i, (_, v) in enumerate( diff --git a/src/gt4py/next/iterator/transforms/transform_concat_where.py b/src/gt4py/next/iterator/transforms/transform_concat_where.py index ffa20908ab..34c424cd85 100644 --- a/src/gt4py/next/iterator/transforms/transform_concat_where.py +++ b/src/gt4py/next/iterator/transforms/transform_concat_where.py @@ -46,7 +46,7 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: cond_expr = DerefCond.apply(cond_expr, refs) return im.as_fieldop( im.lambda_("_tcw_pos", "_tcw_arg0", "_tcw_arg1", *refs)( - im.if_(im.call("in")(im.deref("_tcw_pos"), cond_expr), im.deref("_tcw_arg0"), im.deref("_tcw_arg1")) + im.if_(im.call("in_")(im.deref("_tcw_pos"), cond_expr), im.deref("_tcw_arg0"), im.deref("_tcw_arg1")) ), )(im.make_tuple(*dims), field_a, field_b, *refs) From 2c1464870b8113e0327a1828e093a4c1cc4b49d3 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 20 Feb 2025 01:39:17 +0100 Subject: [PATCH 021/124] Fix infer domain ops --- src/gt4py/next/iterator/transforms/infer_domain_ops.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index db51aa2888..477fc95010 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -93,7 +93,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: raise ValueError(f"{fun} is not a valid comparison operator.") if cpm.is_call_to(node, builtins.BINARY_LOGICAL_BUILTINS) and all( - isinstance(arg, (ir.Literal, ir.FunCall)) for arg in node.args + isinstance(arg.type, ts.DomainType) for arg in node.args ): if cpm.is_call_to(node, "and_"): # TODO: domain promotion @@ -102,8 +102,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: *[domain_utils.SymbolicDomain.from_expr(arg) for arg in node.args] ).as_expr() ) - else: - raise NotImplementedError + raise NotImplementedError() return node From 120080318ff210af53afdcabe4af560227d47e8a Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 20 Feb 2025 01:52:59 +0100 Subject: [PATCH 022/124] Cleanup --- .../transforms/transform_concat_where.py | 22 +++++-------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/transform_concat_where.py b/src/gt4py/next/iterator/transforms/transform_concat_where.py index 34c424cd85..54be53d187 100644 --- a/src/gt4py/next/iterator/transforms/transform_concat_where.py +++ b/src/gt4py/next/iterator/transforms/transform_concat_where.py @@ -16,19 +16,6 @@ ir_makers as im, ) -class DerefCond(PreserveLocationVisitor, NodeTranslator): - - @classmethod - def apply(cls, node: ir.Node, symbol_refs: list): - return cls().visit(node, symbol_refs=symbol_refs) - - def visit_SymRef(self, node: ir.FunCall, symbol_refs: list) -> ir.FunCall: - node = self.generic_visit(node, symbol_refs=symbol_refs) - if node.id in symbol_refs and isinstance(node.type, ts.ScalarType): - node.type = None - return im.deref(node) - return node - class TransformConcatWhere(PreserveLocationVisitor, NodeTranslator): @classmethod @@ -42,11 +29,12 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: cond = domain_utils.SymbolicDomain.from_expr(cond_expr).ranges.keys() dims = [im.call("index")(ir.AxisLiteral(value=k.value, kind=k.kind)) for k in cond] refs = symbol_ref_utils.collect_symbol_refs(cond_expr) - # TODO: this deref pass is not correct - cond_expr = DerefCond.apply(cond_expr, refs) + return im.as_fieldop( - im.lambda_("_tcw_pos", "_tcw_arg0", "_tcw_arg1", *refs)( - im.if_(im.call("in_")(im.deref("_tcw_pos"), cond_expr), im.deref("_tcw_arg0"), im.deref("_tcw_arg1")) + im.lambda_("__tcw_pos", "__tcw_arg0", "__tcw_arg1", *refs)( + im.let(*zip(refs, map(im.deref, refs), strict=True))( + im.if_(im.call("in_")(im.deref("__tcw_pos"), cond_expr), im.deref("__tcw_arg0"), im.deref("__tcw_arg1")) + ) ), )(im.make_tuple(*dims), field_a, field_b, *refs) From cf0ffb2dd10cffc8420608861a0cb22ff397ce72 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 20 Feb 2025 01:55:13 +0100 Subject: [PATCH 023/124] Fix format --- src/gt4py/next/ffront/experimental.py | 8 +++- src/gt4py/next/ffront/fbuiltins.py | 5 ++- .../ffront/foast_passes/type_deduction.py | 42 +++++++++++-------- src/gt4py/next/iterator/builtins.py | 1 + src/gt4py/next/iterator/embedded.py | 2 + src/gt4py/next/iterator/ir.py | 8 +++- .../next/iterator/ir_utils/domain_utils.py | 4 +- src/gt4py/next/iterator/transforms/cse.py | 8 ++-- .../iterator/transforms/infer_domain_ops.py | 14 ++++--- .../next/iterator/transforms/pass_manager.py | 5 +-- .../transforms/transform_concat_where.py | 9 ++-- tests/next_tests/integration_tests/cases.py | 13 ++++-- .../ffront_tests/test_concat_where.py | 38 +++++++++-------- .../ffront_tests/test_type_deduction.py | 21 ++++++---- 14 files changed, 112 insertions(+), 66 deletions(-) diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index 6d477ad015..dfa89468a5 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -10,7 +10,12 @@ from gt4py._core import definitions as core_defs from gt4py.next import common -from gt4py.next.ffront.fbuiltins import BuiltInFunction, FieldOffset, WhereLikeBuiltinFunction, FieldT +from gt4py.next.ffront.fbuiltins import ( + BuiltInFunction, + FieldOffset, + FieldT, + WhereLikeBuiltinFunction, +) @BuiltInFunction @@ -22,6 +27,7 @@ def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivi DomainT = TypeVar("DomainT", bound=common.Field) ConcatWhereBuiltinFunction = WhereLikeBuiltinFunction[_R, DomainT, FieldT] + @ConcatWhereBuiltinFunction def concat_where( mask: common.Domain, diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 927758b5d6..17e6bb2133 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -142,7 +142,8 @@ def __gt_type__(self) -> ts.FunctionType: class WhereLikeBuiltinFunction( - BuiltInFunction[_R, [MaskLikeT, FieldT, FieldT]], Generic[_R, MaskLikeT, FieldT], + BuiltInFunction[_R, [MaskLikeT, FieldT, FieldT]], + Generic[_R, MaskLikeT, FieldT], ): def __call__(self, mask: MaskLikeT, true_field: FieldT, false_field: FieldT) -> _R: if isinstance(true_field, tuple) or isinstance(false_field, tuple): @@ -158,9 +159,11 @@ def __call__(self, mask: MaskLikeT, true_field: FieldT, false_field: FieldT) -> return tuple(self(mask, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R` return super().__call__(mask, true_field, false_field) + MaskT = TypeVar("MaskT", bound=common.Field) WhereBuiltinFunction = WhereLikeBuiltinFunction[_R, MaskT, FieldT] + @BuiltInFunction def neighbor_sum(field: common.Field, /, axis: common.Dimension) -> common.Field: raise NotImplementedError() diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index fb2190172e..31f8f5b4eb 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -592,7 +592,8 @@ def _deduce_dimension_compare_type( ) -> Optional[ts.TypeSpec]: # e.g. `IDim > 1` index_type = ts.ScalarType( - kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())) + kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) + ) if isinstance(left.type, ts.DimensionType): if not right.type == index_type: @@ -622,12 +623,12 @@ def _deduce_compare_type( return self._deduce_dimension_compare_type(node, left=left, right=right) raise errors.DSLError( - left.location, "Comparison operators can only be used between arithmetic types " - "(scalars, fields) or between a dimension and an index type " - "({builtins.INTEGER_INDEX_BUILTIN})." + left.location, + "Comparison operators can only be used between arithmetic types " + "(scalars, fields) or between a dimension and an index type " + "({builtins.INTEGER_INDEX_BUILTIN}).", ) - def _deduce_binop_type( self, node: foast.BinOp, *, left: foast.Expr, right: foast.Expr, **kwargs: Any ) -> Optional[ts.TypeSpec]: @@ -654,7 +655,9 @@ def _deduce_binop_type( if isinstance(left.type, (ts.ScalarType, ts.FieldType)) and isinstance( right.type, (ts.ScalarType, ts.FieldType) ): - is_compatible = type_info.is_logical if node.op in logical_ops else type_info.is_arithmetic + is_compatible = ( + type_info.is_logical if node.op in logical_ops else type_info.is_arithmetic + ) for arg in (left, right): if not is_compatible(arg.type): raise errors.DSLError(arg.location, err_msg) @@ -680,8 +683,11 @@ def _deduce_binop_type( ) from ex elif isinstance(left.type, ts.DomainType) and isinstance(right.type, ts.DomainType): if node.op not in logical_ops: - raise errors.DSLError(node.location, f"{err_msg} Operator " - f"must be one of {', '.join((str(op) for op in logical_ops))}.") + raise errors.DSLError( + node.location, + f"{err_msg} Operator " + f"must be one of {', '.join((str(op) for op in logical_ops))}.", + ) return ts.DomainType(dims=promote_dims(left.type.dims, right.type.dims)) else: raise errors.DSLError(node.location, err_msg) @@ -987,26 +993,26 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: ) def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: - mask_type, true_branch_type, false_branch_type = (arg.type for arg in node.args) + mask_type, true_branch_type, false_branch_type = (arg.type for arg in node.args) assert isinstance(mask_type, ts.DomainType) - assert all(isinstance(arg, (ts.FieldType, ts.ScalarType)) for arg in (true_branch_type, false_branch_type)) + assert all( + isinstance(arg, (ts.FieldType, ts.ScalarType)) + for arg in (true_branch_type, false_branch_type) + ) if (t_dtype := type_info.extract_dtype(true_branch_type)) != ( - f_dtype := type_info.extract_dtype(false_branch_type)): + f_dtype := type_info.extract_dtype(false_branch_type) + ): raise errors.DSLError( node.location, - f"Field arguments must be of same dtype, got '{t_dtype}' != '{f_dtype}'." + f"Field arguments must be of same dtype, got '{t_dtype}' != '{f_dtype}'.", ) return_dims = promote_dims( - mask_type.dims, - type_info.promote(true_branch_type, false_branch_type).dims - ) - return_type = ts.FieldType( - dims=return_dims, - dtype=type_info.promote(t_dtype, f_dtype) + mask_type.dims, type_info.promote(true_branch_type, false_branch_type).dims ) + return_type = ts.FieldType(dims=return_dims, dtype=type_info.promote(t_dtype, f_dtype)) return foast.Call( func=node.func, diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index f9b045747f..fe3afc1960 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -406,6 +406,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] def concat_where(*args): raise BackendNotSelectedError() + @builtin_dispatch def in_(*args): raise BackendNotSelectedError() diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 0ef868cac7..958a6909c2 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1798,10 +1798,12 @@ def impl(*args): def index(axis: common.Dimension) -> common.Field: return IndexField(axis) + @builtins.concat_where.register(EMBEDDED) def concat_where(*args): raise NotImplementedError("To be implemented in frontend embedded.") + @builtins.in_.register(EMBEDDED) def in_(*args): raise NotImplementedError("To be implemented in frontend embedded.") diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index e40b8edbe1..d1e82b6edc 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations -from typing import ClassVar, List, Optional, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar, List, Optional, Union import gt4py.eve as eve from gt4py.eve import Coerced, SymbolName, SymbolRef @@ -63,9 +63,12 @@ class Literal(Expr): class NoneLiteral(Expr): _none_literal: int = 0 + class InfinityLiteral(Expr): if TYPE_CHECKING: - POSITIVE: ClassVar[InfinityLiteral] # TODO(tehrengruber): should be `ClassVar[InfinityLiteral]`, but self-referential not supported in eve + POSITIVE: ClassVar[ + InfinityLiteral + ] # TODO(tehrengruber): should be `ClassVar[InfinityLiteral]`, but self-referential not supported in eve NEGATIVE: ClassVar[InfinityLiteral] name: typing.Literal["POSITIVE", "NEGATIVE"] @@ -73,6 +76,7 @@ class InfinityLiteral(Expr): def __str__(self): return f"{type(self).__name__}.{self.name}" + InfinityLiteral.NEGATIVE = InfinityLiteral(name="NEGATIVE") InfinityLiteral.POSITIVE = InfinityLiteral(name="POSITIVE") diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index efd8ebb2c0..6622997d63 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -220,5 +220,7 @@ def promote_to_same_dimensions( lb, ub = domain_small.ranges[dim].start, domain_small.ranges[dim].stop dims_dict[dim] = SymbolicRange(lb, ub) else: - dims_dict[dim] = SymbolicRange(itir.InfinityLiteral.NEGATIVE, itir.InfinityLiteral.POSITIVE) + dims_dict[dim] = SymbolicRange( + itir.InfinityLiteral.NEGATIVE, itir.InfinityLiteral.POSITIVE + ) return SymbolicDomain(domain_small.grid_type, dims_dict) # TODO: fix for unstructured diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 871d00b3f6..32fd126e63 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -90,15 +90,17 @@ def _is_collectable_expr(node: itir.Node) -> bool: # otherwise # do also not collect index nodes because otherwise the right hand side of SetAts becomes a let statement # instead of an as_fieldop - if cpm.is_call_to(node, ("lift", "shift", "reduce", "map_", "index")) or cpm.is_applied_lift(node): + if cpm.is_call_to( + node, ("lift", "shift", "reduce", "map_", "index") + ) or cpm.is_applied_lift(node): return False return True # do also not collect make_tuple(index) nodes because otherwise the right hand side of SetAts becomes a let statement # instead of an as_fieldop if cpm.is_call_to(node, "make_tuple") and all( - cpm.is_call_to(arg, "index") for arg in node.args + cpm.is_call_to(arg, "index") for arg in node.args ): - return False + return False elif isinstance(node, itir.Lambda): return True diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index 477fc95010..6d874f9efa 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -15,9 +15,8 @@ domain_utils, ir_makers as im, ) -from gt4py.next.type_system import type_specifications as ts -from gt4py.next.iterator.ir_utils.domain_utils import domain_complement from gt4py.next.iterator.transforms.constant_folding import ConstantFolding +from gt4py.next.type_system import type_specifications as ts class InferDomainOps(PreserveLocationVisitor, NodeTranslator): @@ -46,7 +45,10 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: "eq": "eq", "not_eq": "not_eq", } - return self.visit(im.call(complementary_op[node.fun.id])(arg2, arg1), **{**kwargs, "recurse":False}) + return self.visit( + im.call(complementary_op[node.fun.id])(arg2, arg1), + **{**kwargs, "recurse": False}, + ) assert isinstance(arg1.type, ts.DimensionType) dim: common.Dimension = arg1.type.dim @@ -78,8 +80,8 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: max_ = im.plus(value, 1) domain = domain_utils.SymbolicDomain( - common.GridType.CARTESIAN, # TODO - ranges={dim: domain_utils.SymbolicRange(start=min_, stop=max_)} + common.GridType.CARTESIAN, # TODO + ranges={dim: domain_utils.SymbolicRange(start=min_, stop=max_)}, ) return domain.as_expr() @@ -87,7 +89,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: # `IDim != a -> `IDim < a & IDim > a` return im.call("and_")( self.visit(im.less(dim, value), **kwargs), - self.visit(im.greater(dim, value), **kwargs) + self.visit(im.greater(dim, value), **kwargs), ) else: raise ValueError(f"{fun} is not a valid comparison operator.") diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index fdb144cffe..f1540cabc8 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -94,10 +94,9 @@ def apply_common_transforms( ir = nest_concat_wheres.NestConcatWheres.apply(ir) ir = infer_domain_ops.InferDomainOps.apply(ir) - - ir = ConstantFolding.apply(ir) # TODO: remove + ir = ConstantFolding.apply(ir) # TODO: remove ir = transform_concat_where.TransformConcatWhere.apply(ir) - ir = ConstantFolding.apply(ir) # TODO: remove + ir = ConstantFolding.apply(ir) # TODO: remove ir = expand_library_functions.ExpandLibraryFunctions.apply(ir) ir = infer_domain.infer_program( diff --git a/src/gt4py/next/iterator/transforms/transform_concat_where.py b/src/gt4py/next/iterator/transforms/transform_concat_where.py index 54be53d187..62c302335a 100644 --- a/src/gt4py/next/iterator/transforms/transform_concat_where.py +++ b/src/gt4py/next/iterator/transforms/transform_concat_where.py @@ -8,13 +8,12 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir -from gt4py.next.type_system import type_specifications as ts -from gt4py.next.iterator.transforms import symbol_ref_utils from gt4py.next.iterator.ir_utils import ( common_pattern_matcher as cpm, domain_utils, ir_makers as im, ) +from gt4py.next.iterator.transforms import symbol_ref_utils class TransformConcatWhere(PreserveLocationVisitor, NodeTranslator): @@ -33,7 +32,11 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: return im.as_fieldop( im.lambda_("__tcw_pos", "__tcw_arg0", "__tcw_arg1", *refs)( im.let(*zip(refs, map(im.deref, refs), strict=True))( - im.if_(im.call("in_")(im.deref("__tcw_pos"), cond_expr), im.deref("__tcw_arg0"), im.deref("__tcw_arg1")) + im.if_( + im.call("in_")(im.deref("__tcw_pos"), cond_expr), + im.deref("__tcw_arg0"), + im.deref("__tcw_arg1"), + ) ) ), )(im.make_tuple(*dims), field_a, field_b, *refs) diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 8bbbed4ef9..a4f2ebeea3 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -172,7 +172,9 @@ def field( f"'IndexInitializer' only supports fields with a single 'Dimension', got {domain}." ) return constructors.as_field( - domain=domain, data=np.arange(domain.ranges[0].start, domain.ranges[0].stop, dtype=dtype), allocator=allocator + domain=domain, + data=np.arange(domain.ranges[0].start, domain.ranges[0].stop, dtype=dtype), + allocator=allocator, ) def from_case( @@ -349,7 +351,10 @@ def allocate( """ if sizes: assert not domain and all(dim in case.default_sizes for dim in sizes) - domain = {dim: (0, sizes[dim] if dim in sizes else default_size) for dim, default_size in case.default_sizes.items()} + domain = { + dim: (0, sizes[dim] if dim in sizes else default_size) + for dim, default_size in case.default_sizes.items() + } if not domain: domain = {dim: (0, size) for dim, size in case.default_sizes.items()} @@ -357,7 +362,9 @@ def allocate( if not isinstance(domain, gtx.Domain): domain = gtx.domain(domain) - domain = extend_domain(domain, extend) # TODO: this should take into account the Domain of the allocated field + domain = extend_domain( + domain, extend + ) # TODO: this should take into account the Domain of the allocated field arg_type = get_param_types(fieldview_prog)[name] if strategy is None: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 2ed38f2993..a5e65d2c1f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -30,9 +30,7 @@ def testee(ground: cases.IJKField, air: cases.IJKField) -> cases.IJKField: air = cases.allocate(cartesian_case, testee, "air")() k = np.arange(0, cartesian_case.default_sizes[KDim]) - ref = np.where( - k[np.newaxis, np.newaxis, :] == 0, ground.asnumpy(), air.asnumpy() - ) + ref = np.where(k[np.newaxis, np.newaxis, :] == 0, ground.asnumpy(), air.asnumpy()) cases.verify(cartesian_case, testee, ground, air, out=out, ref=ref) @@ -43,7 +41,9 @@ def testee(ground: cases.IJKField, air: cases.IJKField) -> cases.IJKField: return concat_where(KDim == 0, ground, air) out = cases.allocate(cartesian_case, testee, cases.RETURN)() - ground = cases.allocate(cartesian_case, testee, "ground", domain=out.domain.slice_at[:, :, 0:1])() + ground = cases.allocate( + cartesian_case, testee, "ground", domain=out.domain.slice_at[:, :, 0:1] + )() air = cases.allocate(cartesian_case, testee, "air", domain=out.domain.slice_at[:, :, 1:])() ref = np.concatenate((ground.asnumpy(), air.asnumpy()), axis=2) @@ -54,8 +54,8 @@ def testee(ground: cases.IJKField, air: cases.IJKField) -> cases.IJKField: def test_concat_where_non_overlapping_different_dims(cartesian_case): @gtx.field_operator def testee( - ground: cases.KField, # note: boundary field is only defined in K - air: cases.IJKField + ground: cases.KField, # note: boundary field is only defined in K + air: cases.IJKField, ) -> cases.IJKField: return concat_where(KDim == 0, ground, air) @@ -63,7 +63,15 @@ def testee( ground = cases.allocate(cartesian_case, testee, "ground", domain=gtx.domain({KDim: (0, 1)}))() air = cases.allocate(cartesian_case, testee, "air", domain=out.domain.slice_at[:, :, 1:])() - ref = np.concatenate((np.tile(ground.asnumpy(),(*air.domain.shape[0:2], len(ground.domain[KDim].unit_range))), air.asnumpy()), axis=2) + ref = np.concatenate( + ( + np.tile( + ground.asnumpy(), (*air.domain.shape[0:2], len(ground.domain[KDim].unit_range)) + ), + air.asnumpy(), + ), + axis=2, + ) cases.verify(cartesian_case, testee, ground, air, out=out, ref=ref) @@ -71,9 +79,7 @@ def testee( @pytest.mark.uses_frontend_concat_where def test_dimension_two_nested_conditions(cartesian_case): @gtx.field_operator - def testee( - interior: cases.IJKField, boundary: cases.IJKField - ) -> cases.IJKField: + def testee(interior: cases.IJKField, boundary: cases.IJKField) -> cases.IJKField: return concat_where((KDim < 2), boundary, concat_where((KDim >= 5), boundary, interior)) interior = cases.allocate(cartesian_case, testee, "interior")() @@ -82,8 +88,7 @@ def testee( k = np.arange(0, cartesian_case.default_sizes[KDim]) ref = np.where( - (k[np.newaxis, np.newaxis, :] < 2) - | (k[np.newaxis, np.newaxis, :] >= 5), + (k[np.newaxis, np.newaxis, :] < 2) | (k[np.newaxis, np.newaxis, :] >= 5), boundary.asnumpy(), interior.asnumpy(), ) @@ -138,9 +143,7 @@ def testee(interior: cases.KField, boundary: cases.KField) -> cases.KField: @pytest.mark.uses_frontend_concat_where def test_boundary_horizontal_slice(cartesian_case): @gtx.field_operator - def testee( - interior: cases.IJKField, boundary: cases.IJField - ) -> cases.IJKField: + def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField: return concat_where(KDim == 0, boundary, interior) interior = cases.allocate(cartesian_case, testee, "interior")() @@ -160,9 +163,7 @@ def testee( @pytest.mark.uses_frontend_concat_where def test_boundary_single_layer(cartesian_case): @gtx.field_operator - def testee( - interior: cases.IJKField, boundary: cases.IJKField - ) -> cases.IJKField: + def testee(interior: cases.IJKField, boundary: cases.IJKField) -> cases.IJKField: return concat_where(KDim == 0, boundary, interior) interior = cases.allocate(cartesian_case, testee, "interior")() @@ -183,6 +184,7 @@ def testee( @pytest.mark.uses_tuple_returns def test_with_tuples(cartesian_case): pytest.skip("Not implemented in the frontend.") + @gtx.field_operator def testee( k: cases.KField, diff --git a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py index dd2758e5da..341608111b 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py @@ -76,7 +76,12 @@ def add_bools(a: Field[[TDim], bool], b: Field[[TDim], bool]): return a + b with pytest.raises( - errors.DSLError, match=(re.escape("Unsupported operand type(s) for +: 'Field[[TDim], bool]' and 'Field[[TDim], bool]'.")) + errors.DSLError, + match=( + re.escape( + "Unsupported operand type(s) for +: 'Field[[TDim], bool]' and 'Field[[TDim], bool]'." + ) + ), ): _ = FieldOperatorParser.apply_to_function(add_bools) @@ -104,7 +109,9 @@ def float_bitop(a: Field[[TDim], float], b: Field[[TDim], float]): with pytest.raises( errors.DSLError, - match=re.escape("Unsupported operand type(s) for &: 'Field[[TDim], float64]' and 'Field[[TDim], float64]'."), + match=re.escape( + "Unsupported operand type(s) for &: 'Field[[TDim], float64]' and 'Field[[TDim], float64]'." + ), ): _ = FieldOperatorParser.apply_to_function(float_bitop) @@ -142,7 +149,7 @@ def simple_concat_where(a: Field[[TDim], float], b: Field[[TDim], float]): def test_domain_comparison_failure(): def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): - return concat_where(TDim > 1., a, b) + return concat_where(TDim > 1.0, a, b) with pytest.raises( errors.DSLError, @@ -153,7 +160,7 @@ def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): def test_domain_comparison_checkerboard_failure(): def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): - return concat_where(TDim % 2., a, b) + return concat_where(TDim % 2.0, a, b) with pytest.raises( errors.DSLError, @@ -164,11 +171,11 @@ def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): def test_concat_where_invalid_dtype(): def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): - return concat_where(TDim > 0, 1., 2) + return concat_where(TDim > 0, 1.0, 2) with pytest.raises( - errors.DSLError, - match=re.escape("Field arguments must be of same dtype, got 'float64' != 'int32'."), + errors.DSLError, + match=re.escape("Field arguments must be of same dtype, got 'float64' != 'int32'."), ): _ = FieldOperatorParser.apply_to_function(domain_comparison) From 335e932086b19481ef904a65417ac1bd611fcc59 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 13:13:15 +0100 Subject: [PATCH 024/124] Fix broken scan (e.g. test_tuple_scalar_scan) --- src/gt4py/next/iterator/ir_utils/misc.py | 18 +++++++++++++++--- .../next/iterator/transforms/collapse_tuple.py | 15 ++++++++++----- .../transforms_tests/test_collapse_tuple.py | 9 ++++++--- 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index 00a1ab5609..03a3dfb0e3 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -93,6 +93,16 @@ def canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: return expr +def _remove_let_alias(let_expr: itir.FunCall): + assert cpm.is_let(let_expr) + is_aliased_let = True + for param, arg in zip(let_expr.fun.params, let_expr.args, strict=True): # type: ignore[attr-defined] # ensured by cpm.is_let + is_aliased_let &= cpm.is_ref_to(arg, param.id) + if is_aliased_let: + return let_expr.fun.expr # type: ignore[attr-defined] # ensured by cpm.is_let + return let_expr + + def unwrap_scan(stencil: itir.Lambda | itir.FunCall): """ If given a scan, extract stencil part of its scan pass and a back-transformation into a scan. @@ -108,7 +118,7 @@ def unwrap_scan(stencil: itir.Lambda | itir.FunCall): >>> str(stencil) 'λ(arg) → state + ·arg' >>> str(back_trafo(stencil)) - 'scan(λ(state, arg) → (λ(arg) → state + ·arg)(arg), True, 0.0)' + 'scan(λ(state, arg) → state + ·arg, True, 0.0)' In case a regular stencil is given it is returned as-is: @@ -125,8 +135,10 @@ def unwrap_scan(stencil: itir.Lambda | itir.FunCall): def restore_scan(transformed_stencil_like: itir.Lambda): new_scan_pass = im.lambda_(state_param, *transformed_stencil_like.params)( - im.call(transformed_stencil_like)( - *(param.id for param in transformed_stencil_like.params) + _remove_let_alias( + im.call(transformed_stencil_like)( + *(param.id for param in transformed_stencil_like.params) + ) ) ) return im.call("scan")(new_scan_pass, direction, init) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index f0afd5bafc..03364451b4 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -18,7 +18,7 @@ from gt4py import eve from gt4py.eve import utils as eve_utils from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import ir, ir as itir from gt4py.next.iterator.ir_utils import ( common_pattern_matcher as cpm, ir_makers as im, @@ -51,10 +51,7 @@ def _is_trivial_make_tuple_call(node: itir.Expr): """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" if not cpm.is_call_to(node, "make_tuple"): return False - if not all( - isinstance(arg, (itir.SymRef, itir.Literal)) or _is_trivial_make_tuple_call(arg) - for arg in node.args - ): + if not all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args): return False return True @@ -163,6 +160,8 @@ class Transformation(enum.Flag): INLINE_TRIVIAL_LET = enum.auto() #: `as_fieldop(λ(t) → ·t[0]+·t[1])({a, b})` -> as_fieldop(λ(a, b) → ·a+·b)(a, b) FLATTEN_AS_FIELDOP_ARGS = enum.auto() + #: `let(a, b[1])(a)` -> `b[1]` + INLINE_TRIVIAL_TUPLE_LET_VAR = enum.auto() @classmethod def all(self) -> CollapseTuple.Transformation: @@ -507,6 +506,12 @@ def transform_inline_trivial_let(self, node: itir.FunCall, **kwargs) -> Optional return None + def transform_inline_trivial_tuple_let_var(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: + if cpm.is_let(node): + if any(trivial_args := [_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args]): + return inline_lambda(node, eligible_params=trivial_args) + return None + # TODO(tehrengruber): This is a transformation that should be executed before visiting the children. Then # revisiting the body would not be needed. def transform_flatten_as_fieldop_args( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 08f1926277..e8d04096b4 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -128,8 +128,11 @@ def test_propagate_tuple_get(): def test_letify_make_tuple_elements(): - # anything that is not trivial, i.e. a SymRef, works here - el1, el2 = im.let("foo", "foo")("foo"), im.let("bar", "bar")("bar") + fun_type = ts.FunctionType( + pos_only_args=[], pos_or_kw_args={}, kw_only_args={}, returns=int_type + ) + # anything that is not trivial, works here + el1, el2 = im.call(im.ref("foo", fun_type))(), im.call(im.ref("bar", fun_type))() testee = im.make_tuple(el1, el2) expected = im.let(("__ct_el_1", el1), ("__ct_el_2", el2))( im.make_tuple("__ct_el_1", "__ct_el_2") @@ -391,4 +394,4 @@ def test_flatten_as_fieldop_args_scan(): allow_undeclared_symbols=True, within_stencil=False, ) - assert actual == expected + assert actual == expected \ No newline at end of file From c18b7ad67d365c0e9b369f42559379abdcda38fe Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 13:34:36 +0100 Subject: [PATCH 025/124] Fix failing tests --- tests/next_tests/integration_tests/cases.py | 16 +++++++++++++--- .../ffront_tests/test_gt4py_builtins.py | 12 ++++++------ .../transforms_tests/test_collapse_tuple.py | 2 +- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index a4f2ebeea3..5aac6ff595 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -533,7 +533,7 @@ def unstructured_case_3d(unstructured_case): return dataclasses.replace( unstructured_case, default_sizes={**unstructured_case.default_sizes, KDim: 10}, - offset_provider={**unstructured_case.offset_provider, "KOff": KDim}, + offset_provider={**unstructured_case.offset_provider, "Koff": KDim}, ) @@ -602,8 +602,18 @@ def extend_domain( if extend: domain = copy.deepcopy(domain) for dim, (lower, upper) in extend.items(): - domain[dim][0] += -lower - domain[dim][1] += upper + domain = domain.replace( + dim, + common.named_range( + ( + dim, + ( + domain[dim].unit_range.start - lower, + domain[dim].unit_range.stop + upper, + ), + ) + ), + ) return domain diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index ab1c625fef..6ddd62cbd6 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -141,7 +141,7 @@ def test_neighbor_sum(unstructured_case_3d, fop): @pytest.mark.uses_unstructured_shift -def test_reduction_execution_with_offset(unstructured_case): +def test_reduction_execution_with_offset(unstructured_case_3d): EKField: TypeAlias = gtx.Field[[Edge, KDim], np.int32] VKField: TypeAlias = gtx.Field[[Vertex, KDim], np.int32] @@ -158,12 +158,12 @@ def fencil_op(edge_f: EKField) -> VKField: def fencil(edge_f: EKField, out: VKField): fencil_op(edge_f, out=out) - v2e_table = unstructured_case.offset_provider["V2E"].ndarray - field = cases.allocate(unstructured_case, fencil, "edge_f", sizes={KDim: 2})() - out = cases.allocate(unstructured_case, fencil_op, cases.RETURN, sizes={KDim: 1})() + v2e_table = unstructured_case_3d.offset_provider["V2E"].ndarray + field = cases.allocate(unstructured_case_3d, fencil, "edge_f", sizes={KDim: 2})() + out = cases.allocate(unstructured_case_3d, fencil_op, cases.RETURN, sizes={KDim: 1})() cases.verify( - unstructured_case, + unstructured_case_3d, fencil, field, out, @@ -174,7 +174,7 @@ def fencil(edge_f: EKField, out: VKField): initial=0, where=v2e_table != common._DEFAULT_SKIP_VALUE, ).reshape(out.shape), - offset_provider=unstructured_case.offset_provider | {"Koff": KDim}, + offset_provider=unstructured_case_3d.offset_provider, ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index e8d04096b4..636e66940c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -394,4 +394,4 @@ def test_flatten_as_fieldop_args_scan(): allow_undeclared_symbols=True, within_stencil=False, ) - assert actual == expected \ No newline at end of file + assert actual == expected From d399c65848f2916c0c760ed217ec80d62d5a887e Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 13:34:51 +0100 Subject: [PATCH 026/124] Fix format --- .../iterator_tests/transforms_tests/test_collapse_tuple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index e8d04096b4..636e66940c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -394,4 +394,4 @@ def test_flatten_as_fieldop_args_scan(): allow_undeclared_symbols=True, within_stencil=False, ) - assert actual == expected \ No newline at end of file + assert actual == expected From 5ad77013267c520d44ae59dcaeefba7dadb3ee20 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 15:29:50 +0100 Subject: [PATCH 027/124] Fix failing tests --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 6 ++- .../iterator/transforms/collapse_tuple.py | 39 +++++++++++++------ 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 42b82ffdd0..bdb9bd8249 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -445,7 +445,7 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> Cal >>> str(as_fieldop(lambda_("it1", "it2")(plus(deref("it1"), deref("it2"))))("field1", "field2")) '(⇑(λ(it1, it2) → ·it1 + ·it2))(field1, field2)' """ - from gt4py.next.iterator.ir_utils import domain_utils + from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils result = call( call("as_fieldop")( @@ -462,7 +462,9 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> Cal def _populate_domain_annex_wrapper(*args, **kwargs): node = result(*args, **kwargs) - if domain: + # note: if the domain is not a direct construction, e.g. because it is only a reference + # to a domain defined in a let, don't populate the annex + if domain and cpm.is_call_to(domain, ("cartesian_domain", "unstructured_domain")): node.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) return node diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 03364451b4..6b04790644 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -13,7 +13,7 @@ import functools import operator import re -from typing import Optional +from typing import Literal, Optional from gt4py import eve from gt4py.eve import utils as eve_utils @@ -41,12 +41,24 @@ def _with_altered_arg(node: itir.FunCall, arg_idx: int, new_arg: itir.Expr | str ) -def _with_altered_iterator_element_type(type_: it_ts.IteratorType, new_el_type: ts.DataType): +def _with_altered_iterator_element_type( + type_: it_ts.IteratorType, new_el_type: ts.DataType +) -> it_ts.IteratorType: return it_ts.IteratorType( position_dims=type_.position_dims, defined_dims=type_.defined_dims, element_type=new_el_type ) +def _with_altered_iterator_position_dims( + type_: it_ts.IteratorType, new_position_dims: list[common.Dimension] | Literal["unknown"] +) -> it_ts.IteratorType: + return it_ts.IteratorType( + position_dims=new_position_dims, + defined_dims=type_.defined_dims, + element_type=type_.element_type, + ) + + def _is_trivial_make_tuple_call(node: itir.Expr): """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" if not cpm.is_call_to(node, "make_tuple"): @@ -540,19 +552,24 @@ def transform_flatten_as_fieldop_args( if isinstance(arg.type, ts.TupleType): ref_to_orig_arg = im.ref(f"__ct_flat_orig_arg_{len(orig_args_map)}", arg.type) orig_args_map[im.sym(ref_to_orig_arg.id, arg.type)] = arg - new_params_inner, new_args_inner = [], [] + new_params_inner, lift_params = [], [] for i, type_ in enumerate(param.type.element_type.types): - new_params_inner.append( + new_param = im.sym( + _flattened_as_fieldop_param_el_name(param.id, i), + _with_altered_iterator_element_type(param.type, type_), + ) + lift_params.append( im.sym( - _flattened_as_fieldop_param_el_name(param.id, i), - _with_altered_iterator_element_type(param.type, type_), + new_param.id, + _with_altered_iterator_position_dims(new_param.type, "unknown"), ) ) - new_args_inner.append(im.tuple_get(i, ref_to_orig_arg)) + new_params_inner.append(new_param) + new_args.append(im.tuple_get(i, ref_to_orig_arg)) param_substitute = im.lift( - im.lambda_(*new_params_inner)( - im.make_tuple(*[im.deref(im.ref(p.id, p.type)) for p in new_params_inner]) + im.lambda_(*lift_params)( + im.make_tuple(*[im.deref(im.ref(p.id, p.type)) for p in lift_params]) ) )(*[im.ref(p.id, p.type) for p in new_params_inner]) @@ -560,14 +577,14 @@ def transform_flatten_as_fieldop_args( # note: the lift is trivial so inlining it is not an issue with respect to tree size new_body = inline_lambda(new_body, force_inline_lift_args=True) new_params.extend(new_params_inner) - new_args.extend(new_args_inner) else: new_params.append(param) new_args.append(arg) # remove lifts again new_body = inline_lifts.InlineLifts( - flags=inline_lifts.InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT + flags=inline_lifts.InlineLifts.Flag.INLINE_DEREF_LIFT + | inline_lifts.InlineLifts.Flag.PROPAGATE_SHIFT ).visit(new_body) new_body = self.visit(new_body, **kwargs) new_stencil = restore_scan(im.lambda_(*new_params)(new_body)) From d3957bd805f96c1d16e8891b9976a08ecd872366 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 15:30:53 +0100 Subject: [PATCH 028/124] Fix format --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 6b04790644..97c8f1ca02 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -561,7 +561,7 @@ def transform_flatten_as_fieldop_args( lift_params.append( im.sym( new_param.id, - _with_altered_iterator_position_dims(new_param.type, "unknown"), + _with_altered_iterator_position_dims(new_param.type, "unknown"), # type: ignore[arg-type] # always in IteratorType ) ) new_params_inner.append(new_param) From b52a07c4288fe9b1dd6ecca23b050e8f7eeceec1 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 16:14:01 +0100 Subject: [PATCH 029/124] Cleanup --- .../next/iterator/transforms/collapse_tuple.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 97c8f1ca02..923f6d1302 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -170,7 +170,7 @@ class Transformation(enum.Flag): PROPAGATE_NESTED_LET = enum.auto() #: `let(a, 1)(a)` -> `1` or `let(a, b)(f(a))` -> `f(a)` INLINE_TRIVIAL_LET = enum.auto() - #: `as_fieldop(λ(t) → ·t[0]+·t[1])({a, b})` -> as_fieldop(λ(a, b) → ·a+·b)(a, b) + #: `as_fieldop(λ(t) → ·t[0]+·t[1])({a, b})` -> `as_fieldop(λ(a, b) → ·a+·b)(a, b)` FLATTEN_AS_FIELDOP_ARGS = enum.auto() #: `let(a, b[1])(a)` -> `b[1]` INLINE_TRIVIAL_TUPLE_LET_VAR = enum.auto() @@ -529,6 +529,7 @@ def transform_inline_trivial_tuple_let_var(self, node: ir.Node, **kwargs) -> Opt def transform_flatten_as_fieldop_args( self, node: itir.FunCall, **kwargs ) -> Optional[itir.Node]: + # `as_fieldop(λ(t) → ·t[0]+·t[1])({a, b})` -> `as_fieldop(λ(a, b) → ·a+·b)(a, b)` if not cpm.is_applied_as_fieldop(node): return None @@ -545,13 +546,15 @@ def transform_flatten_as_fieldop_args( new_body = stencil.expr domain = node.fun.args[1] if len(node.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop - orig_args_map: dict[itir.Sym, itir.Expr] = {} + remapped_args: dict[ + itir.Sym, itir.Expr + ] = {} # contains the arguments that are remapped, e.g. `{a, b}` new_params: list[itir.Sym] = [] new_args: list[itir.Expr] = [] for param, arg in zip(stencil.params, node.args, strict=True): if isinstance(arg.type, ts.TupleType): - ref_to_orig_arg = im.ref(f"__ct_flat_orig_arg_{len(orig_args_map)}", arg.type) - orig_args_map[im.sym(ref_to_orig_arg.id, arg.type)] = arg + ref_to_remapped_arg = im.ref(f"__ct_flat_remapped_{len(remapped_args)}", arg.type) + remapped_args[im.sym(ref_to_remapped_arg.id, arg.type)] = arg new_params_inner, lift_params = [], [] for i, type_ in enumerate(param.type.element_type.types): new_param = im.sym( @@ -565,8 +568,10 @@ def transform_flatten_as_fieldop_args( ) ) new_params_inner.append(new_param) - new_args.append(im.tuple_get(i, ref_to_orig_arg)) + new_args.append(im.tuple_get(i, ref_to_remapped_arg)) + # an iterator that substitutes the original (tuple) iterator, e.g. `t`. Built + # from the new parameters which are the elements of `t`. param_substitute = im.lift( im.lambda_(*lift_params)( im.make_tuple(*[im.deref(im.ref(p.id, p.type)) for p in lift_params]) @@ -576,6 +581,7 @@ def transform_flatten_as_fieldop_args( new_body = im.let(param.id, param_substitute)(new_body) # note: the lift is trivial so inlining it is not an issue with respect to tree size new_body = inline_lambda(new_body, force_inline_lift_args=True) + new_params.extend(new_params_inner) else: new_params.append(param) @@ -589,4 +595,4 @@ def transform_flatten_as_fieldop_args( new_body = self.visit(new_body, **kwargs) new_stencil = restore_scan(im.lambda_(*new_params)(new_body)) - return im.let(*orig_args_map.items())(im.as_fieldop(new_stencil, domain)(*new_args)) + return im.let(*remapped_args.items())(im.as_fieldop(new_stencil, domain)(*new_args)) From c5c3e5f14d2c1f8e4f1f959f09af31f208a382ec Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 16:51:25 +0100 Subject: [PATCH 030/124] Fix pyproject.toml test marker --- pyproject.toml | 4 +++- tests/next_tests/definitions.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1efce6bd29..f7b2087a55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -326,7 +326,9 @@ markers = [ 'uses_unstructured_shift: tests that use a unstructured connectivity', 'uses_max_over: tests that use the max_over builtin', 'uses_mesh_with_skip_values: tests that use a mesh with skip values', - 'checks_specific_error: tests that rely on the backend to produce a specific error message' + 'checks_specific_error: tests that rely on the backend to produce a specific error message', + 'uses_frontend_concat_where: tests that use the frontend concat_where builtin', + 'uses_gtir_concat_where: tests that use the GTIR concat_where builtin', ] norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*'] testpaths = 'tests' diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index e85283ee1d..81bf8bea70 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -106,8 +106,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_ORIGIN = "uses_origin" USES_REDUCE_WITH_LAMBDA = "uses_reduce_with_lambda" USES_SCAN = "uses_scan" -USES_FRONTEND_CONCAT_WHERE = "uses_frontend_concat_where" -USES_GTIR_CONCAT_WHERE = "uses_gtir_concat_where" USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator" USES_SCAN_IN_STENCIL = "uses_scan_in_stencil" USES_SCAN_WITHOUT_FIELD_ARGS = "uses_scan_without_field_args" @@ -126,6 +124,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_MAX_OVER = "uses_max_over" USES_MESH_WITH_SKIP_VALUES = "uses_mesh_with_skip_values" USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo" +USES_FRONTEND_CONCAT_WHERE = "uses_frontend_concat_where" +USES_GTIR_CONCAT_WHERE = "uses_gtir_concat_where" CHECKS_SPECIFIC_ERROR = "checks_specific_error" # Skip messages (available format keys: 'marker', 'backend') From f59fabf237bef95b0afca6ee3f986fb869f5cdd4 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 16:58:17 +0100 Subject: [PATCH 031/124] Remove unnecessary visits --- pyproject.toml | 2 +- .../next/iterator/transforms/expand_library_functions.py | 5 ++++- src/gt4py/next/iterator/transforms/infer_domain_ops.py | 2 +- src/gt4py/next/iterator/transforms/nest_concat_wheres.py | 3 ++- src/gt4py/next/iterator/transforms/transform_concat_where.py | 2 +- 5 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f7b2087a55..8786ad8381 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -328,7 +328,7 @@ markers = [ 'uses_mesh_with_skip_values: tests that use a mesh with skip values', 'checks_specific_error: tests that rely on the backend to produce a specific error message', 'uses_frontend_concat_where: tests that use the frontend concat_where builtin', - 'uses_gtir_concat_where: tests that use the GTIR concat_where builtin', + 'uses_gtir_concat_where: tests that use the GTIR concat_where builtin' ] norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*'] testpaths = 'tests' diff --git a/src/gt4py/next/iterator/transforms/expand_library_functions.py b/src/gt4py/next/iterator/transforms/expand_library_functions.py index 0f3d005452..4f2527ec5e 100644 --- a/src/gt4py/next/iterator/transforms/expand_library_functions.py +++ b/src/gt4py/next/iterator/transforms/expand_library_functions.py @@ -23,6 +23,8 @@ def apply(cls, node: ir.Node): return cls().visit(node) def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + node = self.generic_visit(node) + if cpm.is_call_to(node, "in_"): ret = [] pos, domain = node.args @@ -36,4 +38,5 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: ) ) # TODO: avoid pos duplication return reduce(im.and_, ret) - return self.generic_visit(node) + + return node diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index 6d874f9efa..22f57f9711 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -92,7 +92,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: self.visit(im.greater(dim, value), **kwargs), ) else: - raise ValueError(f"{fun} is not a valid comparison operator.") + raise AssertionError() if cpm.is_call_to(node, builtins.BINARY_LOGICAL_BUILTINS) and all( isinstance(arg.type, ts.DomainType) for arg in node.args diff --git a/src/gt4py/next/iterator/transforms/nest_concat_wheres.py b/src/gt4py/next/iterator/transforms/nest_concat_wheres.py index 258494e0c4..df305f3765 100644 --- a/src/gt4py/next/iterator/transforms/nest_concat_wheres.py +++ b/src/gt4py/next/iterator/transforms/nest_concat_wheres.py @@ -18,6 +18,7 @@ def apply(cls, node: ir.Node): def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: node = self.generic_visit(node) + if cpm.is_call_to(node, "concat_where"): cond_expr, field_a, field_b = node.args if cpm.is_call_to(cond_expr, ("and_")): @@ -35,4 +36,4 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: cond2 = im.greater(cond_expr.args[0], cond_expr.args[1]) return im.concat_where(cond1, field_b, im.concat_where(cond2, field_b, field_a)) - return self.generic_visit(node) + return node diff --git a/src/gt4py/next/iterator/transforms/transform_concat_where.py b/src/gt4py/next/iterator/transforms/transform_concat_where.py index 62c302335a..6a94406163 100644 --- a/src/gt4py/next/iterator/transforms/transform_concat_where.py +++ b/src/gt4py/next/iterator/transforms/transform_concat_where.py @@ -41,4 +41,4 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: ), )(im.make_tuple(*dims), field_a, field_b, *refs) - return self.generic_visit(node) + return node From c8e06bd0aac5d68cf0a03ec86fd3c9fe7df99bd2 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 17:02:22 +0100 Subject: [PATCH 032/124] Cleanup trace shifts --- src/gt4py/next/iterator/transforms/trace_shifts.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 8dc3d46b24..dea5d807be 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -274,10 +274,7 @@ class TraceShifts(PreserveLocationVisitor, NodeTranslator): def visit_Literal(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: return Sentinel.VALUE - def visit_InfinityLiteral(self, node: ir.SymRef, *, ctx: dict[str, Any]): - return Sentinel.VALUE - - def visit_NegInfinityLiteral(self, node: ir.SymRef, *, ctx: dict[str, Any]): + def visit_InfinityLiteral(self, node: ir.InfinityLiteral, *, ctx: dict[str, Any]): return Sentinel.VALUE def visit_SymRef(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: From f748da7828ee5d725c1d512e0a3099e06f380377 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 17:19:23 +0100 Subject: [PATCH 033/124] Fix type inference --- src/gt4py/next/iterator/type_system/type_synthesizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index ea4c6861d7..f80c17e2fc 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -116,7 +116,7 @@ def synthesize_binary_math_comparison_builtins( return ts.DomainType(dims=[rhs.dim]) if isinstance(lhs, ts.DimensionType) and isinstance(rhs, ts.ScalarType): return ts.DomainType(dims=[lhs.dim]) - assert isinstance(lhs, ts.ScalarType) and isinstance(rhs, ts.ScalarType) + assert all(isinstance(lhs, (ts.ScalarType, ts.DeferredType)) for arg in (lhs, rhs)) return ts.ScalarType(kind=ts.ScalarKind.BOOL) From 45f8b09606478976e418e82228242a19802640f0 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 17:26:23 +0100 Subject: [PATCH 034/124] Add concat_where transforms to field view transforms --- src/gt4py/next/iterator/ir_utils/domain_utils.py | 2 ++ src/gt4py/next/iterator/transforms/pass_manager.py | 6 ++++++ .../ffront_tests/test_concat_where.py | 14 ++++++++++++++ 3 files changed, 22 insertions(+) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 6622997d63..9d9018714d 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -189,6 +189,8 @@ def domain_intersection(*domains: SymbolicDomain) -> SymbolicDomain: lambda current_expr, el_expr: im.call("minimum")(current_expr, el_expr), [domain.ranges[dim].stop for domain in domains], ) + # constant fold expression to keep the tree small + start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr new_domain_ranges[dim] = SymbolicRange(start, stop) return SymbolicDomain(domains[0].grid_type, new_domain_ranges) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index f1540cabc8..477f62e6e4 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -198,5 +198,11 @@ def apply_fieldview_transforms( ir = inline_dynamic_shifts.InlineDynamicShifts.apply( ir ) # domain inference does not support dynamic offsets yet + + # TODO: deduplicate with regular pass manager + ir = nest_concat_wheres.NestConcatWheres.apply(ir) + ir = infer_domain_ops.InferDomainOps.apply(ir) + ir = ConstantFolding.apply(ir) + ir = infer_domain.infer_program(ir, offset_provider=offset_provider) return ir diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index a5e65d2c1f..e9eddf9701 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -19,6 +19,20 @@ ) +@pytest.mark.uses_frontend_concat_where +def test_concat_where_simple(cartesian_case): + @gtx.field_operator + def testee(ground: cases.IJKField, air: cases.IJKField) -> cases.IJKField: + return concat_where(KDim > 0, air, ground) + + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + ground = cases.allocate(cartesian_case, testee, "ground")() + air = cases.allocate(cartesian_case, testee, "air")() + + k = np.arange(0, cartesian_case.default_sizes[KDim]) + ref = np.where(k[np.newaxis, np.newaxis, :] == 0, ground.asnumpy(), air.asnumpy()) + cases.verify(cartesian_case, testee, ground, air, out=out, ref=ref) + @pytest.mark.uses_frontend_concat_where def test_concat_where(cartesian_case): @gtx.field_operator From b3647bf71b39267bc5f09846afd12a66185783c2 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 17:55:32 +0100 Subject: [PATCH 035/124] Fix typo --- tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py index 341608111b..1393ad1687 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py @@ -164,7 +164,7 @@ def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): with pytest.raises( errors.DSLError, - match=re.escape("Unsupported operand type(s) for %."), + match=re.escape("Unsupported operand type(s) for %"), ): _ = FieldOperatorParser.apply_to_function(domain_comparison) From 6ea11e50ed856510d1fb8e1630e2dd4fd0744147 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 18:40:27 +0100 Subject: [PATCH 036/124] Add support for tuples --- .../ffront/foast_passes/type_deduction.py | 30 +++++++++++-------- src/gt4py/next/ffront/foast_to_gtir.py | 10 ++++--- .../ffront_tests/test_concat_where.py | 5 +--- 3 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 31f8f5b4eb..370d0be85c 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -10,7 +10,7 @@ import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, NodeVisitor, traits -from gt4py.next import errors +from gt4py.next import errors, utils from gt4py.next.common import DimensionKind, promote_dims from gt4py.next.ffront import ( # noqa dialect_ast_enums, @@ -997,22 +997,26 @@ def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: assert isinstance(mask_type, ts.DomainType) assert all( - isinstance(arg, (ts.FieldType, ts.ScalarType)) + isinstance(el, (ts.FieldType, ts.ScalarType)) for arg in (true_branch_type, false_branch_type) + for el in type_info.primitive_constituents(arg) ) - if (t_dtype := type_info.extract_dtype(true_branch_type)) != ( - f_dtype := type_info.extract_dtype(false_branch_type) - ): - raise errors.DSLError( - node.location, - f"Field arguments must be of same dtype, got '{t_dtype}' != '{f_dtype}'.", - ) - - return_dims = promote_dims( - mask_type.dims, type_info.promote(true_branch_type, false_branch_type).dims + @utils.tree_map( + collection_type=ts.TupleType, + result_collection_constructor=lambda el: ts.TupleType(types=list(el)), ) - return_type = ts.FieldType(dims=return_dims, dtype=type_info.promote(t_dtype, f_dtype)) + def deduce_return_type(tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.ScalarType): + if (t_dtype := type_info.extract_dtype(tb)) != (f_dtype := type_info.extract_dtype(fb)): + raise errors.DSLError( + node.location, + f"Field arguments must be of same dtype, got '{t_dtype}' != '{f_dtype}'.", + ) + return_dims = promote_dims(mask_type.dims, type_info.promote(tb, fb).dims) + return_type = ts.FieldType(dims=return_dims, dtype=type_info.promote(t_dtype, f_dtype)) + return return_type + + return_type = deduce_return_type(true_branch_type, false_branch_type) return foast.Call( func=node.func, diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 514c7526c7..3e7acf5082 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -397,10 +397,12 @@ def create_if( return im.let(cond_symref_name, cond_)(result) def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - if not isinstance(node.type, ts.TupleType): # to keep the IR simpler - return im.call("concat_where")(*self.visit(node.args)) - else: - raise NotImplementedError() + domain, true_branch, false_branch = self.visit(node.args) + return lowering_utils.process_elements( + lambda tb, fb: im.call("concat_where")(domain, tb, fb), + (true_branch, false_branch), + node.type, + ) # TODO: tuple case diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index e9eddf9701..34b3bbe79f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -33,6 +33,7 @@ def testee(ground: cases.IJKField, air: cases.IJKField) -> cases.IJKField: ref = np.where(k[np.newaxis, np.newaxis, :] == 0, ground.asnumpy(), air.asnumpy()) cases.verify(cartesian_case, testee, ground, air, out=out, ref=ref) + @pytest.mark.uses_frontend_concat_where def test_concat_where(cartesian_case): @gtx.field_operator @@ -197,11 +198,8 @@ def testee(interior: cases.IJKField, boundary: cases.IJKField) -> cases.IJKField @pytest.mark.uses_frontend_concat_where @pytest.mark.uses_tuple_returns def test_with_tuples(cartesian_case): - pytest.skip("Not implemented in the frontend.") - @gtx.field_operator def testee( - k: cases.KField, interior0: cases.IJKField, boundary0: cases.IJField, interior1: cases.IJKField, @@ -230,7 +228,6 @@ def testee( cases.verify( cartesian_case, testee, - k, interior0, boundary0, interior1, From 60d0d9afd869a32240aaca9f2305f7549ec8131a Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 23:09:21 +0100 Subject: [PATCH 037/124] Fixes --- src/gt4py/next/ffront/foast_to_gtir.py | 5 +++++ .../transforms/expand_library_functions.py | 5 +++++ .../next/iterator/transforms/global_tmps.py | 3 ++- .../next/iterator/transforms/infer_domain.py | 20 ++++++++++++++++++- .../iterator/transforms/infer_domain_ops.py | 13 ++++++++---- .../iterator/transforms/nest_concat_wheres.py | 1 + .../next/iterator/transforms/pass_manager.py | 12 ++++++----- .../transforms/transform_concat_where.py | 6 ++++++ .../next/program_processors/runners/gtfn.py | 2 +- 9 files changed, 55 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 3e7acf5082..62f1c15050 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -229,6 +229,11 @@ def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef | itir.Axis return itir.AxisLiteral(value=node.type.dim.value, kind=node.type.dim.kind) return im.ref(node.id) + def visit_Attribute(self, node: foast.Attribute, **kwargs): + if isinstance(node.type, ts.DimensionType): + return itir.AxisLiteral(value=node.type.dim.value, kind=node.type.dim.kind) + raise AssertionError() + def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: return im.tuple_get(node.index, self.visit(node.value, **kwargs)) diff --git a/src/gt4py/next/iterator/transforms/expand_library_functions.py b/src/gt4py/next/iterator/transforms/expand_library_functions.py index 4f2527ec5e..6711c7d7e9 100644 --- a/src/gt4py/next/iterator/transforms/expand_library_functions.py +++ b/src/gt4py/next/iterator/transforms/expand_library_functions.py @@ -18,6 +18,11 @@ class ExpandLibraryFunctions(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ( + "type", + "domain", + ) + @classmethod def apply(cls, node: ir.Node): return cls().visit(node) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index ac7fcb8f1c..a54497f1fc 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -186,7 +186,8 @@ def create_global_tmps( This pass looks at all `as_fieldop` calls and transforms field-typed subexpressions of its arguments into temporaries. """ - program = infer_domain.infer_program(program, offset_provider=offset_provider) + # TODO: document why to keep existing domains, add test + program = infer_domain.infer_program(program, offset_provider=offset_provider, keep_existing_domains=True) program = type_inference.infer( program, offset_provider_type=common.offset_provider_to_type(offset_provider) ) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 1fc0c7067a..dddf1fc0a4 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -21,6 +21,7 @@ domain_utils, ir_makers as im, ) +from gt4py.next.iterator.ir_utils.domain_utils import SymbolicDomain from gt4py.next.iterator.transforms import constant_folding, trace_shifts from gt4py.next.utils import flatten_nested_tuple, tree_map @@ -54,6 +55,7 @@ class InferenceOptions(typing.TypedDict): offset_provider: common.OffsetProvider symbolic_domain_sizes: Optional[dict[str, str]] allow_uninferred: bool + keep_existing_domains: bool class DomainAnnexDebugger(eve.NodeVisitor): @@ -197,11 +199,16 @@ def _infer_as_fieldop( offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], allow_uninferred: bool, + keep_existing_domains: bool ) -> tuple[itir.FunCall, AccessedDomains]: assert isinstance(applied_fieldop, itir.FunCall) assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") if not allow_uninferred and target_domain is DomainAccessDescriptor.NEVER: raise ValueError("'target_domain' cannot be 'NEVER' unless `allow_uninferred=True`.") + + if len(applied_fieldop.fun.args) == 2 and keep_existing_domains: + target_domain = SymbolicDomain.from_expr(applied_fieldop.fun.args[1]) + # FIXME[#1582](tehrengruber): Temporary solution for `tuple_get` on scan result. See `test_solve_triag`. if isinstance(target_domain, tuple): target_domain = _domain_union(*flatten_nested_tuple(target_domain)) # type: ignore[arg-type] # mypy not smart enough @@ -241,6 +248,7 @@ def _infer_as_fieldop( offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, allow_uninferred=allow_uninferred, + keep_existing_domains=keep_existing_domains ) transformed_inputs.append(transformed_input) @@ -431,6 +439,7 @@ def infer_expr( offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, allow_uninferred: bool = False, + keep_existing_domains: bool = False ) -> tuple[itir.Expr, AccessedDomains]: """ Infer the domain of all field subexpressions of `expr`. @@ -445,6 +454,10 @@ def infer_expr( name that evaluates to the length of that axis. - allow_uninferred: Allow `as_fieldop` expressions whose domain is either unknown (e.g. because of a dynamic shift) or never accessed. + # TODO: describe why this is needed with concat_where (if inside as_fieldop might shrinken the + actually access domain) + - keep_existing_domains: If `True`, keep existing domains in `as_fieldop` expressions and + use them to propagate the domain further. Returns: A tuple containing the inferred expression with all applied `as_fieldop` (that are accessed) @@ -457,8 +470,10 @@ def infer_expr( offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, allow_uninferred=allow_uninferred, + keep_existing_domains=keep_existing_domains, ) - expr.annex.domain = domain + if not keep_existing_domains or not hasattr(expr.annex, "domain"): + expr.annex.domain = domain return expr, accessed_domains @@ -496,6 +511,8 @@ def infer_program( offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, allow_uninferred: bool = False, + # TODO: add test + keep_existing_domains: bool = False ) -> itir.Program: """ Infer the domain of all field subexpressions inside a program. @@ -517,6 +534,7 @@ def infer_program( offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, allow_uninferred=allow_uninferred, + keep_existing_domains=keep_existing_domains ) for stmt in program.body ], diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index 22f57f9711..508634e255 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -5,7 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - +import dataclasses from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next import common @@ -16,13 +16,18 @@ ir_makers as im, ) from gt4py.next.iterator.transforms.constant_folding import ConstantFolding +from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import _get_gridtype from gt4py.next.type_system import type_specifications as ts +@dataclasses.dataclass class InferDomainOps(PreserveLocationVisitor, NodeTranslator): + grid_type: common.GridType + @classmethod - def apply(cls, node: ir.Node): - return cls().visit(node, recurse=True) + def apply(cls, program: ir.Program): + # TODO: move _get_gridtype + return cls(grid_type=_get_gridtype(program.body)).visit(program, recurse=True) def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: if kwargs["recurse"]: @@ -80,7 +85,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: max_ = im.plus(value, 1) domain = domain_utils.SymbolicDomain( - common.GridType.CARTESIAN, # TODO + self.grid_type, ranges={dim: domain_utils.SymbolicRange(start=min_, stop=max_)}, ) diff --git a/src/gt4py/next/iterator/transforms/nest_concat_wheres.py b/src/gt4py/next/iterator/transforms/nest_concat_wheres.py index df305f3765..829bdf2808 100644 --- a/src/gt4py/next/iterator/transforms/nest_concat_wheres.py +++ b/src/gt4py/next/iterator/transforms/nest_concat_wheres.py @@ -21,6 +21,7 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: if cpm.is_call_to(node, "concat_where"): cond_expr, field_a, field_b = node.args + # TODO: don't duplicate exprs here if cpm.is_call_to(cond_expr, ("and_")): conds = cond_expr.args return im.concat_where( diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 477f62e6e4..ac23a80073 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -94,17 +94,19 @@ def apply_common_transforms( ir = nest_concat_wheres.NestConcatWheres.apply(ir) ir = infer_domain_ops.InferDomainOps.apply(ir) - ir = ConstantFolding.apply(ir) # TODO: remove - ir = transform_concat_where.TransformConcatWhere.apply(ir) - ir = ConstantFolding.apply(ir) # TODO: remove - ir = expand_library_functions.ExpandLibraryFunctions.apply(ir) - ir = infer_domain.infer_program( ir, offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, ) + # Note: executing domain inference again afterwards will give wrong domains. + # This might be problematic in the temporary extraction, where we do this... + ir = ConstantFolding.apply(ir) # TODO: remove + ir = transform_concat_where.TransformConcatWhere.apply(ir) + ir = ConstantFolding.apply(ir) # TODO: remove + ir = expand_library_functions.ExpandLibraryFunctions.apply(ir) + for _ in range(10): inlined = ir diff --git a/src/gt4py/next/iterator/transforms/transform_concat_where.py b/src/gt4py/next/iterator/transforms/transform_concat_where.py index 6a94406163..363b7401d5 100644 --- a/src/gt4py/next/iterator/transforms/transform_concat_where.py +++ b/src/gt4py/next/iterator/transforms/transform_concat_where.py @@ -17,6 +17,11 @@ class TransformConcatWhere(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ( + "type", + "domain", + ) + @classmethod def apply(cls, node: ir.Node): return cls().visit(node) @@ -39,6 +44,7 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: ) ) ), + node.annex.domain.as_expr() )(im.make_tuple(*dims), field_a, field_b, *refs) return node diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index a8961fd9bc..333008c6a3 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -219,7 +219,7 @@ class Params: run_gtfn_gpu = GTFNBackendFactory(gpu=True) -run_gtfn_gpu_cached = GTFNBackendFactory(gpu=True, cached=True) +run_gtfn_gpu_cached = GTFNBackendFactory(gpu=True, cached=True, otf_workflow__cached_translation=True) run_gtfn_no_transforms = GTFNBackendFactory( otf_workflow__bare_translation__enable_itir_transforms=False From 132e5764566597f83dac75fdd5f1175f8d0f9303 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 21 Feb 2025 01:29:15 +0100 Subject: [PATCH 038/124] Improve docs --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 3 ++- src/gt4py/next/iterator/transforms/expand_library_functions.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 2e6a3505dd..d8fb19f694 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -468,7 +468,8 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> Cal def _populate_domain_annex_wrapper(*args, **kwargs): node = result(*args, **kwargs) # note: if the domain is not a direct construction, e.g. because it is only a reference - # to a domain defined in a let, don't populate the annex + # to a domain defined in a let, don't populate the annex, since we can not create a + # symbolic domain for it. if domain and cpm.is_call_to(domain, ("cartesian_domain", "unstructured_domain")): node.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) return node diff --git a/src/gt4py/next/iterator/transforms/expand_library_functions.py b/src/gt4py/next/iterator/transforms/expand_library_functions.py index 6711c7d7e9..5dc628c38a 100644 --- a/src/gt4py/next/iterator/transforms/expand_library_functions.py +++ b/src/gt4py/next/iterator/transforms/expand_library_functions.py @@ -30,6 +30,8 @@ def apply(cls, node: ir.Node): def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: node = self.generic_visit(node) + # `in_({i, j, k}, u⟨ Iₕ: [i0, i1[, Iₕ: [j0, j1[, Iₕ: [k0, k1[ ⟩` + # -> `i0 < i < i1 & j0 < j < j1 & k0 < k < k1` if cpm.is_call_to(node, "in_"): ret = [] pos, domain = node.args From e46907559b90f59eef1df66a2c86433bc3d73364 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 21 Feb 2025 01:38:26 +0100 Subject: [PATCH 039/124] Improve docs --- src/gt4py/next/iterator/transforms/global_tmps.py | 4 +++- src/gt4py/next/iterator/transforms/infer_domain.py | 10 +++++----- .../next/iterator/transforms/transform_concat_where.py | 2 +- src/gt4py/next/program_processors/runners/gtfn.py | 4 +++- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index a54497f1fc..04c22f6f4f 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -187,7 +187,9 @@ def create_global_tmps( arguments into temporaries. """ # TODO: document why to keep existing domains, add test - program = infer_domain.infer_program(program, offset_provider=offset_provider, keep_existing_domains=True) + program = infer_domain.infer_program( + program, offset_provider=offset_provider, keep_existing_domains=True + ) program = type_inference.infer( program, offset_provider_type=common.offset_provider_to_type(offset_provider) ) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index dddf1fc0a4..cb622c6f61 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -199,7 +199,7 @@ def _infer_as_fieldop( offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], allow_uninferred: bool, - keep_existing_domains: bool + keep_existing_domains: bool, ) -> tuple[itir.FunCall, AccessedDomains]: assert isinstance(applied_fieldop, itir.FunCall) assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") @@ -248,7 +248,7 @@ def _infer_as_fieldop( offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, allow_uninferred=allow_uninferred, - keep_existing_domains=keep_existing_domains + keep_existing_domains=keep_existing_domains, ) transformed_inputs.append(transformed_input) @@ -439,7 +439,7 @@ def infer_expr( offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, allow_uninferred: bool = False, - keep_existing_domains: bool = False + keep_existing_domains: bool = False, ) -> tuple[itir.Expr, AccessedDomains]: """ Infer the domain of all field subexpressions of `expr`. @@ -512,7 +512,7 @@ def infer_program( symbolic_domain_sizes: Optional[dict[str, str]] = None, allow_uninferred: bool = False, # TODO: add test - keep_existing_domains: bool = False + keep_existing_domains: bool = False, ) -> itir.Program: """ Infer the domain of all field subexpressions inside a program. @@ -534,7 +534,7 @@ def infer_program( offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, allow_uninferred=allow_uninferred, - keep_existing_domains=keep_existing_domains + keep_existing_domains=keep_existing_domains, ) for stmt in program.body ], diff --git a/src/gt4py/next/iterator/transforms/transform_concat_where.py b/src/gt4py/next/iterator/transforms/transform_concat_where.py index 363b7401d5..be238c812b 100644 --- a/src/gt4py/next/iterator/transforms/transform_concat_where.py +++ b/src/gt4py/next/iterator/transforms/transform_concat_where.py @@ -44,7 +44,7 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: ) ) ), - node.annex.domain.as_expr() + node.annex.domain.as_expr(), )(im.make_tuple(*dims), field_a, field_b, *refs) return node diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 333008c6a3..b6983a9bb3 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -219,7 +219,9 @@ class Params: run_gtfn_gpu = GTFNBackendFactory(gpu=True) -run_gtfn_gpu_cached = GTFNBackendFactory(gpu=True, cached=True, otf_workflow__cached_translation=True) +run_gtfn_gpu_cached = GTFNBackendFactory( + gpu=True, cached=True, otf_workflow__cached_translation=True +) run_gtfn_no_transforms = GTFNBackendFactory( otf_workflow__bare_translation__enable_itir_transforms=False From 24e2f57af197b5eb055bda21290d455f5f9e50fb Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 21 Feb 2025 13:49:48 +0100 Subject: [PATCH 040/124] Fix typo --- src/gt4py/next/iterator/transforms/infer_domain_ops.py | 2 +- .../feature_tests/ffront_tests/test_concat_where.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index 508634e255..60e582a00d 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -79,7 +79,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: elif cpm.is_call_to(node, "greater_equal"): min_ = value max_ = ir.InfinityLiteral.POSITIVE - # IDim == 1 + # IDim == 1 # TODO: isn't this removed before and rewritten as two concat_where? elif cpm.is_call_to(node, "eq"): min_ = value max_ = im.plus(value, 1) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 34b3bbe79f..1939bbd1bf 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -69,7 +69,7 @@ def testee(ground: cases.IJKField, air: cases.IJKField) -> cases.IJKField: def test_concat_where_non_overlapping_different_dims(cartesian_case): @gtx.field_operator def testee( - ground: cases.KField, # note: boundary field is only defined in K + ground: cases.IJField, # note: boundary field is only defined in K air: cases.IJKField, ) -> cases.IJKField: return concat_where(KDim == 0, ground, air) From d14fb21b1fc5e249e77eba5451d6263e22f4d3b3 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Mon, 24 Feb 2025 16:27:02 +0100 Subject: [PATCH 041/124] Cleanup & improve test coverage --- src/gt4py/next/iterator/ir.py | 6 +- .../iterator/transforms/constant_folding.py | 3 +- .../iterator/transforms/infer_domain_ops.py | 10 +- .../ffront_tests/test_concat_where.py | 98 ++++++++++++------- 4 files changed, 74 insertions(+), 43 deletions(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index d1e82b6edc..f054cfc203 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -7,6 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations +import typing from typing import TYPE_CHECKING, ClassVar, List, Optional, Union import gt4py.eve as eve @@ -65,10 +66,9 @@ class NoneLiteral(Expr): class InfinityLiteral(Expr): + # TODO(tehrengruber): self referential `ClassVar` not supported in eve. if TYPE_CHECKING: - POSITIVE: ClassVar[ - InfinityLiteral - ] # TODO(tehrengruber): should be `ClassVar[InfinityLiteral]`, but self-referential not supported in eve + POSITIVE: ClassVar[InfinityLiteral] NEGATIVE: ClassVar[InfinityLiteral] name: typing.Literal["POSITIVE", "NEGATIVE"] diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index cec465ad68..ccc76ead06 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -32,8 +32,7 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node: return new_node.args[0] if cpm.is_call_to(new_node, "plus"): - a, b = new_node.args - for arg, other_arg in ((a, b), (b, a)): + for arg in new_node.args: # `a + inf` -> `inf` if arg == ir.InfinityLiteral.POSITIVE: return ir.InfinityLiteral.POSITIVE diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index 60e582a00d..ac57453b84 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -63,23 +63,23 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: min_: int | ir.InfinityLiteral max_: int | ir.InfinityLiteral - # IDim < 1 + # `IDim < 1` if cpm.is_call_to(node, "less"): min_ = ir.InfinityLiteral.NEGATIVE max_ = value - # IDim <= 1 + # `IDim <= 1` elif cpm.is_call_to(node, "less_equal"): min_ = ir.InfinityLiteral.NEGATIVE max_ = im.plus(value, 1) - # IDim > 1 + # `IDim > 1` elif cpm.is_call_to(node, "greater"): min_ = im.plus(value, 1) max_ = ir.InfinityLiteral.POSITIVE - # IDim >= 1 + # `IDim >= 1` elif cpm.is_call_to(node, "greater_equal"): min_ = value max_ = ir.InfinityLiteral.POSITIVE - # IDim == 1 # TODO: isn't this removed before and rewritten as two concat_where? + # `IDim == 1` # TODO: isn't this removed before and rewritten as two concat_where? elif cpm.is_call_to(node, "eq"): min_ = value max_ = im.plus(value, 1) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 1939bbd1bf..7da87fb2bb 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -66,29 +66,81 @@ def testee(ground: cases.IJKField, air: cases.IJKField) -> cases.IJKField: @pytest.mark.uses_frontend_concat_where -def test_concat_where_non_overlapping_different_dims(cartesian_case): +def test_concat_where_single_level_broadcast(cartesian_case): @gtx.field_operator - def testee( - ground: cases.IJField, # note: boundary field is only defined in K - air: cases.IJKField, - ) -> cases.IJKField: - return concat_where(KDim == 0, ground, air) + def testee(a: cases.KField, b: cases.IJKField) -> cases.IJKField: + return concat_where(KDim == 0, a, b) out = cases.allocate(cartesian_case, testee, cases.RETURN)() - ground = cases.allocate(cartesian_case, testee, "ground", domain=gtx.domain({KDim: (0, 1)}))() - air = cases.allocate(cartesian_case, testee, "air", domain=out.domain.slice_at[:, :, 1:])() + a = cases.allocate( + cartesian_case, testee, "a", domain=gtx.domain({KDim: out.domain.shape[2]}) + )() + b = cases.allocate(cartesian_case, testee, "b", domain=out.domain.slice_at[:, :, 1:])() ref = np.concatenate( ( - np.tile( - ground.asnumpy(), (*air.domain.shape[0:2], len(ground.domain[KDim].unit_range)) - ), - air.asnumpy(), + np.tile(a.asnumpy()[0], (*b.domain.shape[0:2], 1)), + b.asnumpy(), ), axis=2, ) + cases.verify(cartesian_case, testee, a, b, out=out, ref=ref) - cases.verify(cartesian_case, testee, ground, air, out=out, ref=ref) + +@pytest.mark.uses_frontend_concat_where +def test_concat_where_single_level_broadcast(cartesian_case): + @gtx.field_operator + def testee(a: cases.KField, b: cases.IJKField) -> cases.IJKField: + return concat_where(KDim == 0, a, b) + + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + # note: this field is only defined on K: 0, 1, i.e., contains only a single value + a = cases.allocate(cartesian_case, testee, "a", domain=gtx.domain({KDim: (0, 1)}))() + b = cases.allocate(cartesian_case, testee, "b", domain=out.domain.slice_at[:, :, 1:])() + + ref = np.concatenate( + ( + np.tile(a.asnumpy()[0], (*b.domain.shape[0:2], 1)), + b.asnumpy(), + ), + axis=2, + ) + cases.verify(cartesian_case, testee, a, b, out=out, ref=ref) + + +@pytest.mark.uses_frontend_concat_where +def test_lap_like(cartesian_case): + pytest.xfail("Requires #1847.") + + @gtx.field_operator + def testee( + input: cases.IJKField, boundary: float, shape: tuple[np.int64, np.int64, np.int64] + ) -> cases.IJKField: + return concat_where( + (IDim == 0) + | (JDim == 0) + | (KDim == 0) + | (IDim == shape[0] - 1) + | (JDim == shape[1] - 1) + | (KDim == shape[2] - 1), + boundary, + input, + ) + + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + input = cases.allocate( + cartesian_case, testee, "input", domain=out.domain.slice_at[1:-1, 1:-1, 1:-1] + )() + boundary = 2.0 + + ref = np.full(out.domain.shape, np.nan) + ref[0, :, :] = boundary + ref[:, 0, :] = boundary + ref[:, :, 0] = boundary + ref[-1, :, :] = boundary + ref[:, -1, :] = boundary + ref[:, :, -1] = boundary + cases.verify(cartesian_case, testee, input, boundary, out.domain.shape, out=out, ref=ref) @pytest.mark.uses_frontend_concat_where @@ -155,26 +207,6 @@ def testee(interior: cases.KField, boundary: cases.KField) -> cases.KField: cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) -@pytest.mark.uses_frontend_concat_where -def test_boundary_horizontal_slice(cartesian_case): - @gtx.field_operator - def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField: - return concat_where(KDim == 0, boundary, interior) - - interior = cases.allocate(cartesian_case, testee, "interior")() - boundary = cases.allocate(cartesian_case, testee, "boundary")() - out = cases.allocate(cartesian_case, testee, cases.RETURN)() - - k = np.arange(0, cartesian_case.default_sizes[KDim]) - ref = np.where( - k[np.newaxis, np.newaxis, :] == 0, - boundary.asnumpy()[:, :, np.newaxis], - interior.asnumpy(), - ) - - cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) - - @pytest.mark.uses_frontend_concat_where def test_boundary_single_layer(cartesian_case): @gtx.field_operator From 1e3ced5c1d3d35324e1ad3167630379829b7f822 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Mon, 24 Feb 2025 16:28:55 +0100 Subject: [PATCH 042/124] Cleanup --- .../feature_tests/ffront_tests/test_concat_where.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 7da87fb2bb..4f6b1bf806 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -88,7 +88,7 @@ def testee(a: cases.KField, b: cases.IJKField) -> cases.IJKField: @pytest.mark.uses_frontend_concat_where -def test_concat_where_single_level_broadcast(cartesian_case): +def test_concat_where_single_level_restricted_domain_broadcast(cartesian_case): @gtx.field_operator def testee(a: cases.KField, b: cases.IJKField) -> cases.IJKField: return concat_where(KDim == 0, a, b) From 595b675abf7f9ed237f0d8fd386c6b2529630e56 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Mon, 24 Feb 2025 17:00:37 +0100 Subject: [PATCH 043/124] Cleanup --- .../next/type_system/type_specifications.py | 2 + .../ffront_tests/test_concat_where.py | 100 +++++++++++------- 2 files changed, 62 insertions(+), 40 deletions(-) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 839e89bd34..cc70f41b48 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -139,4 +139,6 @@ def __str__(self) -> str: class DomainType(DataType): + # TODO(tehrengruber): Remove "unknown" here again after the result type of `as_fieldop` + # is always precisely known. This is the case after #1853. dims: list[common.Dimension] | Literal["unknown"] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 4f6b1bf806..e6a63765f7 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -109,38 +109,43 @@ def testee(a: cases.KField, b: cases.IJKField) -> cases.IJKField: @pytest.mark.uses_frontend_concat_where -def test_lap_like(cartesian_case): - pytest.xfail("Requires #1847.") +def test_boundary_single_layer_3d_bc(cartesian_case): + @gtx.field_operator + def testee(interior: cases.IJKField, boundary: cases.IJKField) -> cases.IJKField: + return concat_where(KDim == 0, boundary, interior) + + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary", sizes={KDim: 1})() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + k = np.arange(0, cartesian_case.default_sizes[KDim]) + ref = np.where( + k[np.newaxis, np.newaxis, :] == 0, + np.broadcast_to(boundary.asnumpy(), interior.shape), + interior.asnumpy(), + ) + + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) + +@pytest.mark.uses_frontend_concat_where +def test_boundary_single_layer_2d_bc(cartesian_case): @gtx.field_operator - def testee( - input: cases.IJKField, boundary: float, shape: tuple[np.int64, np.int64, np.int64] - ) -> cases.IJKField: - return concat_where( - (IDim == 0) - | (JDim == 0) - | (KDim == 0) - | (IDim == shape[0] - 1) - | (JDim == shape[1] - 1) - | (KDim == shape[2] - 1), - boundary, - input, - ) + def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField: + return concat_where(KDim == 0, boundary, interior) + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() - input = cases.allocate( - cartesian_case, testee, "input", domain=out.domain.slice_at[1:-1, 1:-1, 1:-1] - )() - boundary = 2.0 - ref = np.full(out.domain.shape, np.nan) - ref[0, :, :] = boundary - ref[:, 0, :] = boundary - ref[:, :, 0] = boundary - ref[-1, :, :] = boundary - ref[:, -1, :] = boundary - ref[:, :, -1] = boundary - cases.verify(cartesian_case, testee, input, boundary, out.domain.shape, out=out, ref=ref) + k = np.arange(0, cartesian_case.default_sizes[KDim]) + ref = np.where( + k[np.newaxis, np.newaxis, :] == 0, + boundary.asnumpy()[:, :, np.newaxis], + interior.asnumpy(), + ) + + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) @pytest.mark.uses_frontend_concat_where @@ -208,23 +213,38 @@ def testee(interior: cases.KField, boundary: cases.KField) -> cases.KField: @pytest.mark.uses_frontend_concat_where -def test_boundary_single_layer(cartesian_case): +def test_lap_like(cartesian_case): + pytest.xfail("Requires #1847.") + @gtx.field_operator - def testee(interior: cases.IJKField, boundary: cases.IJKField) -> cases.IJKField: - return concat_where(KDim == 0, boundary, interior) + def testee( + input: cases.IJKField, boundary: float, shape: tuple[np.int64, np.int64, np.int64] + ) -> cases.IJKField: + return concat_where( + (IDim == 0) + | (JDim == 0) + | (KDim == 0) + | (IDim == shape[0] - 1) + | (JDim == shape[1] - 1) + | (KDim == shape[2] - 1), + boundary, + input, + ) - interior = cases.allocate(cartesian_case, testee, "interior")() - boundary = cases.allocate(cartesian_case, testee, "boundary", sizes={KDim: 1})() out = cases.allocate(cartesian_case, testee, cases.RETURN)() + input = cases.allocate( + cartesian_case, testee, "input", domain=out.domain.slice_at[1:-1, 1:-1, 1:-1] + )() + boundary = 2.0 - k = np.arange(0, cartesian_case.default_sizes[KDim]) - ref = np.where( - k[np.newaxis, np.newaxis, :] == 0, - np.broadcast_to(boundary.asnumpy(), interior.shape), - interior.asnumpy(), - ) - - cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) + ref = np.full(out.domain.shape, np.nan) + ref[0, :, :] = boundary + ref[:, 0, :] = boundary + ref[:, :, 0] = boundary + ref[-1, :, :] = boundary + ref[:, -1, :] = boundary + ref[:, :, -1] = boundary + cases.verify(cartesian_case, testee, input, boundary, out.domain.shape, out=out, ref=ref) @pytest.mark.uses_frontend_concat_where From 59a1226e9def4ab635ccdddc7243888591ad0a64 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 28 Feb 2025 14:51:55 +0100 Subject: [PATCH 044/124] Improve type inference for concat_where tuple case --- src/gt4py/next/ffront/foast_to_gtir.py | 8 +------ .../iterator/type_system/type_synthesizer.py | 21 +++++++++++++++++-- .../iterator_tests/test_type_inference.py | 19 ++++++++++++++++- 3 files changed, 38 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 62f1c15050..1b225510ca 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -403,13 +403,7 @@ def create_if( def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: domain, true_branch, false_branch = self.visit(node.args) - return lowering_utils.process_elements( - lambda tb, fb: im.call("concat_where")(domain, tb, fb), - (true_branch, false_branch), - node.type, - ) - - # TODO: tuple case + return im.concat_where(domain, true_branch, false_branch) def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: expr = self.visit(node.args[0], **kwargs) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index f80c17e2fc..77bc68f06e 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -13,7 +13,7 @@ import inspect from gt4py.eve.extended_typing import Callable, Iterable, Optional, Union -from gt4py.next import common +from gt4py.next import common, utils from gt4py.next.iterator import builtins from gt4py.next.iterator.type_system import type_specifications as it_ts from gt4py.next.type_system import type_info, type_specifications as ts @@ -128,6 +128,7 @@ def _(lhs, rhs) -> ts.ScalarType | ts.TupleType | ts.DomainType: @_register_builtin_type_synthesizer(fun_names=builtins.BINARY_LOGICAL_BUILTINS) def _(lhs, rhs) -> ts.ScalarType | ts.TupleType | ts.DomainType: if isinstance(lhs, ts.DomainType) and isinstance(rhs, ts.DomainType): + assert lhs.dims != "unknown" and rhs.dims != "unknown" return ts.DomainType(dims=common.promote_dims(lhs.dims, rhs.dims)) else: return synthesize_binary_math_comparison_builtins(lhs, rhs) @@ -230,7 +231,23 @@ def concat_where( ) -> ts.FieldType | ts.TupleType | ts.DeferredType: if isinstance(true_field, ts.DeferredType) or isinstance(false_field, ts.DeferredType): return ts.DeferredType(constraint=None) - return type_info.promote(true_field, false_field) + + @utils.tree_map( + collection_type=ts.TupleType, + result_collection_constructor=lambda el: ts.TupleType(types=list(el)), + ) + def deduce_return_type(tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.ScalarType): + tb_dtype, fb_type = (type_info.extract_dtype(b) for b in [tb, fb]) + + assert ( + tb_dtype == fb_type + ), f"Field arguments must be of same dtype, got '{tb_dtype}' != '{fb_type}'." + + return_dims = common.promote_dims(domain.dims, type_info.promote(tb, fb).dims) + return_type = ts.FieldType(dims=return_dims, dtype=type_info.promote(t_dtype, f_dtype)) + return return_type + + return deduce_return_type(true_field, false_field) @_register_builtin_type_synthesizer diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 6a0fe82f8f..fb7a66ab43 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -52,6 +52,7 @@ int_list_type = ts.ListType(element_type=int_type) float_i_field = ts.FieldType(dims=[IDim], dtype=float64_type) +float_ij_field = ts.FieldType(dims=[IDim, JDim], dtype=float64_type) float_vertex_k_field = ts.FieldType(dims=[Vertex, KDim], dtype=float64_type) float_edge_k_field = ts.FieldType(dims=[Edge, KDim], dtype=float64_type) float_vertex_v2e_field = ts.FieldType(dims=[Vertex, V2EDim], dtype=float64_type) @@ -72,7 +73,6 @@ def expression_test_cases(): return ( # itir expr, type - # TODO: write test for IDim < 10, concat_where (im.call("abs")(1), int_type), (im.call("power")(2.0, 2), float64_type), (im.plus(1, 2), int_type), @@ -213,6 +213,23 @@ def expression_test_cases(): ), ts.TupleType(types=[float_i_field, float_i_field]), ), + # concat_where + ( + im.concat_where( + im.domain(common.GridType.CARTESIAN, {IDim: (0, 1)}), + im.ref("a", float_i_field), + im.ref("b", float_ij_field), + ), + float_ij_field, + ), + ( + im.concat_where( + im.domain(common.GridType.CARTESIAN, {IDim: (0, 1)}), + im.ref("a", ts.TupleType(types=[float_i_field] * 2)), + im.ref("b", ts.TupleType(types=[float_i_field] * 2)), + ), + ts.TupleType(types=[float_i_field] * 2), + ), ) From f832a19fe5d4df404a078b59ceeb3a681447fe5e Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 28 Feb 2025 14:59:52 +0100 Subject: [PATCH 045/124] Fix typo --- src/gt4py/next/iterator/type_system/type_synthesizer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 77bc68f06e..6a9bb08878 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -237,14 +237,14 @@ def concat_where( result_collection_constructor=lambda el: ts.TupleType(types=list(el)), ) def deduce_return_type(tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.ScalarType): - tb_dtype, fb_type = (type_info.extract_dtype(b) for b in [tb, fb]) + tb_dtype, fb_dtype = (type_info.extract_dtype(b) for b in [tb, fb]) assert ( - tb_dtype == fb_type - ), f"Field arguments must be of same dtype, got '{tb_dtype}' != '{fb_type}'." + tb_dtype == fb_dtype + ), f"Field arguments must be of same dtype, got '{tb_dtype}' != '{fb_dtype}'." return_dims = common.promote_dims(domain.dims, type_info.promote(tb, fb).dims) - return_type = ts.FieldType(dims=return_dims, dtype=type_info.promote(t_dtype, f_dtype)) + return_type = ts.FieldType(dims=return_dims, dtype=type_info.promote(tb_dtype, fb_dtype)) return return_type return deduce_return_type(true_field, false_field) From 75cc4f2c5ef41b1e192573a3dadbd2a7778c5822 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 2 Mar 2025 10:11:19 +0100 Subject: [PATCH 046/124] Fix bug in infer domain ops --- .../iterator/transforms/expand_library_functions.py | 4 ++-- .../next/iterator/transforms/infer_domain_ops.py | 10 +++++----- .../feature_tests/ffront_tests/test_concat_where.py | 11 ++++++----- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/expand_library_functions.py b/src/gt4py/next/iterator/transforms/expand_library_functions.py index 5dc628c38a..5b6dd9ec1e 100644 --- a/src/gt4py/next/iterator/transforms/expand_library_functions.py +++ b/src/gt4py/next/iterator/transforms/expand_library_functions.py @@ -31,7 +31,7 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: node = self.generic_visit(node) # `in_({i, j, k}, u⟨ Iₕ: [i0, i1[, Iₕ: [j0, j1[, Iₕ: [k0, k1[ ⟩` - # -> `i0 < i < i1 & j0 < j < j1 & k0 < k < k1` + # -> `i0 <= i < i1 & j0 <= j < j1 & k0 <= k < k1` if cpm.is_call_to(node, "in_"): ret = [] pos, domain = node.args @@ -43,7 +43,7 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: im.less_equal(v.start, im.tuple_get(i, pos)), im.less(im.tuple_get(i, pos), v.stop), ) - ) # TODO: avoid pos duplication + ) # TODO(tehrengruber): Avoid position expr duplication. return reduce(im.and_, ret) return node diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index ac57453b84..f8e54e4373 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -41,12 +41,12 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: ): # TODO: add tests arg1, arg2 = node.args if isinstance(arg2, ir.AxisLiteral): - # take complementary operation if we have e.g. `IDim > 1` use `1 <= IDim` + # take complementary operation if we have e.g. `0 < IDim` use `IDim > 0` complementary_op = { - "less": "greater_equal", - "less_equal": "greater", - "greater": "less_equal", - "greater_equal": "less", + "less": "greater", + "less_equal": "greater_equal", + "greater": "greater_equal", + "greater_equal": "less_equal", "eq": "eq", "not_eq": "not_eq", } diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index e6a63765f7..653cc1737e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -170,16 +170,17 @@ def testee(interior: cases.IJKField, boundary: cases.IJKField) -> cases.IJKField @pytest.mark.uses_frontend_concat_where def test_dimension_two_conditions_and(cartesian_case): @gtx.field_operator - def testee(interior: cases.KField, boundary: cases.KField) -> cases.KField: - return concat_where(((KDim > 2) & (KDim <= 5)), interior, boundary) + def testee(interior: cases.KField, boundary: cases.KField, nlev: np.int32) -> cases.KField: + return concat_where(0 < KDim < (nlev-1), interior, boundary) interior = cases.allocate(cartesian_case, testee, "interior")() boundary = cases.allocate(cartesian_case, testee, "boundary")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() - k = np.arange(0, cartesian_case.default_sizes[KDim]) - ref = np.where((k > 2) & (k <= 5), interior.asnumpy(), boundary.asnumpy()) - cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) + nlev = cartesian_case.default_sizes[KDim] + k = np.arange(0, nlev) + ref = np.where((0 < k) & (k < (nlev-1)), interior.asnumpy(), boundary.asnumpy()) + cases.verify(cartesian_case, testee, interior, boundary, nlev, out=out, ref=ref) @pytest.mark.uses_frontend_concat_where From 6e85bd056f84d2d5e4b6aa4f7af2e642c6b05d02 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 2 Mar 2025 11:54:28 +0100 Subject: [PATCH 047/124] Address review comments --- .../unit_tests/iterator_tests/test_type_inference.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 577c7bce1c..c8620af0b6 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -238,6 +238,17 @@ def test_adhoc_polymorphism(): assert result.type == ts.TupleType(types=[bool_type, int_type]) +def test_binary_lambda(): + func = im.lambda_("a", "b")(im.make_tuple("a", "b")) + testee = im.call(func)(im.ref("a_", bool_type), im.ref("b_", int_type)) + + result = itir_type_inference.infer( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + + assert result.type == ts.TupleType(types=[bool_type, int_type]) + + def test_aliased_function(): testee = im.let("f", im.lambda_("x")("x"))(im.call("f")(1)) result = itir_type_inference.infer(testee, offset_provider_type={}) From 9978a43e011acedf4f38ba60093b542b200f0ba3 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 2 Mar 2025 11:57:07 +0100 Subject: [PATCH 048/124] Address review comments --- .../unit_tests/iterator_tests/test_type_inference.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index c8620af0b6..c13cb1d119 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -246,7 +246,16 @@ def test_binary_lambda(): testee, offset_provider_type={}, allow_undeclared_symbols=True ) - assert result.type == ts.TupleType(types=[bool_type, int_type]) + expected_type = ts.TupleType(types=[bool_type, int_type]) + assert result.type == expected_type + assert result.fun.params[0].type == bool_type + assert result.fun.params[1].type == int_type + assert result.fun.type == ts.FunctionType( + pos_only_args=[bool_type, int_type], + pos_or_kw_args={}, + kw_only_args={}, + returns=expected_type, + ) def test_aliased_function(): From 232d4b8d580cb24965ff17d7fe4042b3463a0c7d Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 2 Mar 2025 12:03:13 +0100 Subject: [PATCH 049/124] Address review comments --- src/gt4py/next/iterator/type_system/inference.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index cc7a7123b9..d6faefc372 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -37,6 +37,10 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: assert type_info.is_compatible_type( node.type, type_ ), "Node already has a type which differs." + # Also populate the type of the parameters of a lambda. That way the one can access the type + # of a parameter by a lookup in the symbol table. As long as `_set_node_type` is used + # exclusively, the information stays consistent with the types stored in the `FunctionType` + # of the lambda itself. if isinstance(node, itir.Lambda): assert isinstance(type_, ts.FunctionType) for param, param_type in zip(node.params, type_.pos_only_args): From d0f93be87c523baad650717b80142351e76b5ab3 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 3 Mar 2025 16:06:39 +0100 Subject: [PATCH 050/124] Fix deferred type in concat_where --- src/gt4py/next/iterator/type_system/type_synthesizer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 6a9bb08878..e738e39553 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -237,6 +237,9 @@ def concat_where( result_collection_constructor=lambda el: ts.TupleType(types=list(el)), ) def deduce_return_type(tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.ScalarType): + if any(isinstance(b, ts.DeferredType) for b in [tb, fb]): + return ts.DeferredType(constraint=ts.FieldType) + tb_dtype, fb_dtype = (type_info.extract_dtype(b) for b in [tb, fb]) assert ( From cf50a37b6f4aae11f73353387dae7418d1546f28 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 3 Mar 2025 16:33:00 +0100 Subject: [PATCH 051/124] Fix tuple concat_where (not fully done yet) --- .../next/iterator/transforms/infer_domain.py | 17 ++++++++++------- .../transforms/transform_concat_where.py | 8 +++++++- .../ffront_tests/test_concat_where.py | 4 ++-- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index cb622c6f61..0de8e02bf3 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -8,6 +8,7 @@ from __future__ import annotations +import functools import itertools import typing @@ -377,21 +378,23 @@ def _infer_concat_where( **kwargs: Unpack[InferenceOptions], ) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "concat_where") - assert isinstance(domain, domain_utils.SymbolicDomain) + # assert isinstance(domain, domain_utils.SymbolicDomain) # todo: per el assert infered_args_expr = [] actual_domains: AccessedDomains = {} cond, true_field, false_field = expr.args symbolic_cond = domain_utils.SymbolicDomain.from_expr(cond) for arg in [true_field, false_field]: if arg == true_field: - extended_cond = domain_utils.promote_to_same_dimensions(symbolic_cond, domain) - domain_ = domain_utils.domain_intersection(domain, extended_cond) + extended_cond = tree_map( + functools.partial(domain_utils.promote_to_same_dimensions, symbolic_cond) + )(domain) + domain_ = tree_map(domain_utils.domain_intersection)(domain, extended_cond) elif arg == false_field: cond_complement = domain_utils.domain_complement(symbolic_cond) - extended_cond_complement = domain_utils.promote_to_same_dimensions( - cond_complement, domain - ) - domain_ = domain_utils.domain_intersection(domain, extended_cond_complement) + extended_cond_complement = tree_map( + functools.partial(domain_utils.promote_to_same_dimensions, cond_complement) + )(domain) + domain_ = tree_map(domain_utils.domain_intersection)(domain, extended_cond_complement) infered_arg_expr, actual_domains_arg = infer_expr(arg, domain_, **kwargs) infered_args_expr.append(infered_arg_expr) diff --git a/src/gt4py/next/iterator/transforms/transform_concat_where.py b/src/gt4py/next/iterator/transforms/transform_concat_where.py index be238c812b..2f43948a87 100644 --- a/src/gt4py/next/iterator/transforms/transform_concat_where.py +++ b/src/gt4py/next/iterator/transforms/transform_concat_where.py @@ -7,6 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next import utils from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils import ( common_pattern_matcher as cpm, @@ -34,6 +35,11 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: dims = [im.call("index")(ir.AxisLiteral(value=k.value, kind=k.kind)) for k in cond] refs = symbol_ref_utils.collect_symbol_refs(cond_expr) + # TODO: cleanup + domains = utils.flatten_nested_tuple(node.annex.domain) + assert all(domain == domains[0] for domain in domains) + domain_expr = domains[0].as_expr() + return im.as_fieldop( im.lambda_("__tcw_pos", "__tcw_arg0", "__tcw_arg1", *refs)( im.let(*zip(refs, map(im.deref, refs), strict=True))( @@ -44,7 +50,7 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: ) ) ), - node.annex.domain.as_expr(), + domain_expr, )(im.make_tuple(*dims), field_a, field_b, *refs) return node diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 653cc1737e..40ecfda24c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -171,7 +171,7 @@ def testee(interior: cases.IJKField, boundary: cases.IJKField) -> cases.IJKField def test_dimension_two_conditions_and(cartesian_case): @gtx.field_operator def testee(interior: cases.KField, boundary: cases.KField, nlev: np.int32) -> cases.KField: - return concat_where(0 < KDim < (nlev-1), interior, boundary) + return concat_where(0 < KDim < (nlev - 1), interior, boundary) interior = cases.allocate(cartesian_case, testee, "interior")() boundary = cases.allocate(cartesian_case, testee, "boundary")() @@ -179,7 +179,7 @@ def testee(interior: cases.KField, boundary: cases.KField, nlev: np.int32) -> ca nlev = cartesian_case.default_sizes[KDim] k = np.arange(0, nlev) - ref = np.where((0 < k) & (k < (nlev-1)), interior.asnumpy(), boundary.asnumpy()) + ref = np.where((0 < k) & (k < (nlev - 1)), interior.asnumpy(), boundary.asnumpy()) cases.verify(cartesian_case, testee, interior, boundary, nlev, out=out, ref=ref) From 5fc42cea4366c99fc280a0bbc9c188a4a72cebf5 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 3 Mar 2025 17:26:26 +0100 Subject: [PATCH 052/124] Fix tuple concat_where (not fully done yet) --- .../iterator/transforms/infer_domain_ops.py | 24 +++++++++---------- .../next/iterator/transforms/pass_manager.py | 5 ++++ 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index f8e54e4373..e20e1bcb22 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -99,17 +99,17 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: else: raise AssertionError() - if cpm.is_call_to(node, builtins.BINARY_LOGICAL_BUILTINS) and all( - isinstance(arg.type, ts.DomainType) for arg in node.args - ): - if cpm.is_call_to(node, "and_"): - # TODO: domain promotion - return ConstantFolding.apply( - domain_utils.domain_intersection( - *[domain_utils.SymbolicDomain.from_expr(arg) for arg in node.args] - ).as_expr() - ) - else: - raise NotImplementedError() + # if cpm.is_call_to(node, builtins.BINARY_LOGICAL_BUILTINS) and all( + # isinstance(arg.type, ts.DomainType) for arg in node.args + # ): + # if cpm.is_call_to(node, "and_"): + # # TODO: domain promotion + # return ConstantFolding.apply( + # domain_utils.domain_intersection( + # *[domain_utils.SymbolicDomain.from_expr(arg) for arg in node.args] + # ).as_expr() + # ) + # else: + # raise NotImplementedError() return node diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index ac23a80073..a5497ed3fa 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -91,8 +91,12 @@ def apply_common_transforms( ir = inline_dynamic_shifts.InlineDynamicShifts.apply( ir ) # domain inference does not support dynamic offsets yet + # TODO(tehrengruber): fuse into one pass? InferDomainOps might create an `and` again so the + # second call is required. This happens in test_fused_velocity_advection_stencil_15_to_18. + # TODO(tehrengruber): also write a test case. ir = nest_concat_wheres.NestConcatWheres.apply(ir) ir = infer_domain_ops.InferDomainOps.apply(ir) + ir = nest_concat_wheres.NestConcatWheres.apply(ir) ir = infer_domain.infer_program( ir, @@ -204,6 +208,7 @@ def apply_fieldview_transforms( # TODO: deduplicate with regular pass manager ir = nest_concat_wheres.NestConcatWheres.apply(ir) ir = infer_domain_ops.InferDomainOps.apply(ir) + ir = nest_concat_wheres.NestConcatWheres.apply(ir) ir = ConstantFolding.apply(ir) ir = infer_domain.infer_program(ir, offset_provider=offset_provider) From 77edc985a1b0b4eb2b4f778362c0c58984a5a55e Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 11 Mar 2025 15:18:18 +0100 Subject: [PATCH 053/124] Unclean fixes (revert tuple lowering) --- src/gt4py/next/ffront/foast_to_gtir.py | 10 +++- .../next/iterator/ir_utils/domain_utils.py | 16 ++++++ .../next/iterator/transforms/infer_domain.py | 32 ++++++------ .../iterator/transforms/infer_domain_ops.py | 4 +- .../iterator/transforms/nest_concat_wheres.py | 50 +++++++++++++++---- .../next/iterator/transforms/pass_manager.py | 5 -- 6 files changed, 85 insertions(+), 32 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 1b225510ca..eadc860179 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -403,7 +403,15 @@ def create_if( def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: domain, true_branch, false_branch = self.visit(node.args) - return im.concat_where(domain, true_branch, false_branch) + return lowering_utils.process_elements( + lambda tb, fb: im.call("concat_where")(domain, tb, fb), + (true_branch, false_branch), + node.type, + ) + # TODO: use this case again. breaks domain inference in fused_velocity_advection_stencil_1_to_7 + # because some tuple elements are never accessed and the collapse tuple + # does not propagate across concat where + #return im.concat_where(domain, true_branch, false_branch) def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: expr = self.visit(node.args[0], **kwargs) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 9d9018714d..b3f341eac7 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -226,3 +226,19 @@ def promote_to_same_dimensions( itir.InfinityLiteral.NEGATIVE, itir.InfinityLiteral.POSITIVE ) return SymbolicDomain(domain_small.grid_type, dims_dict) # TODO: fix for unstructured + + +def is_finite(range_or_domain: SymbolicRange | SymbolicDomain) -> bool: + """ + Return whether a range is unbounded in (at least) one direction. + + The expression is required to be constant folded before for the result to be reliable. + """ + if isinstance(range_ := range_or_domain, SymbolicRange): + # TODO: assert no infinity literal in here + if any(v in [itir.InfinityLiteral.POSITIVE, itir.InfinityLiteral.NEGATIVE] for v in [range_.start, range_.stop]): + return False + return True + elif isinstance(domain := range_or_domain, SymbolicDomain): + return all(is_finite(range_) for range_ in domain.ranges.values()) + raise ValueError("Expected a SymbolicRange or SymbolicDomain.") \ No newline at end of file diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 0de8e02bf3..d03caf65b5 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -15,7 +15,7 @@ from gt4py import eve from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Callable, Optional, TypeAlias, Unpack -from gt4py.next import common +from gt4py.next import common, utils from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ( common_pattern_matcher as cpm, @@ -127,8 +127,8 @@ def _canonicalize_domain_structure( ... ) True """ - if d1 is DomainAccessDescriptor.NEVER and isinstance(d2, tuple): - return _canonicalize_domain_structure((DomainAccessDescriptor.NEVER,) * len(d2), d2) + if not isinstance(d1, tuple) and isinstance(d2, tuple): + return _canonicalize_domain_structure((d1,) * len(d2), d2) if d2 is DomainAccessDescriptor.NEVER and isinstance(d1, tuple): return _canonicalize_domain_structure(d1, (DomainAccessDescriptor.NEVER,) * len(d1)) if isinstance(d1, tuple) and isinstance(d2, tuple): @@ -378,23 +378,25 @@ def _infer_concat_where( **kwargs: Unpack[InferenceOptions], ) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "concat_where") - # assert isinstance(domain, domain_utils.SymbolicDomain) # todo: per el assert + #assert all(isinstance(domain, domain_utils.SymbolicDomain) for domain in utils.flatten_nested_tuple(domain)) infered_args_expr = [] actual_domains: AccessedDomains = {} cond, true_field, false_field = expr.args symbolic_cond = domain_utils.SymbolicDomain.from_expr(cond) + cond_complement = domain_utils.domain_complement(symbolic_cond) + for arg in [true_field, false_field]: - if arg == true_field: - extended_cond = tree_map( - functools.partial(domain_utils.promote_to_same_dimensions, symbolic_cond) - )(domain) - domain_ = tree_map(domain_utils.domain_intersection)(domain, extended_cond) - elif arg == false_field: - cond_complement = domain_utils.domain_complement(symbolic_cond) - extended_cond_complement = tree_map( - functools.partial(domain_utils.promote_to_same_dimensions, cond_complement) - )(domain) - domain_ = tree_map(domain_utils.domain_intersection)(domain, extended_cond_complement) + @tree_map + def mapper(d: NonTupleDomainAccess): + if isinstance(d, DomainAccessDescriptor): + return d + promoted_cond = domain_utils.promote_to_same_dimensions( + symbolic_cond if arg == true_field else cond_complement, + d + ) + return domain_utils.domain_intersection(d, promoted_cond) + + domain_ = mapper(domain) infered_arg_expr, actual_domains_arg = infer_expr(arg, domain_, **kwargs) infered_args_expr.append(infered_arg_expr) diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index e20e1bcb22..cb499da671 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -92,10 +92,10 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: return domain.as_expr() elif cpm.is_call_to(node, "not_eq"): # `IDim != a -> `IDim < a & IDim > a` - return im.call("and_")( + return self.visit(im.call("and_")( self.visit(im.less(dim, value), **kwargs), self.visit(im.greater(dim, value), **kwargs), - ) + ), **{**kwargs, "recurse": False}) else: raise AssertionError() diff --git a/src/gt4py/next/iterator/transforms/nest_concat_wheres.py b/src/gt4py/next/iterator/transforms/nest_concat_wheres.py index 829bdf2808..b08deea75f 100644 --- a/src/gt4py/next/iterator/transforms/nest_concat_wheres.py +++ b/src/gt4py/next/iterator/transforms/nest_concat_wheres.py @@ -7,8 +7,12 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next import common from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im, \ + domain_utils +from gt4py.next.iterator.ir_utils.domain_utils import SymbolicDomain +from gt4py.next.iterator import ir as itir class NestConcatWheres(PreserveLocationVisitor, NodeTranslator): @@ -19,22 +23,50 @@ def apply(cls, node: ir.Node): def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: node = self.generic_visit(node) + # TODO: do not duplicate exprs if cpm.is_call_to(node, "concat_where"): cond_expr, field_a, field_b = node.args # TODO: don't duplicate exprs here - if cpm.is_call_to(cond_expr, ("and_")): + if cpm.is_call_to(cond_expr, "and_"): conds = cond_expr.args - return im.concat_where( + return self.visit(im.concat_where( conds[0], im.concat_where(conds[1], field_a, field_b), field_b - ) - if cpm.is_call_to(cond_expr, ("or_")): + )) + if cpm.is_call_to(cond_expr, "or_"): conds = cond_expr.args - return im.concat_where( + return self.visit(im.concat_where( conds[0], field_a, im.concat_where(conds[1], field_a, field_b) - ) - if cpm.is_call_to(cond_expr, ("eq")): + )) + if cpm.is_call_to(cond_expr, "eq"): cond1 = im.less(cond_expr.args[0], cond_expr.args[1]) cond2 = im.greater(cond_expr.args[0], cond_expr.args[1]) - return im.concat_where(cond1, field_b, im.concat_where(cond2, field_b, field_a)) + return self.visit(im.concat_where(cond1, field_b, im.concat_where(cond2, field_b, field_a))) + + # concat_where([1, 2[, a, b) -> concat_where([-inf, 1] | [2, inf[, b, a) + if cpm.is_call_to(cond_expr, ("cartesian_domain", "unstructured_domain")): + domain = SymbolicDomain.from_expr(cond_expr) + if len(domain.ranges) == 1: + dim, range_ = next(iter(domain.ranges.items())) + if domain_utils.is_finite(range_): + complement = _range_complement(range_) + new_domains = [im.domain( + domain.grid_type, + {dim: (cr.start, cr.stop)} + ) for cr in complement] + # TODO: fp transform + return self.visit(im.concat_where(im.call("or_")(*new_domains), field_b, field_a)) + else: + # TODO(tehrengruber): Implement. Note that this case can not be triggered by + # the frontend. + raise NotImplementedError() return node + + +def _range_complement(range_: domain_utils.SymbolicRange) -> tuple[domain_utils.SymbolicRange, domain_utils.SymbolicRange]: + # `[a, b[` -> `[-inf, a[` ∪ `[b, inf[` + assert not any(isinstance(b, itir.InfinityLiteral) for b in [range_.start, range_.stop]) + return ( + domain_utils.SymbolicRange(itir.InfinityLiteral.NEGATIVE, range_.start), + domain_utils.SymbolicRange(range_.stop, itir.InfinityLiteral.POSITIVE) + ) \ No newline at end of file diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index a5497ed3fa..6d953ea4bc 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -91,10 +91,6 @@ def apply_common_transforms( ir = inline_dynamic_shifts.InlineDynamicShifts.apply( ir ) # domain inference does not support dynamic offsets yet - # TODO(tehrengruber): fuse into one pass? InferDomainOps might create an `and` again so the - # second call is required. This happens in test_fused_velocity_advection_stencil_15_to_18. - # TODO(tehrengruber): also write a test case. - ir = nest_concat_wheres.NestConcatWheres.apply(ir) ir = infer_domain_ops.InferDomainOps.apply(ir) ir = nest_concat_wheres.NestConcatWheres.apply(ir) @@ -206,7 +202,6 @@ def apply_fieldview_transforms( ) # domain inference does not support dynamic offsets yet # TODO: deduplicate with regular pass manager - ir = nest_concat_wheres.NestConcatWheres.apply(ir) ir = infer_domain_ops.InferDomainOps.apply(ir) ir = nest_concat_wheres.NestConcatWheres.apply(ir) ir = ConstantFolding.apply(ir) From 1a4bf3a6c3cc262f62889afd41bc090c4ee52c8c Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 14 Mar 2025 11:55:44 +0100 Subject: [PATCH 054/124] Enable laplacian test --- .../ffront_tests/test_concat_where.py | 25 ++++++++----------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 40ecfda24c..39eba850ec 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -215,36 +215,31 @@ def testee(interior: cases.KField, boundary: cases.KField) -> cases.KField: @pytest.mark.uses_frontend_concat_where def test_lap_like(cartesian_case): - pytest.xfail("Requires #1847.") - @gtx.field_operator def testee( - input: cases.IJKField, boundary: float, shape: tuple[np.int64, np.int64, np.int64] - ) -> cases.IJKField: + input: cases.IJField, boundary: np.int32, shape: tuple[np.int32, np.int32] + ) -> cases.IJField: return concat_where( (IDim == 0) | (JDim == 0) - | (KDim == 0) | (IDim == shape[0] - 1) - | (JDim == shape[1] - 1) - | (KDim == shape[2] - 1), + | (JDim == shape[1] - 1), boundary, input, ) out = cases.allocate(cartesian_case, testee, cases.RETURN)() input = cases.allocate( - cartesian_case, testee, "input", domain=out.domain.slice_at[1:-1, 1:-1, 1:-1] + cartesian_case, testee, "input", domain=out.domain.slice_at[1:-1, 1:-1] )() - boundary = 2.0 + boundary = 2 ref = np.full(out.domain.shape, np.nan) - ref[0, :, :] = boundary - ref[:, 0, :] = boundary - ref[:, :, 0] = boundary - ref[-1, :, :] = boundary - ref[:, -1, :] = boundary - ref[:, :, -1] = boundary + ref[0, :] = boundary + ref[:, 0] = boundary + ref[-1, :] = boundary + ref[:, -1] = boundary + ref[1:-1, 1:-1] = input.asnumpy() cases.verify(cartesian_case, testee, input, boundary, out.domain.shape, out=out, ref=ref) From 1ab8c696bce9afefd53e8e54d25c4d0d30b74fbc Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 21 Mar 2025 22:40:06 +0100 Subject: [PATCH 055/124] embedded concat_where --- src/gt4py/next/common.py | 40 +++++ src/gt4py/next/embedded/nd_array_field.py | 147 +++++++++--------- src/gt4py/next/ffront/foast_to_gtir.py | 2 +- .../next/iterator/ir_utils/domain_utils.py | 7 +- .../iterator/transforms/constant_folding.py | 2 +- .../next/iterator/transforms/global_tmps.py | 10 +- .../next/iterator/transforms/infer_domain.py | 9 +- .../iterator/transforms/infer_domain_ops.py | 12 +- .../iterator/transforms/nest_concat_wheres.py | 47 +++--- .../next/iterator/transforms/pass_manager.py | 2 +- tests/next_tests/definitions.py | 1 - .../ffront_tests/test_concat_where.py | 27 +++- 12 files changed, 191 insertions(+), 115 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index d2bc090ce6..9b6bb24a43 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -109,6 +109,28 @@ def __add__(self, offset: int) -> Connectivity: def __sub__(self, offset: int) -> Connectivity: return self + (-offset) + def __gt__(self, value: int) -> Domain: + return Domain(dims=(self,), ranges=(UnitRange(value + 1, Infinity.POSITIVE),)) + + def __ge__(self, value: int) -> Domain: + return Domain(dims=(self,), ranges=(UnitRange(value, Infinity.POSITIVE),)) + + def __lt__(self, value: int) -> Domain: + return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value),)) + + def __eq__(self, value: Dimension | int) -> bool | Domain: + if isinstance(value, Dimension): + return self.value == value.value + elif isinstance(value, int): + # TODO probably only within valid embedded context? + return Domain(dims=(self,), ranges=(UnitRange(value, value + 1),)) + else: + raise TypeError( + f"Cannot compare Dimension with {type(value)}, only with int or Dimension." + ) + + # TODO add other comparison operators and tests + class Infinity(enum.Enum): """Describes an unbounded `UnitRange`.""" @@ -500,6 +522,24 @@ def __and__(self, other: Domain) -> Domain: ) return Domain(dims=broadcast_dims, ranges=intersected_ranges) + def __or__(self, other: Domain) -> Domain: + # TODO support arbitrary union of domains + # TODO add tests + if self.ndim > 1 or other.ndim > 1: + raise NotImplementedError("Union of multidimensional domains is not supported.") + if self.ndim == 0: + return other + if other.ndim == 0: + return self + sorted_ = sorted((self, other), key=lambda x: x.ranges[0].start) + if sorted_[0].ranges[0].stop >= sorted_[1].ranges[0].start: + return Domain( + dims=(self.dims[0],), + ranges=(UnitRange(sorted_[0].ranges[0].start, sorted_[1].ranges[0].stop),), + ) + else: + return (sorted_[0], sorted_[1]) + @functools.cached_property def slice_at(self) -> utils.IndexerCallable[slice, Domain]: """ diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 537482508b..f036597966 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -809,25 +809,6 @@ def _hyperslice( NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) -def _compute_mask_slices( - mask: core_defs.NDArrayObject, -) -> list[tuple[bool, slice]]: - """Take a 1-dimensional mask and return a sequence of mappings from boolean values to slices.""" - # TODO: does it make sense to upgrade this naive algorithm to numpy? - assert mask.ndim == 1 - cur = bool(mask[0].item()) - ind = 0 - res = [] - for i in range(1, mask.shape[0]): - # Use `.item()` to extract the scalar from a 0-d array in case of e.g. cupy - if (mask_i := bool(mask[i].item())) != cur: - res.append((cur, slice(ind, i))) - cur = mask_i - ind = i - res.append((cur, slice(ind, mask.shape[0]))) - return res - - def _trim_empty_domains( lst: Iterable[tuple[bool, common.Domain]], ) -> list[tuple[bool, common.Domain]]: @@ -895,82 +876,108 @@ def _stack_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[c def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field: # TODO(havogt): this function could be extended to a general concat - # currently only concatenate along the given dimension and requires the fields to be ordered + # currently only concatenate along the given dimension + sorted_fields = sorted(fields, key=lambda f: f.domain[dim].unit_range.start) if ( - len(fields) > 1 - and not embedded_common.domain_intersection(*[f.domain for f in fields]).is_empty() + len(sorted_fields) > 1 + and not embedded_common.domain_intersection(*[f.domain for f in sorted_fields]).is_empty() ): raise ValueError("Fields to concatenate must not overlap.") - new_domain = _stack_domains(*[f.domain for f in fields], dim=dim) + new_domain = _stack_domains(*[f.domain for f in sorted_fields], dim=dim) if new_domain is None: raise embedded_exceptions.NonContiguousDomain(f"Cannot concatenate fields along {dim}.") - nd_array_class = _get_nd_array_class(*fields) + nd_array_class = _get_nd_array_class(*sorted_fields) return nd_array_class.from_array( nd_array_class.array_ns.concatenate( - [nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape) for f in fields], + [ + nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape) + for f in sorted_fields + ], axis=new_domain.dim_index(dim, allow_missing=False), ), domain=new_domain, ) +def _invert_domain( + domains: common.Domain | tuple[common.Domain], +) -> common.Domain | tuple[common.Domain, ...]: + if not isinstance(domains, tuple): + domains = (domains,) + + assert all(d.ndim == 1 for d in domains) + dim = domains[0].dims[0] + assert all(d.dims[0] == dim for d in domains) + sorted_domains = sorted(domains, key=lambda d: d.ranges[0].start) + + result = [] + if domains[0].ranges[0].start is not common.Infinity.NEGATIVE: + result.append( + common.Domain( + dims=(dim,), + ranges=(common.UnitRange(common.Infinity.NEGATIVE, domains[0].ranges[0].start),), + ) + ) + for i in range(len(sorted_domains) - 1): + if sorted_domains[i].ranges[0].stop != sorted_domains[i + 1].ranges[0].start: + result.append( + common.Domain( + dims=(dim,), + ranges=( + common.UnitRange( + sorted_domains[i].ranges[0].stop, sorted_domains[i + 1].ranges[0].start + ), + ), + ) + ) + if domains[-1].ranges[0].stop is not common.Infinity.POSITIVE: + result.append( + common.Domain( + dims=(dim,), + ranges=(common.UnitRange(domains[-1].ranges[0].stop, common.Infinity.POSITIVE),), + ) + ) + return tuple(result) + + +def _intersect_multiple( + domain: common.Domain, domains: common.Domain | tuple[common.Domain] +) -> tuple[common.Domain, ...]: + if not isinstance(domains, tuple): + domains = (domains,) + + return tuple( + intersection + for d in domains + if not (intersection := embedded_common.domain_intersection(domain, d)).is_empty() + ) + + def _concat_where( - mask_field: common.Field, true_field: common.Field, false_field: common.Field + masks: common.Domain | tuple[common.Domain, ...], + true_field: common.Field, + false_field: common.Field, ) -> common.Field: - cls_ = _get_nd_array_class(mask_field, true_field, false_field) - xp = cls_.array_ns - if mask_field.domain.ndim != 1: + if not isinstance(masks, tuple): + masks = (masks,) + if any(m.ndim for m in masks) != 1: raise NotImplementedError( "'concat_where': Can only concatenate fields with a 1-dimensional mask." ) - mask_dim = mask_field.domain.dims[0] + mask_dim = masks[0].dims[0] # intersect the field in dimensions orthogonal to the mask, then all slices in the mask field have same domain t_broadcasted, f_broadcasted = _intersect_fields(true_field, false_field, ignore_dims=mask_dim) - # TODO(havogt): for clarity, most of it could be implemented on named_range in the masked dimension, but we currently lack the utils - # compute the consecutive ranges (first relative, then domain) of true and false values - mask_values_to_slices_mapping: Iterable[tuple[bool, slice]] = _compute_mask_slices( - mask_field.ndarray - ) - mask_values_to_domain_mapping: Iterable[tuple[bool, common.Domain]] = ( - (mask, mask_field.domain.slice_at[domain_slice]) - for mask, domain_slice in mask_values_to_slices_mapping - ) - # mask domains intersected with the respective fields - mask_values_to_intersected_domains_mapping: Iterable[tuple[bool, common.Domain]] = ( - ( - mask_value, - embedded_common.domain_intersection( - t_broadcasted.domain if mask_value else f_broadcasted.domain, mask_domain - ), - ) - for mask_value, mask_domain in mask_values_to_domain_mapping - ) - - # remove the empty domains from the beginning and end - mask_values_to_intersected_domains_mapping = _trim_empty_domains( - mask_values_to_intersected_domains_mapping - ) - if any(d.is_empty() for _, d in mask_values_to_intersected_domains_mapping): - raise embedded_exceptions.NonContiguousDomain( - f"In 'concat_where', cannot concatenate the following 'Domain's: {[d for _, d in mask_values_to_intersected_domains_mapping]}." - ) + true_domains = _intersect_multiple(t_broadcasted.domain, masks) + t_slices = tuple(t_broadcasted[d] for d in true_domains) - # slice the fields with the domain ranges - transformed = [ - t_broadcasted[d] if v else f_broadcasted[d] - for v, d in mask_values_to_intersected_domains_mapping - ] + inverted_masks = _invert_domain(masks) + false_domains = _intersect_multiple(f_broadcasted.domain, inverted_masks) + f_slices = tuple(f_broadcasted[d] for d in false_domains) - # stack the fields together - if transformed: - return _concat(*transformed, dim=mask_dim) - else: - result_domain = common.Domain(common.NamedRange(mask_dim, common.UnitRange(0, 0))) - result_array = xp.empty(result_domain.shape) - return cls_.from_array(result_array, domain=result_domain) + return _concat(*f_slices, *t_slices, dim=mask_dim) NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 2c99c4a3b5..47fa707b3f 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -413,7 +413,7 @@ def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: # TODO: use this case again. breaks domain inference in fused_velocity_advection_stencil_1_to_7 # because some tuple elements are never accessed and the collapse tuple # does not propagate across concat where - #return im.concat_where(domain, true_branch, false_branch) + # return im.concat_where(domain, true_branch, false_branch) def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return im.call("broadcast")(*self.visit(node.args, **kwargs)) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 2c21a4e393..ce463f4088 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -242,9 +242,12 @@ def is_finite(range_or_domain: SymbolicRange | SymbolicDomain) -> bool: """ if isinstance(range_ := range_or_domain, SymbolicRange): # TODO: assert no infinity literal in here - if any(v in [itir.InfinityLiteral.POSITIVE, itir.InfinityLiteral.NEGATIVE] for v in [range_.start, range_.stop]): + if any( + v in [itir.InfinityLiteral.POSITIVE, itir.InfinityLiteral.NEGATIVE] + for v in [range_.start, range_.stop] + ): return False return True elif isinstance(domain := range_or_domain, SymbolicDomain): return all(is_finite(range_) for range_ in domain.ranges.values()) - raise ValueError("Expected a SymbolicRange or SymbolicDomain.") \ No newline at end of file + raise ValueError("Expected a SymbolicRange or SymbolicDomain.") diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index d640f54175..b9a92c161f 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -284,4 +284,4 @@ def transform_fold_infinity_arithmetic(self, node: ir.FunCall) -> Optional[ir.No if a == ir.InfinityLiteral.NEGATIVE or b == ir.InfinityLiteral.POSITIVE: return im.literal_from_value(False) - return None \ No newline at end of file + return None diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 58111ffe19..70b2990c5b 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -190,11 +190,15 @@ def create_global_tmps( arguments into temporaries. """ # TODO: document why to keep existing domains, add test - offset_provider_type = common.offset_provider_to_type(offset_provider) program = infer_domain.infer_program( - program, offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, keep_existing_domains=True + program, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + keep_existing_domains=True, + ) + program = type_inference.infer( + program, offset_provider_type=common.offset_provider_to_type(offset_provider) ) - program = type_inference.infer(program, offset_provider_type=common.offset_provider_to_type(offset_provider)) if not uids: uids = eve_utils.UIDGenerator(prefix="__tmp") diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 5e3a5cfcf4..57f59e1587 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -8,14 +8,13 @@ from __future__ import annotations -import functools import itertools import typing from gt4py import eve from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Callable, Optional, TypeAlias, Unpack -from gt4py.next import common, utils +from gt4py.next import common from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ( common_pattern_matcher as cpm, @@ -379,7 +378,7 @@ def _infer_concat_where( **kwargs: Unpack[InferenceOptions], ) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "concat_where") - #assert all(isinstance(domain, domain_utils.SymbolicDomain) for domain in utils.flatten_nested_tuple(domain)) + # assert all(isinstance(domain, domain_utils.SymbolicDomain) for domain in utils.flatten_nested_tuple(domain)) infered_args_expr = [] actual_domains: AccessedDomains = {} cond, true_field, false_field = expr.args @@ -387,13 +386,13 @@ def _infer_concat_where( cond_complement = domain_utils.domain_complement(symbolic_cond) for arg in [true_field, false_field]: + @tree_map def mapper(d: NonTupleDomainAccess): if isinstance(d, DomainAccessDescriptor): return d promoted_cond = domain_utils.promote_to_same_dimensions( - symbolic_cond if arg == true_field else cond_complement, - d + symbolic_cond if arg == true_field else cond_complement, d ) return domain_utils.domain_intersection(d, promoted_cond) diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index cb499da671..fde4b22696 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -15,7 +15,6 @@ domain_utils, ir_makers as im, ) -from gt4py.next.iterator.transforms.constant_folding import ConstantFolding from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import _get_gridtype from gt4py.next.type_system import type_specifications as ts @@ -92,10 +91,13 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: return domain.as_expr() elif cpm.is_call_to(node, "not_eq"): # `IDim != a -> `IDim < a & IDim > a` - return self.visit(im.call("and_")( - self.visit(im.less(dim, value), **kwargs), - self.visit(im.greater(dim, value), **kwargs), - ), **{**kwargs, "recurse": False}) + return self.visit( + im.call("and_")( + self.visit(im.less(dim, value), **kwargs), + self.visit(im.greater(dim, value), **kwargs), + ), + **{**kwargs, "recurse": False}, + ) else: raise AssertionError() diff --git a/src/gt4py/next/iterator/transforms/nest_concat_wheres.py b/src/gt4py/next/iterator/transforms/nest_concat_wheres.py index b08deea75f..74fa31f951 100644 --- a/src/gt4py/next/iterator/transforms/nest_concat_wheres.py +++ b/src/gt4py/next/iterator/transforms/nest_concat_wheres.py @@ -7,12 +7,13 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next import common -from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im, \ - domain_utils +from gt4py.next.iterator import ir, ir as itir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) from gt4py.next.iterator.ir_utils.domain_utils import SymbolicDomain -from gt4py.next.iterator import ir as itir class NestConcatWheres(PreserveLocationVisitor, NodeTranslator): @@ -29,18 +30,20 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: # TODO: don't duplicate exprs here if cpm.is_call_to(cond_expr, "and_"): conds = cond_expr.args - return self.visit(im.concat_where( - conds[0], im.concat_where(conds[1], field_a, field_b), field_b - )) + return self.visit( + im.concat_where(conds[0], im.concat_where(conds[1], field_a, field_b), field_b) + ) if cpm.is_call_to(cond_expr, "or_"): conds = cond_expr.args - return self.visit(im.concat_where( - conds[0], field_a, im.concat_where(conds[1], field_a, field_b) - )) + return self.visit( + im.concat_where(conds[0], field_a, im.concat_where(conds[1], field_a, field_b)) + ) if cpm.is_call_to(cond_expr, "eq"): cond1 = im.less(cond_expr.args[0], cond_expr.args[1]) cond2 = im.greater(cond_expr.args[0], cond_expr.args[1]) - return self.visit(im.concat_where(cond1, field_b, im.concat_where(cond2, field_b, field_a))) + return self.visit( + im.concat_where(cond1, field_b, im.concat_where(cond2, field_b, field_a)) + ) # concat_where([1, 2[, a, b) -> concat_where([-inf, 1] | [2, inf[, b, a) if cpm.is_call_to(cond_expr, ("cartesian_domain", "unstructured_domain")): @@ -49,12 +52,14 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: dim, range_ = next(iter(domain.ranges.items())) if domain_utils.is_finite(range_): complement = _range_complement(range_) - new_domains = [im.domain( - domain.grid_type, - {dim: (cr.start, cr.stop)} - ) for cr in complement] + new_domains = [ + im.domain(domain.grid_type, {dim: (cr.start, cr.stop)}) + for cr in complement + ] # TODO: fp transform - return self.visit(im.concat_where(im.call("or_")(*new_domains), field_b, field_a)) + return self.visit( + im.concat_where(im.call("or_")(*new_domains), field_b, field_a) + ) else: # TODO(tehrengruber): Implement. Note that this case can not be triggered by # the frontend. @@ -63,10 +68,12 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: return node -def _range_complement(range_: domain_utils.SymbolicRange) -> tuple[domain_utils.SymbolicRange, domain_utils.SymbolicRange]: +def _range_complement( + range_: domain_utils.SymbolicRange, +) -> tuple[domain_utils.SymbolicRange, domain_utils.SymbolicRange]: # `[a, b[` -> `[-inf, a[` ∪ `[b, inf[` assert not any(isinstance(b, itir.InfinityLiteral) for b in [range_.start, range_.stop]) return ( domain_utils.SymbolicRange(itir.InfinityLiteral.NEGATIVE, range_.start), - domain_utils.SymbolicRange(range_.stop, itir.InfinityLiteral.POSITIVE) - ) \ No newline at end of file + domain_utils.SymbolicRange(range_.stop, itir.InfinityLiteral.POSITIVE), + ) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index b0c02cf182..b44bfcfc83 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -21,8 +21,8 @@ inline_fundefs, inline_lifts, nest_concat_wheres, - transform_concat_where, prune_broadcast, + transform_concat_where, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 5216b2ae32..022369b9b7 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -172,7 +172,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): XFAIL, UNSUPPORTED_MESSAGE, ), # we can't extract the field type from scan args - (USES_FRONTEND_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), ] ROUNDTRIP_SKIP_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 39eba850ec..09cc2d92a8 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -167,11 +167,28 @@ def testee(interior: cases.IJKField, boundary: cases.IJKField) -> cases.IJKField cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) +# @pytest.mark.uses_frontend_concat_where +# def test_dimension_two_illegal_threeway_comparison(cartesian_case): +# @gtx.field_operator +# def testee(interior: cases.KField, boundary: cases.KField, nlev: np.int32) -> cases.KField: +# return concat_where(0 < KDim < (nlev - 1), interior, boundary) + +# interior = cases.allocate(cartesian_case, testee, "interior")() +# boundary = cases.allocate(cartesian_case, testee, "boundary")() +# out = cases.allocate(cartesian_case, testee, cases.RETURN)() + +# nlev = cartesian_case.default_sizes[KDim] +# k = np.arange(0, nlev) +# ref = np.where((0 < k) & (k < (nlev - 1)), interior.asnumpy(), boundary.asnumpy()) +# with pytest.raises: # TODO +# cases.verify(cartesian_case, testee, interior, boundary, nlev, out=out, ref=ref) + + @pytest.mark.uses_frontend_concat_where def test_dimension_two_conditions_and(cartesian_case): @gtx.field_operator def testee(interior: cases.KField, boundary: cases.KField, nlev: np.int32) -> cases.KField: - return concat_where(0 < KDim < (nlev - 1), interior, boundary) + return concat_where((0 < KDim) & (KDim < (nlev - 1)), interior, boundary) interior = cases.allocate(cartesian_case, testee, "interior")() boundary = cases.allocate(cartesian_case, testee, "boundary")() @@ -219,13 +236,11 @@ def test_lap_like(cartesian_case): def testee( input: cases.IJField, boundary: np.int32, shape: tuple[np.int32, np.int32] ) -> cases.IJField: + # TODO add support for multi-dimensional concat_where masks return concat_where( - (IDim == 0) - | (JDim == 0) - | (IDim == shape[0] - 1) - | (JDim == shape[1] - 1), + (IDim == 0) | (IDim == shape[0] - 1), boundary, - input, + concat_where((JDim == 0) | (JDim == shape[1] - 1), boundary, input), ) out = cases.allocate(cartesian_case, testee, cases.RETURN)() From ae0782660b75bdf2ef0b2953f1e3138334b97dff Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 21 Mar 2025 23:18:06 +0100 Subject: [PATCH 056/124] add support for more comparison operators --- src/gt4py/next/common.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 9b6bb24a43..6de1ec2818 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -118,6 +118,10 @@ def __ge__(self, value: int) -> Domain: def __lt__(self, value: int) -> Domain: return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value),)) + def __le__(self, value: int) -> Domain: + # TODO add test + return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value + 1),)) + def __eq__(self, value: Dimension | int) -> bool | Domain: if isinstance(value, Dimension): return self.value == value.value @@ -129,7 +133,20 @@ def __eq__(self, value: Dimension | int) -> bool | Domain: f"Cannot compare Dimension with {type(value)}, only with int or Dimension." ) - # TODO add other comparison operators and tests + def __ne__(self, value: Dimension | int) -> bool | tuple[Domain, Domain]: + # TODO add test + if isinstance(value, Dimension): + return self.value != value.value + elif isinstance(value, int): + # TODO probably only within valid embedded context? + return ( + Domain(self, UnitRange(Infinity.NEGATIVE, value)), + Domain(self, UnitRange(value + 1, Infinity.POSITIVE)), + ) + else: + raise TypeError( + f"Cannot compare Dimension with {type(value)}, only with int or Dimension." + ) class Infinity(enum.Enum): From a8fe04e6d13df31631d2c155a60cc2c7fc9144be Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Sat, 22 Mar 2025 09:39:59 +0100 Subject: [PATCH 057/124] change Dimension comparison --- src/gt4py/next/common.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 6de1ec2818..bfadd66b8c 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -129,9 +129,7 @@ def __eq__(self, value: Dimension | int) -> bool | Domain: # TODO probably only within valid embedded context? return Domain(dims=(self,), ranges=(UnitRange(value, value + 1),)) else: - raise TypeError( - f"Cannot compare Dimension with {type(value)}, only with int or Dimension." - ) + return False def __ne__(self, value: Dimension | int) -> bool | tuple[Domain, Domain]: # TODO add test @@ -144,9 +142,7 @@ def __ne__(self, value: Dimension | int) -> bool | tuple[Domain, Domain]: Domain(self, UnitRange(value + 1, Infinity.POSITIVE)), ) else: - raise TypeError( - f"Cannot compare Dimension with {type(value)}, only with int or Dimension." - ) + return True class Infinity(enum.Enum): From 40cf33bda3c551e1ea328e4f874ba34ffd520fb5 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Sun, 23 Mar 2025 08:55:20 +0100 Subject: [PATCH 058/124] embedded: non-python int comparison --- src/gt4py/next/common.py | 16 ++++++++-------- .../ffront_tests/test_concat_where.py | 2 ++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index bfadd66b8c..bb08946d46 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -109,33 +109,33 @@ def __add__(self, offset: int) -> Connectivity: def __sub__(self, offset: int) -> Connectivity: return self + (-offset) - def __gt__(self, value: int) -> Domain: + def __gt__(self, value: core_defs.IntegralScalar) -> Domain: return Domain(dims=(self,), ranges=(UnitRange(value + 1, Infinity.POSITIVE),)) - def __ge__(self, value: int) -> Domain: + def __ge__(self, value: core_defs.IntegralScalar) -> Domain: return Domain(dims=(self,), ranges=(UnitRange(value, Infinity.POSITIVE),)) - def __lt__(self, value: int) -> Domain: + def __lt__(self, value: core_defs.IntegralScalar) -> Domain: return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value),)) - def __le__(self, value: int) -> Domain: + def __le__(self, value: core_defs.IntegralScalar) -> Domain: # TODO add test return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value + 1),)) - def __eq__(self, value: Dimension | int) -> bool | Domain: + def __eq__(self, value: Dimension | core_defs.IntegralScalar) -> bool | Domain: if isinstance(value, Dimension): return self.value == value.value - elif isinstance(value, int): + elif isinstance(value, core_defs.INTEGRAL_TYPES): # TODO probably only within valid embedded context? return Domain(dims=(self,), ranges=(UnitRange(value, value + 1),)) else: return False - def __ne__(self, value: Dimension | int) -> bool | tuple[Domain, Domain]: + def __ne__(self, value: Dimension | core_defs.IntegralScalar) -> bool | tuple[Domain, Domain]: # TODO add test if isinstance(value, Dimension): return self.value != value.value - elif isinstance(value, int): + elif isinstance(value, core_defs.INTEGRAL_TYPES): # TODO probably only within valid embedded context? return ( Domain(self, UnitRange(Infinity.NEGATIVE, value)), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 09cc2d92a8..548faeba9e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -18,6 +18,8 @@ exec_alloc_descriptor, ) +# TODO test non-Python int for embedded comparison + @pytest.mark.uses_frontend_concat_where def test_concat_where_simple(cartesian_case): From 8c4fc4555e1d34cf421503ee208bf97a62d596d7 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 14 Apr 2025 14:55:27 +0200 Subject: [PATCH 059/124] Fix import --- .../transforms_tests/test_constant_folding.py | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 0dd74e44bb..fe29f4c599 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -8,7 +8,7 @@ from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.constant_folding import ConstantFolding -from gt4py.next.iterator import ir +from gt4py.next.iterator import ir as itir import pytest from gt4py.next.iterator.ir_utils import ir_makers as im @@ -166,86 +166,86 @@ def test_constant_folding(test_case): # TODO: integrate into test structure above def test_constant_folding_inf_maximum(): - testee = im.call("maximum")(im.literal_from_value(1), ir.InfinityLiteral.POSITIVE) - expected = ir.InfinityLiteral.POSITIVE + testee = im.call("maximum")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE) + expected = itir.InfinityLiteral.POSITIVE actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("maximum")(ir.InfinityLiteral.POSITIVE, im.literal_from_value(1)) - expected = ir.InfinityLiteral.POSITIVE + testee = im.call("maximum")(itir.InfinityLiteral.POSITIVE, im.literal_from_value(1)) + expected = itir.InfinityLiteral.POSITIVE actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("maximum")(im.literal_from_value(1), ir.InfinityLiteral.NEGATIVE) + testee = im.call("maximum")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE) expected = im.literal_from_value(1) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("maximum")(ir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) + testee = im.call("maximum")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) expected = im.literal_from_value(1) actual = ConstantFolding.apply(testee) assert actual == expected def test_constant_folding_inf_minimum(): - testee = im.call("minimum")(im.literal_from_value(1), ir.InfinityLiteral.POSITIVE) + testee = im.call("minimum")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE) expected = im.literal_from_value(1) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("minimum")(ir.InfinityLiteral.POSITIVE, im.literal_from_value(1)) + testee = im.call("minimum")(itir.InfinityLiteral.POSITIVE, im.literal_from_value(1)) expected = im.literal_from_value(1) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("minimum")(im.literal_from_value(1), ir.InfinityLiteral.NEGATIVE) - expected = ir.InfinityLiteral.NEGATIVE + testee = im.call("minimum")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE) + expected = itir.InfinityLiteral.NEGATIVE actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("minimum")(ir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) - expected = ir.InfinityLiteral.NEGATIVE + testee = im.call("minimum")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) + expected = itir.InfinityLiteral.NEGATIVE actual = ConstantFolding.apply(testee) assert actual == expected def test_constant_greater_less(): - testee = im.call("greater")(im.literal_from_value(1), ir.InfinityLiteral.POSITIVE) + testee = im.call("greater")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE) expected = im.literal_from_value(False) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("greater")(im.literal_from_value(1), ir.InfinityLiteral.NEGATIVE) + testee = im.call("greater")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE) expected = im.literal_from_value(True) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("less")(im.literal_from_value(1), ir.InfinityLiteral.POSITIVE) + testee = im.call("less")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE) expected = im.literal_from_value(True) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("less")(im.literal_from_value(1), ir.InfinityLiteral.NEGATIVE) + testee = im.call("less")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE) expected = im.literal_from_value(False) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("greater")(ir.InfinityLiteral.POSITIVE, im.literal_from_value(1)) + testee = im.call("greater")(itir.InfinityLiteral.POSITIVE, im.literal_from_value(1)) expected = im.literal_from_value(True) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("greater")(ir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) + testee = im.call("greater")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) expected = im.literal_from_value(False) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("less")(ir.InfinityLiteral.POSITIVE, im.literal_from_value(1)) + testee = im.call("less")(itir.InfinityLiteral.POSITIVE, im.literal_from_value(1)) expected = im.literal_from_value(False) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.call("less")(ir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) + testee = im.call("less")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) expected = im.literal_from_value(True) actual = ConstantFolding.apply(testee) assert actual == expected From a41eb1a47f2e764b025e46300fdc5647883204ef Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 24 Apr 2025 12:04:45 +0200 Subject: [PATCH 060/124] feat[next]: GTIR concat_where frontend --- src/gt4py/next/ffront/experimental.py | 18 +- src/gt4py/next/ffront/fbuiltins.py | 15 +- .../ffront/foast_passes/type_deduction.py | 155 +++++++++++++----- .../type_system/type_specifications.py | 4 - .../iterator/type_system/type_synthesizer.py | 10 +- src/gt4py/next/type_system/type_info.py | 7 +- .../next/type_system/type_specifications.py | 11 +- .../ffront_tests/test_type_deduction.py | 56 ++++++- .../iterator_tests/test_type_inference.py | 8 +- 9 files changed, 221 insertions(+), 63 deletions(-) diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index bd22aebe57..dfa89468a5 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -6,11 +6,16 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Tuple +from typing import Tuple, TypeVar from gt4py._core import definitions as core_defs from gt4py.next import common -from gt4py.next.ffront.fbuiltins import BuiltInFunction, FieldOffset, WhereBuiltinFunction +from gt4py.next.ffront.fbuiltins import ( + BuiltInFunction, + FieldOffset, + FieldT, + WhereLikeBuiltinFunction, +) @BuiltInFunction @@ -18,9 +23,14 @@ def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivi raise NotImplementedError() -@WhereBuiltinFunction +_R = TypeVar("_R") +DomainT = TypeVar("DomainT", bound=common.Field) +ConcatWhereBuiltinFunction = WhereLikeBuiltinFunction[_R, DomainT, FieldT] + + +@ConcatWhereBuiltinFunction def concat_where( - mask: common.Field, + mask: common.Domain, true_field: common.Field | core_defs.ScalarT | Tuple, false_field: common.Field | core_defs.ScalarT | Tuple, /, diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index ee14006b22..17e6bb2133 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -66,6 +66,8 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp return ts.OffsetType elif t is core_defs.ScalarT: return ts.ScalarType + elif t is common.Domain: + return ts.DomainType elif t is type: return ( ts.FunctionType @@ -135,14 +137,15 @@ def __gt_type__(self) -> ts.FunctionType: ) -MaskT = TypeVar("MaskT", bound=common.Field) +MaskLikeT = TypeVar("MaskLikeT", bound=common.Field) FieldT = TypeVar("FieldT", bound=Union[common.Field, core_defs.Scalar, Tuple]) -class WhereBuiltinFunction( - BuiltInFunction[_R, [MaskT, FieldT, FieldT]], Generic[_R, MaskT, FieldT] +class WhereLikeBuiltinFunction( + BuiltInFunction[_R, [MaskLikeT, FieldT, FieldT]], + Generic[_R, MaskLikeT, FieldT], ): - def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: + def __call__(self, mask: MaskLikeT, true_field: FieldT, false_field: FieldT) -> _R: if isinstance(true_field, tuple) or isinstance(false_field, tuple): if not (isinstance(true_field, tuple) and isinstance(false_field, tuple)): raise ValueError( @@ -157,6 +160,10 @@ def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: return super().__call__(mask, true_field, false_field) +MaskT = TypeVar("MaskT", bound=common.Field) +WhereBuiltinFunction = WhereLikeBuiltinFunction[_R, MaskT, FieldT] + + @BuiltInFunction def neighbor_sum(field: common.Field, /, axis: common.Dimension) -> common.Field: raise NotImplementedError() diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 26bcadaef1..370d0be85c 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -10,8 +10,8 @@ import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, NodeVisitor, traits -from gt4py.next import errors -from gt4py.next.common import DimensionKind +from gt4py.next import errors, utils +from gt4py.next.common import DimensionKind, promote_dims from gt4py.next.ffront import ( # noqa dialect_ast_enums, experimental, @@ -20,6 +20,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.foast_passes.utils import compute_assign_indices +from gt4py.next.iterator import builtins from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -566,16 +567,10 @@ def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> foast.Compare: op=node.op, left=new_left, right=new_right, location=node.location, type=new_type ) - def _deduce_compare_type( + def _deduce_arithmetic_compare_type( self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any ) -> Optional[ts.TypeSpec]: - # check both types compatible - for arg in (left, right): - if not type_info.is_arithmetic(arg.type): - raise errors.DSLError( - arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'." - ) - + # e.g. `1 < 2` self._check_operand_dtypes_match(node, left=left, right=right) try: @@ -592,6 +587,48 @@ def _deduce_compare_type( f" in call to '{node.op}'.", ) from ex + def _deduce_dimension_compare_type( + self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any + ) -> Optional[ts.TypeSpec]: + # e.g. `IDim > 1` + index_type = ts.ScalarType( + kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) + ) + + if isinstance(left.type, ts.DimensionType): + if not right.type == index_type: + raise errors.DSLError( + right.location, + f"Expected an {index_type}, but got '{right.type}' instead.", + ) + return ts.DomainType(dims=[left.type.dim]) + elif isinstance(right.type, ts.DimensionType): + if not left.type == index_type: + raise errors.DSLError( + left.location, + f"Expected an {index_type}, but got '{right.type}' instead.", + ) + return ts.DomainType(dims=[right.type.dim]) + else: + raise AssertionError() + + def _deduce_compare_type( + self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any + ) -> Optional[ts.TypeSpec]: + # e.g. `1 < 1` + if all(type_info.is_arithmetic(arg) for arg in (left.type, right.type)): + return self._deduce_arithmetic_compare_type(node, left=left, right=right) + # e.g. `IDim > 1` + if any(isinstance(arg, ts.DimensionType) for arg in (left.type, right.type)): + return self._deduce_dimension_compare_type(node, left=left, right=right) + + raise errors.DSLError( + left.location, + "Comparison operators can only be used between arithmetic types " + "(scalars, fields) or between a dimension and an index type " + "({builtins.INTEGER_INDEX_BUILTIN}).", + ) + def _deduce_binop_type( self, node: foast.BinOp, *, left: foast.Expr, right: foast.Expr, **kwargs: Any ) -> Optional[ts.TypeSpec]: @@ -612,37 +649,48 @@ def _deduce_binop_type( dialect_ast_enums.BinaryOperator.BIT_OR, dialect_ast_enums.BinaryOperator.BIT_XOR, } - is_compatible = type_info.is_logical if node.op in logical_ops else type_info.is_arithmetic - - # check both types compatible - for arg in (left, right): - if not is_compatible(arg.type): - raise errors.DSLError( - arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'." - ) - - left_type = cast(ts.FieldType | ts.ScalarType, left.type) - right_type = cast(ts.FieldType | ts.ScalarType, right.type) - if node.op == dialect_ast_enums.BinaryOperator.POW: - return left_type + err_msg = f"Unsupported operand type(s) for {node.op}: '{left.type}' and '{right.type}'." - if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral( - right_type + if isinstance(left.type, (ts.ScalarType, ts.FieldType)) and isinstance( + right.type, (ts.ScalarType, ts.FieldType) ): - raise errors.DSLError( - arg.location, - f"Type '{right_type}' can not be used in operator '{node.op}', it only accepts 'int'.", + is_compatible = ( + type_info.is_logical if node.op in logical_ops else type_info.is_arithmetic ) + for arg in (left, right): + if not is_compatible(arg.type): + raise errors.DSLError(arg.location, err_msg) - try: - return type_info.promote(left_type, right_type) - except ValueError as ex: - raise errors.DSLError( - node.location, - f"Could not promote '{left_type}' and '{right_type}' to common type" - f" in call to '{node.op}'.", - ) from ex + if node.op == dialect_ast_enums.BinaryOperator.POW: + return left.type + + if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral( + right.type + ): + raise errors.DSLError( + arg.location, + f"Type '{right.type}' can not be used in operator '{node.op}', it only accepts 'int'.", + ) + + try: + return type_info.promote(left.type, right.type) + except ValueError as ex: + raise errors.DSLError( + node.location, + f"Could not promote '{left.type}' and '{right.type}' to common type" + f" in call to '{node.op}'.", + ) from ex + elif isinstance(left.type, ts.DomainType) and isinstance(right.type, ts.DomainType): + if node.op not in logical_ops: + raise errors.DSLError( + node.location, + f"{err_msg} Operator " + f"must be one of {', '.join((str(op) for op in logical_ops))}.", + ) + return ts.DomainType(dims=promote_dims(left.type.dims, right.type.dims)) + else: + raise errors.DSLError(node.location, err_msg) def _check_operand_dtypes_match( self, node: foast.BinOp | foast.Compare, left: foast.Expr, right: foast.Expr @@ -908,6 +956,7 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: ) try: + # TODO(tehrengruber): the construct_tuple_type function doesn't look correct if isinstance(true_branch_type, ts.TupleType) and isinstance( false_branch_type, ts.TupleType ): @@ -943,7 +992,39 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: location=node.location, ) - _visit_concat_where = _visit_where + def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: + mask_type, true_branch_type, false_branch_type = (arg.type for arg in node.args) + + assert isinstance(mask_type, ts.DomainType) + assert all( + isinstance(el, (ts.FieldType, ts.ScalarType)) + for arg in (true_branch_type, false_branch_type) + for el in type_info.primitive_constituents(arg) + ) + + @utils.tree_map( + collection_type=ts.TupleType, + result_collection_constructor=lambda el: ts.TupleType(types=list(el)), + ) + def deduce_return_type(tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.ScalarType): + if (t_dtype := type_info.extract_dtype(tb)) != (f_dtype := type_info.extract_dtype(fb)): + raise errors.DSLError( + node.location, + f"Field arguments must be of same dtype, got '{t_dtype}' != '{f_dtype}'.", + ) + return_dims = promote_dims(mask_type.dims, type_info.promote(tb, fb).dims) + return_type = ts.FieldType(dims=return_dims, dtype=type_info.promote(t_dtype, f_dtype)) + return return_type + + return_type = deduce_return_type(true_branch_type, false_branch_type) + + return foast.Call( + func=node.func, + args=node.args, + kwargs=node.kwargs, + type=return_type, + location=node.location, + ) def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> foast.Call: arg_type = cast(ts.FieldType | ts.ScalarType, node.args[0].type) diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index 9695982e97..39e9e607ce 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -16,10 +16,6 @@ class NamedRangeType(ts.TypeSpec): dim: common.Dimension -class DomainType(ts.DataType): - dims: list[common.Dimension] | Literal["unknown"] - - class OffsetLiteralType(ts.TypeSpec): value: ts.ScalarType | str diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index be731254c2..8c381d51a8 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -187,9 +187,9 @@ def named_range( @_register_builtin_type_synthesizer(fun_names=["cartesian_domain", "unstructured_domain"]) -def _(*args: it_ts.NamedRangeType) -> it_ts.DomainType: +def _(*args: it_ts.NamedRangeType) -> ts.DomainType: assert all(isinstance(arg, it_ts.NamedRangeType) for arg in args) - return it_ts.DomainType(dims=[arg.dim for arg in args]) + return ts.DomainType(dims=[arg.dim for arg in args]) @_register_builtin_type_synthesizer @@ -271,7 +271,7 @@ def apply_lift( def _convert_as_fieldop_input_to_iterator( - domain: it_ts.DomainType, input_: ts.TypeSpec + domain: ts.DomainType, input_: ts.TypeSpec ) -> it_ts.IteratorType: # get the dimensions of all non-zero-dimensional field inputs and check they agree all_input_dims = ( @@ -311,7 +311,7 @@ def _convert_as_fieldop_input_to_iterator( @_register_builtin_type_synthesizer def as_fieldop( stencil: TypeSynthesizer, - domain: Optional[it_ts.DomainType] = None, + domain: Optional[ts.DomainType] = None, *, offset_provider_type: common.OffsetProviderType, ) -> TypeSynthesizer: @@ -326,7 +326,7 @@ def as_fieldop( # `as_fieldop(it1, it2 -> deref(it1) + deref(it2))(i_field, j_field)` # it is unclear if the result has dimension I, J or J, I. if domain is None: - domain = it_ts.DomainType(dims="unknown") + domain = ts.DomainType(dims="unknown") @TypeSynthesizer def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index e6f9073350..42e586bdf5 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -235,6 +235,8 @@ def is_floating_point(symbol_type: ts.TypeSpec) -> bool: >>> is_floating_point(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))) True """ + if not isinstance(symbol_type, (ts.ScalarType, ts.FieldType)): + return False return isinstance(dtype := extract_dtype(symbol_type), ts.ScalarType) and dtype.kind in [ ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64, @@ -278,6 +280,8 @@ def is_integral(symbol_type: ts.TypeSpec) -> bool: >>> is_integral(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))) True """ + if not isinstance(symbol_type, (ts.ScalarType, ts.FieldType)): + return False return is_integer(extract_dtype(symbol_type)) @@ -305,7 +309,8 @@ def is_number(symbol_type: ts.TypeSpec) -> bool: def is_logical(symbol_type: ts.TypeSpec) -> bool: return ( - isinstance(dtype := extract_dtype(symbol_type), ts.ScalarType) + isinstance(symbol_type, (ts.FieldType, ts.ScalarType)) + and isinstance(dtype := extract_dtype(symbol_type), ts.ScalarType) and dtype.kind is ts.ScalarKind.BOOL ) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index e83ade9ccc..4822719b40 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Iterator, Optional, Sequence, Union +from typing import Iterator, Literal, Optional, Sequence, Union from gt4py.eve import datamodels as eve_datamodels, type_definitions as eve_types from gt4py.next import common @@ -49,6 +49,9 @@ class VoidType(TypeSpec): class DimensionType(TypeSpec): dim: common.Dimension + def __str__(self) -> str: + return str(self.dim) + class OffsetType(TypeSpec): # TODO(havogt): replace by ConnectivityType @@ -138,3 +141,9 @@ def __str__(self) -> str: kwarg_strs = [f"{key}: {value}" for key, value in self.pos_or_kw_args.items()] args_str = ", ".join((*arg_strs, *kwarg_strs)) return f"({args_str}) -> {self.returns}" + + +class DomainType(DataType): + # TODO(tehrengruber): Remove "unknown" here again after the result type of `as_fieldop` + # is always precisely known. This is the case after #1853. + dims: list[common.Dimension] | Literal["unknown"] diff --git a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py index d3396a54e0..4f882402f2 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py @@ -28,6 +28,7 @@ neighbor_sum, where, ) +from gt4py.next.ffront.experimental import concat_where from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.experimental import as_offset from gt4py.next.ffront.func_to_foast import FieldOperatorParser @@ -75,7 +76,12 @@ def add_bools(a: Field[[TDim], bool], b: Field[[TDim], bool]): return a + b with pytest.raises( - errors.DSLError, match=(r"Type 'Field\[\[TDim\], bool\]' can not be used in operator '\+'.") + errors.DSLError, + match=( + re.escape( + "Unsupported operand type(s) for +: 'Field[[TDim], bool]' and 'Field[[TDim], bool]'." + ) + ), ): _ = FieldOperatorParser.apply_to_function(add_bools) @@ -95,13 +101,15 @@ def nonmatching(a: Field[[X], float64], b: Field[[Y], float64]): ) -def test_bitopping_float(): +def test_bitop_float(): def float_bitop(a: Field[[TDim], float], b: Field[[TDim], float]): return a & b with pytest.raises( errors.DSLError, - match=(r"Type 'Field\[\[TDim\], float64\]' can not be used in operator '\&'."), + match=re.escape( + "Unsupported operand type(s) for &: 'Field[[TDim], float64]' and 'Field[[TDim], float64]'." + ), ): _ = FieldOperatorParser.apply_to_function(float_bitop) @@ -128,6 +136,48 @@ def not_int(a: Field[[TDim], int64]): _ = FieldOperatorParser.apply_to_function(not_int) +def test_concat_where(): + def simple_concat_where(a: Field[[TDim], float], b: Field[[TDim], float]): + return concat_where(TDim > 0, a, b) + + parsed = FieldOperatorParser.apply_to_function(simple_concat_where) + compare_node = parsed.body.stmts[0].value.args[0] + assert compare_node.type == ts.DomainType(dims=[TDim]) + + +def test_domain_comparison_failure(): + def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): + return concat_where(TDim > 1.0, a, b) + + with pytest.raises( + errors.DSLError, + match=re.escape("Expected an int32, but got 'float64' instead."), + ): + _ = FieldOperatorParser.apply_to_function(domain_comparison) + + +def test_domain_comparison_checkerboard_failure(): + def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): + return concat_where(TDim % 2.0, a, b) + + with pytest.raises( + errors.DSLError, + match=re.escape("Unsupported operand type(s) for %"), + ): + _ = FieldOperatorParser.apply_to_function(domain_comparison) + + +def test_concat_where_invalid_dtype(): + def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): + return concat_where(TDim > 0, 1.0, 2) + + with pytest.raises( + errors.DSLError, + match=re.escape("Field arguments must be of same dtype, got 'float64' != 'int32'."), + ): + _ = FieldOperatorParser.apply_to_function(domain_comparison) + + @pytest.fixture def premap_setup(): X = Dimension("X") diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 05316902a7..db933ce736 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -94,7 +94,7 @@ def expression_test_cases(): im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) ), - it_ts.DomainType(dims=[IDim]), + ts.DomainType(dims=[IDim]), ), ( im.call("unstructured_domain")( @@ -102,7 +102,7 @@ def expression_test_cases(): itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 ) ), - it_ts.DomainType(dims=[Vertex]), + ts.DomainType(dims=[Vertex]), ), # make_tuple ( @@ -327,7 +327,7 @@ def test_cartesian_fencil_definition(): program_type = it_ts.ProgramType(params={"inp": float_i_field, "out": float_i_field}) assert result.type == program_type - domain_type = it_ts.DomainType(dims=[IDim]) + domain_type = ts.DomainType(dims=[IDim]) assert result.body[0].domain.type == domain_type assert result.body[0].expr.type == float_i_field assert result.body[0].target.type == float_i_field @@ -366,7 +366,7 @@ def test_unstructured_fencil_definition(): params={"inp": float_edge_k_field, "out": float_vertex_k_field} ) assert result.type == program_type - domain_type = it_ts.DomainType(dims=[Vertex, KDim]) + domain_type = ts.DomainType(dims=[Vertex, KDim]) assert result.body[0].domain.type == domain_type assert result.body[0].expr.type == float_vertex_k_field assert result.body[0].target.type == float_vertex_k_field From e2c053c1ce1b2ea9c604fc4ccc6e6abd76076982 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 24 Apr 2025 12:14:30 +0200 Subject: [PATCH 061/124] disable concat_where tests --- pyproject.toml | 1 + tests/next_tests/definitions.py | 4 ++++ .../feature_tests/ffront_tests/test_concat_where.py | 2 ++ .../unit_tests/embedded_tests/test_nd_array_field.py | 1 + 4 files changed, 8 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index b5925be4b5..db34567a49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -302,6 +302,7 @@ markers = [ 'uses_unstructured_shift: tests that use a unstructured connectivity', 'uses_max_over: tests that use the max_over builtin', 'uses_mesh_with_skip_values: tests that use a mesh with skip values', + 'uses_concat_where: tests that use the concat_where builtin', 'checks_specific_error: tests that rely on the backend to produce a specific error message' ] norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*'] diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 1f81076abf..ba30cfef4f 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -127,6 +127,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_MAX_OVER = "uses_max_over" USES_MESH_WITH_SKIP_VALUES = "uses_mesh_with_skip_values" USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo" +USES_CONCAT_WHERE = "uses_concat_where" CHECKS_SPECIFIC_ERROR = "checks_specific_error" # Skip messages (available format keys: 'marker', 'backend') @@ -143,6 +144,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLES_ARGS_WITH_DIFFERENT_BUT_PROMOTABLE_DIMS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), ] # Markers to skip because of missing features in the domain inference DOMAIN_INFERENCE_SKIP_LIST = [ @@ -169,10 +171,12 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): XFAIL, UNSUPPORTED_MESSAGE, ), # we can't extract the field type from scan args + (USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), ] ROUNDTRIP_SKIP_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLES_ARGS_WITH_DIFFERENT_BUT_PROMOTABLE_DIMS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), ] GTFN_SKIP_TEST_LIST = ( COMMON_SKIP_TEST_LIST diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 364434029f..44ef9b62f0 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -17,6 +17,8 @@ exec_alloc_descriptor, ) +pytestmark = pytest.mark.uses_concat_where + def test_boundary_same_size_fields(cartesian_case): @gtx.field_operator diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 9bdc6ab5c1..7c29faca92 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -983,6 +983,7 @@ def test_hyperslice(index_array, expected): assert result == expected +@pytest.mark.uses_concat_where @pytest.mark.parametrize( "mask_data, true_data, false_data, expected", [ From 4b46fcdd54e0d5ca5b9ddff8635a37ca3505bcb4 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 24 Apr 2025 12:39:25 +0200 Subject: [PATCH 062/124] one more it_ts.DomainType --- src/gt4py/next/iterator/type_system/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index b8aaed3f23..bd5989ef2e 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -448,7 +448,7 @@ def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType: def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.TupleType: domain = self.visit(node.domain, ctx=ctx) - assert isinstance(domain, it_ts.DomainType) + assert isinstance(domain, ts.DomainType) assert domain.dims != "unknown" assert node.dtype return type_info.apply_to_primitive_constituents( From d77a4c077c127297116578ada16c4ed2d5ea3d76 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 24 Apr 2025 12:58:47 +0200 Subject: [PATCH 063/124] add test for concat_where on scalars and fix typing --- src/gt4py/next/embedded/nd_array_field.py | 2 +- .../next/ffront/foast_passes/type_deduction.py | 17 ++++++++++++++--- .../ffront_tests/test_type_deduction.py | 9 +++++++++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 22b91d947c..339516ccf4 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -974,7 +974,7 @@ def _concat_where( return cls_.from_array(result_array, domain=result_domain) -NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] +NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) def _make_reduction( diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 370d0be85c..f6afebd60b 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -429,7 +429,7 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs: Any) -> foast.IfStmt: if not isinstance(new_node.condition.type, ts.ScalarType): raise errors.DSLError( node.location, - "Condition for 'if' must be scalar, " f"got '{new_node.condition.type}' instead.", + f"Condition for 'if' must be scalar, got '{new_node.condition.type}' instead.", ) if new_node.condition.type.kind != ts.ScalarKind.BOOL: @@ -688,6 +688,8 @@ def _deduce_binop_type( f"{err_msg} Operator " f"must be one of {', '.join((str(op) for op in logical_ops))}.", ) + assert isinstance(right.type.dims, list) + assert isinstance(left.type.dims, list) return ts.DomainType(dims=promote_dims(left.type.dims, right.type.dims)) else: raise errors.DSLError(node.location, err_msg) @@ -1006,13 +1008,22 @@ def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: collection_type=ts.TupleType, result_collection_constructor=lambda el: ts.TupleType(types=list(el)), ) - def deduce_return_type(tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.ScalarType): + def deduce_return_type( + tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.ScalarType + ) -> ts.FieldType: if (t_dtype := type_info.extract_dtype(tb)) != (f_dtype := type_info.extract_dtype(fb)): raise errors.DSLError( node.location, f"Field arguments must be of same dtype, got '{t_dtype}' != '{f_dtype}'.", ) - return_dims = promote_dims(mask_type.dims, type_info.promote(tb, fb).dims) + assert isinstance(mask_type.dims, list) + promoted_branches = type_info.promote(tb, fb) + branches_dims = ( + [] if isinstance(promoted_branches, ts.ScalarType) else promoted_branches.dims + ) + return_dims = promote_dims(mask_type.dims, branches_dims) + assert isinstance(t_dtype, ts.ScalarType) + assert isinstance(f_dtype, ts.ScalarType) return_type = ts.FieldType(dims=return_dims, dtype=type_info.promote(t_dtype, f_dtype)) return return_type diff --git a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py index 4f882402f2..0b326a2293 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py @@ -145,6 +145,15 @@ def simple_concat_where(a: Field[[TDim], float], b: Field[[TDim], float]): assert compare_node.type == ts.DomainType(dims=[TDim]) +def test_concat_where_scalar(): + def simple_concat_where(a: float, b: float): + return concat_where(TDim > 0, a, b) + + parsed = FieldOperatorParser.apply_to_function(simple_concat_where) + compare_node = parsed.body.stmts[0].value.args[0] + assert compare_node.type == ts.DomainType(dims=[TDim]) + + def test_domain_comparison_failure(): def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): return concat_where(TDim > 1.0, a, b) From f41c112e3824919ca873e8917b1e2dab47504c19 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 25 Apr 2025 09:00:35 +0200 Subject: [PATCH 064/124] add test for chained comparison --- src/gt4py/next/ffront/func_to_foast.py | 12 +++++------- .../unit_tests/ffront_tests/test_type_deduction.py | 12 ++++++++++++ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index ef20b99d91..60282bf6c6 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -26,7 +26,6 @@ SingleAssignTargetPass, SingleStaticAssignPass, StringifyAnnotationsPass, - UnchainComparesPass, ) from gt4py.next.ffront.dialect_parser import DialectParser from gt4py.next.ffront.foast_introspection import StmtReturnKind, deduce_stmt_return_kind @@ -144,12 +143,11 @@ class FieldOperatorParser(DialectParser[foast.FunctionDefinition]): """ @classmethod - def _preprocess_definition_ast(cls, definition_ast: ast.AST) -> ast.AST: - sta = StringifyAnnotationsPass.apply(definition_ast) - ssa = SingleStaticAssignPass.apply(sta) - sat = SingleAssignTargetPass.apply(ssa) - ucc = UnchainComparesPass.apply(sat) - return ucc + def _preprocess_definition_ast(cls, ast: ast.AST) -> ast.AST: + ast = StringifyAnnotationsPass.apply(ast) + ast = SingleStaticAssignPass.apply(ast) + ast = SingleAssignTargetPass.apply(ast) + return ast @classmethod def _postprocess_dialect_ast( diff --git a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py index 0b326a2293..8be3fa0dbd 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py @@ -187,6 +187,18 @@ def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): _ = FieldOperatorParser.apply_to_function(domain_comparison) +def test_domain_chained_comparison_failure(): + def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): + return concat_where(0 < TDim < 42, a, b) + + # _ = FieldOperatorParser.apply_to_function(domain_comparison) + with pytest.raises( + errors.DSLError, + match=re.escape("TODO"), + ): + _ = FieldOperatorParser.apply_to_function(domain_comparison) + + @pytest.fixture def premap_setup(): X = Dimension("X") From 359a921a774853095ace454420cfd92b7dbfa445 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 18 May 2025 17:48:48 +0200 Subject: [PATCH 065/124] Fix broken merge --- src/gt4py/next/iterator/type_system/type_synthesizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 19b0804975..295ec8fec9 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -560,7 +560,7 @@ def applied_as_fieldop( ) assert all(isinstance(dim, common.Dimension) for dim in output_dims) - deduced_domain = it_ts.DomainType(dims=output_dims) + deduced_domain = ts.DomainType(dims=output_dims) if deduced_domain: domain = deduced_domain From 4764c2be09cea5f9c126c301d8ae71b6adfb3c07 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 18 May 2025 17:58:38 +0200 Subject: [PATCH 066/124] Simplify tuple lowering, unit tests, cleanup --- src/gt4py/next/ffront/foast_to_gtir.py | 9 +- src/gt4py/next/iterator/ir_utils/ir_makers.py | 9 +- .../iterator/transforms/collapse_tuple.py | 5 +- .../transforms/concat_where_transforms.py | 112 ++++++++++++++++++ .../transforms/expand_library_functions.py | 4 +- .../next/iterator/transforms/pass_manager.py | 7 +- .../transforms/transform_concat_where.py | 56 --------- .../iterator/type_system/type_synthesizer.py | 6 +- tests/next_tests/integration_tests/cases.py | 2 +- .../ffront_tests/test_concat_where.py | 45 +++++++ .../transforms_tests/test_collapse_tuple.py | 23 +++- .../test_transform_concat_where.py | 68 +++++++++++ 12 files changed, 265 insertions(+), 81 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/concat_where_transforms.py delete mode 100644 src/gt4py/next/iterator/transforms/transform_concat_where.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_concat_where.py diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 1b677d187b..a40fd65971 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -408,16 +408,11 @@ def create_if( return im.let(cond_symref_name, cond_)(result) def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - domain, true_branch, false_branch = self.visit(node.args) - return lowering_utils.process_elements( - lambda tb, fb: im.call("concat_where")(domain, tb, fb), - (true_branch, false_branch), - node.type, - ) + domain, true_branch, false_branch = self.visit(node.args, **kwargs) # TODO: use this case again. breaks domain inference in fused_velocity_advection_stencil_1_to_7 # because some tuple elements are never accessed and the collapse tuple # does not propagate across concat where - # return im.concat_where(domain, true_branch, false_branch) + return im.concat_where(domain, true_branch, false_branch) def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return im.call("broadcast")(*self.visit(node.args, **kwargs)) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 50a050211d..3aca150da6 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -245,6 +245,7 @@ def if_(cond, true_val, false_val): def concat_where(cond, true_field, false_field): """Create a concat_where FunCall, shorthand for ``call("concat_where")(expr)``.""" + return call("concat_where")(cond, true_field, false_field) @@ -442,18 +443,18 @@ def domain( """ if isinstance(grid_type, common.GridType): grid_type = f"{grid_type!s}_domain" - return call(grid_type)( + expr = call(grid_type)( *[ call("named_range")( - itir.AxisLiteral(value=d.value, kind=d.kind) - if isinstance(d, common.Dimension) - else itir.AxisLiteral(value=d), + itir.AxisLiteral(value=d.value, kind=d.kind), r[0], r[1], ) for d, r in ranges.items() ] ) + expr.type = ts.DomainType(dims=list(ranges.keys())) + return expr def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> Callable: diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index e9479514cd..6928c8dcf8 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -308,9 +308,10 @@ def transform_propagate_tuple_get(self, node: itir.FunCall, **kwargs) -> Optiona self.fp_transform(im.tuple_get(idx.value, expr.fun.expr), **kwargs) ) )(*expr.args) - elif cpm.is_call_to(expr, "if_"): + elif cpm.is_call_to(expr, ("if_", "concat_where")): + fun = expr.fun cond, true_branch, false_branch = expr.args - return im.if_( + return im.call(fun)( cond, self.fp_transform(im.tuple_get(idx.value, true_branch), **kwargs), self.fp_transform(im.tuple_get(idx.value, false_branch), **kwargs), diff --git a/src/gt4py/next/iterator/transforms/concat_where_transforms.py b/src/gt4py/next/iterator/transforms/concat_where_transforms.py new file mode 100644 index 0000000000..b1a1936b23 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/concat_where_transforms.py @@ -0,0 +1,112 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +from typing import Optional + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next import common, utils +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) +from gt4py.next.iterator.transforms import infer_domain, symbol_ref_utils +from gt4py.next.iterator.type_system import inference as type_inference +from gt4py.next.type_system import type_specifications as ts + + +class _TransformTupleConcatWhere(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ( + "type", + "domain", + ) + + @classmethod + def apply(cls, node: ir.Node, offset_provider_type: Optional[common.OffsetProviderType] = None): + node = type_inference.infer( + node, + offset_provider_type=offset_provider_type, + ) + return cls().visit(node) + + def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + node = self.generic_visit(node) + + # `concat_where(cond, {a, b}, {c, d})` + # -> `{concat_where(cond, a, c), concat_where(cond, a, c)}` + if cpm.is_call_to(node, "concat_where") and isinstance(node.args[1].type, ts.TupleType): + cond, true_branch, false_branch = node.args + new_els = [] + for i in range(len(true_branch.type.types)): + new_els.append( + im.concat_where(cond, im.tuple_get(i, "__tb"), im.tuple_get(i, "__fb")) + ) + + new_node = im.let(("__tb", true_branch), ("__fb", false_branch))( + im.make_tuple(*new_els) + ) + # restore domain information + new_node, _ = infer_domain.infer_expr( + new_node, + node.annex.domain, + keep_existing_domains=True, + # offset provider not needed as all as_fieldop already have a domain + offset_provider={}, + ) + return new_node + + return node + + +expand_tuple = _TransformTupleConcatWhere.apply + + +class _ExpandConcatWhere(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ( + "type", + "domain", + ) + + @classmethod + def apply(cls, node: ir.Node): + node = cls().visit(node) + node = type_inference.SanitizeTypes().visit(node) + return node + + def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + node = self.generic_visit(node) + if cpm.is_call_to(node, "concat_where"): + cond, true_branch, false_branch = node.args + assert isinstance(cond.type, ts.DomainType) + position = [im.index(dim) for dim in cond.type.dims] + refs = symbol_ref_utils.collect_symbol_refs(cond) + + domains = utils.flatten_nested_tuple(node.annex.domain) + assert all( + domain == domains[0] for domain in domains + ), "At this point all `concat_where` arguments should be posed on the same domain." + assert isinstance(domains[0], domain_utils.SymbolicDomain) + domain_expr = domains[0].as_expr() + + return im.as_fieldop( + im.lambda_("__tcw_pos", "__tcw_arg0", "__tcw_arg1", *refs)( + im.let(*zip(refs, map(im.deref, refs), strict=True))( + im.if_( + im.call("in_")(im.deref("__tcw_pos"), cond), + im.deref("__tcw_arg0"), + im.deref("__tcw_arg1"), + ) + ) + ), + domain_expr, + )(im.make_tuple(*position), true_branch, false_branch, *refs) + + return node + + +expand = _ExpandConcatWhere.apply diff --git a/src/gt4py/next/iterator/transforms/expand_library_functions.py b/src/gt4py/next/iterator/transforms/expand_library_functions.py index 5b6dd9ec1e..2ad3c783da 100644 --- a/src/gt4py/next/iterator/transforms/expand_library_functions.py +++ b/src/gt4py/next/iterator/transforms/expand_library_functions.py @@ -35,8 +35,8 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: if cpm.is_call_to(node, "in_"): ret = [] pos, domain = node.args - for i, (_, v) in enumerate( - domain_utils.SymbolicDomain.from_expr(node.args[1]).ranges.items() + for i, v in enumerate( + domain_utils.SymbolicDomain.from_expr(node.args[1]).ranges.values() ): ret.append( im.and_( diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 292853ef8e..9facdd14e3 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -12,8 +12,9 @@ from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import ( - expand_library_functions, + concat_where_transforms, dead_code_elimination, + expand_library_functions, fuse_as_fieldop, global_tmps, infer_domain, @@ -23,7 +24,6 @@ inline_lifts, nest_concat_wheres, remove_broadcast, - transform_concat_where, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -97,8 +97,9 @@ def apply_common_transforms( # Note: executing domain inference again afterwards will give wrong domains. # This might be problematic in the temporary extraction, where we do this... + ir = concat_where_transforms.expand_tuple(ir) ir = ConstantFolding.apply(ir) # TODO: remove - ir = transform_concat_where.TransformConcatWhere.apply(ir) + ir = concat_where_transforms.expand(ir) ir = ConstantFolding.apply(ir) # TODO: remove ir = expand_library_functions.ExpandLibraryFunctions.apply(ir) diff --git a/src/gt4py/next/iterator/transforms/transform_concat_where.py b/src/gt4py/next/iterator/transforms/transform_concat_where.py deleted file mode 100644 index 2f43948a87..0000000000 --- a/src/gt4py/next/iterator/transforms/transform_concat_where.py +++ /dev/null @@ -1,56 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next import utils -from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import ( - common_pattern_matcher as cpm, - domain_utils, - ir_makers as im, -) -from gt4py.next.iterator.transforms import symbol_ref_utils - - -class TransformConcatWhere(PreserveLocationVisitor, NodeTranslator): - PRESERVED_ANNEX_ATTRS = ( - "type", - "domain", - ) - - @classmethod - def apply(cls, node: ir.Node): - return cls().visit(node) - - def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: - node = self.generic_visit(node) - if cpm.is_call_to(node, "concat_where"): - cond_expr, field_a, field_b = node.args - cond = domain_utils.SymbolicDomain.from_expr(cond_expr).ranges.keys() - dims = [im.call("index")(ir.AxisLiteral(value=k.value, kind=k.kind)) for k in cond] - refs = symbol_ref_utils.collect_symbol_refs(cond_expr) - - # TODO: cleanup - domains = utils.flatten_nested_tuple(node.annex.domain) - assert all(domain == domains[0] for domain in domains) - domain_expr = domains[0].as_expr() - - return im.as_fieldop( - im.lambda_("__tcw_pos", "__tcw_arg0", "__tcw_arg1", *refs)( - im.let(*zip(refs, map(im.deref, refs), strict=True))( - im.if_( - im.call("in_")(im.deref("__tcw_pos"), cond_expr), - im.deref("__tcw_arg0"), - im.deref("__tcw_arg1"), - ) - ) - ), - domain_expr, - )(im.make_tuple(*dims), field_a, field_b, *refs) - - return node diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 295ec8fec9..20bffd2b10 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -369,11 +369,7 @@ def _collect_and_check_dimensions(input_: ts.TypeSpec) -> list[common.Dimension] .filter(lambda dims: len(dims) > 0) .to_list() ) - if all_input_dims: - assert all(cur_input_dims == all_input_dims[0] for cur_input_dims in all_input_dims) - return all_input_dims[0] - - return [] + return common.promote_dims(*all_input_dims) def _convert_as_fieldop_input_to_iterator( diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index fe9addb8ea..0a5ae74070 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -403,7 +403,7 @@ def verify( out: Optional[FieldViewInout] = None, inout: Optional[FieldViewInout] = None, offset_provider: Optional[OffsetProvider] = None, - comparison: Callable[[Any, Any], bool] = np.allclose, + comparison: Callable[[Any, Any], bool] = gt_utils.tree_map(np.allclose), ) -> None: """ Check the result of executing a fieldview program or operator against ref. diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 548faeba9e..5796cf3534 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -12,6 +12,7 @@ from next_tests.integration_tests.cases import IDim, JDim, KDim, cartesian_case from gt4py import next as gtx from gt4py.next import errors +from gt4py.next import broadcast from gt4py.next.ffront.experimental import concat_where from next_tests.integration_tests import cases from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -300,3 +301,47 @@ def testee( out=out, ref=(ref0, ref1), ) + + +@pytest.mark.uses_frontend_concat_where +@pytest.mark.uses_tuple_returns +def test_with_tuples_different_domain(cartesian_case): + @gtx.field_operator + def testee( + interior0: cases.IJKField, + boundary0: cases.IJKField, + interior1: cases.KField, + boundary1: cases.KField, + ) -> Tuple[cases.IJKField, cases.IJKField]: + a, b = concat_where(KDim == 0, (boundary0, boundary1), (interior0, interior1)) + # the broadcast is only needed since we can not return fields on different domains yet + return a, broadcast(b, (IDim, JDim, KDim)) + + interior0 = cases.allocate(cartesian_case, testee, "interior0")() + boundary0 = cases.allocate(cartesian_case, testee, "boundary0")() + interior1 = cases.allocate(cartesian_case, testee, "interior1")() + boundary1 = cases.allocate(cartesian_case, testee, "boundary1")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + k = np.arange(0, cartesian_case.default_sizes[KDim]) + ref0 = np.where( + k[np.newaxis, np.newaxis, :] == 0, + boundary0.asnumpy(), + interior0.asnumpy(), + ) + ref1 = np.where( + k == 0, + boundary1.asnumpy(), + interior1.asnumpy(), + ) + + cases.verify( + cartesian_case, + testee, + interior0, + boundary0, + interior1, + boundary1, + out=out, + ref=(ref0, ref1), + ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 71cecd5adc..ecda23dc4a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause from gt4py.next import common +import pytest from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple from gt4py.next.type_system import type_specifications as ts @@ -114,7 +115,27 @@ def test_simple_tuple_get_make_tuple(): assert expected == actual -def test_propagate_tuple_get(): +@pytest.mark.parametrize("fun", ["if_", "concat_where"]) +def test_propagate_tuple_get(fun): + testee = im.tuple_get( + 0, im.call(fun)("cond", im.make_tuple("el1", "el2"), im.make_tuple("el1", "el2")) + ) + expected = im.call(fun)( + "cond", + im.tuple_get(0, im.make_tuple("el1", "el2")), + im.tuple_get(0, im.make_tuple("el1", "el2")), + ) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + enabled_transformations=CollapseTuple.Transformation.PROPAGATE_TUPLE_GET, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert expected == actual + + +def test_propagate_tuple_get_let(): expected = im.let(("el1", 1), ("el2", 2))(im.tuple_get(0, im.make_tuple("el1", "el2"))) testee = im.tuple_get(0, im.let(("el1", 1), ("el2", 2))(im.make_tuple("el1", "el2"))) actual = CollapseTuple.apply( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_concat_where.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_concat_where.py new file mode 100644 index 0000000000..9c00e675af --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_concat_where.py @@ -0,0 +1,68 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +from gt4py.next import common +import pytest +from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils +from gt4py.next.iterator.transforms import transform_concat_where +from gt4py.next.iterator.transforms import inline_lambdas +from gt4py.next.type_system import type_specifications as ts +from gt4py.next.iterator.type_system import type_specifications as it_ts + +int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) +IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) + + +def test_trivial(): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 2)}) + + cond = im.domain(common.GridType.CARTESIAN, {IDim: (0, 1)}) + testee = im.concat_where(cond, "true_branch", "false_branch") + testee.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) + expected = im.as_fieldop( + im.lambda_("__tcw_pos", "__tcw_arg0", "__tcw_arg1")( + im.if_( + im.call("in_")(im.deref("__tcw_pos"), cond), + im.deref("__tcw_arg0"), + im.deref("__tcw_arg1"), + ) + ), + domain, + )(im.make_tuple(im.index(IDim)), "true_branch", "false_branch", "cond") + + actual = transform_concat_where.TransformConcatWhere.apply(testee) + actual = inline_lambdas.InlineLambdas.apply(actual) # simplify + + assert actual == expected + + +def test_capturing_cond(): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 1)}) + + cond = im.domain(common.GridType.CARTESIAN, {IDim: ("start", "stop")}) + testee = im.concat_where(cond, "true_branch", "false_branch") + testee.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) + expected = im.as_fieldop( + im.lambda_("__tcw_pos", "__tcw_arg0", "__tcw_arg1", "start", "stop")( + im.if_( + im.call("in_")( + im.deref("__tcw_pos"), + im.domain( + common.GridType.CARTESIAN, {IDim: (im.deref("start"), im.deref("stop"))} + ), + ), + im.deref("__tcw_arg0"), + im.deref("__tcw_arg1"), + ) + ), + domain, + )(im.make_tuple(im.index(IDim)), "true_branch", "false_branch", "start", "stop") + + actual = transform_concat_where.TransformConcatWhere.apply(testee) + actual = inline_lambdas.InlineLambdas.apply(actual) # simplify + + assert actual == expected From 62db9ff3efdd5c797d53b74b0afab432301053ee Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 18 May 2025 18:41:25 +0200 Subject: [PATCH 067/124] Small fix --- .../next/iterator/transforms/concat_where_transforms.py | 2 +- src/gt4py/next/iterator/transforms/pass_manager.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/concat_where_transforms.py b/src/gt4py/next/iterator/transforms/concat_where_transforms.py index b1a1936b23..83558402c7 100644 --- a/src/gt4py/next/iterator/transforms/concat_where_transforms.py +++ b/src/gt4py/next/iterator/transforms/concat_where_transforms.py @@ -27,7 +27,7 @@ class _TransformTupleConcatWhere(PreserveLocationVisitor, NodeTranslator): ) @classmethod - def apply(cls, node: ir.Node, offset_provider_type: Optional[common.OffsetProviderType] = None): + def apply(cls, node: ir.Node, offset_provider_type: common.OffsetProviderType): node = type_inference.infer( node, offset_provider_type=offset_provider_type, diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 9facdd14e3..2ad26005fc 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -97,10 +97,10 @@ def apply_common_transforms( # Note: executing domain inference again afterwards will give wrong domains. # This might be problematic in the temporary extraction, where we do this... - ir = concat_where_transforms.expand_tuple(ir) - ir = ConstantFolding.apply(ir) # TODO: remove + ir = concat_where_transforms.expand_tuple(ir, offset_provider_type=offset_provider_type) + #ir = ConstantFolding.apply(ir) # TODO: remove ir = concat_where_transforms.expand(ir) - ir = ConstantFolding.apply(ir) # TODO: remove + #ir = ConstantFolding.apply(ir) # TODO: remove ir = expand_library_functions.ExpandLibraryFunctions.apply(ir) for _ in range(10): From 7053c392d80e5643f722c2fed735eda7c6473e0c Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 18 May 2025 18:48:25 +0200 Subject: [PATCH 068/124] Cleanup --- src/gt4py/next/type_system/type_specifications.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 4822719b40..0a572dcc0f 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -144,6 +144,4 @@ def __str__(self) -> str: class DomainType(DataType): - # TODO(tehrengruber): Remove "unknown" here again after the result type of `as_fieldop` - # is always precisely known. This is the case after #1853. - dims: list[common.Dimension] | Literal["unknown"] + dims: list[common.Dimension] From d597a4dcdfeb5ce81b2e5b59a312c21c8b3a8839 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 18 May 2025 18:51:38 +0200 Subject: [PATCH 069/124] Cleanup --- .../transforms_tests/test_transform_concat_where.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_concat_where.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_concat_where.py index 9c00e675af..bc163aaad4 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_concat_where.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_concat_where.py @@ -8,7 +8,7 @@ from gt4py.next import common import pytest from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils -from gt4py.next.iterator.transforms import transform_concat_where +from gt4py.next.iterator.transforms import concat_where_transforms from gt4py.next.iterator.transforms import inline_lambdas from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.type_system import type_specifications as it_ts @@ -32,9 +32,9 @@ def test_trivial(): ) ), domain, - )(im.make_tuple(im.index(IDim)), "true_branch", "false_branch", "cond") + )(im.make_tuple(im.index(IDim)), "true_branch", "false_branch") - actual = transform_concat_where.TransformConcatWhere.apply(testee) + actual = concat_where_transforms.expand(testee) actual = inline_lambdas.InlineLambdas.apply(actual) # simplify assert actual == expected @@ -62,7 +62,7 @@ def test_capturing_cond(): domain, )(im.make_tuple(im.index(IDim)), "true_branch", "false_branch", "start", "stop") - actual = transform_concat_where.TransformConcatWhere.apply(testee) + actual = concat_where_transforms.expand(testee) actual = inline_lambdas.InlineLambdas.apply(actual) # simplify assert actual == expected From 45ccbbc9de08eb5795a026b7dd2cbc00c9dd67c3 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 18 May 2025 19:08:10 +0200 Subject: [PATCH 070/124] Cleanup --- src/gt4py/next/iterator/transforms/inline_fundefs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/inline_fundefs.py b/src/gt4py/next/iterator/transforms/inline_fundefs.py index 03b20d14fe..2b8767e4a2 100644 --- a/src/gt4py/next/iterator/transforms/inline_fundefs.py +++ b/src/gt4py/next/iterator/transforms/inline_fundefs.py @@ -32,6 +32,7 @@ def prune_unreferenced_fundefs(program: itir.Program) -> itir.Program: """ Remove all function declarations that are never called. + >>> from gt4py.next import common >>> from gt4py.next.iterator.ir_utils import ir_makers as im >>> fun1 = itir.FunctionDefinition( ... id="fun1", @@ -43,6 +44,7 @@ def prune_unreferenced_fundefs(program: itir.Program) -> itir.Program: ... params=[im.sym("a")], ... expr=im.deref("a"), ... ) + >>> IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) >>> program = itir.Program( ... id="testee", ... function_definitions=[fun1, fun2], @@ -51,7 +53,7 @@ def prune_unreferenced_fundefs(program: itir.Program) -> itir.Program: ... body=[ ... itir.SetAt( ... expr=im.call("fun1")("inp"), - ... domain=im.domain("cartesian_domain", {"IDim": (0, 10)}), + ... domain=im.domain("cartesian_domain", {IDim: (0, 10)}), ... target=im.ref("out"), ... ) ... ], From 0326e80d91439ab626fd95b45c176ab6b76dd1b4 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 20 May 2025 15:03:51 +0200 Subject: [PATCH 071/124] Add more unit tests --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 6 +- src/gt4py/next/iterator/ir_utils/misc.py | 19 +++++ .../transforms/concat_where_transforms.py | 1 - .../iterator/transforms/infer_domain_ops.py | 36 ++++------ .../next/iterator/transforms/pass_manager.py | 2 - .../codegens/gtfn/itir_to_gtfn_ir.py | 20 +----- .../next/type_system/type_specifications.py | 2 +- .../test_expand_library_functions.py | 39 ++++++++++ .../transforms_tests/test_infer_domain_ops.py | 71 +++++++++++++++++++ 9 files changed, 150 insertions(+), 46 deletions(-) create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_library_functions.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_infer_domain_ops.py diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 3aca150da6..739aa5d90d 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -446,7 +446,7 @@ def domain( expr = call(grid_type)( *[ call("named_range")( - itir.AxisLiteral(value=d.value, kind=d.kind), + axis_literal(d), r[0], r[1], ) @@ -522,6 +522,10 @@ def _impl(*its: itir.Expr) -> itir.FunCall: return _impl +def axis_literal(dim: common.Dimension) -> itir.AxisLiteral: + return itir.AxisLiteral(value=dim.value, kind=dim.kind) + + def cast_as_fieldop(type_: str, domain: Optional[itir.FunCall] = None): """ Promotes the function `cast_` to a field_operator. diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index 63090903df..988b26c793 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -12,6 +12,7 @@ from gt4py import eve from gt4py.eve import utils as eve_utils +from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import inline_lambdas @@ -216,3 +217,21 @@ def extract_projector( projector = projector if cur_projector is None else im.compose(cur_projector, projector) projector = inline_lambdas.InlineLambdas.apply(projector) return extract_projector(expr, projector, _depth + 1) + + +def grid_type_from_domain(domain: itir.FunCall) -> common.GridType: + if cpm.is_call_to(domain, "cartesian_domain"): + return common.GridType.CARTESIAN + else: + assert cpm.is_call_to(domain, "unstructured_domain") + return common.GridType.UNSTRUCTURED + + +def grid_type_from_program(program: itir.Program) -> common.GridType: + domains = program.walk_values().if_isinstance(itir.SetAt).getattr("domain").to_set() + grid_types = {grid_type_from_domain(d) for d in domains} + if len(grid_types) != 1: + raise ValueError( + f"Found 'set_at' with more than one 'GridType': '{grid_types}'. This is currently not supported." + ) + return grid_types.pop() diff --git a/src/gt4py/next/iterator/transforms/concat_where_transforms.py b/src/gt4py/next/iterator/transforms/concat_where_transforms.py index 83558402c7..e294c0562e 100644 --- a/src/gt4py/next/iterator/transforms/concat_where_transforms.py +++ b/src/gt4py/next/iterator/transforms/concat_where_transforms.py @@ -5,7 +5,6 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Optional from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next import common, utils diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index fde4b22696..6eb062979a 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -14,8 +14,9 @@ common_pattern_matcher as cpm, domain_utils, ir_makers as im, + misc as ir_misc, ) -from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import _get_gridtype +from gt4py.next.iterator.type_system import inference from gt4py.next.type_system import type_specifications as ts @@ -25,26 +26,25 @@ class InferDomainOps(PreserveLocationVisitor, NodeTranslator): @classmethod def apply(cls, program: ir.Program): - # TODO: move _get_gridtype - return cls(grid_type=_get_gridtype(program.body)).visit(program, recurse=True) + return cls(grid_type=ir_misc.grid_type_from_program(program)).visit(program, recurse=True) def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: if kwargs["recurse"]: node = self.generic_visit(node, **kwargs) - # IDim < a + # e.g. `IDim < a` if ( cpm.is_call_to(node, builtins.BINARY_MATH_COMPARISON_BUILTINS) and any(isinstance(arg, ir.AxisLiteral) for arg in node.args) and any(isinstance(arg, ir.Expr) for arg in node.args) - ): # TODO: add tests + ): arg1, arg2 = node.args if isinstance(arg2, ir.AxisLiteral): # take complementary operation if we have e.g. `0 < IDim` use `IDim > 0` complementary_op = { "less": "greater", "less_equal": "greater_equal", - "greater": "greater_equal", + "greater": "less", "greater_equal": "less_equal", "eq": "eq", "not_eq": "not_eq", @@ -54,6 +54,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: **{**kwargs, "recurse": False}, ) + inference.reinfer(arg1) assert isinstance(arg1.type, ts.DimensionType) dim: common.Dimension = arg1.type.dim value: ir.Expr = arg2 @@ -93,25 +94,16 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: # `IDim != a -> `IDim < a & IDim > a` return self.visit( im.call("and_")( - self.visit(im.less(dim, value), **kwargs), - self.visit(im.greater(dim, value), **kwargs), + self.visit( + im.less(im.axis_literal(dim), value), **(kwargs | {"recurse": False}) + ), + self.visit( + im.greater(im.axis_literal(dim), value), **(kwargs | {"recurse": False}) + ), ), - **{**kwargs, "recurse": False}, + **(kwargs | {"recurse": False}), ) else: raise AssertionError() - # if cpm.is_call_to(node, builtins.BINARY_LOGICAL_BUILTINS) and all( - # isinstance(arg.type, ts.DomainType) for arg in node.args - # ): - # if cpm.is_call_to(node, "and_"): - # # TODO: domain promotion - # return ConstantFolding.apply( - # domain_utils.domain_intersection( - # *[domain_utils.SymbolicDomain.from_expr(arg) for arg in node.args] - # ).as_expr() - # ) - # else: - # raise NotImplementedError() - return node diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 2ad26005fc..52c67ee5cb 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -98,9 +98,7 @@ def apply_common_transforms( # Note: executing domain inference again afterwards will give wrong domains. # This might be problematic in the temporary extraction, where we do this... ir = concat_where_transforms.expand_tuple(ir, offset_provider_type=offset_provider_type) - #ir = ConstantFolding.apply(ir) # TODO: remove ir = concat_where_transforms.expand(ir) - #ir = ConstantFolding.apply(ir) # TODO: remove ir = expand_library_functions.ExpandLibraryFunctions.apply(ir) for _ in range(10): diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 696cfc62ea..a445390583 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -84,24 +84,6 @@ def _get_domains(nodes: Iterable[itir.Stmt]) -> Iterable[itir.FunCall]: return result -def _extract_grid_type(domain: itir.FunCall) -> common.GridType: - if domain.fun == itir.SymRef(id="cartesian_domain"): - return common.GridType.CARTESIAN - else: - assert domain.fun == itir.SymRef(id="unstructured_domain") - return common.GridType.UNSTRUCTURED - - -def _get_gridtype(body: list[itir.Stmt]) -> common.GridType: - domains = _get_domains(body) - grid_types = {_extract_grid_type(d) for d in domains} - if len(grid_types) != 1: - raise ValueError( - f"Found 'set_at' with more than one 'GridType': '{grid_types}'. This is currently not supported." - ) - return grid_types.pop() - - def _name_from_named_range(named_range_call: itir.FunCall) -> str: assert isinstance(named_range_call, itir.FunCall) and named_range_call.fun == itir.SymRef( id="named_range" @@ -342,7 +324,7 @@ def apply( raise TypeError(f"Expected a 'Program', got '{type(node).__name__}'.") node = itir_type_inference.infer(node, offset_provider_type=offset_provider_type) - grid_type = _get_gridtype(node.body) + grid_type = ir_utils_misc.grid_type_from_program(node) if grid_type == common.GridType.UNSTRUCTURED: node = _CannonicalizeUnstructuredDomain.apply(node) return cls( diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 0a572dcc0f..c69d1ae00d 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Iterator, Literal, Optional, Sequence, Union +from typing import Iterator, Optional, Sequence, Union from gt4py.eve import datamodels as eve_datamodels, type_definitions as eve_types from gt4py.next import common diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_library_functions.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_library_functions.py new file mode 100644 index 0000000000..40a046d866 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_library_functions.py @@ -0,0 +1,39 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import textwrap + +from gt4py.eve.utils import UIDGenerator +from gt4py.next import common +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.expand_library_functions import ExpandLibraryFunctions + +from next_tests.integration_tests.cases import IDim, JDim, KDim + + +def test_trivial(): + pos = im.make_tuple(0, 1) + bounds = { + IDim: (3, 4), + JDim: (5, 6), + } + testee = im.call("in_")(pos, im.domain(common.GridType.CARTESIAN, bounds)) + expected = im.and_( + im.and_( + im.less_equal(bounds[IDim][0], im.tuple_get(0, pos)), + im.less(im.tuple_get(0, pos), bounds[IDim][1]), + ), + im.and_( + im.less_equal(bounds[JDim][0], im.tuple_get(1, pos)), + im.less(im.tuple_get(1, pos), bounds[JDim][1]), + ), + ) + actual = ExpandLibraryFunctions.apply(testee) + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_infer_domain_ops.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_infer_domain_ops.py new file mode 100644 index 0000000000..77ba3719be --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_infer_domain_ops.py @@ -0,0 +1,71 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import textwrap + +from gt4py.eve.utils import UIDGenerator +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.infer_domain_ops import InferDomainOps +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding + +from next_tests.integration_tests.cases import IDim, JDim, KDim + + +def test_data(): + return [ + ( + im.less(im.axis_literal(IDim), 1), + im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 1)}), + ), + ( + im.less_equal(im.axis_literal(IDim), 1), + im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 2)}), + ), + ( + im.greater(im.axis_literal(IDim), 1), + im.domain(common.GridType.CARTESIAN, {IDim: (2, itir.InfinityLiteral.POSITIVE)}), + ), + ( + im.greater_equal(im.axis_literal(IDim), 1), + im.domain(common.GridType.CARTESIAN, {IDim: (1, itir.InfinityLiteral.POSITIVE)}), + ), + ( + im.less(1, im.axis_literal(IDim)), + im.domain(common.GridType.CARTESIAN, {IDim: (2, itir.InfinityLiteral.POSITIVE)}), + ), + ( + im.less_equal(1, im.axis_literal(IDim)), + im.domain(common.GridType.CARTESIAN, {IDim: (1, itir.InfinityLiteral.POSITIVE)}), + ), + ( + im.greater(1, im.axis_literal(IDim)), + im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 1)}), + ), + ( + im.greater_equal(1, im.axis_literal(IDim)), + im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 2)}), + ), + (im.eq(1, im.axis_literal(IDim)), im.domain(common.GridType.CARTESIAN, {IDim: (1, 2)})), + ( + im.not_eq(1, im.axis_literal(IDim)), + im.and_( + im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 1)}), + im.domain(common.GridType.CARTESIAN, {IDim: (2, itir.InfinityLiteral.POSITIVE)}), + ), + ), + ] + + +@pytest.mark.parametrize("testee,expected", test_data()) +def test_trivial(testee, expected): + actual = InferDomainOps(grid_type=common.GridType.CARTESIAN).visit(testee, recurse=True) + actual = ConstantFolding.apply(actual) # simplify expr to get simpler expected expressions + assert actual == expected From 1ce4ed465b51802f84882fb6e610bdc9d858feb4 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 22 May 2025 12:26:05 +0200 Subject: [PATCH 072/124] Cleanup --- .../transforms/concat_where_transforms.py | 72 ++++++++++++++++- .../iterator/transforms/infer_domain_ops.py | 29 ++++--- .../iterator/transforms/nest_concat_wheres.py | 79 ------------------- .../next/iterator/transforms/pass_manager.py | 3 +- ...t_where.py => test_expand_concat_where.py} | 2 +- .../test_expand_tuple_concat_where.py | 47 +++++++++++ .../test_nest_concat_where.py | 45 +++++++++++ 7 files changed, 179 insertions(+), 98 deletions(-) delete mode 100644 src/gt4py/next/iterator/transforms/nest_concat_wheres.py rename tests/next_tests/unit_tests/iterator_tests/transforms_tests/{test_transform_concat_where.py => test_expand_concat_where.py} (98%) create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_tuple_concat_where.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_nest_concat_where.py diff --git a/src/gt4py/next/iterator/transforms/concat_where_transforms.py b/src/gt4py/next/iterator/transforms/concat_where_transforms.py index e294c0562e..d9c1f5cab9 100644 --- a/src/gt4py/next/iterator/transforms/concat_where_transforms.py +++ b/src/gt4py/next/iterator/transforms/concat_where_transforms.py @@ -17,6 +17,24 @@ from gt4py.next.iterator.transforms import infer_domain, symbol_ref_utils from gt4py.next.iterator.type_system import inference as type_inference from gt4py.next.type_system import type_specifications as ts +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next.iterator import ir, ir as itir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) +from gt4py.next.iterator.ir_utils.domain_utils import SymbolicDomain + +def _range_complement( + range_: domain_utils.SymbolicRange, +) -> tuple[domain_utils.SymbolicRange, domain_utils.SymbolicRange]: + # `[a, b[` -> `[-inf, a[` ∪ `[b, inf[` + assert not any(isinstance(b, itir.InfinityLiteral) for b in [range_.start, range_.stop]) + return ( + domain_utils.SymbolicRange(itir.InfinityLiteral.NEGATIVE, range_.start), + domain_utils.SymbolicRange(range_.stop, itir.InfinityLiteral.POSITIVE), + ) class _TransformTupleConcatWhere(PreserveLocationVisitor, NodeTranslator): @@ -26,10 +44,17 @@ class _TransformTupleConcatWhere(PreserveLocationVisitor, NodeTranslator): ) @classmethod - def apply(cls, node: ir.Node, offset_provider_type: common.OffsetProviderType): + def apply( + cls, + node: ir.Node, + *, + offset_provider_type: common.OffsetProviderType, + allow_undeclared_symbols: bool = False + ) -> ir.Node: node = type_inference.infer( node, offset_provider_type=offset_provider_type, + allow_undeclared_symbols=allow_undeclared_symbols ) return cls().visit(node) @@ -109,3 +134,48 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: expand = _ExpandConcatWhere.apply + +class NestConcatWheres(PreserveLocationVisitor, NodeTranslator): + @classmethod + def apply(cls, node: ir.Node): + return cls().visit(node) + + def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + node = self.generic_visit(node) + + # TODO: do not duplicate exprs + if cpm.is_call_to(node, "concat_where"): + cond_expr, field_a, field_b = node.args + # TODO: don't duplicate exprs here + if cpm.is_call_to(cond_expr, "and_"): + conds = cond_expr.args + return self.visit( + im.concat_where(conds[0], im.concat_where(conds[1], field_a, field_b), field_b) + ) + if cpm.is_call_to(cond_expr, "or_"): + conds = cond_expr.args + return self.visit( + im.concat_where(conds[0], field_a, im.concat_where(conds[1], field_a, field_b)) + ) + + # concat_where([1, 2[, a, b) -> concat_where([-inf, 1] | [2, inf[, b, a) + if cpm.is_call_to(cond_expr, ("cartesian_domain", "unstructured_domain")): + domain = SymbolicDomain.from_expr(cond_expr) + if len(domain.ranges) == 1: + dim, range_ = next(iter(domain.ranges.items())) + if domain_utils.is_finite(range_): + complement = _range_complement(range_) + new_domains = [ + im.domain(domain.grid_type, {dim: (cr.start, cr.stop)}) + for cr in complement + ] + # TODO: fp transform + return self.visit( + im.concat_where(im.call("or_")(*new_domains), field_b, field_a) + ) + else: + # TODO(tehrengruber): Implement. Note that this case can not be triggered by + # the frontend. + raise NotImplementedError() + + return node \ No newline at end of file diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index 6eb062979a..26e475bfcc 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -9,7 +9,7 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next import common -from gt4py.next.iterator import builtins, ir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ( common_pattern_matcher as cpm, domain_utils, @@ -25,21 +25,20 @@ class InferDomainOps(PreserveLocationVisitor, NodeTranslator): grid_type: common.GridType @classmethod - def apply(cls, program: ir.Program): + def apply(cls, program: itir.Program): return cls(grid_type=ir_misc.grid_type_from_program(program)).visit(program, recurse=True) - def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: + def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Node: if kwargs["recurse"]: node = self.generic_visit(node, **kwargs) # e.g. `IDim < a` if ( cpm.is_call_to(node, builtins.BINARY_MATH_COMPARISON_BUILTINS) - and any(isinstance(arg, ir.AxisLiteral) for arg in node.args) - and any(isinstance(arg, ir.Expr) for arg in node.args) + and any(isinstance(arg, itir.AxisLiteral) for arg in node.args) ): arg1, arg2 = node.args - if isinstance(arg2, ir.AxisLiteral): + if isinstance(arg2, itir.AxisLiteral): # take complementary operation if we have e.g. `0 < IDim` use `IDim > 0` complementary_op = { "less": "greater", @@ -57,29 +56,29 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: inference.reinfer(arg1) assert isinstance(arg1.type, ts.DimensionType) dim: common.Dimension = arg1.type.dim - value: ir.Expr = arg2 + value: itir.Expr = arg2 if cpm.is_call_to(node, ("less", "less_equal", "greater", "greater_equal", "eq")): - min_: int | ir.InfinityLiteral - max_: int | ir.InfinityLiteral + min_: int | itir.InfinityLiteral + max_: int | itir.InfinityLiteral # `IDim < 1` if cpm.is_call_to(node, "less"): - min_ = ir.InfinityLiteral.NEGATIVE + min_ = itir.InfinityLiteral.NEGATIVE max_ = value # `IDim <= 1` elif cpm.is_call_to(node, "less_equal"): - min_ = ir.InfinityLiteral.NEGATIVE + min_ = itir.InfinityLiteral.NEGATIVE max_ = im.plus(value, 1) # `IDim > 1` elif cpm.is_call_to(node, "greater"): min_ = im.plus(value, 1) - max_ = ir.InfinityLiteral.POSITIVE + max_ = itir.InfinityLiteral.POSITIVE # `IDim >= 1` elif cpm.is_call_to(node, "greater_equal"): min_ = value - max_ = ir.InfinityLiteral.POSITIVE - # `IDim == 1` # TODO: isn't this removed before and rewritten as two concat_where? + max_ = itir.InfinityLiteral.POSITIVE + # `IDim == 1` elif cpm.is_call_to(node, "eq"): min_ = value max_ = im.plus(value, 1) @@ -91,7 +90,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: return domain.as_expr() elif cpm.is_call_to(node, "not_eq"): - # `IDim != a -> `IDim < a & IDim > a` + # `IDim != a` -> `IDim < a & IDim > a` return self.visit( im.call("and_")( self.visit( diff --git a/src/gt4py/next/iterator/transforms/nest_concat_wheres.py b/src/gt4py/next/iterator/transforms/nest_concat_wheres.py deleted file mode 100644 index 74fa31f951..0000000000 --- a/src/gt4py/next/iterator/transforms/nest_concat_wheres.py +++ /dev/null @@ -1,79 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next.iterator import ir, ir as itir -from gt4py.next.iterator.ir_utils import ( - common_pattern_matcher as cpm, - domain_utils, - ir_makers as im, -) -from gt4py.next.iterator.ir_utils.domain_utils import SymbolicDomain - - -class NestConcatWheres(PreserveLocationVisitor, NodeTranslator): - @classmethod - def apply(cls, node: ir.Node): - return cls().visit(node) - - def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: - node = self.generic_visit(node) - - # TODO: do not duplicate exprs - if cpm.is_call_to(node, "concat_where"): - cond_expr, field_a, field_b = node.args - # TODO: don't duplicate exprs here - if cpm.is_call_to(cond_expr, "and_"): - conds = cond_expr.args - return self.visit( - im.concat_where(conds[0], im.concat_where(conds[1], field_a, field_b), field_b) - ) - if cpm.is_call_to(cond_expr, "or_"): - conds = cond_expr.args - return self.visit( - im.concat_where(conds[0], field_a, im.concat_where(conds[1], field_a, field_b)) - ) - if cpm.is_call_to(cond_expr, "eq"): - cond1 = im.less(cond_expr.args[0], cond_expr.args[1]) - cond2 = im.greater(cond_expr.args[0], cond_expr.args[1]) - return self.visit( - im.concat_where(cond1, field_b, im.concat_where(cond2, field_b, field_a)) - ) - - # concat_where([1, 2[, a, b) -> concat_where([-inf, 1] | [2, inf[, b, a) - if cpm.is_call_to(cond_expr, ("cartesian_domain", "unstructured_domain")): - domain = SymbolicDomain.from_expr(cond_expr) - if len(domain.ranges) == 1: - dim, range_ = next(iter(domain.ranges.items())) - if domain_utils.is_finite(range_): - complement = _range_complement(range_) - new_domains = [ - im.domain(domain.grid_type, {dim: (cr.start, cr.stop)}) - for cr in complement - ] - # TODO: fp transform - return self.visit( - im.concat_where(im.call("or_")(*new_domains), field_b, field_a) - ) - else: - # TODO(tehrengruber): Implement. Note that this case can not be triggered by - # the frontend. - raise NotImplementedError() - - return node - - -def _range_complement( - range_: domain_utils.SymbolicRange, -) -> tuple[domain_utils.SymbolicRange, domain_utils.SymbolicRange]: - # `[a, b[` -> `[-inf, a[` ∪ `[b, inf[` - assert not any(isinstance(b, itir.InfinityLiteral) for b in [range_.start, range_.stop]) - return ( - domain_utils.SymbolicRange(itir.InfinityLiteral.NEGATIVE, range_.start), - domain_utils.SymbolicRange(range_.stop, itir.InfinityLiteral.POSITIVE), - ) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 52c67ee5cb..2b83e724e4 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -22,7 +22,6 @@ inline_dynamic_shifts, inline_fundefs, inline_lifts, - nest_concat_wheres, remove_broadcast, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet @@ -86,7 +85,7 @@ def apply_common_transforms( ir ) # domain inference does not support dynamic offsets yet ir = infer_domain_ops.InferDomainOps.apply(ir) - ir = nest_concat_wheres.NestConcatWheres.apply(ir) + ir = concat_where_transforms.NestConcatWheres.apply(ir) ir = infer_domain.infer_program( ir, diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_concat_where.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_concat_where.py similarity index 98% rename from tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_concat_where.py rename to tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_concat_where.py index bc163aaad4..7aa36b8acf 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_concat_where.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_concat_where.py @@ -65,4 +65,4 @@ def test_capturing_cond(): actual = concat_where_transforms.expand(testee) actual = inline_lambdas.InlineLambdas.apply(actual) # simplify - assert actual == expected + assert actual == expected \ No newline at end of file diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_tuple_concat_where.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_tuple_concat_where.py new file mode 100644 index 0000000000..14fd4a3938 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_tuple_concat_where.py @@ -0,0 +1,47 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +from gt4py.next import common +import pytest +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils +from gt4py.next.iterator.transforms import concat_where_transforms, inline_lambdas, infer_domain, collapse_tuple +from gt4py.next.type_system import type_specifications as ts +from gt4py.next.iterator.type_system import type_specifications as it_ts + +int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) +IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) +field_type = ts.FieldType(dims=[IDim], dtype=int_type) + +def test_trivial(): + cond = im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 1)}) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 2)}) + symbolic_domain = domain_utils.SymbolicDomain.from_expr(domain) + + testee = im.concat_where( + cond, + im.make_tuple(im.ref("a", field_type), im.ref("c", field_type)), + im.make_tuple(im.ref("b", field_type), im.ref("d", field_type)) + ) + testee, _ = infer_domain.infer_expr( + testee, + (symbolic_domain, symbolic_domain), + keep_existing_domains=True, + offset_provider={}, + ) + + expected = im.make_tuple( + im.concat_where(cond, "a", "b"), + im.concat_where(cond, "c", "d") + ) + + actual = concat_where_transforms.expand_tuple( + testee, offset_provider_type={}, allow_undeclared_symbols=True) + + actual = collapse_tuple.CollapseTuple.apply(actual, allow_undeclared_symbols=True, within_stencil=False) + + assert actual == expected \ No newline at end of file diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_nest_concat_where.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_nest_concat_where.py new file mode 100644 index 0000000000..1869e6b6df --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_nest_concat_where.py @@ -0,0 +1,45 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +from gt4py.next import common +import pytest +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils +from gt4py.next.iterator.transforms import concat_where_transforms, inline_lambdas, infer_domain, collapse_tuple +from gt4py.next.type_system import type_specifications as ts +from gt4py.next.iterator.type_system import type_specifications as it_ts + +int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) +IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) +field_type = ts.FieldType(dims=[IDim], dtype=int_type) + +def test_data(): + return [ + # testee, expected + ( + im.concat_where(im.and_("cond1", "cond2"), "a", "b"), + im.concat_where("cond1", im.concat_where("cond2", "a", "b"), "b"), + ), + ( + im.concat_where(im.or_("cond1", "cond2"), "a", "b"), + im.concat_where("cond1", "a", im.concat_where("cond2", "a", "b")), + ), + ( + im.concat_where(im.domain(common.GridType.CARTESIAN, {IDim: (0, 1)}), "a", "b"), + im.concat_where( + im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 0)}), + "b", + im.concat_where(im.domain(common.GridType.CARTESIAN, {IDim: (1, itir.InfinityLiteral.POSITIVE)}), "b", "a") + ), + ) + ] + +@pytest.mark.parametrize("testee, expected", test_data()) +def test_nested_concat_where(testee, expected): + actual = concat_where_transforms.NestConcatWheres.apply(testee) + + assert actual == expected \ No newline at end of file From 62688b2a21e71ea57f2c69734275346075a45838 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 23 May 2025 10:52:15 +0200 Subject: [PATCH 073/124] Cleanup --- .../ffront/foast_passes/type_deduction.py | 8 +- .../transforms/concat_where/__init__.py | 18 ++ .../concat_where/expand_tuple_args.py | 68 +++++++ .../concat_where/simplify_domain_argument.py | 76 ++++++++ .../concat_where/transform_to_as_fieldop.py | 72 +++++++ .../transforms/concat_where_transforms.py | 181 ------------------ .../next/iterator/transforms/infer_domain.py | 4 +- .../iterator/transforms/infer_domain_ops.py | 5 +- .../next/iterator/transforms/pass_manager.py | 13 +- .../iterator/type_system/type_synthesizer.py | 9 +- ...=> test_concat_where_expand_tuple_args.py} | 26 ++- ...test_concat_where_simplify_domain_args.py} | 23 ++- ...t_concat_where_transform_to_as_fieldop.py} | 12 +- 13 files changed, 294 insertions(+), 221 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/concat_where/__init__.py create mode 100644 src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py create mode 100644 src/gt4py/next/iterator/transforms/concat_where/simplify_domain_argument.py create mode 100644 src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py delete mode 100644 src/gt4py/next/iterator/transforms/concat_where_transforms.py rename tests/next_tests/unit_tests/iterator_tests/transforms_tests/{test_expand_tuple_concat_where.py => test_concat_where_expand_tuple_args.py} (75%) rename tests/next_tests/unit_tests/iterator_tests/transforms_tests/{test_nest_concat_where.py => test_concat_where_simplify_domain_args.py} (75%) rename tests/next_tests/unit_tests/iterator_tests/transforms_tests/{test_expand_concat_where.py => test_concat_where_transform_to_as_fieldop.py} (86%) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 370d0be85c..f40f42e279 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -1006,13 +1006,17 @@ def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: collection_type=ts.TupleType, result_collection_constructor=lambda el: ts.TupleType(types=list(el)), ) - def deduce_return_type(tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.ScalarType): + def deduce_return_type( + tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.ScalarType + ) -> ts.FieldType: if (t_dtype := type_info.extract_dtype(tb)) != (f_dtype := type_info.extract_dtype(fb)): raise errors.DSLError( node.location, f"Field arguments must be of same dtype, got '{t_dtype}' != '{f_dtype}'.", ) - return_dims = promote_dims(mask_type.dims, type_info.promote(tb, fb).dims) + return_dims = promote_dims( + mask_type.dims, type_info.extract_dims(type_info.promote(tb, fb)) + ) return_type = ts.FieldType(dims=return_dims, dtype=type_info.promote(t_dtype, f_dtype)) return return_type diff --git a/src/gt4py/next/iterator/transforms/concat_where/__init__.py b/src/gt4py/next/iterator/transforms/concat_where/__init__.py new file mode 100644 index 0000000000..a9c3fb2576 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/concat_where/__init__.py @@ -0,0 +1,18 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.next.iterator.transforms.concat_where.expand_tuple_args import expand_tuple_args +from gt4py.next.iterator.transforms.concat_where.simplify_domain_argument import ( + simplify_domain_argument, +) +from gt4py.next.iterator.transforms.concat_where.transform_to_as_fieldop import ( + transform_to_as_fieldop, +) + + +__all__ = ["expand_tuple_args", "simplify_domain_argument", "transform_to_as_fieldop"] diff --git a/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py b/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py new file mode 100644 index 0000000000..4cf8c765df --- /dev/null +++ b/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py @@ -0,0 +1,68 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.transforms import infer_domain +from gt4py.next.iterator.type_system import inference as type_inference +from gt4py.next.type_system import type_specifications as ts + + +class _ExpandTupleArgs(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ( + "type", + "domain", + ) + + @classmethod + def apply( + cls, + node: itir.Node, + *, + offset_provider_type: common.OffsetProviderType, + allow_undeclared_symbols: bool = False, + ) -> itir.Node: + node = type_inference.infer( + node, + offset_provider_type=offset_provider_type, + allow_undeclared_symbols=allow_undeclared_symbols, + ) + return cls().visit(node) + + def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: + node = self.generic_visit(node) + + # `concat_where(cond, {a, b}, {c, d})` + # -> `{concat_where(cond, a, c), concat_where(cond, a, c)}` + if cpm.is_call_to(node, "concat_where") and isinstance(node.args[1].type, ts.TupleType): + cond, true_branch, false_branch = node.args + new_els = [] + for i in range(len(true_branch.type.types)): + new_els.append( + im.concat_where(cond, im.tuple_get(i, "__tb"), im.tuple_get(i, "__fb")) + ) + + new_node = im.let(("__tb", true_branch), ("__fb", false_branch))( + im.make_tuple(*new_els) + ) + # restore domain information + new_node, _ = infer_domain.infer_expr( + new_node, + node.annex.domain, + keep_existing_domains=True, + # offset provider not needed as all as_fieldop already have a domain + offset_provider={}, + ) + return new_node + + return node + + +expand_tuple_args = _ExpandTupleArgs.apply diff --git a/src/gt4py/next/iterator/transforms/concat_where/simplify_domain_argument.py b/src/gt4py/next/iterator/transforms/concat_where/simplify_domain_argument.py new file mode 100644 index 0000000000..74f24f28b7 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/concat_where/simplify_domain_argument.py @@ -0,0 +1,76 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next.iterator import ir, ir as itir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) +from gt4py.next.iterator.ir_utils.domain_utils import SymbolicDomain + + +def _range_complement( + range_: domain_utils.SymbolicRange, +) -> tuple[domain_utils.SymbolicRange, domain_utils.SymbolicRange]: + # `[a, b[` -> `[-inf, a[` ∪ `[b, inf[` # noqa: RUF003 + assert not any(isinstance(b, itir.InfinityLiteral) for b in [range_.start, range_.stop]) + return ( + domain_utils.SymbolicRange(itir.InfinityLiteral.NEGATIVE, range_.start), + domain_utils.SymbolicRange(range_.stop, itir.InfinityLiteral.POSITIVE), + ) + + +class _SimplifyDomainArgument(PreserveLocationVisitor, NodeTranslator): + @classmethod + def apply(cls, node: ir.Node): + return cls().visit(node) + + def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + node = self.generic_visit(node) + + # TODO: do not duplicate exprs + if cpm.is_call_to(node, "concat_where"): + cond_expr, field_a, field_b = node.args + # TODO: don't duplicate exprs here + if cpm.is_call_to(cond_expr, "and_"): + conds = cond_expr.args + return self.visit( + im.concat_where(conds[0], im.concat_where(conds[1], field_a, field_b), field_b) + ) + if cpm.is_call_to(cond_expr, "or_"): + conds = cond_expr.args + return self.visit( + im.concat_where(conds[0], field_a, im.concat_where(conds[1], field_a, field_b)) + ) + + # concat_where([1, 2[, a, b) -> concat_where([-inf, 1] | [2, inf[, b, a) + if cpm.is_call_to(cond_expr, ("cartesian_domain", "unstructured_domain")): + domain = SymbolicDomain.from_expr(cond_expr) + if len(domain.ranges) == 1: + dim, range_ = next(iter(domain.ranges.items())) + if domain_utils.is_finite(range_): + complement = _range_complement(range_) + new_domains = [ + im.domain(domain.grid_type, {dim: (cr.start, cr.stop)}) + for cr in complement + ] + # TODO: fp transform + return self.visit( + im.concat_where(im.call("or_")(*new_domains), field_b, field_a) + ) + else: + # TODO(tehrengruber): Implement. Note that this case can not be triggered by + # the frontend. + raise NotImplementedError() + + return node + + +simplify_domain_argument = _SimplifyDomainArgument.apply diff --git a/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py b/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py new file mode 100644 index 0000000000..e99837b1db --- /dev/null +++ b/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py @@ -0,0 +1,72 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next import utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) +from gt4py.next.iterator.transforms import symbol_ref_utils +from gt4py.next.iterator.type_system import inference as type_inference +from gt4py.next.type_system import type_specifications as ts + + +class _TransformToAsFieldop(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ( + "type", + "domain", + ) + + @classmethod + def apply(cls, node: itir.Node): + """ + Transform `concat_where` expressions into equivalent `as_fieldop` expressions. + + Note that (backward) domain inference may not be executed after this pass as it can not + correctly infer the accessed domains when the value selection is represented as an `if_` + inside the `as_fieldop. + """ + node = cls().visit(node) + node = type_inference.SanitizeTypes().visit(node) + return node + + def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: + node = self.generic_visit(node) + if cpm.is_call_to(node, "concat_where"): + cond, true_branch, false_branch = node.args + assert isinstance(cond.type, ts.DomainType) + position = [im.index(dim) for dim in cond.type.dims] + refs = symbol_ref_utils.collect_symbol_refs(cond) + + domains = utils.flatten_nested_tuple(node.annex.domain) + assert all( + domain == domains[0] for domain in domains + ), "At this point all `concat_where` arguments should be posed on the same domain." + assert isinstance(domains[0], domain_utils.SymbolicDomain) + domain_expr = domains[0].as_expr() + + return im.as_fieldop( + im.lambda_("__tcw_pos", "__tcw_arg0", "__tcw_arg1", *refs)( + im.let(*zip(refs, map(im.deref, refs), strict=True))( + im.if_( + im.call("in_")(im.deref("__tcw_pos"), cond), + im.deref("__tcw_arg0"), + im.deref("__tcw_arg1"), + ) + ) + ), + domain_expr, + )(im.make_tuple(*position), true_branch, false_branch, *refs) + + return node + + +transform_to_as_fieldop = _TransformToAsFieldop.apply diff --git a/src/gt4py/next/iterator/transforms/concat_where_transforms.py b/src/gt4py/next/iterator/transforms/concat_where_transforms.py deleted file mode 100644 index d9c1f5cab9..0000000000 --- a/src/gt4py/next/iterator/transforms/concat_where_transforms.py +++ /dev/null @@ -1,181 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next import common, utils -from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import ( - common_pattern_matcher as cpm, - domain_utils, - ir_makers as im, -) -from gt4py.next.iterator.transforms import infer_domain, symbol_ref_utils -from gt4py.next.iterator.type_system import inference as type_inference -from gt4py.next.type_system import type_specifications as ts -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next.iterator import ir, ir as itir -from gt4py.next.iterator.ir_utils import ( - common_pattern_matcher as cpm, - domain_utils, - ir_makers as im, -) -from gt4py.next.iterator.ir_utils.domain_utils import SymbolicDomain - -def _range_complement( - range_: domain_utils.SymbolicRange, -) -> tuple[domain_utils.SymbolicRange, domain_utils.SymbolicRange]: - # `[a, b[` -> `[-inf, a[` ∪ `[b, inf[` - assert not any(isinstance(b, itir.InfinityLiteral) for b in [range_.start, range_.stop]) - return ( - domain_utils.SymbolicRange(itir.InfinityLiteral.NEGATIVE, range_.start), - domain_utils.SymbolicRange(range_.stop, itir.InfinityLiteral.POSITIVE), - ) - - -class _TransformTupleConcatWhere(PreserveLocationVisitor, NodeTranslator): - PRESERVED_ANNEX_ATTRS = ( - "type", - "domain", - ) - - @classmethod - def apply( - cls, - node: ir.Node, - *, - offset_provider_type: common.OffsetProviderType, - allow_undeclared_symbols: bool = False - ) -> ir.Node: - node = type_inference.infer( - node, - offset_provider_type=offset_provider_type, - allow_undeclared_symbols=allow_undeclared_symbols - ) - return cls().visit(node) - - def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: - node = self.generic_visit(node) - - # `concat_where(cond, {a, b}, {c, d})` - # -> `{concat_where(cond, a, c), concat_where(cond, a, c)}` - if cpm.is_call_to(node, "concat_where") and isinstance(node.args[1].type, ts.TupleType): - cond, true_branch, false_branch = node.args - new_els = [] - for i in range(len(true_branch.type.types)): - new_els.append( - im.concat_where(cond, im.tuple_get(i, "__tb"), im.tuple_get(i, "__fb")) - ) - - new_node = im.let(("__tb", true_branch), ("__fb", false_branch))( - im.make_tuple(*new_els) - ) - # restore domain information - new_node, _ = infer_domain.infer_expr( - new_node, - node.annex.domain, - keep_existing_domains=True, - # offset provider not needed as all as_fieldop already have a domain - offset_provider={}, - ) - return new_node - - return node - - -expand_tuple = _TransformTupleConcatWhere.apply - - -class _ExpandConcatWhere(PreserveLocationVisitor, NodeTranslator): - PRESERVED_ANNEX_ATTRS = ( - "type", - "domain", - ) - - @classmethod - def apply(cls, node: ir.Node): - node = cls().visit(node) - node = type_inference.SanitizeTypes().visit(node) - return node - - def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: - node = self.generic_visit(node) - if cpm.is_call_to(node, "concat_where"): - cond, true_branch, false_branch = node.args - assert isinstance(cond.type, ts.DomainType) - position = [im.index(dim) for dim in cond.type.dims] - refs = symbol_ref_utils.collect_symbol_refs(cond) - - domains = utils.flatten_nested_tuple(node.annex.domain) - assert all( - domain == domains[0] for domain in domains - ), "At this point all `concat_where` arguments should be posed on the same domain." - assert isinstance(domains[0], domain_utils.SymbolicDomain) - domain_expr = domains[0].as_expr() - - return im.as_fieldop( - im.lambda_("__tcw_pos", "__tcw_arg0", "__tcw_arg1", *refs)( - im.let(*zip(refs, map(im.deref, refs), strict=True))( - im.if_( - im.call("in_")(im.deref("__tcw_pos"), cond), - im.deref("__tcw_arg0"), - im.deref("__tcw_arg1"), - ) - ) - ), - domain_expr, - )(im.make_tuple(*position), true_branch, false_branch, *refs) - - return node - - -expand = _ExpandConcatWhere.apply - -class NestConcatWheres(PreserveLocationVisitor, NodeTranslator): - @classmethod - def apply(cls, node: ir.Node): - return cls().visit(node) - - def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: - node = self.generic_visit(node) - - # TODO: do not duplicate exprs - if cpm.is_call_to(node, "concat_where"): - cond_expr, field_a, field_b = node.args - # TODO: don't duplicate exprs here - if cpm.is_call_to(cond_expr, "and_"): - conds = cond_expr.args - return self.visit( - im.concat_where(conds[0], im.concat_where(conds[1], field_a, field_b), field_b) - ) - if cpm.is_call_to(cond_expr, "or_"): - conds = cond_expr.args - return self.visit( - im.concat_where(conds[0], field_a, im.concat_where(conds[1], field_a, field_b)) - ) - - # concat_where([1, 2[, a, b) -> concat_where([-inf, 1] | [2, inf[, b, a) - if cpm.is_call_to(cond_expr, ("cartesian_domain", "unstructured_domain")): - domain = SymbolicDomain.from_expr(cond_expr) - if len(domain.ranges) == 1: - dim, range_ = next(iter(domain.ranges.items())) - if domain_utils.is_finite(range_): - complement = _range_complement(range_) - new_domains = [ - im.domain(domain.grid_type, {dim: (cr.start, cr.stop)}) - for cr in complement - ] - # TODO: fp transform - return self.visit( - im.concat_where(im.call("or_")(*new_domains), field_b, field_a) - ) - else: - # TODO(tehrengruber): Implement. Note that this case can not be triggered by - # the frontend. - raise NotImplementedError() - - return node \ No newline at end of file diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 38aa13010d..61bf8f0c73 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -362,7 +362,6 @@ def _infer_concat_where( **kwargs: Unpack[InferenceOptions], ) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "concat_where") - # assert all(isinstance(domain, domain_utils.SymbolicDomain) for domain in utils.flatten_nested_tuple(domain)) infered_args_expr = [] actual_domains: AccessedDomains = {} cond, true_field, false_field = expr.args @@ -376,7 +375,8 @@ def mapper(d: NonTupleDomainAccess): if isinstance(d, DomainAccessDescriptor): return d promoted_cond = domain_utils.promote_to_same_dimensions( - symbolic_cond if arg == true_field else cond_complement, d + symbolic_cond if arg == true_field else cond_complement, # noqa: B023 # function is never used outside the loop + d, ) return domain_utils.domain_intersection(d, promoted_cond) diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index 26e475bfcc..7081905f5e 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -33,9 +33,8 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Node: node = self.generic_visit(node, **kwargs) # e.g. `IDim < a` - if ( - cpm.is_call_to(node, builtins.BINARY_MATH_COMPARISON_BUILTINS) - and any(isinstance(arg, itir.AxisLiteral) for arg in node.args) + if cpm.is_call_to(node, builtins.BINARY_MATH_COMPARISON_BUILTINS) and any( + isinstance(arg, itir.AxisLiteral) for arg in node.args ): arg1, arg2 = node.args if isinstance(arg2, itir.AxisLiteral): diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 2b83e724e4..59f491c960 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -12,7 +12,7 @@ from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import ( - concat_where_transforms, + concat_where, dead_code_elimination, expand_library_functions, fuse_as_fieldop, @@ -85,8 +85,9 @@ def apply_common_transforms( ir ) # domain inference does not support dynamic offsets yet ir = infer_domain_ops.InferDomainOps.apply(ir) - ir = concat_where_transforms.NestConcatWheres.apply(ir) + ir = concat_where.simplify_domain_argument(ir) + ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) ir = infer_domain.infer_program( ir, offset_provider=offset_provider, @@ -94,10 +95,7 @@ def apply_common_transforms( ) ir = remove_broadcast.RemoveBroadcast.apply(ir) - # Note: executing domain inference again afterwards will give wrong domains. - # This might be problematic in the temporary extraction, where we do this... - ir = concat_where_transforms.expand_tuple(ir, offset_provider_type=offset_provider_type) - ir = concat_where_transforms.expand(ir) + ir = concat_where.transform_to_as_fieldop(ir) ir = expand_library_functions.ExpandLibraryFunctions.apply(ir) for _ in range(10): @@ -196,9 +194,8 @@ def apply_fieldview_transforms( ir ) # domain inference does not support dynamic offsets yet - # TODO: deduplicate with regular pass manager ir = infer_domain_ops.InferDomainOps.apply(ir) - ir = nest_concat_wheres.NestConcatWheres.apply(ir) + ir = concat_where.simplify_domain_argument.apply(ir) ir = ConstantFolding.apply(ir) ir = infer_domain.infer_program(ir, offset_provider=offset_provider) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 20bffd2b10..8cbb5f66b7 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -285,9 +285,12 @@ def deduce_return_type(tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.S assert ( tb_dtype == fb_dtype ), f"Field arguments must be of same dtype, got '{tb_dtype}' != '{fb_dtype}'." + dtype = tb_dtype - return_dims = common.promote_dims(domain.dims, type_info.promote(tb, fb).dims) - return_type = ts.FieldType(dims=return_dims, dtype=type_info.promote(tb_dtype, fb_dtype)) + return_dims = common.promote_dims( + domain.dims, type_info.extract_dims(type_info.promote(tb, fb)) + ) + return_type = ts.FieldType(dims=return_dims, dtype=dtype) return return_type return deduce_return_type(true_field, false_field) @@ -373,7 +376,7 @@ def _collect_and_check_dimensions(input_: ts.TypeSpec) -> list[common.Dimension] def _convert_as_fieldop_input_to_iterator( - domain: it_ts.DomainType, input_: ts.TypeSpec + domain: ts.DomainType, input_: ts.TypeSpec ) -> it_ts.IteratorType: """ Convert a field operation input into an iterator type, preserving its dimensions and data type. diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_tuple_concat_where.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_expand_tuple_args.py similarity index 75% rename from tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_tuple_concat_where.py rename to tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_expand_tuple_args.py index 14fd4a3938..42ad292043 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_tuple_concat_where.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_expand_tuple_args.py @@ -9,7 +9,12 @@ import pytest from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils -from gt4py.next.iterator.transforms import concat_where_transforms, inline_lambdas, infer_domain, collapse_tuple +from gt4py.next.iterator.transforms import ( + concat_where, + inline_lambdas, + infer_domain, + collapse_tuple, +) from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.type_system import type_specifications as it_ts @@ -17,6 +22,7 @@ IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) field_type = ts.FieldType(dims=[IDim], dtype=int_type) + def test_trivial(): cond = im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 1)}) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 2)}) @@ -25,7 +31,7 @@ def test_trivial(): testee = im.concat_where( cond, im.make_tuple(im.ref("a", field_type), im.ref("c", field_type)), - im.make_tuple(im.ref("b", field_type), im.ref("d", field_type)) + im.make_tuple(im.ref("b", field_type), im.ref("d", field_type)), ) testee, _ = infer_domain.infer_expr( testee, @@ -34,14 +40,14 @@ def test_trivial(): offset_provider={}, ) - expected = im.make_tuple( - im.concat_where(cond, "a", "b"), - im.concat_where(cond, "c", "d") - ) + expected = im.make_tuple(im.concat_where(cond, "a", "b"), im.concat_where(cond, "c", "d")) - actual = concat_where_transforms.expand_tuple( - testee, offset_provider_type={}, allow_undeclared_symbols=True) + actual = concat_where.expand_tuple_args( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) - actual = collapse_tuple.CollapseTuple.apply(actual, allow_undeclared_symbols=True, within_stencil=False) + actual = collapse_tuple.CollapseTuple.apply( + actual, allow_undeclared_symbols=True, within_stencil=False + ) - assert actual == expected \ No newline at end of file + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_nest_concat_where.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_simplify_domain_args.py similarity index 75% rename from tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_nest_concat_where.py rename to tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_simplify_domain_args.py index 1869e6b6df..cbc5f61950 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_nest_concat_where.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_simplify_domain_args.py @@ -9,7 +9,12 @@ import pytest from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils -from gt4py.next.iterator.transforms import concat_where_transforms, inline_lambdas, infer_domain, collapse_tuple +from gt4py.next.iterator.transforms import ( + concat_where, + inline_lambdas, + infer_domain, + collapse_tuple, +) from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.type_system import type_specifications as it_ts @@ -17,6 +22,7 @@ IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) field_type = ts.FieldType(dims=[IDim], dtype=int_type) + def test_data(): return [ # testee, expected @@ -33,13 +39,20 @@ def test_data(): im.concat_where( im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 0)}), "b", - im.concat_where(im.domain(common.GridType.CARTESIAN, {IDim: (1, itir.InfinityLiteral.POSITIVE)}), "b", "a") + im.concat_where( + im.domain( + common.GridType.CARTESIAN, {IDim: (1, itir.InfinityLiteral.POSITIVE)} + ), + "b", + "a", + ), ), - ) + ), ] + @pytest.mark.parametrize("testee, expected", test_data()) def test_nested_concat_where(testee, expected): - actual = concat_where_transforms.NestConcatWheres.apply(testee) + actual = concat_where.simplify_domain_argument(testee) - assert actual == expected \ No newline at end of file + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_concat_where.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_transform_to_as_fieldop.py similarity index 86% rename from tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_concat_where.py rename to tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_transform_to_as_fieldop.py index 7aa36b8acf..37ebf69b6c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_concat_where.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_transform_to_as_fieldop.py @@ -6,12 +6,10 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause from gt4py.next import common -import pytest + from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils -from gt4py.next.iterator.transforms import concat_where_transforms -from gt4py.next.iterator.transforms import inline_lambdas +from gt4py.next.iterator.transforms import concat_where, inline_lambdas from gt4py.next.type_system import type_specifications as ts -from gt4py.next.iterator.type_system import type_specifications as it_ts int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) @@ -34,7 +32,7 @@ def test_trivial(): domain, )(im.make_tuple(im.index(IDim)), "true_branch", "false_branch") - actual = concat_where_transforms.expand(testee) + actual = concat_where.transform_to_as_fieldop(testee) actual = inline_lambdas.InlineLambdas.apply(actual) # simplify assert actual == expected @@ -62,7 +60,7 @@ def test_capturing_cond(): domain, )(im.make_tuple(im.index(IDim)), "true_branch", "false_branch", "start", "stop") - actual = concat_where_transforms.expand(testee) + actual = concat_where.transform_to_as_fieldop(testee) actual = inline_lambdas.InlineLambdas.apply(actual) # simplify - assert actual == expected \ No newline at end of file + assert actual == expected From eb9adf918cfe1a292314282dcf89dc75cee59fd7 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 23 May 2025 10:53:48 +0200 Subject: [PATCH 074/124] Cleanup --- .../transforms/concat_where/expand_tuple_args.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py b/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py index 4cf8c765df..16b54fa088 100644 --- a/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py +++ b/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py @@ -53,13 +53,13 @@ def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: im.make_tuple(*new_els) ) # restore domain information - new_node, _ = infer_domain.infer_expr( - new_node, - node.annex.domain, - keep_existing_domains=True, - # offset provider not needed as all as_fieldop already have a domain - offset_provider={}, - ) + #new_node, _ = infer_domain.infer_expr( + # new_node, + # node.annex.domain, + # keep_existing_domains=True, + # # offset provider not needed as all as_fieldop already have a domain + # offset_provider={}, + # ) return new_node return node From dc98855049ce172a05a620c2c79bbdd92eae1541 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 23 May 2025 11:25:29 +0200 Subject: [PATCH 075/124] Cleanup --- src/gt4py/next/iterator/transforms/pass_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 59f491c960..3dc42fe6c4 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -195,7 +195,7 @@ def apply_fieldview_transforms( ) # domain inference does not support dynamic offsets yet ir = infer_domain_ops.InferDomainOps.apply(ir) - ir = concat_where.simplify_domain_argument.apply(ir) + ir = concat_where.simplify_domain_argument(ir) ir = ConstantFolding.apply(ir) ir = infer_domain.infer_program(ir, offset_provider=offset_provider) From b4e5fd1f3ad3eea3bc5f67f8ea9c8c03947e380b Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 23 May 2025 17:19:31 +0200 Subject: [PATCH 076/124] Cleanup --- .../ffront/foast_passes/type_deduction.py | 2 +- .../ir_utils/common_pattern_matcher.py | 6 ++- .../next/iterator/ir_utils/domain_utils.py | 18 +++---- .../concat_where/expand_tuple_args.py | 9 ---- .../concat_where/transform_to_as_fieldop.py | 24 ++++++++- .../transforms/expand_library_functions.py | 49 ------------------- .../iterator/transforms/infer_domain_ops.py | 4 +- .../next/iterator/transforms/pass_manager.py | 6 +-- .../iterator/type_system/type_synthesizer.py | 2 +- ...st_concat_where_transform_to_as_fieldop.py | 23 +++++++++ .../test_expand_library_functions.py | 39 --------------- 11 files changed, 62 insertions(+), 120 deletions(-) delete mode 100644 src/gt4py/next/iterator/transforms/expand_library_functions.py delete mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_library_functions.py diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index f40f42e279..325896393f 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -1017,7 +1017,7 @@ def deduce_return_type( return_dims = promote_dims( mask_type.dims, type_info.extract_dims(type_info.promote(tb, fb)) ) - return_type = ts.FieldType(dims=return_dims, dtype=type_info.promote(t_dtype, f_dtype)) + return_type = ts.FieldType(dims=return_dims, dtype=t_dtype) return return_type return_type = deduce_return_type(true_branch_type, false_branch_type) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index da13d20bb6..98766518e6 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -108,8 +108,10 @@ def is_let(node: itir.Node) -> TypeGuard[_FunCallToLambda]: return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda) -def is_ref_to(node, ref: str) -> TypeGuard[itir.SymRef]: - return isinstance(node, itir.SymRef) and node.id == ref +def is_ref_to(node, ref: str | Iterable[str]) -> TypeGuard[itir.SymRef]: + if isinstance(ref, str): + return isinstance(node, itir.SymRef) and node.id == ref + return any(is_ref_to(node, el) for el in ref) def is_identity_as_fieldop(node: itir.Expr) -> TypeGuard[_FunCallToFunCallToRef]: diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index e922171825..c01b0a0dcc 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -14,7 +14,7 @@ from gt4py.next import common from gt4py.next.iterator import builtins, ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import trace_shifts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding @@ -51,7 +51,7 @@ def translate(self, distance: int) -> SymbolicRange: @dataclasses.dataclass(frozen=True) class SymbolicDomain: - grid_type: Literal["unstructured_domain", "cartesian_domain"] + grid_type: common.GridType ranges: dict[ common.Dimension, SymbolicRange ] # TODO(havogt): remove `AxisLiteral` by `Dimension` everywhere @@ -61,25 +61,19 @@ def __hash__(self) -> int: @classmethod def from_expr(cls, node: itir.Node) -> SymbolicDomain: - assert isinstance(node, itir.FunCall) and node.fun in [ - im.ref("unstructured_domain"), - im.ref("cartesian_domain"), - ] + assert cpm.is_call_to(node, ("unstructured_domain", "cartesian_domain")) + grid_type = getattr(common.GridType, node.fun.id[: -len("_domain")].upper()) ranges: dict[common.Dimension, SymbolicRange] = {} for named_range in node.args: - assert ( - isinstance(named_range, itir.FunCall) - and isinstance(named_range.fun, itir.SymRef) - and named_range.fun.id == "named_range" - ) + assert cpm.is_call_to(named_range, "named_range") axis_literal, lower_bound, upper_bound = named_range.args assert isinstance(axis_literal, itir.AxisLiteral) ranges[common.Dimension(value=axis_literal.value, kind=axis_literal.kind)] = ( SymbolicRange(lower_bound, upper_bound) ) - return cls(node.fun.id, ranges) # type: ignore[attr-defined] # ensure by assert above + return cls(grid_type, ranges) def as_expr(self) -> itir.FunCall: converted_ranges: dict[common.Dimension, tuple[itir.Expr, itir.Expr]] = { diff --git a/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py b/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py index 16b54fa088..65e9f6ca0a 100644 --- a/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py +++ b/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py @@ -10,7 +10,6 @@ from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.iterator.transforms import infer_domain from gt4py.next.iterator.type_system import inference as type_inference from gt4py.next.type_system import type_specifications as ts @@ -52,14 +51,6 @@ def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: new_node = im.let(("__tb", true_branch), ("__fb", false_branch))( im.make_tuple(*new_els) ) - # restore domain information - #new_node, _ = infer_domain.infer_expr( - # new_node, - # node.annex.domain, - # keep_existing_domains=True, - # # offset provider not needed as all as_fieldop already have a domain - # offset_provider={}, - # ) return new_node return node diff --git a/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py b/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py index e99837b1db..c703a59e0f 100644 --- a/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import functools + from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next import utils from gt4py.next.iterator import ir as itir @@ -19,6 +21,24 @@ from gt4py.next.type_system import type_specifications as ts +def _in(pos: itir.Expr, domain: itir.Expr) -> itir.Expr: + """ + Given a position and a domain return an expression that evaluates to `True` if the position is inside the domain. + + `in_({i, j, k}, u⟨ Iₕ: [i0, i1[, Iₕ: [j0, j1[, Iₕ: [k0, k1[ ⟩` + -> `i0 <= i < i1 & j0 <= j < j1 & k0 <= k < k1` + """ + ret = [] + for i, v in enumerate(domain_utils.SymbolicDomain.from_expr(domain).ranges.values()): + ret.append( + im.and_( + im.less_equal(v.start, im.tuple_get(i, pos)), + im.less(im.tuple_get(i, pos), v.stop), + ) + ) + return functools.reduce(im.and_, ret) + + class _TransformToAsFieldop(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ( "type", @@ -46,7 +66,9 @@ def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: position = [im.index(dim) for dim in cond.type.dims] refs = symbol_ref_utils.collect_symbol_refs(cond) - domains = utils.flatten_nested_tuple(node.annex.domain) + domains: tuple[domain_utils.SymbolicDomain, ...] = utils.flatten_nested_tuple( + node.annex.domain + ) assert all( domain == domains[0] for domain in domains ), "At this point all `concat_where` arguments should be posed on the same domain." diff --git a/src/gt4py/next/iterator/transforms/expand_library_functions.py b/src/gt4py/next/iterator/transforms/expand_library_functions.py deleted file mode 100644 index 2ad3c783da..0000000000 --- a/src/gt4py/next/iterator/transforms/expand_library_functions.py +++ /dev/null @@ -1,49 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from functools import reduce - -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import ( - common_pattern_matcher as cpm, - domain_utils, - ir_makers as im, -) - - -class ExpandLibraryFunctions(PreserveLocationVisitor, NodeTranslator): - PRESERVED_ANNEX_ATTRS = ( - "type", - "domain", - ) - - @classmethod - def apply(cls, node: ir.Node): - return cls().visit(node) - - def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: - node = self.generic_visit(node) - - # `in_({i, j, k}, u⟨ Iₕ: [i0, i1[, Iₕ: [j0, j1[, Iₕ: [k0, k1[ ⟩` - # -> `i0 <= i < i1 & j0 <= j < j1 & k0 <= k < k1` - if cpm.is_call_to(node, "in_"): - ret = [] - pos, domain = node.args - for i, v in enumerate( - domain_utils.SymbolicDomain.from_expr(node.args[1]).ranges.values() - ): - ret.append( - im.and_( - im.less_equal(v.start, im.tuple_get(i, pos)), - im.less(im.tuple_get(i, pos), v.stop), - ) - ) # TODO(tehrengruber): Avoid position expr duplication. - return reduce(im.and_, ret) - - return node diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index 7081905f5e..6447ceca32 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -58,8 +58,8 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Node: value: itir.Expr = arg2 if cpm.is_call_to(node, ("less", "less_equal", "greater", "greater_equal", "eq")): - min_: int | itir.InfinityLiteral - max_: int | itir.InfinityLiteral + min_: itir.Expr + max_: itir.Expr # `IDim < 1` if cpm.is_call_to(node, "less"): diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 3dc42fe6c4..e2e4217b7f 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -14,7 +14,6 @@ from gt4py.next.iterator.transforms import ( concat_where, dead_code_elimination, - expand_library_functions, fuse_as_fieldop, global_tmps, infer_domain, @@ -87,7 +86,7 @@ def apply_common_transforms( ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.simplify_domain_argument(ir) - ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) + ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( ir, offset_provider=offset_provider, @@ -96,7 +95,6 @@ def apply_common_transforms( ir = remove_broadcast.RemoveBroadcast.apply(ir) ir = concat_where.transform_to_as_fieldop(ir) - ir = expand_library_functions.ExpandLibraryFunctions.apply(ir) for _ in range(10): inlined = ir @@ -196,7 +194,7 @@ def apply_fieldview_transforms( ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.simplify_domain_argument(ir) - ir = ConstantFolding.apply(ir) + ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program(ir, offset_provider=offset_provider) ir = remove_broadcast.RemoveBroadcast.apply(ir) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 8cbb5f66b7..d96d71515a 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -614,7 +614,7 @@ def applied_map( assert isinstance(el_type, ts.DataType) offset_types = [arg.offset_type for arg in args if arg.offset_type] offset_type = offset_types[0] if offset_types else None - assert all(offset_type == arg for arg in offset_types) + assert all(offset_type == arg for arg in offset_types) # type: ignore[operator] # mypy not smart enough return ts.ListType(element_type=el_type, offset_type=offset_type) return applied_map diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_transform_to_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_transform_to_as_fieldop.py index 37ebf69b6c..c9e8bb5806 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_transform_to_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_transform_to_as_fieldop.py @@ -9,10 +9,33 @@ from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils from gt4py.next.iterator.transforms import concat_where, inline_lambdas +from gt4py.next.iterator.transforms.concat_where import transform_to_as_fieldop +from gt4py.next.iterator.transforms.concat_where.transform_to_as_fieldop import _in from gt4py.next.type_system import type_specifications as ts int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) +JDim = common.Dimension(value="JDim", kind=common.DimensionKind.HORIZONTAL) + + +def test_in_helper(): + pos = im.make_tuple(0, 1) + bounds = { + IDim: (3, 4), + JDim: (5, 6), + } + expected = im.and_( + im.and_( + im.less_equal(bounds[IDim][0], im.tuple_get(0, pos)), + im.less(im.tuple_get(0, pos), bounds[IDim][1]), + ), + im.and_( + im.less_equal(bounds[JDim][0], im.tuple_get(1, pos)), + im.less(im.tuple_get(1, pos), bounds[JDim][1]), + ), + ) + actual = _in(pos, im.domain(common.GridType.CARTESIAN, bounds)) + assert actual == expected def test_trivial(): diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_library_functions.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_library_functions.py deleted file mode 100644 index 40a046d866..0000000000 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_expand_library_functions.py +++ /dev/null @@ -1,39 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import pytest -import textwrap - -from gt4py.eve.utils import UIDGenerator -from gt4py.next import common -from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms.expand_library_functions import ExpandLibraryFunctions - -from next_tests.integration_tests.cases import IDim, JDim, KDim - - -def test_trivial(): - pos = im.make_tuple(0, 1) - bounds = { - IDim: (3, 4), - JDim: (5, 6), - } - testee = im.call("in_")(pos, im.domain(common.GridType.CARTESIAN, bounds)) - expected = im.and_( - im.and_( - im.less_equal(bounds[IDim][0], im.tuple_get(0, pos)), - im.less(im.tuple_get(0, pos), bounds[IDim][1]), - ), - im.and_( - im.less_equal(bounds[JDim][0], im.tuple_get(1, pos)), - im.less(im.tuple_get(1, pos), bounds[JDim][1]), - ), - ) - actual = ExpandLibraryFunctions.apply(testee) - assert actual == expected From 0bd26cef933346e8e508f592bbc3c6fc75862267 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 23 May 2025 17:29:25 +0200 Subject: [PATCH 077/124] Cleanup --- .../ffront_tests/test_concat_where.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index d0a62a4bc4..a79c303630 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -162,23 +162,6 @@ def testee(interior: cases.IJKField, boundary: cases.IJKField) -> cases.IJKField cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) -# @pytest.mark.uses_frontend_concat_where -# def test_dimension_two_illegal_threeway_comparison(cartesian_case): -# @gtx.field_operator -# def testee(interior: cases.KField, boundary: cases.KField, nlev: np.int32) -> cases.KField: -# return concat_where(0 < KDim < (nlev - 1), interior, boundary) - -# interior = cases.allocate(cartesian_case, testee, "interior")() -# boundary = cases.allocate(cartesian_case, testee, "boundary")() -# out = cases.allocate(cartesian_case, testee, cases.RETURN)() - -# nlev = cartesian_case.default_sizes[KDim] -# k = np.arange(0, nlev) -# ref = np.where((0 < k) & (k < (nlev - 1)), interior.asnumpy(), boundary.asnumpy()) -# with pytest.raises: # TODO -# cases.verify(cartesian_case, testee, interior, boundary, nlev, out=out, ref=ref) - - def test_dimension_two_conditions_and(cartesian_case): @gtx.field_operator def testee(interior: cases.KField, boundary: cases.KField, nlev: np.int32) -> cases.KField: From e5dbf4a25cb909fd2a7b43fbd4d32eaf6ab25573 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 23 May 2025 17:34:48 +0200 Subject: [PATCH 078/124] Cleanup --- .../next/iterator/transforms/concat_where/expand_tuple_args.py | 1 + src/gt4py/next/type_system/type_specifications.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py b/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py index 65e9f6ca0a..209ceba8c7 100644 --- a/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py +++ b/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py @@ -43,6 +43,7 @@ def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: if cpm.is_call_to(node, "concat_where") and isinstance(node.args[1].type, ts.TupleType): cond, true_branch, false_branch = node.args new_els = [] + assert isinstance(true_branch.type, ts.TupleType) for i in range(len(true_branch.type.types)): new_els.append( im.concat_where(cond, im.tuple_get(i, "__tb"), im.tuple_get(i, "__fb")) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 0a572dcc0f..c69d1ae00d 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Iterator, Literal, Optional, Sequence, Union +from typing import Iterator, Optional, Sequence, Union from gt4py.eve import datamodels as eve_datamodels, type_definitions as eve_types from gt4py.next import common From 0e8faadd5f32a9225d735b052c0ca6685ba1d002 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 23 May 2025 23:45:38 +0200 Subject: [PATCH 079/124] Fix dace --- .../program_processors/runners/dace/gtir_builtin_translators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py index ac7df22ccc..03d5691eab 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py @@ -523,7 +523,7 @@ def parse_range_boundary(expr: gtir.Expr) -> str: domain.append((dim, lower_bound, upper_bound)) elif isinstance(node, domain_utils.SymbolicDomain): - assert str(node.grid_type) in {"cartesian_domain", "unstructured_domain"} + assert isinstance(node.grid_type, gtx_common.GridType) for dim, drange in node.ranges.items(): domain.append( (dim, parse_range_boundary(drange.start), parse_range_boundary(drange.stop)) From 2499389be7e4f2ffebb420fe981afdc8820818ae Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 4 Jun 2025 11:32:42 +0200 Subject: [PATCH 080/124] remove unchain comparison (because doesn't make sense) --- .../ffront/ast_passes/unchain_compares.py | 2 +- src/gt4py/next/ffront/func_to_foast.py | 18 ++++++++-- .../ffront_tests/test_func_to_foast.py | 36 ++++++++++++++++++- .../ffront_tests/test_type_deduction.py | 31 +++++++++------- 4 files changed, 70 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/ffront/ast_passes/unchain_compares.py b/src/gt4py/next/ffront/ast_passes/unchain_compares.py index 669a07c3d8..9f42acddc4 100644 --- a/src/gt4py/next/ffront/ast_passes/unchain_compares.py +++ b/src/gt4py/next/ffront/ast_passes/unchain_compares.py @@ -46,7 +46,7 @@ def visit_Compare(self, node: ast.Compare) -> ast.Compare | ast.BinOp: # the remainder of the chain -> right branch of the new tree # example: ``b > c > d`` - remaining_chain = copy.copy(node) + remaining_chain = copy.deepcopy(node) remaining_chain.left = remaining_chain.comparators.pop(0) remaining_chain.ops.pop(0) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 60282bf6c6..9a8ab17e54 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -10,6 +10,7 @@ import ast import builtins +import textwrap import typing from typing import Any, Callable, Iterable, Mapping, Type @@ -26,6 +27,7 @@ SingleAssignTargetPass, SingleStaticAssignPass, StringifyAnnotationsPass, + UnchainComparesPass, ) from gt4py.next.ffront.dialect_parser import DialectParser from gt4py.next.ffront.foast_introspection import StmtReturnKind, deduce_stmt_return_kind @@ -472,10 +474,20 @@ def _visit_stmts( def visit_Compare(self, node: ast.Compare, **kwargs: Any) -> foast.Compare: loc = self.get_location(node) + if len(node.ops) != 1 or len(node.comparators) != 1: - # Remove comparison chains in a preprocessing pass - # TODO: maybe add a note to the error about preprocessing passes? - raise errors.UnsupportedPythonFeatureError(loc, "comparison chains") + refactored = UnchainComparesPass.apply(node) + raise errors.DSLError( + loc, + textwrap.dedent( + f""" + Comparison chains are not allowed. Please replace + {ast.unparse(node)} + by + {ast.unparse(refactored)} + """, + ), + ) return foast.Compare( op=self.visit(node.ops[0]), left=self.visit(node.left), diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index b7c46719e2..19b56a76c4 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -36,7 +36,17 @@ import gt4py.next as gtx from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.next import astype, broadcast, errors, float32, float64, int32, int64, where +from gt4py.next import ( + astype, + broadcast, + errors, + float32, + float64, + int32, + int64, + where, +) +from gt4py.next.ffront.experimental import concat_where from gt4py.next.ffront import field_operator_ast as foast from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.func_to_foast import FieldOperatorParser @@ -368,3 +378,27 @@ def zero_dims_ternary( msg = r"Incompatible datatypes in operator '=='" with pytest.raises(errors.DSLError, match=msg): _ = FieldOperatorParser.apply_to_function(zero_dims_ternary) + + +def test_domain_chained_comparison_failure(): + def domain_comparison(a: gtx.Field[[TDim], float], b: gtx.Field[[TDim], float]): + return concat_where(0 < TDim < 42, a, b) + + with pytest.raises( + errors.DSLError, + match=r".*chain.*not.*allowed(?s:.)*\(0 < TDim\) & \(TDim < 42\).*", + ): + _ = FieldOperatorParser.apply_to_function(domain_comparison) + + +def test_field_chained_comparison_failure(): + def comparison( + cond: gtx.Field[[TDim], float], a: gtx.Field[[TDim], float], b: gtx.Field[[TDim], float] + ): + return where(0.0 < cond < 42.0, a, b) + + with pytest.raises( + errors.DSLError, + match=r".*chain.*not.*allowed(?s:.)*\(0.0 < cond\) & \(cond < 42.0\).*", + ): + _ = FieldOperatorParser.apply_to_function(comparison) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py index 8be3fa0dbd..a49ceba169 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py @@ -136,6 +136,25 @@ def not_int(a: Field[[TDim], int64]): _ = FieldOperatorParser.apply_to_function(not_int) +def test_compare(): + def compare(a: Field[[TDim], float64], b: Field[[TDim], float64]): + return a < b + + parsed = FieldOperatorParser.apply_to_function(compare) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL) + ) + + +def test_compare_wrong_dtype(): + def compare(a: Field[[TDim], float64], b: Field[[TDim], float32]): + return a < b + + with pytest.raises(errors.DSLError, match=r"Incompatible datatypes"): + _ = FieldOperatorParser.apply_to_function(compare) + + def test_concat_where(): def simple_concat_where(a: Field[[TDim], float], b: Field[[TDim], float]): return concat_where(TDim > 0, a, b) @@ -187,18 +206,6 @@ def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): _ = FieldOperatorParser.apply_to_function(domain_comparison) -def test_domain_chained_comparison_failure(): - def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): - return concat_where(0 < TDim < 42, a, b) - - # _ = FieldOperatorParser.apply_to_function(domain_comparison) - with pytest.raises( - errors.DSLError, - match=re.escape("TODO"), - ): - _ = FieldOperatorParser.apply_to_function(domain_comparison) - - @pytest.fixture def premap_setup(): X = Dimension("X") From 398ec68ce13f5a28848dda5d615c659fa081c2fb Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 4 Jun 2025 12:01:19 +0200 Subject: [PATCH 081/124] improve error messages --- .../next/ffront/foast_passes/type_deduction.py | 7 +++++-- .../unit_tests/ffront_tests/test_type_deduction.py | 13 ++++++++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index f6afebd60b..67cadb6e50 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -595,18 +595,21 @@ def _deduce_dimension_compare_type( kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) ) + def error_msg(left: ts.TypeSpec, right: ts.TypeSpec) -> str: + return f"Dimension comparison needs to be between a 'Dimension' and index of type '{index_type}', got '{left}' and '{right}'." + if isinstance(left.type, ts.DimensionType): if not right.type == index_type: raise errors.DSLError( right.location, - f"Expected an {index_type}, but got '{right.type}' instead.", + error_msg(left.type, right.type), ) return ts.DomainType(dims=[left.type.dim]) elif isinstance(right.type, ts.DimensionType): if not left.type == index_type: raise errors.DSLError( left.location, - f"Expected an {index_type}, but got '{right.type}' instead.", + error_msg(left.type, right.type), ) return ts.DomainType(dims=[right.type.dim]) else: diff --git a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py index a49ceba169..feb84d6ccf 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py @@ -179,7 +179,7 @@ def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): with pytest.raises( errors.DSLError, - match=re.escape("Expected an int32, but got 'float64' instead."), + match=r".*int32.*got.*float64.*", ): _ = FieldOperatorParser.apply_to_function(domain_comparison) @@ -195,6 +195,17 @@ def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): _ = FieldOperatorParser.apply_to_function(domain_comparison) +def test_domain_comparison_with_dimension_failure(): + def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): + return concat_where(TDim > TDim, a, b) + + with pytest.raises( + errors.DSLError, + match=r".*int32.*got.*TDim.*TDim.*", + ): + _ = FieldOperatorParser.apply_to_function(domain_comparison) + + def test_concat_where_invalid_dtype(): def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): return concat_where(TDim > 0, 1.0, 2) From f81393a4cd47ad33cee2e2f69befbc98f0382ae0 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 4 Jun 2025 12:01:39 +0200 Subject: [PATCH 082/124] fix chain test --- tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index f8baf2a2e1..0bdfae395f 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -705,7 +705,7 @@ def test_compare_chain(): def foo( a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64], c: gtx.Field[[TDim], float64] ) -> gtx.Field[[TDim], bool]: - return a > b > c + return (a > b) & (b > c) parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) From eae7dc7cf8d7546e6086faa64a182196864fa09a Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 4 Jun 2025 12:31:11 +0200 Subject: [PATCH 083/124] simplify typing --- src/gt4py/next/embedded/nd_array_field.py | 2 +- src/gt4py/next/ffront/experimental.py | 16 +++------------- src/gt4py/next/ffront/fbuiltins.py | 13 ++++--------- 3 files changed, 8 insertions(+), 23 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 661fb97111..25ce060c7c 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -974,7 +974,7 @@ def _concat_where( return cls_.from_array(result_array, domain=result_domain) -NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) +NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] # TODO(havogt): this is still the "old" concat_where, needs to be replaced in a next PR def _make_reduction( diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index dfa89468a5..c9bea908a8 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -6,16 +6,11 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Tuple, TypeVar +from typing import Tuple from gt4py._core import definitions as core_defs from gt4py.next import common -from gt4py.next.ffront.fbuiltins import ( - BuiltInFunction, - FieldOffset, - FieldT, - WhereLikeBuiltinFunction, -) +from gt4py.next.ffront.fbuiltins import BuiltInFunction, FieldOffset, WhereBuiltinFunction @BuiltInFunction @@ -23,12 +18,7 @@ def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivi raise NotImplementedError() -_R = TypeVar("_R") -DomainT = TypeVar("DomainT", bound=common.Field) -ConcatWhereBuiltinFunction = WhereLikeBuiltinFunction[_R, DomainT, FieldT] - - -@ConcatWhereBuiltinFunction +@WhereBuiltinFunction def concat_where( mask: common.Domain, true_field: common.Field | core_defs.ScalarT | Tuple, diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 743d79ccd3..28fb5bdf72 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -137,15 +137,14 @@ def __gt_type__(self) -> ts.FunctionType: ) -MaskLikeT = TypeVar("MaskLikeT", bound=common.Field) +MaskT = TypeVar("MaskT", bound=Union[common.Field, common.Domain]) FieldT = TypeVar("FieldT", bound=Union[common.Field, core_defs.Scalar, Tuple]) -class WhereLikeBuiltinFunction( - BuiltInFunction[_R, [MaskLikeT, FieldT, FieldT]], - Generic[_R, MaskLikeT, FieldT], +class WhereBuiltinFunction( + BuiltInFunction[_R, [MaskT, FieldT, FieldT]], Generic[_R, MaskT, FieldT] ): - def __call__(self, mask: MaskLikeT, true_field: FieldT, false_field: FieldT) -> _R: + def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: if isinstance(true_field, tuple) or isinstance(false_field, tuple): if not (isinstance(true_field, tuple) and isinstance(false_field, tuple)): raise ValueError( @@ -160,10 +159,6 @@ def __call__(self, mask: MaskLikeT, true_field: FieldT, false_field: FieldT) -> return super().__call__(mask, true_field, false_field) -MaskT = TypeVar("MaskT", bound=common.Field) -WhereBuiltinFunction = WhereLikeBuiltinFunction[_R, MaskT, FieldT] - - @BuiltInFunction def neighbor_sum(field: common.Field, /, axis: common.Dimension) -> common.Field: raise NotImplementedError() From 16e1c65abcf53e4321d227a6305221d770dbd216 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 4 Jun 2025 12:34:38 +0200 Subject: [PATCH 084/124] rename --- src/gt4py/next/ffront/experimental.py | 2 +- src/gt4py/next/ffront/fbuiltins.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index c9bea908a8..b30b25b309 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -20,7 +20,7 @@ def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivi @WhereBuiltinFunction def concat_where( - mask: common.Domain, + cond: common.Domain, true_field: common.Field | core_defs.ScalarT | Tuple, false_field: common.Field | core_defs.ScalarT | Tuple, /, diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 28fb5bdf72..82832dd0f6 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -137,14 +137,14 @@ def __gt_type__(self) -> ts.FunctionType: ) -MaskT = TypeVar("MaskT", bound=Union[common.Field, common.Domain]) +CondT = TypeVar("CondT", bound=Union[common.Field, common.Domain]) FieldT = TypeVar("FieldT", bound=Union[common.Field, core_defs.Scalar, Tuple]) class WhereBuiltinFunction( - BuiltInFunction[_R, [MaskT, FieldT, FieldT]], Generic[_R, MaskT, FieldT] + BuiltInFunction[_R, [CondT, FieldT, FieldT]], Generic[_R, CondT, FieldT] ): - def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: + def __call__(self, cond: CondT, true_field: FieldT, false_field: FieldT) -> _R: if isinstance(true_field, tuple) or isinstance(false_field, tuple): if not (isinstance(true_field, tuple) and isinstance(false_field, tuple)): raise ValueError( @@ -155,8 +155,8 @@ def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: raise ValueError( "Tuple of different size not allowed." ) # TODO(havogt) find a strategy to unify parsing and embedded error messages - return tuple(self(mask, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R` - return super().__call__(mask, true_field, false_field) + return tuple(self(cond, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R` + return super().__call__(cond, true_field, false_field) @BuiltInFunction From 5f7e251a554cd111c746cae5dfd3c47eae2a5918 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 5 Jun 2025 10:08:00 +0200 Subject: [PATCH 085/124] add promotion tests --- .../ffront/foast_passes/type_deduction.py | 8 +++--- .../ffront_tests/test_type_deduction.py | 26 ++++++++++++++++++- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 67cadb6e50..b43c4e961c 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -998,9 +998,9 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: ) def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: - mask_type, true_branch_type, false_branch_type = (arg.type for arg in node.args) + cond_type, true_branch_type, false_branch_type = (arg.type for arg in node.args) - assert isinstance(mask_type, ts.DomainType) + assert isinstance(cond_type, ts.DomainType) assert all( isinstance(el, (ts.FieldType, ts.ScalarType)) for arg in (true_branch_type, false_branch_type) @@ -1019,12 +1019,12 @@ def deduce_return_type( node.location, f"Field arguments must be of same dtype, got '{t_dtype}' != '{f_dtype}'.", ) - assert isinstance(mask_type.dims, list) + assert isinstance(cond_type.dims, list) promoted_branches = type_info.promote(tb, fb) branches_dims = ( [] if isinstance(promoted_branches, ts.ScalarType) else promoted_branches.dims ) - return_dims = promote_dims(mask_type.dims, branches_dims) + return_dims = promote_dims(cond_type.dims, branches_dims) assert isinstance(t_dtype, ts.ScalarType) assert isinstance(f_dtype, ts.ScalarType) return_type = ts.FieldType(dims=return_dims, dtype=type_info.promote(t_dtype, f_dtype)) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py index feb84d6ccf..b31bae3094 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py @@ -34,7 +34,9 @@ from gt4py.next.ffront.func_to_foast import FieldOperatorParser from gt4py.next.type_system import type_info, type_specifications as ts -TDim = Dimension("TDim") # Meaningless dimension, used for tests. +# Meaningless dimensions, used for tests. +TDim = Dimension("TDim") +SDim = Dimension("SDim") def test_unpack_assign(): @@ -173,6 +175,28 @@ def simple_concat_where(a: float, b: float): assert compare_node.type == ts.DomainType(dims=[TDim]) +def test_concat_where_promotion0(): + def concat_where_promotion(a: Field[[SDim], float], b: Field[[SDim], float]): + return concat_where(TDim > 0, a, b) + + parsed = FieldOperatorParser.apply_to_function(concat_where_promotion) + _concat_where_expr = parsed.body.stmts[0].value + assert _concat_where_expr.type == ts.FieldType( + dims=[SDim, TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ) + + +def test_concat_where_promotion1(): + def concat_where_promotion(a: Field[[TDim], float], b: Field[[SDim], float]): + return concat_where(TDim > 0, a, b) + + parsed = FieldOperatorParser.apply_to_function(concat_where_promotion) + _concat_where_expr = parsed.body.stmts[0].value + assert _concat_where_expr.type == ts.FieldType( + dims=[SDim, TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ) + + def test_domain_comparison_failure(): def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): return concat_where(TDim > 1.0, a, b) From b1e8f8938718ba17952a85852ead6e1571e0f682 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Wed, 4 Jun 2025 13:25:49 +0200 Subject: [PATCH 086/124] Fix small type inference bug --- .../next/iterator/transforms/concat_where/expand_tuple_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py b/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py index 209ceba8c7..3f3d8e2603 100644 --- a/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py +++ b/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py @@ -40,7 +40,7 @@ def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: # `concat_where(cond, {a, b}, {c, d})` # -> `{concat_where(cond, a, c), concat_where(cond, a, c)}` - if cpm.is_call_to(node, "concat_where") and isinstance(node.args[1].type, ts.TupleType): + if cpm.is_call_to(node, "concat_where") and isinstance(type_inference.reinfer(node.args[1]).type, ts.TupleType): cond, true_branch, false_branch = node.args new_els = [] assert isinstance(true_branch.type, ts.TupleType) From 06905b8a0f1f7dda285773cfb7678d5a084fdcb8 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 5 Jun 2025 23:53:10 +0200 Subject: [PATCH 087/124] Merge branch 'main' into GTIR_concat_where --- .../feature_tests/ffront_tests/test_concat_where.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index a79c303630..b95ead17f3 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -19,7 +19,7 @@ exec_alloc_descriptor, ) -pytestmark = pytest.mark.uses_frontend_concat_where +pytestmark = pytest.mark.uses_concat_where def test_concat_where_simple(cartesian_case): From d89cff6ef0cb3d2bf797d17141ad731bf37a6ffb Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 6 Jun 2025 00:04:53 +0200 Subject: [PATCH 088/124] Backport fixes from main PR --- .../next/ffront/foast_passes/type_deduction.py | 13 +++---------- src/gt4py/next/type_system/type_specifications.py | 6 ++---- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index b43c4e961c..6f56dc605c 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -691,8 +691,6 @@ def _deduce_binop_type( f"{err_msg} Operator " f"must be one of {', '.join((str(op) for op in logical_ops))}.", ) - assert isinstance(right.type.dims, list) - assert isinstance(left.type.dims, list) return ts.DomainType(dims=promote_dims(left.type.dims, right.type.dims)) else: raise errors.DSLError(node.location, err_msg) @@ -1019,15 +1017,10 @@ def deduce_return_type( node.location, f"Field arguments must be of same dtype, got '{t_dtype}' != '{f_dtype}'.", ) - assert isinstance(cond_type.dims, list) - promoted_branches = type_info.promote(tb, fb) - branches_dims = ( - [] if isinstance(promoted_branches, ts.ScalarType) else promoted_branches.dims + return_dims = promote_dims( + cond_type.dims, type_info.extract_dims(type_info.promote(tb, fb)) ) - return_dims = promote_dims(cond_type.dims, branches_dims) - assert isinstance(t_dtype, ts.ScalarType) - assert isinstance(f_dtype, ts.ScalarType) - return_type = ts.FieldType(dims=return_dims, dtype=type_info.promote(t_dtype, f_dtype)) + return_type = ts.FieldType(dims=return_dims, dtype=t_dtype) return return_type return_type = deduce_return_type(true_branch_type, false_branch_type) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 4822719b40..c69d1ae00d 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Iterator, Literal, Optional, Sequence, Union +from typing import Iterator, Optional, Sequence, Union from gt4py.eve import datamodels as eve_datamodels, type_definitions as eve_types from gt4py.next import common @@ -144,6 +144,4 @@ def __str__(self) -> str: class DomainType(DataType): - # TODO(tehrengruber): Remove "unknown" here again after the result type of `as_fieldop` - # is always precisely known. This is the case after #1853. - dims: list[common.Dimension] | Literal["unknown"] + dims: list[common.Dimension] From 3dac495f36eb67a697c6a2b4d2865e8f1a4723aa Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 6 Jun 2025 00:07:06 +0200 Subject: [PATCH 089/124] Cleanup --- src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 98766518e6..da13d20bb6 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -108,10 +108,8 @@ def is_let(node: itir.Node) -> TypeGuard[_FunCallToLambda]: return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda) -def is_ref_to(node, ref: str | Iterable[str]) -> TypeGuard[itir.SymRef]: - if isinstance(ref, str): - return isinstance(node, itir.SymRef) and node.id == ref - return any(is_ref_to(node, el) for el in ref) +def is_ref_to(node, ref: str) -> TypeGuard[itir.SymRef]: + return isinstance(node, itir.SymRef) and node.id == ref def is_identity_as_fieldop(node: itir.Expr) -> TypeGuard[_FunCallToFunCallToRef]: From 506c2b511fbfdd8f8c4c9150c3ceb7e10bfe8c5b Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 6 Jun 2025 00:18:45 +0200 Subject: [PATCH 090/124] Extract concat_where transformations --- src/gt4py/next/iterator/ir.py | 21 ++++- .../next/iterator/ir_utils/domain_utils.py | 89 +++++++++++++++--- src/gt4py/next/iterator/ir_utils/ir_makers.py | 21 ++++- src/gt4py/next/iterator/pretty_printer.py | 7 ++ .../transforms/concat_where/__init__.py | 18 ++++ .../concat_where/expand_tuple_args.py | 60 ++++++++++++ .../concat_where/simplify_domain_argument.py | 76 +++++++++++++++ .../concat_where/transform_to_as_fieldop.py | 94 +++++++++++++++++++ .../next/iterator/transforms/infer_domain.py | 56 ++++++++++- .../next/iterator/type_system/inference.py | 6 ++ .../iterator/type_system/type_synthesizer.py | 69 +++++++++++--- .../iterator_tests/test_type_inference.py | 17 ++++ .../test_concat_where_expand_tuple_args.py | 53 +++++++++++ .../test_concat_where_simplify_domain_args.py | 58 ++++++++++++ ...st_concat_where_transform_to_as_fieldop.py | 89 ++++++++++++++++++ .../transforms_tests/test_infer_domain_ops.py | 71 ++++++++++++++ 16 files changed, 775 insertions(+), 30 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/concat_where/__init__.py create mode 100644 src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py create mode 100644 src/gt4py/next/iterator/transforms/concat_where/simplify_domain_argument.py create mode 100644 src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_expand_tuple_args.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_simplify_domain_args.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_transform_to_as_fieldop.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_infer_domain_ops.py diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index ea5cf84d86..f054cfc203 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -5,8 +5,10 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations -from typing import ClassVar, List, Optional, Union +import typing +from typing import TYPE_CHECKING, ClassVar, List, Optional, Union import gt4py.eve as eve from gt4py.eve import Coerced, SymbolName, SymbolRef @@ -63,6 +65,22 @@ class NoneLiteral(Expr): _none_literal: int = 0 +class InfinityLiteral(Expr): + # TODO(tehrengruber): self referential `ClassVar` not supported in eve. + if TYPE_CHECKING: + POSITIVE: ClassVar[InfinityLiteral] + NEGATIVE: ClassVar[InfinityLiteral] + + name: typing.Literal["POSITIVE", "NEGATIVE"] + + def __str__(self): + return f"{type(self).__name__}.{self.name}" + + +InfinityLiteral.NEGATIVE = InfinityLiteral(name="NEGATIVE") +InfinityLiteral.POSITIVE = InfinityLiteral(name="POSITIVE") + + class OffsetLiteral(Expr): value: Union[int, str] @@ -142,3 +160,4 @@ class Program(Node, ValidatedSymbolTableTrait): Program.__hash__ = Node.__hash__ # type: ignore[method-assign] SetAt.__hash__ = Node.__hash__ # type: ignore[method-assign] IfStmt.__hash__ = Node.__hash__ # type: ignore[method-assign] +InfinityLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign] diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index afcf02927a..c01b0a0dcc 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -14,7 +14,7 @@ from gt4py.next import common from gt4py.next.iterator import builtins, ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import trace_shifts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding @@ -51,7 +51,7 @@ def translate(self, distance: int) -> SymbolicRange: @dataclasses.dataclass(frozen=True) class SymbolicDomain: - grid_type: Literal["unstructured_domain", "cartesian_domain"] + grid_type: common.GridType ranges: dict[ common.Dimension, SymbolicRange ] # TODO(havogt): remove `AxisLiteral` by `Dimension` everywhere @@ -61,25 +61,19 @@ def __hash__(self) -> int: @classmethod def from_expr(cls, node: itir.Node) -> SymbolicDomain: - assert isinstance(node, itir.FunCall) and node.fun in [ - im.ref("unstructured_domain"), - im.ref("cartesian_domain"), - ] + assert cpm.is_call_to(node, ("unstructured_domain", "cartesian_domain")) + grid_type = getattr(common.GridType, node.fun.id[: -len("_domain")].upper()) ranges: dict[common.Dimension, SymbolicRange] = {} for named_range in node.args: - assert ( - isinstance(named_range, itir.FunCall) - and isinstance(named_range.fun, itir.SymRef) - and named_range.fun.id == "named_range" - ) + assert cpm.is_call_to(named_range, "named_range") axis_literal, lower_bound, upper_bound = named_range.args assert isinstance(axis_literal, itir.AxisLiteral) ranges[common.Dimension(value=axis_literal.value, kind=axis_literal.kind)] = ( SymbolicRange(lower_bound, upper_bound) ) - return cls(node.fun.id, ranges) # type: ignore[attr-defined] # ensure by assert above + return cls(grid_type, ranges) def as_expr(self) -> itir.FunCall: converted_ranges: dict[common.Dimension, tuple[itir.Expr, itir.Expr]] = { @@ -183,3 +177,74 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: new_domain_ranges[dim] = SymbolicRange(start, stop) return SymbolicDomain(domains[0].grid_type, new_domain_ranges) + + +def domain_intersection(*domains: SymbolicDomain) -> SymbolicDomain: + """Return the (set) intersection of a list of domains.""" + new_domain_ranges = {} + assert all(domain.grid_type == domains[0].grid_type for domain in domains) + for dim in domains[0].ranges.keys(): + start = functools.reduce( + lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), + [domain.ranges[dim].start for domain in domains], + ) + stop = functools.reduce( + lambda current_expr, el_expr: im.call("minimum")(current_expr, el_expr), + [domain.ranges[dim].stop for domain in domains], + ) + # constant fold expression to keep the tree small + start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr + new_domain_ranges[dim] = SymbolicRange(start, stop) + + return SymbolicDomain(domains[0].grid_type, new_domain_ranges) + + +def domain_complement(domain: SymbolicDomain) -> SymbolicDomain: + """Return the (set) complement of a domain.""" + dims_dict = {} + for dim in domain.ranges.keys(): + lb, ub = domain.ranges[dim].start, domain.ranges[dim].stop + # `]-inf, a[` -> `[a, inf[` + if lb == itir.InfinityLiteral.NEGATIVE: + dims_dict[dim] = SymbolicRange(start=ub, stop=itir.InfinityLiteral.POSITIVE) + # `[a, inf]` -> `]-inf, a]` + elif ub == itir.InfinityLiteral.POSITIVE: + dims_dict[dim] = SymbolicRange(start=itir.InfinityLiteral.NEGATIVE, stop=lb) + else: + raise ValueError("Invalid domain ranges") + return SymbolicDomain(domain.grid_type, dims_dict) + + +def promote_to_same_dimensions( + domain_small: SymbolicDomain, domain_large: SymbolicDomain +) -> SymbolicDomain: + """Return an extended domain based on a smaller input domain and a larger domain containing the target dimensions.""" + dims_dict = {} + for dim in domain_large.ranges.keys(): + if dim in domain_small.ranges.keys(): + lb, ub = domain_small.ranges[dim].start, domain_small.ranges[dim].stop + dims_dict[dim] = SymbolicRange(lb, ub) + else: + dims_dict[dim] = SymbolicRange( + itir.InfinityLiteral.NEGATIVE, itir.InfinityLiteral.POSITIVE + ) + return SymbolicDomain(domain_small.grid_type, dims_dict) # TODO: fix for unstructured + + +def is_finite(range_or_domain: SymbolicRange | SymbolicDomain) -> bool: + """ + Return whether a range is unbounded in (at least) one direction. + + The expression is required to be constant folded before for the result to be reliable. + """ + if isinstance(range_ := range_or_domain, SymbolicRange): + # TODO: assert no infinity literal in here + if any( + v in [itir.InfinityLiteral.POSITIVE, itir.InfinityLiteral.NEGATIVE] + for v in [range_.start, range_.stop] + ): + return False + return True + elif isinstance(domain := range_or_domain, SymbolicDomain): + return all(is_finite(range_) for range_ in domain.ranges.values()) + raise ValueError("Expected a SymbolicRange or SymbolicDomain.") diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 7acdeb2f61..739aa5d90d 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -243,6 +243,12 @@ def if_(cond, true_val, false_val): return call("if_")(cond, true_val, false_val) +def concat_where(cond, true_field, false_field): + """Create a concat_where FunCall, shorthand for ``call("concat_where")(expr)``.""" + + return call("concat_where")(cond, true_field, false_field) + + def lift(expr): """Create a lift FunCall, shorthand for ``call(call("lift")(expr))``.""" return call(call("lift")(expr)) @@ -437,18 +443,18 @@ def domain( """ if isinstance(grid_type, common.GridType): grid_type = f"{grid_type!s}_domain" - return call(grid_type)( + expr = call(grid_type)( *[ call("named_range")( - itir.AxisLiteral(value=d.value, kind=d.kind) - if isinstance(d, common.Dimension) - else itir.AxisLiteral(value=d), + axis_literal(d), r[0], r[1], ) for d, r in ranges.items() ] ) + expr.type = ts.DomainType(dims=list(ranges.keys())) + return expr def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> Callable: @@ -478,7 +484,8 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> Cal def _populate_domain_annex_wrapper(*args, **kwargs): node = result(*args, **kwargs) # note: if the domain is not a direct construction, e.g. because it is only a reference - # to a domain defined in a let, don't populate the annex + # to a domain defined in a let, don't populate the annex, since we can not create a + # symbolic domain for it. if domain and cpm.is_call_to(domain, ("cartesian_domain", "unstructured_domain")): node.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) return node @@ -515,6 +522,10 @@ def _impl(*its: itir.Expr) -> itir.FunCall: return _impl +def axis_literal(dim: common.Dimension) -> itir.AxisLiteral: + return itir.AxisLiteral(value=dim.value, kind=dim.kind) + + def cast_as_fieldop(type_: str, domain: Optional[itir.FunCall] = None): """ Promotes the function `cast_` to a field_operator. diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 7acbf5d23d..5063e26392 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -133,6 +133,13 @@ def visit_Sym(self, node: ir.Sym, *, prec: int) -> list[str]: def visit_Literal(self, node: ir.Literal, *, prec: int) -> list[str]: return [str(node.value)] + def visit_InfinityLiteral(self, node: ir.InfinityLiteral, *, prec: int) -> list[str]: + if node == ir.InfinityLiteral.POSITIVE: + return ["∞"] + elif node == ir.InfinityLiteral.NEGATIVE: + return ["-∞"] + raise AssertionError() + def visit_OffsetLiteral(self, node: ir.OffsetLiteral, *, prec: int) -> list[str]: return [str(node.value) + "ₒ"] diff --git a/src/gt4py/next/iterator/transforms/concat_where/__init__.py b/src/gt4py/next/iterator/transforms/concat_where/__init__.py new file mode 100644 index 0000000000..a9c3fb2576 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/concat_where/__init__.py @@ -0,0 +1,18 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.next.iterator.transforms.concat_where.expand_tuple_args import expand_tuple_args +from gt4py.next.iterator.transforms.concat_where.simplify_domain_argument import ( + simplify_domain_argument, +) +from gt4py.next.iterator.transforms.concat_where.transform_to_as_fieldop import ( + transform_to_as_fieldop, +) + + +__all__ = ["expand_tuple_args", "simplify_domain_argument", "transform_to_as_fieldop"] diff --git a/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py b/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py new file mode 100644 index 0000000000..3f3d8e2603 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py @@ -0,0 +1,60 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.type_system import inference as type_inference +from gt4py.next.type_system import type_specifications as ts + + +class _ExpandTupleArgs(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ( + "type", + "domain", + ) + + @classmethod + def apply( + cls, + node: itir.Node, + *, + offset_provider_type: common.OffsetProviderType, + allow_undeclared_symbols: bool = False, + ) -> itir.Node: + node = type_inference.infer( + node, + offset_provider_type=offset_provider_type, + allow_undeclared_symbols=allow_undeclared_symbols, + ) + return cls().visit(node) + + def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: + node = self.generic_visit(node) + + # `concat_where(cond, {a, b}, {c, d})` + # -> `{concat_where(cond, a, c), concat_where(cond, a, c)}` + if cpm.is_call_to(node, "concat_where") and isinstance(type_inference.reinfer(node.args[1]).type, ts.TupleType): + cond, true_branch, false_branch = node.args + new_els = [] + assert isinstance(true_branch.type, ts.TupleType) + for i in range(len(true_branch.type.types)): + new_els.append( + im.concat_where(cond, im.tuple_get(i, "__tb"), im.tuple_get(i, "__fb")) + ) + + new_node = im.let(("__tb", true_branch), ("__fb", false_branch))( + im.make_tuple(*new_els) + ) + return new_node + + return node + + +expand_tuple_args = _ExpandTupleArgs.apply diff --git a/src/gt4py/next/iterator/transforms/concat_where/simplify_domain_argument.py b/src/gt4py/next/iterator/transforms/concat_where/simplify_domain_argument.py new file mode 100644 index 0000000000..74f24f28b7 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/concat_where/simplify_domain_argument.py @@ -0,0 +1,76 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next.iterator import ir, ir as itir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) +from gt4py.next.iterator.ir_utils.domain_utils import SymbolicDomain + + +def _range_complement( + range_: domain_utils.SymbolicRange, +) -> tuple[domain_utils.SymbolicRange, domain_utils.SymbolicRange]: + # `[a, b[` -> `[-inf, a[` ∪ `[b, inf[` # noqa: RUF003 + assert not any(isinstance(b, itir.InfinityLiteral) for b in [range_.start, range_.stop]) + return ( + domain_utils.SymbolicRange(itir.InfinityLiteral.NEGATIVE, range_.start), + domain_utils.SymbolicRange(range_.stop, itir.InfinityLiteral.POSITIVE), + ) + + +class _SimplifyDomainArgument(PreserveLocationVisitor, NodeTranslator): + @classmethod + def apply(cls, node: ir.Node): + return cls().visit(node) + + def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + node = self.generic_visit(node) + + # TODO: do not duplicate exprs + if cpm.is_call_to(node, "concat_where"): + cond_expr, field_a, field_b = node.args + # TODO: don't duplicate exprs here + if cpm.is_call_to(cond_expr, "and_"): + conds = cond_expr.args + return self.visit( + im.concat_where(conds[0], im.concat_where(conds[1], field_a, field_b), field_b) + ) + if cpm.is_call_to(cond_expr, "or_"): + conds = cond_expr.args + return self.visit( + im.concat_where(conds[0], field_a, im.concat_where(conds[1], field_a, field_b)) + ) + + # concat_where([1, 2[, a, b) -> concat_where([-inf, 1] | [2, inf[, b, a) + if cpm.is_call_to(cond_expr, ("cartesian_domain", "unstructured_domain")): + domain = SymbolicDomain.from_expr(cond_expr) + if len(domain.ranges) == 1: + dim, range_ = next(iter(domain.ranges.items())) + if domain_utils.is_finite(range_): + complement = _range_complement(range_) + new_domains = [ + im.domain(domain.grid_type, {dim: (cr.start, cr.stop)}) + for cr in complement + ] + # TODO: fp transform + return self.visit( + im.concat_where(im.call("or_")(*new_domains), field_b, field_a) + ) + else: + # TODO(tehrengruber): Implement. Note that this case can not be triggered by + # the frontend. + raise NotImplementedError() + + return node + + +simplify_domain_argument = _SimplifyDomainArgument.apply diff --git a/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py b/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py new file mode 100644 index 0000000000..d5195c120b --- /dev/null +++ b/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py @@ -0,0 +1,94 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import functools + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next import utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) +from gt4py.next.iterator.transforms import symbol_ref_utils +from gt4py.next.iterator.type_system import inference as type_inference +from gt4py.next.type_system import type_specifications as ts + + +def _in(pos: itir.Expr, domain: itir.Expr) -> itir.Expr: + """ + Given a position and a domain return an expression that evaluates to `True` if the position is inside the domain. + + `in_({i, j, k}, u⟨ Iₕ: [i0, i1[, Iₕ: [j0, j1[, Iₕ: [k0, k1[ ⟩` + -> `i0 <= i < i1 & j0 <= j < j1 & k0 <= k < k1` + """ + ret = [] + for i, v in enumerate(domain_utils.SymbolicDomain.from_expr(domain).ranges.values()): + ret.append( + im.and_( + im.less_equal(v.start, im.tuple_get(i, pos)), + im.less(im.tuple_get(i, pos), v.stop), + ) + ) + return functools.reduce(im.and_, ret) + + +class _TransformToAsFieldop(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ( + "type", + "domain", + ) + + @classmethod + def apply(cls, node: itir.Node): + """ + Transform `concat_where` expressions into equivalent `as_fieldop` expressions. + + Note that (backward) domain inference may not be executed after this pass as it can not + correctly infer the accessed domains when the value selection is represented as an `if_` + inside the `as_fieldop. + """ + node = cls().visit(node) + node = type_inference.SanitizeTypes().visit(node) + return node + + def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: + node = self.generic_visit(node) + if cpm.is_call_to(node, "concat_where"): + cond, true_branch, false_branch = node.args + assert isinstance(cond.type, ts.DomainType) + position = [im.index(dim) for dim in cond.type.dims] + refs = symbol_ref_utils.collect_symbol_refs(cond) + + domains: tuple[domain_utils.SymbolicDomain, ...] = utils.flatten_nested_tuple( + node.annex.domain + ) + assert all( + domain == domains[0] for domain in domains + ), "At this point all `concat_where` arguments should be posed on the same domain." + assert isinstance(domains[0], domain_utils.SymbolicDomain) + domain_expr = domains[0].as_expr() + + return im.as_fieldop( + im.lambda_("__tcw_pos", "__tcw_arg0", "__tcw_arg1", *refs)( + im.let(*zip(refs, map(im.deref, refs), strict=True))( + im.if_( + _in(im.deref("__tcw_pos"), cond), + im.deref("__tcw_arg0"), + im.deref("__tcw_arg1"), + ) + ) + ), + domain_expr, + )(im.make_tuple(*position), true_branch, false_branch, *refs) + + return node + + +transform_to_as_fieldop = _TransformToAsFieldop.apply diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 9af5268b93..61bf8f0c73 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -21,6 +21,7 @@ ir_makers as im, misc as ir_misc, ) +from gt4py.next.iterator.ir_utils.domain_utils import SymbolicDomain from gt4py.next.iterator.transforms import constant_folding, trace_shifts from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.type_system import type_info, type_specifications as ts @@ -56,6 +57,7 @@ class InferenceOptions(typing.TypedDict): offset_provider: common.OffsetProvider | common.OffsetProviderType symbolic_domain_sizes: Optional[dict[str, str]] allow_uninferred: bool + keep_existing_domains: bool class DomainAnnexDebugger(eve.NodeVisitor): @@ -182,11 +184,16 @@ def _infer_as_fieldop( offset_provider: common.OffsetProvider | common.OffsetProviderType, symbolic_domain_sizes: Optional[dict[str, str]], allow_uninferred: bool, + keep_existing_domains: bool, ) -> tuple[itir.FunCall, AccessedDomains]: assert isinstance(applied_fieldop, itir.FunCall) assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") if not allow_uninferred and target_domain is DomainAccessDescriptor.NEVER: raise ValueError("'target_domain' cannot be 'NEVER' unless `allow_uninferred=True`.") + + if len(applied_fieldop.fun.args) == 2 and keep_existing_domains: + target_domain = SymbolicDomain.from_expr(applied_fieldop.fun.args[1]) + # FIXME[#1582](tehrengruber): Temporary solution for `tuple_get` on scan result. See `test_solve_triag`. if isinstance(target_domain, tuple): target_domain = _domain_union(*flatten_nested_tuple(target_domain)) # type: ignore[arg-type] # mypy not smart enough @@ -226,6 +233,7 @@ def _infer_as_fieldop( offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, allow_uninferred=allow_uninferred, + keep_existing_domains=keep_existing_domains, ) transformed_inputs.append(transformed_input) @@ -348,6 +356,40 @@ def _infer_if( return result_expr, actual_domains +def _infer_concat_where( + expr: itir.Expr, + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: + assert cpm.is_call_to(expr, "concat_where") + infered_args_expr = [] + actual_domains: AccessedDomains = {} + cond, true_field, false_field = expr.args + symbolic_cond = domain_utils.SymbolicDomain.from_expr(cond) + cond_complement = domain_utils.domain_complement(symbolic_cond) + + for arg in [true_field, false_field]: + + @tree_map + def mapper(d: NonTupleDomainAccess): + if isinstance(d, DomainAccessDescriptor): + return d + promoted_cond = domain_utils.promote_to_same_dimensions( + symbolic_cond if arg == true_field else cond_complement, # noqa: B023 # function is never used outside the loop + d, + ) + return domain_utils.domain_intersection(d, promoted_cond) + + domain_ = mapper(domain) + + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain_, **kwargs) + infered_args_expr.append(infered_arg_expr) + actual_domains = _merge_domains(actual_domains, actual_domains_arg) + + result_expr = im.call(expr.fun)(cond, *infered_args_expr) + return result_expr, actual_domains + + def _infer_broadcast( expr: itir.Expr, domain: DomainAccess, @@ -380,6 +422,8 @@ def _infer_expr( return _infer_tuple_get(expr, domain, **kwargs) elif cpm.is_call_to(expr, "if_"): return _infer_if(expr, domain, **kwargs) + elif cpm.is_call_to(expr, "concat_where"): + return _infer_concat_where(expr, domain, **kwargs) elif cpm.is_call_to(expr, "broadcast"): return _infer_broadcast(expr, domain, **kwargs) elif ( @@ -399,6 +443,7 @@ def infer_expr( offset_provider: common.OffsetProvider | common.OffsetProviderType, symbolic_domain_sizes: Optional[dict[str, str]] = None, allow_uninferred: bool = False, + keep_existing_domains: bool = False, ) -> tuple[itir.Expr, AccessedDomains]: """ Infer the domain of all field subexpressions of `expr`. @@ -413,6 +458,10 @@ def infer_expr( name that evaluates to the length of that axis. - allow_uninferred: Allow `as_fieldop` expressions whose domain is either unknown (e.g. because of a dynamic shift) or never accessed. + # TODO: describe why this is needed with concat_where (if inside as_fieldop might shrinken the + actually access domain) + - keep_existing_domains: If `True`, keep existing domains in `as_fieldop` expressions and + use them to propagate the domain further. Returns: A tuple containing the inferred expression with all applied `as_fieldop` (that are accessed) @@ -458,8 +507,10 @@ def infer_expr( offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, allow_uninferred=allow_uninferred, + keep_existing_domains=keep_existing_domains, ) - expr.annex.domain = domain + if not keep_existing_domains or not hasattr(expr.annex, "domain"): + expr.annex.domain = domain return expr, accessed_domains @@ -497,6 +548,8 @@ def infer_program( offset_provider: common.OffsetProvider | common.OffsetProviderType, symbolic_domain_sizes: Optional[dict[str, str]] = None, allow_uninferred: bool = False, + # TODO: add test + keep_existing_domains: bool = False, ) -> itir.Program: """ Infer the domain of all field subexpressions inside a program. @@ -522,6 +575,7 @@ def infer_program( offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, allow_uninferred=allow_uninferred, + keep_existing_domains=keep_existing_domains, ) for stmt in program.body ], diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 7e4ad504c2..3470af6493 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -511,6 +511,12 @@ def visit_Literal(self, node: itir.Literal, **kwargs) -> ts.ScalarType: assert isinstance(node.type, ts.ScalarType) return node.type + def visit_InfinityLiteral(self, node: itir.InfinityLiteral, **kwargs) -> ts.ScalarType: + return ts.ScalarType(kind=ts.ScalarKind.INT32) + + def visit_NegInfinityLiteral(self, node: itir.InfinityLiteral, **kwargs) -> ts.ScalarType: + return ts.ScalarType(kind=ts.ScalarKind.INT32) + def visit_SymRef( self, node: itir.SymRef, *, ctx: dict[str, ts.TypeSpec] ) -> ts.TypeSpec | type_synthesizer.TypeSynthesizer: diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 8f683c2ff9..2d9a918bf5 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -15,7 +15,7 @@ from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Callable, Iterable, Optional, Union -from gt4py.next import common +from gt4py.next import common, utils from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.type_system import type_specifications as it_ts @@ -149,13 +149,31 @@ def _(arg: ts.ScalarType) -> ts.ScalarType: return ts.ScalarType(kind=ts.ScalarKind.BOOL) -@_register_builtin_type_synthesizer( - fun_names=builtins.BINARY_MATH_COMPARISON_BUILTINS | builtins.BINARY_LOGICAL_BUILTINS -) -def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType | ts.TupleType: +def synthesize_binary_math_comparison_builtins( + lhs, rhs +) -> ts.ScalarType | ts.TupleType | ts.DomainType: + if isinstance(lhs, ts.ScalarType) and isinstance(rhs, ts.DimensionType): + return ts.DomainType(dims=[rhs.dim]) + if isinstance(lhs, ts.DimensionType) and isinstance(rhs, ts.ScalarType): + return ts.DomainType(dims=[lhs.dim]) + assert all(isinstance(lhs, (ts.ScalarType, ts.DeferredType)) for arg in (lhs, rhs)) return ts.ScalarType(kind=ts.ScalarKind.BOOL) +@_register_builtin_type_synthesizer(fun_names=builtins.BINARY_MATH_COMPARISON_BUILTINS) +def _(lhs, rhs) -> ts.ScalarType | ts.TupleType | ts.DomainType: + return synthesize_binary_math_comparison_builtins(lhs, rhs) + + +@_register_builtin_type_synthesizer(fun_names=builtins.BINARY_LOGICAL_BUILTINS) +def _(lhs, rhs) -> ts.ScalarType | ts.TupleType | ts.DomainType: + if isinstance(lhs, ts.DomainType) and isinstance(rhs, ts.DomainType): + assert lhs.dims != "unknown" and rhs.dims != "unknown" + return ts.DomainType(dims=common.promote_dims(lhs.dims, rhs.dims)) + else: + return synthesize_binary_math_comparison_builtins(lhs, rhs) + + @_register_builtin_type_synthesizer def deref(it: it_ts.IteratorType | ts.DeferredType) -> ts.DataType | ts.DeferredType: if isinstance(it, ts.DeferredType): @@ -245,6 +263,39 @@ def index(arg: ts.DimensionType) -> ts.FieldType: ) +@_register_builtin_type_synthesizer +def concat_where( + domain: ts.DomainType, + true_field: ts.FieldType | ts.TupleType | ts.DeferredType, + false_field: ts.FieldType | ts.TupleType | ts.DeferredType, +) -> ts.FieldType | ts.TupleType | ts.DeferredType: + if isinstance(true_field, ts.DeferredType) or isinstance(false_field, ts.DeferredType): + return ts.DeferredType(constraint=None) + + @utils.tree_map( + collection_type=ts.TupleType, + result_collection_constructor=lambda el: ts.TupleType(types=list(el)), + ) + def deduce_return_type(tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.ScalarType): + if any(isinstance(b, ts.DeferredType) for b in [tb, fb]): + return ts.DeferredType(constraint=ts.FieldType) + + tb_dtype, fb_dtype = (type_info.extract_dtype(b) for b in [tb, fb]) + + assert ( + tb_dtype == fb_dtype + ), f"Field arguments must be of same dtype, got '{tb_dtype}' != '{fb_dtype}'." + dtype = tb_dtype + + return_dims = common.promote_dims( + domain.dims, type_info.extract_dims(type_info.promote(tb, fb)) + ) + return_type = ts.FieldType(dims=return_dims, dtype=dtype) + return return_type + + return deduce_return_type(true_field, false_field) + + @_register_builtin_type_synthesizer def broadcast( arg: ts.FieldType | ts.ScalarType | ts.DeferredType, dims: tuple[ts.DimensionType] @@ -321,11 +372,7 @@ def _collect_and_check_dimensions(input_: ts.TypeSpec) -> list[common.Dimension] .filter(lambda dims: len(dims) > 0) .to_list() ) - if all_input_dims: - assert all(cur_input_dims == all_input_dims[0] for cur_input_dims in all_input_dims) - return all_input_dims[0] - - return [] + return common.promote_dims(*all_input_dims) def _convert_as_fieldop_input_to_iterator( @@ -571,7 +618,7 @@ def applied_map( assert isinstance(el_type, ts.DataType) offset_types = [arg.offset_type for arg in args if arg.offset_type] offset_type = offset_types[0] if offset_types else None - assert all(offset_type == arg for arg in offset_types) + assert all(offset_type == arg for arg in offset_types) # type: ignore[operator] # mypy not smart enough return ts.ListType(element_type=el_type, offset_type=offset_type) return applied_map diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 4a574f256a..0589463777 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -259,6 +259,23 @@ def expression_test_cases(): ), ts.TupleType(types=[float_i_field, float_i_field]), ), + # concat_where + ( + im.concat_where( + im.domain(common.GridType.CARTESIAN, {IDim: (0, 1)}), + im.ref("a", float_i_field), + im.ref("b", float_ij_field), + ), + float_ij_field, + ), + ( + im.concat_where( + im.domain(common.GridType.CARTESIAN, {IDim: (0, 1)}), + im.ref("a", ts.TupleType(types=[float_i_field] * 2)), + im.ref("b", ts.TupleType(types=[float_i_field] * 2)), + ), + ts.TupleType(types=[float_i_field] * 2), + ), ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_expand_tuple_args.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_expand_tuple_args.py new file mode 100644 index 0000000000..42ad292043 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_expand_tuple_args.py @@ -0,0 +1,53 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +from gt4py.next import common +import pytest +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils +from gt4py.next.iterator.transforms import ( + concat_where, + inline_lambdas, + infer_domain, + collapse_tuple, +) +from gt4py.next.type_system import type_specifications as ts +from gt4py.next.iterator.type_system import type_specifications as it_ts + +int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) +IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) +field_type = ts.FieldType(dims=[IDim], dtype=int_type) + + +def test_trivial(): + cond = im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 1)}) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 2)}) + symbolic_domain = domain_utils.SymbolicDomain.from_expr(domain) + + testee = im.concat_where( + cond, + im.make_tuple(im.ref("a", field_type), im.ref("c", field_type)), + im.make_tuple(im.ref("b", field_type), im.ref("d", field_type)), + ) + testee, _ = infer_domain.infer_expr( + testee, + (symbolic_domain, symbolic_domain), + keep_existing_domains=True, + offset_provider={}, + ) + + expected = im.make_tuple(im.concat_where(cond, "a", "b"), im.concat_where(cond, "c", "d")) + + actual = concat_where.expand_tuple_args( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + + actual = collapse_tuple.CollapseTuple.apply( + actual, allow_undeclared_symbols=True, within_stencil=False + ) + + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_simplify_domain_args.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_simplify_domain_args.py new file mode 100644 index 0000000000..cbc5f61950 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_simplify_domain_args.py @@ -0,0 +1,58 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +from gt4py.next import common +import pytest +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils +from gt4py.next.iterator.transforms import ( + concat_where, + inline_lambdas, + infer_domain, + collapse_tuple, +) +from gt4py.next.type_system import type_specifications as ts +from gt4py.next.iterator.type_system import type_specifications as it_ts + +int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) +IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) +field_type = ts.FieldType(dims=[IDim], dtype=int_type) + + +def test_data(): + return [ + # testee, expected + ( + im.concat_where(im.and_("cond1", "cond2"), "a", "b"), + im.concat_where("cond1", im.concat_where("cond2", "a", "b"), "b"), + ), + ( + im.concat_where(im.or_("cond1", "cond2"), "a", "b"), + im.concat_where("cond1", "a", im.concat_where("cond2", "a", "b")), + ), + ( + im.concat_where(im.domain(common.GridType.CARTESIAN, {IDim: (0, 1)}), "a", "b"), + im.concat_where( + im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 0)}), + "b", + im.concat_where( + im.domain( + common.GridType.CARTESIAN, {IDim: (1, itir.InfinityLiteral.POSITIVE)} + ), + "b", + "a", + ), + ), + ), + ] + + +@pytest.mark.parametrize("testee, expected", test_data()) +def test_nested_concat_where(testee, expected): + actual = concat_where.simplify_domain_argument(testee) + + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_transform_to_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_transform_to_as_fieldop.py new file mode 100644 index 0000000000..c9e8bb5806 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_transform_to_as_fieldop.py @@ -0,0 +1,89 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +from gt4py.next import common + +from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils +from gt4py.next.iterator.transforms import concat_where, inline_lambdas +from gt4py.next.iterator.transforms.concat_where import transform_to_as_fieldop +from gt4py.next.iterator.transforms.concat_where.transform_to_as_fieldop import _in +from gt4py.next.type_system import type_specifications as ts + +int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) +IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) +JDim = common.Dimension(value="JDim", kind=common.DimensionKind.HORIZONTAL) + + +def test_in_helper(): + pos = im.make_tuple(0, 1) + bounds = { + IDim: (3, 4), + JDim: (5, 6), + } + expected = im.and_( + im.and_( + im.less_equal(bounds[IDim][0], im.tuple_get(0, pos)), + im.less(im.tuple_get(0, pos), bounds[IDim][1]), + ), + im.and_( + im.less_equal(bounds[JDim][0], im.tuple_get(1, pos)), + im.less(im.tuple_get(1, pos), bounds[JDim][1]), + ), + ) + actual = _in(pos, im.domain(common.GridType.CARTESIAN, bounds)) + assert actual == expected + + +def test_trivial(): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 2)}) + + cond = im.domain(common.GridType.CARTESIAN, {IDim: (0, 1)}) + testee = im.concat_where(cond, "true_branch", "false_branch") + testee.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) + expected = im.as_fieldop( + im.lambda_("__tcw_pos", "__tcw_arg0", "__tcw_arg1")( + im.if_( + im.call("in_")(im.deref("__tcw_pos"), cond), + im.deref("__tcw_arg0"), + im.deref("__tcw_arg1"), + ) + ), + domain, + )(im.make_tuple(im.index(IDim)), "true_branch", "false_branch") + + actual = concat_where.transform_to_as_fieldop(testee) + actual = inline_lambdas.InlineLambdas.apply(actual) # simplify + + assert actual == expected + + +def test_capturing_cond(): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 1)}) + + cond = im.domain(common.GridType.CARTESIAN, {IDim: ("start", "stop")}) + testee = im.concat_where(cond, "true_branch", "false_branch") + testee.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) + expected = im.as_fieldop( + im.lambda_("__tcw_pos", "__tcw_arg0", "__tcw_arg1", "start", "stop")( + im.if_( + im.call("in_")( + im.deref("__tcw_pos"), + im.domain( + common.GridType.CARTESIAN, {IDim: (im.deref("start"), im.deref("stop"))} + ), + ), + im.deref("__tcw_arg0"), + im.deref("__tcw_arg1"), + ) + ), + domain, + )(im.make_tuple(im.index(IDim)), "true_branch", "false_branch", "start", "stop") + + actual = concat_where.transform_to_as_fieldop(testee) + actual = inline_lambdas.InlineLambdas.apply(actual) # simplify + + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_infer_domain_ops.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_infer_domain_ops.py new file mode 100644 index 0000000000..77ba3719be --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_infer_domain_ops.py @@ -0,0 +1,71 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import textwrap + +from gt4py.eve.utils import UIDGenerator +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.infer_domain_ops import InferDomainOps +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding + +from next_tests.integration_tests.cases import IDim, JDim, KDim + + +def test_data(): + return [ + ( + im.less(im.axis_literal(IDim), 1), + im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 1)}), + ), + ( + im.less_equal(im.axis_literal(IDim), 1), + im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 2)}), + ), + ( + im.greater(im.axis_literal(IDim), 1), + im.domain(common.GridType.CARTESIAN, {IDim: (2, itir.InfinityLiteral.POSITIVE)}), + ), + ( + im.greater_equal(im.axis_literal(IDim), 1), + im.domain(common.GridType.CARTESIAN, {IDim: (1, itir.InfinityLiteral.POSITIVE)}), + ), + ( + im.less(1, im.axis_literal(IDim)), + im.domain(common.GridType.CARTESIAN, {IDim: (2, itir.InfinityLiteral.POSITIVE)}), + ), + ( + im.less_equal(1, im.axis_literal(IDim)), + im.domain(common.GridType.CARTESIAN, {IDim: (1, itir.InfinityLiteral.POSITIVE)}), + ), + ( + im.greater(1, im.axis_literal(IDim)), + im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 1)}), + ), + ( + im.greater_equal(1, im.axis_literal(IDim)), + im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 2)}), + ), + (im.eq(1, im.axis_literal(IDim)), im.domain(common.GridType.CARTESIAN, {IDim: (1, 2)})), + ( + im.not_eq(1, im.axis_literal(IDim)), + im.and_( + im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 1)}), + im.domain(common.GridType.CARTESIAN, {IDim: (2, itir.InfinityLiteral.POSITIVE)}), + ), + ), + ] + + +@pytest.mark.parametrize("testee,expected", test_data()) +def test_trivial(testee, expected): + actual = InferDomainOps(grid_type=common.GridType.CARTESIAN).visit(testee, recurse=True) + actual = ConstantFolding.apply(actual) # simplify expr to get simpler expected expressions + assert actual == expected From af36bc929295f2fcc405f40b8b6535a22b3210df Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 6 Jun 2025 00:19:48 +0200 Subject: [PATCH 091/124] Small fix --- .../test_concat_where_transform_to_as_fieldop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_transform_to_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_transform_to_as_fieldop.py index c9e8bb5806..2517c7ab55 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_transform_to_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_transform_to_as_fieldop.py @@ -47,7 +47,7 @@ def test_trivial(): expected = im.as_fieldop( im.lambda_("__tcw_pos", "__tcw_arg0", "__tcw_arg1")( im.if_( - im.call("in_")(im.deref("__tcw_pos"), cond), + _in(im.deref("__tcw_pos"), cond), im.deref("__tcw_arg0"), im.deref("__tcw_arg1"), ) @@ -70,7 +70,7 @@ def test_capturing_cond(): expected = im.as_fieldop( im.lambda_("__tcw_pos", "__tcw_arg0", "__tcw_arg1", "start", "stop")( im.if_( - im.call("in_")( + _in( im.deref("__tcw_pos"), im.domain( common.GridType.CARTESIAN, {IDim: (im.deref("start"), im.deref("stop"))} From f1a99bd64d94364967cfdfecc6b784b4ce4da2ca Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 6 Jun 2025 00:20:07 +0200 Subject: [PATCH 092/124] Format --- .../iterator/transforms/concat_where/expand_tuple_args.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py b/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py index 3f3d8e2603..ea4086976f 100644 --- a/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py +++ b/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py @@ -40,7 +40,9 @@ def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: # `concat_where(cond, {a, b}, {c, d})` # -> `{concat_where(cond, a, c), concat_where(cond, a, c)}` - if cpm.is_call_to(node, "concat_where") and isinstance(type_inference.reinfer(node.args[1]).type, ts.TupleType): + if cpm.is_call_to(node, "concat_where") and isinstance( + type_inference.reinfer(node.args[1]).type, ts.TupleType + ): cond, true_branch, false_branch = node.args new_els = [] assert isinstance(true_branch.type, ts.TupleType) From 1f6b284cea3efbde36bb811f1e8f7fb35510502f Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 6 Jun 2025 00:21:00 +0200 Subject: [PATCH 093/124] Format --- src/gt4py/next/iterator/type_system/type_synthesizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 2d9a918bf5..8a69bda9c0 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -618,7 +618,7 @@ def applied_map( assert isinstance(el_type, ts.DataType) offset_types = [arg.offset_type for arg in args if arg.offset_type] offset_type = offset_types[0] if offset_types else None - assert all(offset_type == arg for arg in offset_types) # type: ignore[operator] # mypy not smart enough + assert all(offset_type == arg for arg in offset_types) return ts.ListType(element_type=el_type, offset_type=offset_type) return applied_map From 45a2e233b32da62d6616b0af9f8f5b878a7ac21a Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 6 Jun 2025 01:07:43 +0200 Subject: [PATCH 094/124] Cleanup --- .../iterator/transforms/collapse_tuple.py | 2 +- .../concat_where/simplify_domain_argument.py | 51 ++++++++++++------ .../iterator/transforms/constant_folding.py | 2 +- .../transforms/fixed_point_transformation.py | 54 ++++++++++++------- .../iterator/transforms/fuse_as_fieldop.py | 2 +- .../test_concat_where_simplify_domain_args.py | 6 +-- 6 files changed, 73 insertions(+), 44 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 6b5d4d3ed7..38426cacc5 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -113,7 +113,7 @@ def _flattened_as_fieldop_param_el_name(param: str, idx: int) -> str: # should revisit the pattern here and try to find a more general mechanism. @dataclasses.dataclass(frozen=True, kw_only=True) class CollapseTuple( - fixed_point_transformation.FixedPointTransformation, eve.PreserveLocationVisitor + fixed_point_transformation.CombinedFixedPointTransform, eve.PreserveLocationVisitor ): """ Simplifies `make_tuple`, `tuple_get` calls. diff --git a/src/gt4py/next/iterator/transforms/concat_where/simplify_domain_argument.py b/src/gt4py/next/iterator/transforms/concat_where/simplify_domain_argument.py index 74f24f28b7..b626bcb5a1 100644 --- a/src/gt4py/next/iterator/transforms/concat_where/simplify_domain_argument.py +++ b/src/gt4py/next/iterator/transforms/concat_where/simplify_domain_argument.py @@ -6,14 +6,17 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next.iterator import ir, ir as itir +from typing import Optional + +from gt4py.eve import PreserveLocationVisitor +from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ( common_pattern_matcher as cpm, domain_utils, ir_makers as im, ) from gt4py.next.iterator.ir_utils.domain_utils import SymbolicDomain +from gt4py.next.iterator.transforms import fixed_point_transformation def _range_complement( @@ -27,27 +30,41 @@ def _range_complement( ) -class _SimplifyDomainArgument(PreserveLocationVisitor, NodeTranslator): +class _SimplifyDomainArgument( + PreserveLocationVisitor, fixed_point_transformation.FixedPointTransformation +): @classmethod - def apply(cls, node: ir.Node): + def apply(cls, node: itir.Node): return cls().visit(node) - def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: - node = self.generic_visit(node) - - # TODO: do not duplicate exprs + def transform(self, node: itir.Node) -> Optional[itir.Node]: # type: ignore[override] # ignore kwargs for simplicity if cpm.is_call_to(node, "concat_where"): cond_expr, field_a, field_b = node.args - # TODO: don't duplicate exprs here if cpm.is_call_to(cond_expr, "and_"): conds = cond_expr.args - return self.visit( - im.concat_where(conds[0], im.concat_where(conds[1], field_a, field_b), field_b) + return im.let(("__cwsda_field_a", field_a), ("__cwsda_field_b", field_b))( + self.fp_transform( + im.concat_where( + conds[0], + self.fp_transform( + im.concat_where(conds[1], "__cwsda_field_a", "__cwsda_field_b") + ), + "__cwsda_field_b", + ) + ) ) if cpm.is_call_to(cond_expr, "or_"): conds = cond_expr.args - return self.visit( - im.concat_where(conds[0], field_a, im.concat_where(conds[1], field_a, field_b)) + return im.let(("__cwsda_field_a", field_a), ("__cwsda_field_b", field_b))( + self.fp_transform( + im.concat_where( + conds[0], + "__cwsda_field_a", + self.fp_transform( + im.concat_where(conds[1], "__cwsda_field_a", "__cwsda_field_b") + ), + ) + ) ) # concat_where([1, 2[, a, b) -> concat_where([-inf, 1] | [2, inf[, b, a) @@ -61,16 +78,16 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: im.domain(domain.grid_type, {dim: (cr.start, cr.stop)}) for cr in complement ] - # TODO: fp transform - return self.visit( + return self.fp_transform( im.concat_where(im.call("or_")(*new_domains), field_b, field_a) ) else: # TODO(tehrengruber): Implement. Note that this case can not be triggered by - # the frontend. + # the frontend yet since domains can only be created by expressions like + # `IDim < 10`. raise NotImplementedError() - return node + return None simplify_domain_argument = _SimplifyDomainArgument.apply diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 774e7a6702..4a16122a71 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -56,7 +56,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: @dataclasses.dataclass(frozen=True, kw_only=True) class ConstantFolding( - fixed_point_transformation.FixedPointTransformation, eve.PreserveLocationVisitor + fixed_point_transformation.CombinedFixedPointTransform, eve.PreserveLocationVisitor ): PRESERVED_ANNEX_ATTRS = ( "type", diff --git a/src/gt4py/next/iterator/transforms/fixed_point_transformation.py b/src/gt4py/next/iterator/transforms/fixed_point_transformation.py index 598edaaf1f..c9b253ed7a 100644 --- a/src/gt4py/next/iterator/transforms/fixed_point_transformation.py +++ b/src/gt4py/next/iterator/transforms/fixed_point_transformation.py @@ -15,20 +15,11 @@ from gt4py.next.iterator.type_system import inference as itir_type_inference -@dataclasses.dataclass(frozen=True, kw_only=True) class FixedPointTransformation(eve.NodeTranslator): """ - Transformation pass that transforms until no transformation is applicable anymore. + Base class for iterative transformations that converge when a fixed-point is reached. """ - #: Enum of all transformation (names). The transformations need to be defined as methods - #: named `transform_`. - Transformation: ClassVar[Type[enum.Flag]] - - #: All transformations enabled in this instance, e.g. `Transformation.T1 & Transformation.T2`. - #: Usually the default value is chosen to be all transformations. - enabled_transformations: enum.Flag - REINFER_TYPES: ClassVar[bool] = False def visit(self, node, **kwargs): @@ -43,18 +34,44 @@ def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node: new_node = self.transform(node, **kwargs) if new_node is None: break + else: + new_node = self.post_transform(node, new_node) assert new_node != node node = new_node return node - def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: - """ - Transform node once. + def post_transform(self, node: ir.Node, new_node: ir.Node) -> ir.Node: + if self.REINFER_TYPES: + itir_type_inference.reinfer(new_node) + self._preserve_annex(node, new_node) + return new_node - Execute transformations until one is applicable. As soon as a transformation occured - the function will return the transformed node. Note that the transformation itself - may call other transformations on child nodes again. - """ + """ + Transform node once. + + Execute transformation if applicable. When a transformation occurred the function will return + the transformed node. Note that the transformation itself may call other transformations on + child nodes again. + """ + + def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: ... + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class CombinedFixedPointTransform(FixedPointTransformation): + """ + Base class for a set of iterative transformations that converge when a fixed-point is reached. + """ + + #: Enum of all transformation (names). The transformations need to be defined as methods + #: named `transform_`. + Transformation: ClassVar[Type[enum.Flag]] + + #: All transformations enabled in this instance, e.g. `Transformation.T1 & Transformation.T2`. + #: Usually the default value is chosen to be all transformations. + enabled_transformations: enum.Flag + + def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: for transformation in self.Transformation: if self.enabled_transformations & transformation: assert isinstance(transformation.name, str) @@ -64,8 +81,5 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: assert ( result is not node ), f"Transformation {transformation.name.lower()} should have returned None, since nothing changed." - if self.REINFER_TYPES: - itir_type_inference.reinfer(result) - self._preserve_annex(node, result) return result return None diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 26a8bcad1c..0e01dafed0 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -231,7 +231,7 @@ def _make_tuple_element_inline_predicate(node: itir.Expr): @dataclasses.dataclass(frozen=True, kw_only=True) class FuseAsFieldOp( - fixed_point_transformation.FixedPointTransformation, eve.PreserveLocationVisitor + fixed_point_transformation.CombinedFixedPointTransform, eve.PreserveLocationVisitor ): """ Merge multiple `as_fieldop` calls into one. diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_simplify_domain_args.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_simplify_domain_args.py index cbc5f61950..beca1084b4 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_simplify_domain_args.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_simplify_domain_args.py @@ -8,15 +8,12 @@ from gt4py.next import common import pytest from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms import ( concat_where, inline_lambdas, - infer_domain, - collapse_tuple, ) from gt4py.next.type_system import type_specifications as ts -from gt4py.next.iterator.type_system import type_specifications as it_ts int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) @@ -54,5 +51,6 @@ def test_data(): @pytest.mark.parametrize("testee, expected", test_data()) def test_nested_concat_where(testee, expected): actual = concat_where.simplify_domain_argument(testee) + actual = inline_lambdas.InlineLambdas.apply(actual, opcount_preserving=True) assert actual == expected From 52c96ed7f2c471c60ce27059c0f9d130b2d0eb32 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 6 Jun 2025 10:07:35 +0200 Subject: [PATCH 095/124] Cleanup --- src/gt4py/next/iterator/transforms/infer_domain.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 61bf8f0c73..cd41d94451 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -458,10 +458,13 @@ def infer_expr( name that evaluates to the length of that axis. - allow_uninferred: Allow `as_fieldop` expressions whose domain is either unknown (e.g. because of a dynamic shift) or never accessed. - # TODO: describe why this is needed with concat_where (if inside as_fieldop might shrinken the actually access domain) - keep_existing_domains: If `True`, keep existing domains in `as_fieldop` expressions and - use them to propagate the domain further. + use them to propagate the domain further. This is useful in cases where after a + transformation some nodes are missing domain information that needs to be repopulated, + but we can't reinfer everything because some domain access information has been lost. + For example when a `concat_where` is transformed into an `as_fieldop` with an if we lose + some information that could lead to unnecessary overcomputation and out-of-bounds accesses. Returns: A tuple containing the inferred expression with all applied `as_fieldop` (that are accessed) @@ -548,7 +551,6 @@ def infer_program( offset_provider: common.OffsetProvider | common.OffsetProviderType, symbolic_domain_sizes: Optional[dict[str, str]] = None, allow_uninferred: bool = False, - # TODO: add test keep_existing_domains: bool = False, ) -> itir.Program: """ From 034c660a80729ef2360a60fbb40c28637b9cc55f Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 6 Jun 2025 10:16:15 +0200 Subject: [PATCH 096/124] Cleanup --- .../iterator/transforms/infer_domain_ops.py | 107 ++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 src/gt4py/next/iterator/transforms/infer_domain_ops.py diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py new file mode 100644 index 0000000000..6447ceca32 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -0,0 +1,107 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +import dataclasses + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next import common +from gt4py.next.iterator import builtins, ir as itir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, + misc as ir_misc, +) +from gt4py.next.iterator.type_system import inference +from gt4py.next.type_system import type_specifications as ts + + +@dataclasses.dataclass +class InferDomainOps(PreserveLocationVisitor, NodeTranslator): + grid_type: common.GridType + + @classmethod + def apply(cls, program: itir.Program): + return cls(grid_type=ir_misc.grid_type_from_program(program)).visit(program, recurse=True) + + def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Node: + if kwargs["recurse"]: + node = self.generic_visit(node, **kwargs) + + # e.g. `IDim < a` + if cpm.is_call_to(node, builtins.BINARY_MATH_COMPARISON_BUILTINS) and any( + isinstance(arg, itir.AxisLiteral) for arg in node.args + ): + arg1, arg2 = node.args + if isinstance(arg2, itir.AxisLiteral): + # take complementary operation if we have e.g. `0 < IDim` use `IDim > 0` + complementary_op = { + "less": "greater", + "less_equal": "greater_equal", + "greater": "less", + "greater_equal": "less_equal", + "eq": "eq", + "not_eq": "not_eq", + } + return self.visit( + im.call(complementary_op[node.fun.id])(arg2, arg1), + **{**kwargs, "recurse": False}, + ) + + inference.reinfer(arg1) + assert isinstance(arg1.type, ts.DimensionType) + dim: common.Dimension = arg1.type.dim + value: itir.Expr = arg2 + + if cpm.is_call_to(node, ("less", "less_equal", "greater", "greater_equal", "eq")): + min_: itir.Expr + max_: itir.Expr + + # `IDim < 1` + if cpm.is_call_to(node, "less"): + min_ = itir.InfinityLiteral.NEGATIVE + max_ = value + # `IDim <= 1` + elif cpm.is_call_to(node, "less_equal"): + min_ = itir.InfinityLiteral.NEGATIVE + max_ = im.plus(value, 1) + # `IDim > 1` + elif cpm.is_call_to(node, "greater"): + min_ = im.plus(value, 1) + max_ = itir.InfinityLiteral.POSITIVE + # `IDim >= 1` + elif cpm.is_call_to(node, "greater_equal"): + min_ = value + max_ = itir.InfinityLiteral.POSITIVE + # `IDim == 1` + elif cpm.is_call_to(node, "eq"): + min_ = value + max_ = im.plus(value, 1) + + domain = domain_utils.SymbolicDomain( + self.grid_type, + ranges={dim: domain_utils.SymbolicRange(start=min_, stop=max_)}, + ) + + return domain.as_expr() + elif cpm.is_call_to(node, "not_eq"): + # `IDim != a` -> `IDim < a & IDim > a` + return self.visit( + im.call("and_")( + self.visit( + im.less(im.axis_literal(dim), value), **(kwargs | {"recurse": False}) + ), + self.visit( + im.greater(im.axis_literal(dim), value), **(kwargs | {"recurse": False}) + ), + ), + **(kwargs | {"recurse": False}), + ) + else: + raise AssertionError() + + return node From c46045927cf283cae528a80c59351ddfe20e4e06 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 6 Jun 2025 10:22:26 +0200 Subject: [PATCH 097/124] Cleanup --- src/gt4py/next/iterator/ir_utils/domain_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index c01b0a0dcc..f8923be3bb 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -228,7 +228,7 @@ def promote_to_same_dimensions( dims_dict[dim] = SymbolicRange( itir.InfinityLiteral.NEGATIVE, itir.InfinityLiteral.POSITIVE ) - return SymbolicDomain(domain_small.grid_type, dims_dict) # TODO: fix for unstructured + return SymbolicDomain(domain_small.grid_type, dims_dict) def is_finite(range_or_domain: SymbolicRange | SymbolicDomain) -> bool: From bbf00165bf27560c3dd2f67f742a6cedb2b65348 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 6 Jun 2025 10:23:52 +0200 Subject: [PATCH 098/124] Cleanup --- src/gt4py/next/iterator/ir_utils/domain_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index f8923be3bb..667d4ea9fe 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -238,7 +238,6 @@ def is_finite(range_or_domain: SymbolicRange | SymbolicDomain) -> bool: The expression is required to be constant folded before for the result to be reliable. """ if isinstance(range_ := range_or_domain, SymbolicRange): - # TODO: assert no infinity literal in here if any( v in [itir.InfinityLiteral.POSITIVE, itir.InfinityLiteral.NEGATIVE] for v in [range_.start, range_.stop] From cbeee8e244fbad9d134c591c4ac77c64dd1c964b Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 6 Jun 2025 10:24:07 +0200 Subject: [PATCH 099/124] Cleanup --- src/gt4py/next/iterator/ir_utils/domain_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 667d4ea9fe..b745051cce 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -246,4 +246,5 @@ def is_finite(range_or_domain: SymbolicRange | SymbolicDomain) -> bool: return True elif isinstance(domain := range_or_domain, SymbolicDomain): return all(is_finite(range_) for range_ in domain.ranges.values()) - raise ValueError("Expected a SymbolicRange or SymbolicDomain.") + raise ValueError("Expected a 'SymbolicRange' or 'SymbolicDomain'.") + From 9d179c6701240982cbc1d8af6865a0414c3dfae3 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 6 Jun 2025 10:26:09 +0200 Subject: [PATCH 100/124] Fix infer domain ops --- .../next/iterator/ir_utils/domain_utils.py | 1 - src/gt4py/next/iterator/ir_utils/misc.py | 19 ++++++++++++++++++ .../codegens/gtfn/itir_to_gtfn_ir.py | 20 +------------------ 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index b745051cce..9cab0bc9cd 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -247,4 +247,3 @@ def is_finite(range_or_domain: SymbolicRange | SymbolicDomain) -> bool: elif isinstance(domain := range_or_domain, SymbolicDomain): return all(is_finite(range_) for range_ in domain.ranges.values()) raise ValueError("Expected a 'SymbolicRange' or 'SymbolicDomain'.") - diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index 63090903df..988b26c793 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -12,6 +12,7 @@ from gt4py import eve from gt4py.eve import utils as eve_utils +from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import inline_lambdas @@ -216,3 +217,21 @@ def extract_projector( projector = projector if cur_projector is None else im.compose(cur_projector, projector) projector = inline_lambdas.InlineLambdas.apply(projector) return extract_projector(expr, projector, _depth + 1) + + +def grid_type_from_domain(domain: itir.FunCall) -> common.GridType: + if cpm.is_call_to(domain, "cartesian_domain"): + return common.GridType.CARTESIAN + else: + assert cpm.is_call_to(domain, "unstructured_domain") + return common.GridType.UNSTRUCTURED + + +def grid_type_from_program(program: itir.Program) -> common.GridType: + domains = program.walk_values().if_isinstance(itir.SetAt).getattr("domain").to_set() + grid_types = {grid_type_from_domain(d) for d in domains} + if len(grid_types) != 1: + raise ValueError( + f"Found 'set_at' with more than one 'GridType': '{grid_types}'. This is currently not supported." + ) + return grid_types.pop() diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 696cfc62ea..a445390583 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -84,24 +84,6 @@ def _get_domains(nodes: Iterable[itir.Stmt]) -> Iterable[itir.FunCall]: return result -def _extract_grid_type(domain: itir.FunCall) -> common.GridType: - if domain.fun == itir.SymRef(id="cartesian_domain"): - return common.GridType.CARTESIAN - else: - assert domain.fun == itir.SymRef(id="unstructured_domain") - return common.GridType.UNSTRUCTURED - - -def _get_gridtype(body: list[itir.Stmt]) -> common.GridType: - domains = _get_domains(body) - grid_types = {_extract_grid_type(d) for d in domains} - if len(grid_types) != 1: - raise ValueError( - f"Found 'set_at' with more than one 'GridType': '{grid_types}'. This is currently not supported." - ) - return grid_types.pop() - - def _name_from_named_range(named_range_call: itir.FunCall) -> str: assert isinstance(named_range_call, itir.FunCall) and named_range_call.fun == itir.SymRef( id="named_range" @@ -342,7 +324,7 @@ def apply( raise TypeError(f"Expected a 'Program', got '{type(node).__name__}'.") node = itir_type_inference.infer(node, offset_provider_type=offset_provider_type) - grid_type = _get_gridtype(node.body) + grid_type = ir_utils_misc.grid_type_from_program(node) if grid_type == common.GridType.UNSTRUCTURED: node = _CannonicalizeUnstructuredDomain.apply(node) return cls( From a421f79740f0ca096cec0673bb11c5d952598ac1 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 6 Jun 2025 12:49:06 +0200 Subject: [PATCH 101/124] Fix failing doctest --- src/gt4py/next/iterator/transforms/inline_fundefs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/inline_fundefs.py b/src/gt4py/next/iterator/transforms/inline_fundefs.py index 03b20d14fe..2b8767e4a2 100644 --- a/src/gt4py/next/iterator/transforms/inline_fundefs.py +++ b/src/gt4py/next/iterator/transforms/inline_fundefs.py @@ -32,6 +32,7 @@ def prune_unreferenced_fundefs(program: itir.Program) -> itir.Program: """ Remove all function declarations that are never called. + >>> from gt4py.next import common >>> from gt4py.next.iterator.ir_utils import ir_makers as im >>> fun1 = itir.FunctionDefinition( ... id="fun1", @@ -43,6 +44,7 @@ def prune_unreferenced_fundefs(program: itir.Program) -> itir.Program: ... params=[im.sym("a")], ... expr=im.deref("a"), ... ) + >>> IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) >>> program = itir.Program( ... id="testee", ... function_definitions=[fun1, fun2], @@ -51,7 +53,7 @@ def prune_unreferenced_fundefs(program: itir.Program) -> itir.Program: ... body=[ ... itir.SetAt( ... expr=im.call("fun1")("inp"), - ... domain=im.domain("cartesian_domain", {"IDim": (0, 10)}), + ... domain=im.domain("cartesian_domain", {IDim: (0, 10)}), ... target=im.ref("out"), ... ) ... ], From aadf5823e18fe0ed0854bb67aecba876f1c55043 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 11 Jun 2025 14:07:40 +0200 Subject: [PATCH 102/124] remove uses_concat_where from COMMON_SKIP_TEST_LIST --- tests/next_tests/definitions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index d4d30f3e9d..ad3ff4bbfc 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -144,7 +144,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLES_ARGS_WITH_DIFFERENT_BUT_PROMOTABLE_DIMS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), ] # Markers to skip because of missing features in the domain inference DOMAIN_INFERENCE_SKIP_LIST = [ From 2bd9d3b8438fb520ec4ef9488d9195965fa400e8 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 25 Jun 2025 11:44:28 +0200 Subject: [PATCH 103/124] add test cases for empty branches --- .../ffront_tests/test_concat_where.py | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index b95ead17f3..d7eb667746 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -65,6 +65,38 @@ def testee(ground: cases.IJKField, air: cases.IJKField) -> cases.IJKField: cases.verify(cartesian_case, testee, ground, air, out=out, ref=ref) +def test_concat_where_scalar_broadcast(cartesian_case): + @gtx.field_operator + def testee(a: np.int32, b: cases.IJKField, N: np.int32) -> cases.IJKField: + return concat_where(KDim < N - 1, a, b) + + a = 3 + b = cases.allocate(cartesian_case, testee, "b")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.concatenate( + ( + np.full((*out.domain.shape[0:2], out.domain.shape[2] - 1), a), + b.asnumpy()[:, :, -1:], + ), + axis=2, + ) + cases.verify(cartesian_case, testee, a, b, cartesian_case.default_sizes[KDim], out=out, ref=ref) + + +def test_concat_where_scalar_broadcast_on_empty_branch(cartesian_case): + @gtx.field_operator + def testee(a: np.int32, b: cases.IJKField) -> cases.IJKField: + return concat_where(KDim == 0, a, b) + + a = 3 + b = cases.allocate(cartesian_case, testee, "b")() + out = cases.allocate(cartesian_case, testee, cases.RETURN, domain=b.domain.slice_at[:, :, 1:])() + + ref = b.asnumpy()[:, :, 1:] + cases.verify(cartesian_case, testee, a, b, out=out, ref=ref) + + def test_concat_where_single_level_broadcast(cartesian_case): @gtx.field_operator def testee(a: cases.KField, b: cases.IJKField) -> cases.IJKField: @@ -144,6 +176,21 @@ def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField: cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) +def test_boundary_single_layer_2d_bc_on_empty_branch(cartesian_case): + @gtx.field_operator + def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField: + return concat_where(KDim == 0, boundary, interior) + + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate( + cartesian_case, testee, cases.RETURN, domain=interior.domain.slice_at[:, :, 1:] + )() + + ref = interior.asnumpy()[:, :, 1:] + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) + + def test_dimension_two_nested_conditions(cartesian_case): @gtx.field_operator def testee(interior: cases.IJKField, boundary: cases.IJKField) -> cases.IJKField: @@ -273,6 +320,28 @@ def testee( ) +def test_nested_conditions_with_empty_branches(cartesian_case): + @gtx.field_operator + def testee(interior: cases.IField, boundary: cases.IField, N: gtx.int32) -> cases.IField: + interior = concat_where(IDim == 0, boundary, interior) + interior = concat_where((1 <= IDim) & (IDim < N - 1), interior * 2, interior) + interior = concat_where(IDim == N - 1, boundary, interior) + return interior + + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + N = cartesian_case.default_sizes[IDim] + + i = np.arange(0, cartesian_case.default_sizes[IDim]) + ref = np.where( + (i[:] == 0) | (i[:] == N - 1), + boundary.asnumpy(), + interior.asnumpy() * 2, + ) + cases.verify(cartesian_case, testee, interior, boundary, N, out=out, ref=ref) + + @pytest.mark.uses_tuple_returns def test_with_tuples_different_domain(cartesian_case): @gtx.field_operator From 78a61ca9d7d56d2cb226b38c01743d7cb8a467ef Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 27 Jun 2025 15:40:04 +0200 Subject: [PATCH 104/124] extend test case scalar_broadcast_on_empty_branch --- .../feature_tests/ffront_tests/test_concat_where.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index d7eb667746..0ef99d6b50 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -86,15 +86,15 @@ def testee(a: np.int32, b: cases.IJKField, N: np.int32) -> cases.IJKField: def test_concat_where_scalar_broadcast_on_empty_branch(cartesian_case): @gtx.field_operator - def testee(a: np.int32, b: cases.IJKField) -> cases.IJKField: - return concat_where(KDim == 0, a, b) + def testee(a: np.int32, b: cases.KField, N: np.int32) -> cases.KField: + return concat_where(KDim < N, a, b) a = 3 b = cases.allocate(cartesian_case, testee, "b")() - out = cases.allocate(cartesian_case, testee, cases.RETURN, domain=b.domain.slice_at[:, :, 1:])() + out = cases.allocate(cartesian_case, testee, cases.RETURN, domain=b.domain.slice_at[1:])() - ref = b.asnumpy()[:, :, 1:] - cases.verify(cartesian_case, testee, a, b, out=out, ref=ref) + ref = b.asnumpy()[1:] + cases.verify(cartesian_case, testee, a, b, 1, out=out, ref=ref) def test_concat_where_single_level_broadcast(cartesian_case): From 31da410ac94d0a4af50a1e5369cd852f25985480 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 27 Jun 2025 20:09:50 +0200 Subject: [PATCH 105/124] pre-commit - format code --- .../transforms/concat_where/transform_to_as_fieldop.py | 6 +++--- src/gt4py/next/iterator/type_system/type_synthesizer.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py b/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py index d5195c120b..1a9b2d13ef 100644 --- a/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py @@ -69,9 +69,9 @@ def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: domains: tuple[domain_utils.SymbolicDomain, ...] = utils.flatten_nested_tuple( node.annex.domain ) - assert all( - domain == domains[0] for domain in domains - ), "At this point all `concat_where` arguments should be posed on the same domain." + assert all(domain == domains[0] for domain in domains), ( + "At this point all `concat_where` arguments should be posed on the same domain." + ) assert isinstance(domains[0], domain_utils.SymbolicDomain) domain_expr = domains[0].as_expr() diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 8a69bda9c0..ce99532645 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -282,9 +282,9 @@ def deduce_return_type(tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.S tb_dtype, fb_dtype = (type_info.extract_dtype(b) for b in [tb, fb]) - assert ( - tb_dtype == fb_dtype - ), f"Field arguments must be of same dtype, got '{tb_dtype}' != '{fb_dtype}'." + assert tb_dtype == fb_dtype, ( + f"Field arguments must be of same dtype, got '{tb_dtype}' != '{fb_dtype}'." + ) dtype = tb_dtype return_dims = common.promote_dims( From 8c642b822e87feee6a79bf61bc9ff9a66658f1a0 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 8 Jul 2025 16:29:43 +0200 Subject: [PATCH 106/124] address review comments --- .../next/iterator/ir_utils/domain_utils.py | 68 +++++++++++-------- .../transforms/concat_where/__init__.py | 8 +-- ...ent.py => canonicalize_domain_argument.py} | 20 +++--- .../concat_where/transform_to_as_fieldop.py | 13 ++-- .../transforms/fixed_point_transformation.py | 20 +++--- .../next/iterator/transforms/infer_domain.py | 4 +- ..._concat_where_canonicalize_domain_args.py} | 18 ++++- 7 files changed, 88 insertions(+), 63 deletions(-) rename src/gt4py/next/iterator/transforms/concat_where/{simplify_domain_argument.py => canonicalize_domain_argument.py} (83%) rename tests/next_tests/unit_tests/iterator_tests/transforms_tests/{test_concat_where_simplify_domain_args.py => test_concat_where_canonicalize_domain_args.py} (76%) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 9cab0bc9cd..7cd3fb9ce5 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -10,7 +10,7 @@ import dataclasses import functools -from typing import Any, Literal, Mapping, Optional +from typing import Any, Iterable, Literal, Mapping, Optional from gt4py.next import common from gt4py.next.iterator import builtins, ir as itir @@ -49,6 +49,12 @@ def translate(self, distance: int) -> SymbolicRange: return SymbolicRange(im.plus(self.start, distance), im.plus(self.stop, distance)) +_GRID_TYPE_MAPPING = { + "unstructured_domain": common.GridType.UNSTRUCTURED, + "cartesian_domain": common.GridType.CARTESIAN, +} + + @dataclasses.dataclass(frozen=True) class SymbolicDomain: grid_type: common.GridType @@ -62,7 +68,6 @@ def __hash__(self) -> int: @classmethod def from_expr(cls, node: itir.Node) -> SymbolicDomain: assert cpm.is_call_to(node, ("unstructured_domain", "cartesian_domain")) - grid_type = getattr(common.GridType, node.fun.id[: -len("_domain")].upper()) ranges: dict[common.Dimension, SymbolicRange] = {} for named_range in node.args: @@ -73,7 +78,7 @@ def from_expr(cls, node: itir.Node) -> SymbolicDomain: ranges[common.Dimension(value=axis_literal.value, kind=axis_literal.kind)] = ( SymbolicRange(lower_bound, upper_bound) ) - return cls(grid_type, ranges) + return cls(_GRID_TYPE_MAPPING[node.fun.id], ranges) def as_expr(self) -> itir.FunCall: converted_ranges: dict[common.Dimension, tuple[itir.Expr, itir.Expr]] = { @@ -200,35 +205,38 @@ def domain_intersection(*domains: SymbolicDomain) -> SymbolicDomain: def domain_complement(domain: SymbolicDomain) -> SymbolicDomain: - """Return the (set) complement of a domain.""" + """ + Return the (set) complement of a half-infinite domain. + + Note: after canonicalization of concat_where, the domain is always half-infinite, + i.e. it has ranges of the form `]-inf, a[` or `[a, inf[`. + """ dims_dict = {} for dim in domain.ranges.keys(): lb, ub = domain.ranges[dim].start, domain.ranges[dim].stop + assert (lb == itir.InfinityLiteral.NEGATIVE) != (ub == itir.InfinityLiteral.POSITIVE) # `]-inf, a[` -> `[a, inf[` if lb == itir.InfinityLiteral.NEGATIVE: dims_dict[dim] = SymbolicRange(start=ub, stop=itir.InfinityLiteral.POSITIVE) # `[a, inf]` -> `]-inf, a]` - elif ub == itir.InfinityLiteral.POSITIVE: + else: # ub == itir.InfinityLiteral.POSITIVE: dims_dict[dim] = SymbolicRange(start=itir.InfinityLiteral.NEGATIVE, stop=lb) - else: - raise ValueError("Invalid domain ranges") return SymbolicDomain(domain.grid_type, dims_dict) -def promote_to_same_dimensions( - domain_small: SymbolicDomain, domain_large: SymbolicDomain +def promote_domain( + domain: SymbolicDomain, target_dims: Iterable[common.Dimension] ) -> SymbolicDomain: - """Return an extended domain based on a smaller input domain and a larger domain containing the target dimensions.""" + """Return a domain that is extended with the dimensions of target_dims.""" + assert set(domain.ranges.keys()).issubset(target_dims) dims_dict = {} - for dim in domain_large.ranges.keys(): - if dim in domain_small.ranges.keys(): - lb, ub = domain_small.ranges[dim].start, domain_small.ranges[dim].stop - dims_dict[dim] = SymbolicRange(lb, ub) - else: - dims_dict[dim] = SymbolicRange( - itir.InfinityLiteral.NEGATIVE, itir.InfinityLiteral.POSITIVE - ) - return SymbolicDomain(domain_small.grid_type, dims_dict) + for dim in target_dims: + dims_dict[dim] = ( + domain.ranges[dim] + if dim in domain.ranges + else SymbolicRange(itir.InfinityLiteral.NEGATIVE, itir.InfinityLiteral.POSITIVE) + ) + return SymbolicDomain(domain.grid_type, dims_dict) def is_finite(range_or_domain: SymbolicRange | SymbolicDomain) -> bool: @@ -237,13 +245,15 @@ def is_finite(range_or_domain: SymbolicRange | SymbolicDomain) -> bool: The expression is required to be constant folded before for the result to be reliable. """ - if isinstance(range_ := range_or_domain, SymbolicRange): - if any( - v in [itir.InfinityLiteral.POSITIVE, itir.InfinityLiteral.NEGATIVE] - for v in [range_.start, range_.stop] - ): - return False - return True - elif isinstance(domain := range_or_domain, SymbolicDomain): - return all(is_finite(range_) for range_ in domain.ranges.values()) - raise ValueError("Expected a 'SymbolicRange' or 'SymbolicDomain'.") + match range_or_domain: + case SymbolicRange() as range_: + if any( + v in [itir.InfinityLiteral.POSITIVE, itir.InfinityLiteral.NEGATIVE] + for v in [range_.start, range_.stop] + ): + return False + return True + case SymbolicDomain() as domain: + return all(is_finite(range_) for range_ in domain.ranges.values()) + case _: + raise ValueError("Expected a 'SymbolicRange' or 'SymbolicDomain'.") diff --git a/src/gt4py/next/iterator/transforms/concat_where/__init__.py b/src/gt4py/next/iterator/transforms/concat_where/__init__.py index a9c3fb2576..31f6872aac 100644 --- a/src/gt4py/next/iterator/transforms/concat_where/__init__.py +++ b/src/gt4py/next/iterator/transforms/concat_where/__init__.py @@ -6,13 +6,13 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from gt4py.next.iterator.transforms.concat_where.expand_tuple_args import expand_tuple_args -from gt4py.next.iterator.transforms.concat_where.simplify_domain_argument import ( - simplify_domain_argument, +from gt4py.next.iterator.transforms.concat_where.canonicalize_domain_argument import ( + canonicalize_domain_argument, ) +from gt4py.next.iterator.transforms.concat_where.expand_tuple_args import expand_tuple_args from gt4py.next.iterator.transforms.concat_where.transform_to_as_fieldop import ( transform_to_as_fieldop, ) -__all__ = ["expand_tuple_args", "simplify_domain_argument", "transform_to_as_fieldop"] +__all__ = ["canonicalize_domain_argument", "expand_tuple_args", "transform_to_as_fieldop"] diff --git a/src/gt4py/next/iterator/transforms/concat_where/simplify_domain_argument.py b/src/gt4py/next/iterator/transforms/concat_where/canonicalize_domain_argument.py similarity index 83% rename from src/gt4py/next/iterator/transforms/concat_where/simplify_domain_argument.py rename to src/gt4py/next/iterator/transforms/concat_where/canonicalize_domain_argument.py index b626bcb5a1..371466047c 100644 --- a/src/gt4py/next/iterator/transforms/concat_where/simplify_domain_argument.py +++ b/src/gt4py/next/iterator/transforms/concat_where/canonicalize_domain_argument.py @@ -30,9 +30,13 @@ def _range_complement( ) -class _SimplifyDomainArgument( +class _CanonicalizeDomainArgument( PreserveLocationVisitor, fixed_point_transformation.FixedPointTransformation ): + """ + TODO(tehrengruber): Explain why this is the canonical form. + """ + @classmethod def apply(cls, node: itir.Node): return cls().visit(node) @@ -42,26 +46,26 @@ def transform(self, node: itir.Node) -> Optional[itir.Node]: # type: ignore[ove cond_expr, field_a, field_b = node.args if cpm.is_call_to(cond_expr, "and_"): conds = cond_expr.args - return im.let(("__cwsda_field_a", field_a), ("__cwsda_field_b", field_b))( + return im.let(("__cwcda_field_a", field_a), ("__cwcda_field_b", field_b))( self.fp_transform( im.concat_where( conds[0], self.fp_transform( - im.concat_where(conds[1], "__cwsda_field_a", "__cwsda_field_b") + im.concat_where(conds[1], "__cwcda_field_a", "__cwcda_field_b") ), - "__cwsda_field_b", + "__cwcda_field_b", ) ) ) if cpm.is_call_to(cond_expr, "or_"): conds = cond_expr.args - return im.let(("__cwsda_field_a", field_a), ("__cwsda_field_b", field_b))( + return im.let(("__cwcda_field_a", field_a), ("__cwcda_field_b", field_b))( self.fp_transform( im.concat_where( conds[0], - "__cwsda_field_a", + "__cwcda_field_a", self.fp_transform( - im.concat_where(conds[1], "__cwsda_field_a", "__cwsda_field_b") + im.concat_where(conds[1], "__cwcda_field_a", "__cwcda_field_b") ), ) ) @@ -90,4 +94,4 @@ def transform(self, node: itir.Node) -> Optional[itir.Node]: # type: ignore[ove return None -simplify_domain_argument = _SimplifyDomainArgument.apply +canonicalize_domain_argument = _CanonicalizeDomainArgument.apply diff --git a/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py b/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py index d5195c120b..d46f338f9a 100644 --- a/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py @@ -28,14 +28,13 @@ def _in(pos: itir.Expr, domain: itir.Expr) -> itir.Expr: `in_({i, j, k}, u⟨ Iₕ: [i0, i1[, Iₕ: [j0, j1[, Iₕ: [k0, k1[ ⟩` -> `i0 <= i < i1 & j0 <= j < j1 & k0 <= k < k1` """ - ret = [] - for i, v in enumerate(domain_utils.SymbolicDomain.from_expr(domain).ranges.values()): - ret.append( - im.and_( - im.less_equal(v.start, im.tuple_get(i, pos)), - im.less(im.tuple_get(i, pos), v.stop), - ) + ret = [ + im.and_( + im.less_equal(v.start, im.tuple_get(i, pos)), + im.less(im.tuple_get(i, pos), v.stop), ) + for i, v in enumerate(domain_utils.SymbolicDomain.from_expr(domain).ranges.values()) + ] return functools.reduce(im.and_, ret) diff --git a/src/gt4py/next/iterator/transforms/fixed_point_transformation.py b/src/gt4py/next/iterator/transforms/fixed_point_transformation.py index c9b253ed7a..58aa8912ad 100644 --- a/src/gt4py/next/iterator/transforms/fixed_point_transformation.py +++ b/src/gt4py/next/iterator/transforms/fixed_point_transformation.py @@ -35,26 +35,26 @@ def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node: if new_node is None: break else: - new_node = self.post_transform(node, new_node) + new_node = self._post_transform(node, new_node) assert new_node != node node = new_node return node - def post_transform(self, node: ir.Node, new_node: ir.Node) -> ir.Node: + def _post_transform(self, node: ir.Node, new_node: ir.Node) -> ir.Node: if self.REINFER_TYPES: itir_type_inference.reinfer(new_node) self._preserve_annex(node, new_node) return new_node - """ - Transform node once. - - Execute transformation if applicable. When a transformation occurred the function will return - the transformed node. Note that the transformation itself may call other transformations on - child nodes again. - """ + def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: + """ + Transform node once. - def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: ... + Execute transformation if applicable. When a transformation occurred the function will return + the transformed node. Note that the transformation itself may call other transformations on + child nodes again. + """ + ... @dataclasses.dataclass(frozen=True, kw_only=True) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index cd41d94451..5c21702d86 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -374,9 +374,9 @@ def _infer_concat_where( def mapper(d: NonTupleDomainAccess): if isinstance(d, DomainAccessDescriptor): return d - promoted_cond = domain_utils.promote_to_same_dimensions( + promoted_cond = domain_utils.promote_domain( symbolic_cond if arg == true_field else cond_complement, # noqa: B023 # function is never used outside the loop - d, + d.ranges.keys(), ) return domain_utils.domain_intersection(d, promoted_cond) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_simplify_domain_args.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_canonicalize_domain_args.py similarity index 76% rename from tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_simplify_domain_args.py rename to tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_canonicalize_domain_args.py index beca1084b4..58086670f9 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_simplify_domain_args.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_concat_where_canonicalize_domain_args.py @@ -20,7 +20,7 @@ field_type = ts.FieldType(dims=[IDim], dtype=int_type) -def test_data(): +def cases(): return [ # testee, expected ( @@ -45,12 +45,24 @@ def test_data(): ), ), ), + ( # not transformed + im.concat_where( + im.domain(common.GridType.CARTESIAN, {IDim: (0, itir.InfinityLiteral.POSITIVE)}), + "a", + "b", + ), + im.concat_where( + im.domain(common.GridType.CARTESIAN, {IDim: (0, itir.InfinityLiteral.POSITIVE)}), + "a", + "b", + ), + ), ] -@pytest.mark.parametrize("testee, expected", test_data()) +@pytest.mark.parametrize("testee, expected", cases()) def test_nested_concat_where(testee, expected): - actual = concat_where.simplify_domain_argument(testee) + actual = concat_where.canonicalize_domain_argument(testee) actual = inline_lambdas.InlineLambdas.apply(actual, opcount_preserving=True) assert actual == expected From 7877f6d3d88407ef70b55eb02952ee4d2f64ed44 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 9 Jul 2025 10:38:43 +0200 Subject: [PATCH 107/124] add domain_utils tests --- .../next/iterator/ir_utils/domain_utils.py | 5 + .../ir_utils_test.py/test_domain_utils.py | 145 ++++++++++++++++++ 2 files changed, 150 insertions(+) create mode 100644 tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 7cd3fb9ce5..e1d27c00e1 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -45,6 +45,11 @@ class SymbolicRange: start: itir.Expr stop: itir.Expr + def __post_init__(self) -> None: + # TODO(havogt): added this defensive checks as code seems to make this reasonable assumption + assert self.start is not itir.InfinityLiteral.POSITIVE + assert self.stop is not itir.InfinityLiteral.NEGATIVE + def translate(self, distance: int) -> SymbolicRange: return SymbolicRange(im.plus(self.start, distance), im.plus(self.stop, distance)) diff --git a/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py b/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py new file mode 100644 index 0000000000..136b1aeaad --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py @@ -0,0 +1,145 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import domain_utils, ir_makers as im + +I = common.Dimension("I") +J = common.Dimension("J") + +a_range = domain_utils.SymbolicRange(0, 10) +another_range = domain_utils.SymbolicRange(5, 15) +infinity_range = domain_utils.SymbolicRange( + itir.InfinityLiteral.NEGATIVE, itir.InfinityLiteral.POSITIVE +) +# the two next ranges are complement of each other (which is used in tests) +right_infinity_range = domain_utils.SymbolicRange(0, itir.InfinityLiteral.POSITIVE) +left_infinity_range = domain_utils.SymbolicRange(itir.InfinityLiteral.NEGATIVE, 0) + + +def _make_domain(i: int): + return domain_utils.SymbolicDomain( + grid_type=common.GridType.CARTESIAN, + ranges={I: domain_utils.SymbolicRange(im.ref(f"start{i}"), im.ref(f"end{i}"))}, + ) + + +def test_symbolic_range(): + with pytest.raises(AssertionError): + domain_utils.SymbolicRange(itir.InfinityLiteral.POSITIVE, 0) + with pytest.raises(AssertionError): + domain_utils.SymbolicRange(0, itir.InfinityLiteral.NEGATIVE) + + +def test_domain_union(): + domain0 = _make_domain(0) + domain1 = _make_domain(1) + domain2 = _make_domain(2) + + expected = domain_utils.SymbolicDomain( + grid_type=common.GridType.CARTESIAN, + ranges={ + I: domain_utils.SymbolicRange( + im.minimum(im.minimum(im.ref("start0"), im.ref("start1")), im.ref("start2")), + im.maximum(im.maximum(im.ref("end0"), im.ref("end1")), im.ref("end2")), + ) + }, + ) + assert expected == domain_utils.domain_union(domain0, domain1, domain2) + + +def test_domain_intersection(): + domain0 = _make_domain(0) + domain1 = _make_domain(1) + domain2 = _make_domain(2) + + expected = domain_utils.SymbolicDomain( + grid_type=common.GridType.CARTESIAN, + ranges={ + I: domain_utils.SymbolicRange( + im.maximum(im.maximum(im.ref("start0"), im.ref("start1")), im.ref("start2")), + im.minimum(im.minimum(im.ref("end0"), im.ref("end1")), im.ref("end2")), + ) + }, + ) + assert expected == domain_utils.domain_intersection(domain0, domain1, domain2) + + +@pytest.mark.parametrize( + "ranges, expected", + [ + ({I: a_range}, None), + ({I: infinity_range}, None), + ({I: a_range, J: right_infinity_range}, None), + ( + {I: right_infinity_range, J: left_infinity_range}, + {I: left_infinity_range, J: right_infinity_range}, + ), + ], +) +def test_domain_complement(ranges, expected): + if expected is None: + with pytest.raises(AssertionError): + domain_utils.domain_complement( + domain_utils.SymbolicDomain(grid_type=common.GridType.CARTESIAN, ranges=ranges) + ) + else: + assert domain_utils.domain_complement( + domain_utils.SymbolicDomain(grid_type=common.GridType.CARTESIAN, ranges=ranges) + ) == domain_utils.SymbolicDomain(grid_type=common.GridType.CARTESIAN, ranges=expected) + + +@pytest.mark.parametrize( + "testee_ranges, dimensions, expected_ranges", + [ + ({I: a_range}, [I, J], {I: a_range, J: infinity_range}), + ({I: a_range}, [J, I], {I: a_range, J: infinity_range}), + ({I: a_range}, [I], {I: a_range}), + ({I: a_range}, [J], None), + ({I: a_range, J: another_range}, [J], None), + ], +) +def test_promote_domain(testee_ranges, dimensions, expected_ranges): + testee = domain_utils.SymbolicDomain(grid_type=common.GridType.CARTESIAN, ranges=testee_ranges) + if expected_ranges is None: + with pytest.raises(AssertionError): + domain_utils.promote_domain(testee, dimensions) + else: + expected = domain_utils.SymbolicDomain( + grid_type=common.GridType.CARTESIAN, ranges=expected_ranges + ) + promoted = domain_utils.promote_domain(testee, dimensions) + assert promoted == expected + + +def test_is_finite_symbolic_range(): + assert not domain_utils.is_finite(infinity_range) + assert not domain_utils.is_finite(left_infinity_range) + assert not domain_utils.is_finite(right_infinity_range) + assert domain_utils.is_finite(a_range) + + +@pytest.mark.parametrize( + "ranges, expected", + [ + ({I: a_range, J: a_range}, True), + ({I: a_range, J: another_range}, True), + ({I: right_infinity_range, J: a_range}, False), + ({I: a_range, J: right_infinity_range}, False), + ], +) +def test_is_finite_symbolic_domain(ranges, expected): + assert ( + domain_utils.is_finite( + domain_utils.SymbolicDomain(grid_type=common.GridType.CARTESIAN, ranges=ranges) + ) + == expected + ) From 086910ccf841b715738af3ad46335dc5ae902976 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 9 Jul 2025 10:47:07 +0200 Subject: [PATCH 108/124] refactor domain ops --- .../next/iterator/ir_utils/domain_utils.py | 75 ++++++++++--------- .../ir_utils_test.py/test_domain_utils.py | 18 +++++ 2 files changed, 57 insertions(+), 36 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index e1d27c00e1..3daf3af6bc 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -10,7 +10,7 @@ import dataclasses import functools -from typing import Any, Iterable, Literal, Mapping, Optional +from typing import Any, Callable, Iterable, Literal, Mapping, Optional from gt4py.next import common from gt4py.next.iterator import builtins, ir as itir @@ -63,9 +63,7 @@ def translate(self, distance: int) -> SymbolicRange: @dataclasses.dataclass(frozen=True) class SymbolicDomain: grid_type: common.GridType - ranges: dict[ - common.Dimension, SymbolicRange - ] # TODO(havogt): remove `AxisLiteral` by `Dimension` everywhere + ranges: dict[common.Dimension, SymbolicRange] def __hash__(self) -> int: return hash((self.grid_type, frozenset(self.ranges.items()))) @@ -168,47 +166,52 @@ def translate( raise AssertionError("Number of shifts must be a multiple of 2.") -def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: - """Return the (set) union of a list of domains.""" - new_domain_ranges = {} +def _range_op( + start_op: Callable[[itir.Expr, itir.Expr], itir.Expr], + stop_op: Callable[[itir.Expr, itir.Expr], itir.Expr], + *ranges: SymbolicRange, +) -> SymbolicRange: + """Uses start_op and stop_op to fold the start and stop of a list of ranges.""" + start = functools.reduce( + lambda current_expr, el_expr: start_op(current_expr, el_expr), + [range_.start for range_ in ranges], + ) + stop = functools.reduce( + lambda current_expr, el_expr: stop_op(current_expr, el_expr), + [range_.stop for range_ in ranges], + ) + # constant fold expression to keep the tree small + start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr + return SymbolicRange(start, stop) + + +_range_union = functools.partial(_range_op, im.minimum, im.maximum) +_range_intersection = functools.partial(_range_op, im.maximum, im.minimum) + + +def _domain_op( + range_op: Callable[..., SymbolicRange], + *domains: SymbolicDomain, +) -> SymbolicDomain: + """ + Applies range_op to the ranges of a list of domains with same dimensions and grid_type. + """ assert all(domain.grid_type == domains[0].grid_type for domain in domains) assert all(domain.ranges.keys() == domains[0].ranges.keys() for domain in domains) - for dim in domains[0].ranges.keys(): - start = functools.reduce( - lambda current_expr, el_expr: im.minimum(current_expr, el_expr), - [domain.ranges[dim].start for domain in domains], - ) - stop = functools.reduce( - lambda current_expr, el_expr: im.maximum(current_expr, el_expr), - [domain.ranges[dim].stop for domain in domains], - ) - # constant fold expression to keep the tree small - start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr - new_domain_ranges[dim] = SymbolicRange(start, stop) - return SymbolicDomain(domains[0].grid_type, new_domain_ranges) - - -def domain_intersection(*domains: SymbolicDomain) -> SymbolicDomain: - """Return the (set) intersection of a list of domains.""" new_domain_ranges = {} - assert all(domain.grid_type == domains[0].grid_type for domain in domains) for dim in domains[0].ranges.keys(): - start = functools.reduce( - lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), - [domain.ranges[dim].start for domain in domains], - ) - stop = functools.reduce( - lambda current_expr, el_expr: im.call("minimum")(current_expr, el_expr), - [domain.ranges[dim].stop for domain in domains], - ) - # constant fold expression to keep the tree small - start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr - new_domain_ranges[dim] = SymbolicRange(start, stop) + new_domain_ranges[dim] = range_op(*[domain.ranges[dim] for domain in domains]) return SymbolicDomain(domains[0].grid_type, new_domain_ranges) +domain_union = functools.partial(_domain_op, _range_union) +"""Return the (set) union of a list of domains.""" +domain_intersection = functools.partial(_domain_op, _range_intersection) +"""Return the intersection of a list of domains.""" + + def domain_complement(domain: SymbolicDomain) -> SymbolicDomain: """ Return the (set) complement of a half-infinite domain. diff --git a/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py b/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py index 136b1aeaad..04544dd220 100644 --- a/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py +++ b/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py @@ -39,6 +39,24 @@ def test_symbolic_range(): domain_utils.SymbolicRange(0, itir.InfinityLiteral.NEGATIVE) +def test_domain_op_preconditions(): + domain_a = domain_utils.SymbolicDomain( + grid_type=common.GridType.CARTESIAN, + ranges={I: domain_utils.SymbolicRange(0, 10)}, + ) + domain_b = domain_utils.SymbolicDomain( + grid_type=common.GridType.CARTESIAN, + ranges={J: domain_utils.SymbolicRange(5, 15)}, + ) + with pytest.raises(AssertionError): + domain_utils._domain_op(domain_utils._range_union, domain_a, domain_b) + domain_c = domain_utils.SymbolicDomain( + grid_type=common.GridType.UNSTRUCTURED, ranges={I: domain_utils.SymbolicRange(0, 10)} + ) + with pytest.raises(AssertionError): + domain_utils._domain_op(domain_utils._range_union, domain_a, domain_c) + + def test_domain_union(): domain0 = _make_domain(0) domain1 = _make_domain(1) From 58a0492419d5d0f14ba56177dc133467a7daf5a5 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 9 Jul 2025 10:57:23 +0200 Subject: [PATCH 109/124] fix formatting --- .../transforms/concat_where/transform_to_as_fieldop.py | 6 +++--- src/gt4py/next/iterator/type_system/type_synthesizer.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py b/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py index d46f338f9a..a693770ad8 100644 --- a/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py @@ -68,9 +68,9 @@ def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: domains: tuple[domain_utils.SymbolicDomain, ...] = utils.flatten_nested_tuple( node.annex.domain ) - assert all( - domain == domains[0] for domain in domains - ), "At this point all `concat_where` arguments should be posed on the same domain." + assert all(domain == domains[0] for domain in domains), ( + "At this point all `concat_where` arguments should be posed on the same domain." + ) assert isinstance(domains[0], domain_utils.SymbolicDomain) domain_expr = domains[0].as_expr() diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 8a69bda9c0..ce99532645 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -282,9 +282,9 @@ def deduce_return_type(tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.S tb_dtype, fb_dtype = (type_info.extract_dtype(b) for b in [tb, fb]) - assert ( - tb_dtype == fb_dtype - ), f"Field arguments must be of same dtype, got '{tb_dtype}' != '{fb_dtype}'." + assert tb_dtype == fb_dtype, ( + f"Field arguments must be of same dtype, got '{tb_dtype}' != '{fb_dtype}'." + ) dtype = tb_dtype return_dims = common.promote_dims( From 81b9309e619483c7986c78147640dcbd72fbca1e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 9 Jul 2025 11:17:52 +0200 Subject: [PATCH 110/124] add type inference test --- .../unit_tests/iterator_tests/test_type_inference.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 0589463777..a0361e7ba2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -276,6 +276,14 @@ def expression_test_cases(): ), ts.TupleType(types=[float_i_field] * 2), ), + ( + im.concat_where( + im.domain(common.GridType.CARTESIAN, {IDim: (0, 1)}), + im.ref("a", ts.TupleType(types=[float_i_field, float_ij_field])), + im.ref("b", ts.TupleType(types=[float_i_field] * 2)), + ), + ts.TupleType(types=[float_i_field, float_ij_field]), + ), ) From a176174674bb85e3ac8c3bb4dff56d12fb51165c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 9 Jul 2025 11:43:32 +0200 Subject: [PATCH 111/124] cleanup --- .../next/iterator/ir_utils/domain_utils.py | 28 ++++++++------ .../ir_utils_test.py/test_domain_utils.py | 37 ++++++++++++++----- 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 3daf3af6bc..e4c09d0564 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -166,18 +166,18 @@ def translate( raise AssertionError("Number of shifts must be a multiple of 2.") -def _range_op( - start_op: Callable[[itir.Expr, itir.Expr], itir.Expr], - stop_op: Callable[[itir.Expr, itir.Expr], itir.Expr], +def _reduce_ranges( *ranges: SymbolicRange, + start_reduce_op: Callable[[itir.Expr, itir.Expr], itir.Expr], + stop_reduce_op: Callable[[itir.Expr, itir.Expr], itir.Expr], ) -> SymbolicRange: """Uses start_op and stop_op to fold the start and stop of a list of ranges.""" start = functools.reduce( - lambda current_expr, el_expr: start_op(current_expr, el_expr), + lambda current_expr, el_expr: start_reduce_op(current_expr, el_expr), [range_.start for range_ in ranges], ) stop = functools.reduce( - lambda current_expr, el_expr: stop_op(current_expr, el_expr), + lambda current_expr, el_expr: stop_reduce_op(current_expr, el_expr), [range_.stop for range_ in ranges], ) # constant fold expression to keep the tree small @@ -185,13 +185,17 @@ def _range_op( return SymbolicRange(start, stop) -_range_union = functools.partial(_range_op, im.minimum, im.maximum) -_range_intersection = functools.partial(_range_op, im.maximum, im.minimum) +_range_union = functools.partial( + _reduce_ranges, start_reduce_op=im.minimum, stop_reduce_op=im.maximum +) +_range_intersection = functools.partial( + _reduce_ranges, start_reduce_op=im.maximum, stop_reduce_op=im.minimum +) -def _domain_op( - range_op: Callable[..., SymbolicRange], +def _reduce_domains( *domains: SymbolicDomain, + range_reduce_op: Callable[..., SymbolicRange], ) -> SymbolicDomain: """ Applies range_op to the ranges of a list of domains with same dimensions and grid_type. @@ -201,14 +205,14 @@ def _domain_op( new_domain_ranges = {} for dim in domains[0].ranges.keys(): - new_domain_ranges[dim] = range_op(*[domain.ranges[dim] for domain in domains]) + new_domain_ranges[dim] = range_reduce_op(*[domain.ranges[dim] for domain in domains]) return SymbolicDomain(domains[0].grid_type, new_domain_ranges) -domain_union = functools.partial(_domain_op, _range_union) +domain_union = functools.partial(_reduce_domains, range_reduce_op=_range_union) """Return the (set) union of a list of domains.""" -domain_intersection = functools.partial(_domain_op, _range_intersection) +domain_intersection = functools.partial(_reduce_domains, range_reduce_op=_range_intersection) """Return the intersection of a list of domains.""" diff --git a/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py b/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py index 04544dd220..69a2ed772b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py +++ b/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py @@ -28,7 +28,10 @@ def _make_domain(i: int): return domain_utils.SymbolicDomain( grid_type=common.GridType.CARTESIAN, - ranges={I: domain_utils.SymbolicRange(im.ref(f"start{i}"), im.ref(f"end{i}"))}, + ranges={ + I: domain_utils.SymbolicRange(im.ref(f"start_I_{i}"), im.ref(f"end_I_{i}")), + J: domain_utils.SymbolicRange(im.ref(f"start_J_{i}"), im.ref(f"end_J_{i}")), + }, ) @@ -49,12 +52,12 @@ def test_domain_op_preconditions(): ranges={J: domain_utils.SymbolicRange(5, 15)}, ) with pytest.raises(AssertionError): - domain_utils._domain_op(domain_utils._range_union, domain_a, domain_b) + domain_utils._reduce_domains(domain_a, domain_b, range_reduce_op=domain_utils._range_union) domain_c = domain_utils.SymbolicDomain( grid_type=common.GridType.UNSTRUCTURED, ranges={I: domain_utils.SymbolicRange(0, 10)} ) with pytest.raises(AssertionError): - domain_utils._domain_op(domain_utils._range_union, domain_a, domain_c) + domain_utils._reduce_domains(domain_a, domain_c, range_reduce_op=domain_utils._range_union) def test_domain_union(): @@ -66,9 +69,17 @@ def test_domain_union(): grid_type=common.GridType.CARTESIAN, ranges={ I: domain_utils.SymbolicRange( - im.minimum(im.minimum(im.ref("start0"), im.ref("start1")), im.ref("start2")), - im.maximum(im.maximum(im.ref("end0"), im.ref("end1")), im.ref("end2")), - ) + im.minimum( + im.minimum(im.ref("start_I_0"), im.ref("start_I_1")), im.ref("start_I_2") + ), + im.maximum(im.maximum(im.ref("end_I_0"), im.ref("end_I_1")), im.ref("end_I_2")), + ), + J: domain_utils.SymbolicRange( + im.minimum( + im.minimum(im.ref("start_J_0"), im.ref("start_J_1")), im.ref("start_J_2") + ), + im.maximum(im.maximum(im.ref("end_J_0"), im.ref("end_J_1")), im.ref("end_J_2")), + ), }, ) assert expected == domain_utils.domain_union(domain0, domain1, domain2) @@ -83,9 +94,17 @@ def test_domain_intersection(): grid_type=common.GridType.CARTESIAN, ranges={ I: domain_utils.SymbolicRange( - im.maximum(im.maximum(im.ref("start0"), im.ref("start1")), im.ref("start2")), - im.minimum(im.minimum(im.ref("end0"), im.ref("end1")), im.ref("end2")), - ) + im.maximum( + im.maximum(im.ref("start_I_0"), im.ref("start_I_1")), im.ref("start_I_2") + ), + im.minimum(im.minimum(im.ref("end_I_0"), im.ref("end_I_1")), im.ref("end_I_2")), + ), + J: domain_utils.SymbolicRange( + im.maximum( + im.maximum(im.ref("start_J_0"), im.ref("start_J_1")), im.ref("start_J_2") + ), + im.minimum(im.minimum(im.ref("end_J_0"), im.ref("end_J_1")), im.ref("end_J_2")), + ), }, ) assert expected == domain_utils.domain_intersection(domain0, domain1, domain2) From 6a1087d15cef247e822208faa38b2244fe913522 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 9 Jul 2025 12:51:06 +0200 Subject: [PATCH 112/124] delete an obsolete assert --- .../program_processors/runners/dace/gtir_builtin_translators.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py index ac7df22ccc..3de4bd87b4 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py @@ -523,7 +523,6 @@ def parse_range_boundary(expr: gtir.Expr) -> str: domain.append((dim, lower_bound, upper_bound)) elif isinstance(node, domain_utils.SymbolicDomain): - assert str(node.grid_type) in {"cartesian_domain", "unstructured_domain"} for dim, drange in node.ranges.items(): domain.append( (dim, parse_range_boundary(drange.start), parse_range_boundary(drange.stop)) From 3678a389b35974420737419dced30e11a4adbf5f Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 9 Jul 2025 13:27:19 +0200 Subject: [PATCH 113/124] remove embedded implementation --- src/gt4py/next/common.py | 53 -------- src/gt4py/next/embedded/nd_array_field.py | 147 +++++++++++----------- 2 files changed, 70 insertions(+), 130 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 30726ee862..dc6f24e9dd 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -109,41 +109,6 @@ def __add__(self, offset: int) -> Connectivity: def __sub__(self, offset: int) -> Connectivity: return self + (-offset) - def __gt__(self, value: core_defs.IntegralScalar) -> Domain: - return Domain(dims=(self,), ranges=(UnitRange(value + 1, Infinity.POSITIVE),)) - - def __ge__(self, value: core_defs.IntegralScalar) -> Domain: - return Domain(dims=(self,), ranges=(UnitRange(value, Infinity.POSITIVE),)) - - def __lt__(self, value: core_defs.IntegralScalar) -> Domain: - return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value),)) - - def __le__(self, value: core_defs.IntegralScalar) -> Domain: - # TODO add test - return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value + 1),)) - - def __eq__(self, value: Dimension | core_defs.IntegralScalar) -> bool | Domain: - if isinstance(value, Dimension): - return self.value == value.value - elif isinstance(value, core_defs.INTEGRAL_TYPES): - # TODO probably only within valid embedded context? - return Domain(dims=(self,), ranges=(UnitRange(value, value + 1),)) - else: - return False - - def __ne__(self, value: Dimension | core_defs.IntegralScalar) -> bool | tuple[Domain, Domain]: - # TODO add test - if isinstance(value, Dimension): - return self.value != value.value - elif isinstance(value, core_defs.INTEGRAL_TYPES): - # TODO probably only within valid embedded context? - return ( - Domain(self, UnitRange(Infinity.NEGATIVE, value)), - Domain(self, UnitRange(value + 1, Infinity.POSITIVE)), - ) - else: - return True - class Infinity(enum.Enum): """Describes an unbounded `UnitRange`.""" @@ -535,24 +500,6 @@ def __and__(self, other: Domain) -> Domain: ) return Domain(dims=broadcast_dims, ranges=intersected_ranges) - def __or__(self, other: Domain) -> Domain: - # TODO support arbitrary union of domains - # TODO add tests - if self.ndim > 1 or other.ndim > 1: - raise NotImplementedError("Union of multidimensional domains is not supported.") - if self.ndim == 0: - return other - if other.ndim == 0: - return self - sorted_ = sorted((self, other), key=lambda x: x.ranges[0].start) - if sorted_[0].ranges[0].stop >= sorted_[1].ranges[0].start: - return Domain( - dims=(self.dims[0],), - ranges=(UnitRange(sorted_[0].ranges[0].start, sorted_[1].ranges[0].stop),), - ) - else: - return (sorted_[0], sorted_[1]) - @functools.cached_property def slice_at(self) -> utils.IndexerCallable[slice, Domain]: """ diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index a2101e6c99..25ce060c7c 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -810,6 +810,25 @@ def _hyperslice( NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) +def _compute_mask_slices( + mask: core_defs.NDArrayObject, +) -> list[tuple[bool, slice]]: + """Take a 1-dimensional mask and return a sequence of mappings from boolean values to slices.""" + # TODO: does it make sense to upgrade this naive algorithm to numpy? + assert mask.ndim == 1 + cur = bool(mask[0].item()) + ind = 0 + res = [] + for i in range(1, mask.shape[0]): + # Use `.item()` to extract the scalar from a 0-d array in case of e.g. cupy + if (mask_i := bool(mask[i].item())) != cur: + res.append((cur, slice(ind, i))) + cur = mask_i + ind = i + res.append((cur, slice(ind, mask.shape[0]))) + return res + + def _trim_empty_domains( lst: Iterable[tuple[bool, common.Domain]], ) -> list[tuple[bool, common.Domain]]: @@ -877,108 +896,82 @@ def _stack_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[c def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field: # TODO(havogt): this function could be extended to a general concat - # currently only concatenate along the given dimension - sorted_fields = sorted(fields, key=lambda f: f.domain[dim].unit_range.start) + # currently only concatenate along the given dimension and requires the fields to be ordered if ( - len(sorted_fields) > 1 - and not embedded_common.domain_intersection(*[f.domain for f in sorted_fields]).is_empty() + len(fields) > 1 + and not embedded_common.domain_intersection(*[f.domain for f in fields]).is_empty() ): raise ValueError("Fields to concatenate must not overlap.") - new_domain = _stack_domains(*[f.domain for f in sorted_fields], dim=dim) + new_domain = _stack_domains(*[f.domain for f in fields], dim=dim) if new_domain is None: raise embedded_exceptions.NonContiguousDomain(f"Cannot concatenate fields along {dim}.") - nd_array_class = _get_nd_array_class(*sorted_fields) + nd_array_class = _get_nd_array_class(*fields) return nd_array_class.from_array( nd_array_class.array_ns.concatenate( - [ - nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape) - for f in sorted_fields - ], + [nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape) for f in fields], axis=new_domain.dim_index(dim, allow_missing=False), ), domain=new_domain, ) -def _invert_domain( - domains: common.Domain | tuple[common.Domain], -) -> common.Domain | tuple[common.Domain, ...]: - if not isinstance(domains, tuple): - domains = (domains,) - - assert all(d.ndim == 1 for d in domains) - dim = domains[0].dims[0] - assert all(d.dims[0] == dim for d in domains) - sorted_domains = sorted(domains, key=lambda d: d.ranges[0].start) - - result = [] - if domains[0].ranges[0].start is not common.Infinity.NEGATIVE: - result.append( - common.Domain( - dims=(dim,), - ranges=(common.UnitRange(common.Infinity.NEGATIVE, domains[0].ranges[0].start),), - ) - ) - for i in range(len(sorted_domains) - 1): - if sorted_domains[i].ranges[0].stop != sorted_domains[i + 1].ranges[0].start: - result.append( - common.Domain( - dims=(dim,), - ranges=( - common.UnitRange( - sorted_domains[i].ranges[0].stop, sorted_domains[i + 1].ranges[0].start - ), - ), - ) - ) - if domains[-1].ranges[0].stop is not common.Infinity.POSITIVE: - result.append( - common.Domain( - dims=(dim,), - ranges=(common.UnitRange(domains[-1].ranges[0].stop, common.Infinity.POSITIVE),), - ) - ) - return tuple(result) - - -def _intersect_multiple( - domain: common.Domain, domains: common.Domain | tuple[common.Domain] -) -> tuple[common.Domain, ...]: - if not isinstance(domains, tuple): - domains = (domains,) - - return tuple( - intersection - for d in domains - if not (intersection := embedded_common.domain_intersection(domain, d)).is_empty() - ) - - def _concat_where( - masks: common.Domain | tuple[common.Domain, ...], - true_field: common.Field, - false_field: common.Field, + mask_field: common.Field, true_field: common.Field, false_field: common.Field ) -> common.Field: - if not isinstance(masks, tuple): - masks = (masks,) - if any(m.ndim for m in masks) != 1: + cls_ = _get_nd_array_class(mask_field, true_field, false_field) + xp = cls_.array_ns + if mask_field.domain.ndim != 1: raise NotImplementedError( "'concat_where': Can only concatenate fields with a 1-dimensional mask." ) - mask_dim = masks[0].dims[0] + mask_dim = mask_field.domain.dims[0] # intersect the field in dimensions orthogonal to the mask, then all slices in the mask field have same domain t_broadcasted, f_broadcasted = _intersect_fields(true_field, false_field, ignore_dims=mask_dim) - true_domains = _intersect_multiple(t_broadcasted.domain, masks) - t_slices = tuple(t_broadcasted[d] for d in true_domains) + # TODO(havogt): for clarity, most of it could be implemented on named_range in the masked dimension, but we currently lack the utils + # compute the consecutive ranges (first relative, then domain) of true and false values + mask_values_to_slices_mapping: Iterable[tuple[bool, slice]] = _compute_mask_slices( + mask_field.ndarray + ) + mask_values_to_domain_mapping: Iterable[tuple[bool, common.Domain]] = ( + (mask, mask_field.domain.slice_at[domain_slice]) + for mask, domain_slice in mask_values_to_slices_mapping + ) + # mask domains intersected with the respective fields + mask_values_to_intersected_domains_mapping: Iterable[tuple[bool, common.Domain]] = ( + ( + mask_value, + embedded_common.domain_intersection( + t_broadcasted.domain if mask_value else f_broadcasted.domain, mask_domain + ), + ) + for mask_value, mask_domain in mask_values_to_domain_mapping + ) + + # remove the empty domains from the beginning and end + mask_values_to_intersected_domains_mapping = _trim_empty_domains( + mask_values_to_intersected_domains_mapping + ) + if any(d.is_empty() for _, d in mask_values_to_intersected_domains_mapping): + raise embedded_exceptions.NonContiguousDomain( + f"In 'concat_where', cannot concatenate the following 'Domain's: {[d for _, d in mask_values_to_intersected_domains_mapping]}." + ) - inverted_masks = _invert_domain(masks) - false_domains = _intersect_multiple(f_broadcasted.domain, inverted_masks) - f_slices = tuple(f_broadcasted[d] for d in false_domains) + # slice the fields with the domain ranges + transformed = [ + t_broadcasted[d] if v else f_broadcasted[d] + for v, d in mask_values_to_intersected_domains_mapping + ] - return _concat(*f_slices, *t_slices, dim=mask_dim) + # stack the fields together + if transformed: + return _concat(*transformed, dim=mask_dim) + else: + result_domain = common.Domain(common.NamedRange(mask_dim, common.UnitRange(0, 0))) + result_array = xp.empty(result_domain.shape) + return cls_.from_array(result_array, domain=result_domain) NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] # TODO(havogt): this is still the "old" concat_where, needs to be replaced in a next PR From 94ef41ae34b72842325632d33fe426bd07b93af9 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 9 Jul 2025 13:47:11 +0200 Subject: [PATCH 114/124] address review comments --- .../next/iterator/ir_utils/domain_utils.py | 27 ++++++++----------- .../next/iterator/transforms/infer_domain.py | 3 ++- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index e4c09d0564..52853899c4 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -203,15 +203,15 @@ def _reduce_domains( assert all(domain.grid_type == domains[0].grid_type for domain in domains) assert all(domain.ranges.keys() == domains[0].ranges.keys() for domain in domains) - new_domain_ranges = {} - for dim in domains[0].ranges.keys(): - new_domain_ranges[dim] = range_reduce_op(*[domain.ranges[dim] for domain in domains]) + dims = domains[0].ranges.keys() + new_domain_ranges = {dim: range_reduce_op(*(d.ranges[dim] for d in domains)) for dim in dims} return SymbolicDomain(domains[0].grid_type, new_domain_ranges) domain_union = functools.partial(_reduce_domains, range_reduce_op=_range_union) """Return the (set) union of a list of domains.""" + domain_intersection = functools.partial(_reduce_domains, range_reduce_op=_range_intersection) """Return the intersection of a list of domains.""" @@ -241,13 +241,12 @@ def promote_domain( ) -> SymbolicDomain: """Return a domain that is extended with the dimensions of target_dims.""" assert set(domain.ranges.keys()).issubset(target_dims) - dims_dict = {} - for dim in target_dims: - dims_dict[dim] = ( - domain.ranges[dim] - if dim in domain.ranges - else SymbolicRange(itir.InfinityLiteral.NEGATIVE, itir.InfinityLiteral.POSITIVE) - ) + dims_dict = { + dim: domain.ranges[dim] + if dim in domain.ranges + else SymbolicRange(itir.InfinityLiteral.NEGATIVE, itir.InfinityLiteral.POSITIVE) + for dim in target_dims + } return SymbolicDomain(domain.grid_type, dims_dict) @@ -259,12 +258,8 @@ def is_finite(range_or_domain: SymbolicRange | SymbolicDomain) -> bool: """ match range_or_domain: case SymbolicRange() as range_: - if any( - v in [itir.InfinityLiteral.POSITIVE, itir.InfinityLiteral.NEGATIVE] - for v in [range_.start, range_.stop] - ): - return False - return True + infinity_literals = (itir.InfinityLiteral.POSITIVE, itir.InfinityLiteral.NEGATIVE) + return not (range_.start in infinity_literals or range_.stop in infinity_literals) case SymbolicDomain() as domain: return all(is_finite(range_) for range_ in domain.ranges.values()) case _: diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 61b1883b17..c22b775468 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -454,11 +454,12 @@ def infer_expr( Arguments: - expr: The expression to be inferred. - domain: The domain `expr` is read at. + + Keyword Arguments: - symbolic_domain_sizes: A dictionary mapping axes names, e.g., `I`, `Vertex`, to a symbol name that evaluates to the length of that axis. - allow_uninferred: Allow `as_fieldop` expressions whose domain is either unknown (e.g. because of a dynamic shift) or never accessed. - actually access domain) - keep_existing_domains: If `True`, keep existing domains in `as_fieldop` expressions and use them to propagate the domain further. This is useful in cases where after a transformation some nodes are missing domain information that needs to be repopulated, From de19ee92302ab11965febc11361d5034ab04a281 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 10 Jul 2025 09:28:58 +0200 Subject: [PATCH 115/124] fix merge conflict --- src/gt4py/next/iterator/ir_utils/misc.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index 863fe5ef68..00ff9abbd9 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -253,21 +253,3 @@ def unique_symbol(sym: SymOrStr, reserved_names: Iterable[str]) -> SymOrStr: name = name + "_" return name - - -def grid_type_from_domain(domain: itir.FunCall) -> common.GridType: - if cpm.is_call_to(domain, "cartesian_domain"): - return common.GridType.CARTESIAN - else: - assert cpm.is_call_to(domain, "unstructured_domain") - return common.GridType.UNSTRUCTURED - - -def grid_type_from_program(program: itir.Program) -> common.GridType: - domains = program.walk_values().if_isinstance(itir.SetAt).getattr("domain").to_set() - grid_types = {grid_type_from_domain(d) for d in domains} - if len(grid_types) != 1: - raise ValueError( - f"Found 'set_at' with more than one 'GridType': '{grid_types}'. This is currently not supported." - ) - return grid_types.pop() From 0feead0ab9b4669438b4a5003866ebdca0d72aa1 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 10 Jul 2025 09:29:30 +0200 Subject: [PATCH 116/124] document some tests --- .../feature_tests/ffront_tests/test_concat_where.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 0ef99d6b50..691810c9ef 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -51,6 +51,8 @@ def testee(ground: cases.IJKField, air: cases.IJKField) -> cases.IJKField: def test_concat_where_non_overlapping(cartesian_case): + """Fields only defined in their respective region in concat_where.""" + @gtx.field_operator def testee(ground: cases.IJKField, air: cases.IJKField) -> cases.IJKField: return concat_where(KDim == 0, ground, air) @@ -85,6 +87,8 @@ def testee(a: np.int32, b: cases.IJKField, N: np.int32) -> cases.IJKField: def test_concat_where_scalar_broadcast_on_empty_branch(cartesian_case): + """Output domain such that the scalar branch is never active.""" + @gtx.field_operator def testee(a: np.int32, b: cases.KField, N: np.int32) -> cases.KField: return concat_where(KDim < N, a, b) @@ -224,7 +228,7 @@ def testee(interior: cases.KField, boundary: cases.KField, nlev: np.int32) -> ca cases.verify(cartesian_case, testee, interior, boundary, nlev, out=out, ref=ref) -def test_dimension_two_conditions_eq(cartesian_case): +def test_dimension_eq_in_middle_of_domain(cartesian_case): @gtx.field_operator def testee(interior: cases.KField, boundary: cases.KField) -> cases.KField: return concat_where((KDim == 2), interior, boundary) From 7ba38d536e247ef6f64dac9370235b9b747ef041 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 10 Jul 2025 09:35:21 +0200 Subject: [PATCH 117/124] fix test structure in constant folding --- .../transforms_tests/test_constant_folding.py | 152 ++++++++---------- 1 file changed, 65 insertions(+), 87 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index bbc6433348..a56b539014 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -193,6 +193,71 @@ def test_value_from_literal(value, expected): im.plus(im.maximum(im.minus(1, "a"), im.plus("a", 1)), im.minus(1, "a")), ), ), + # InfinityLiteral folding + ( + im.call("maximum")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE), + itir.InfinityLiteral.POSITIVE, + ), + ( + im.call("maximum")(itir.InfinityLiteral.POSITIVE, im.literal_from_value(1)), + itir.InfinityLiteral.POSITIVE, + ), + ( + im.call("maximum")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE), + im.literal_from_value(1), + ), + ( + im.call("maximum")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)), + im.literal_from_value(1), + ), + ( + im.call("minimum")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE), + im.literal_from_value(1), + ), + ( + im.call("minimum")(itir.InfinityLiteral.POSITIVE, im.literal_from_value(1)), + im.literal_from_value(1), + ), + ( + im.call("minimum")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE), + itir.InfinityLiteral.NEGATIVE, + ), + ( + im.call("minimum")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)), + itir.InfinityLiteral.NEGATIVE, + ), + ( + im.call("greater")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE), + im.literal_from_value(False), + ), + ( + im.call("greater")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE), + im.literal_from_value(True), + ), + ( + im.call("less")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE), + im.literal_from_value(True), + ), + ( + im.call("less")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE), + im.literal_from_value(False), + ), + ( + im.call("greater")(itir.InfinityLiteral.POSITIVE, im.literal_from_value(1)), + im.literal_from_value(True), + ), + ( + im.call("greater")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)), + im.literal_from_value(False), + ), + ( + im.call("less")(itir.InfinityLiteral.POSITIVE, im.literal_from_value(1)), + im.literal_from_value(False), + ), + ( + im.call("less")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)), + im.literal_from_value(True), + ), ), ids=lambda x: str(x[0]), ) @@ -200,90 +265,3 @@ def test_constant_folding(test_case): testee, expected = test_case actual = constant_folding.ConstantFolding.apply(testee) assert actual == im.ensure_expr(expected) - - -# TODO: integrate into test structure above -def test_constant_folding_inf_maximum(): - testee = im.call("maximum")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE) - expected = itir.InfinityLiteral.POSITIVE - actual = constant_folding.ConstantFolding.apply(testee) - assert actual == expected - - testee = im.call("maximum")(itir.InfinityLiteral.POSITIVE, im.literal_from_value(1)) - expected = itir.InfinityLiteral.POSITIVE - actual = constant_folding.ConstantFolding.apply(testee) - assert actual == expected - - testee = im.call("maximum")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE) - expected = im.literal_from_value(1) - actual = constant_folding.ConstantFolding.apply(testee) - assert actual == expected - - testee = im.call("maximum")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) - expected = im.literal_from_value(1) - actual = constant_folding.ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_inf_minimum(): - testee = im.call("minimum")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE) - expected = im.literal_from_value(1) - actual = constant_folding.ConstantFolding.apply(testee) - assert actual == expected - - testee = im.call("minimum")(itir.InfinityLiteral.POSITIVE, im.literal_from_value(1)) - expected = im.literal_from_value(1) - actual = constant_folding.ConstantFolding.apply(testee) - assert actual == expected - - testee = im.call("minimum")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE) - expected = itir.InfinityLiteral.NEGATIVE - actual = constant_folding.ConstantFolding.apply(testee) - assert actual == expected - - testee = im.call("minimum")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) - expected = itir.InfinityLiteral.NEGATIVE - actual = constant_folding.ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_greater_less(): - testee = im.call("greater")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE) - expected = im.literal_from_value(False) - actual = constant_folding.ConstantFolding.apply(testee) - assert actual == expected - - testee = im.call("greater")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE) - expected = im.literal_from_value(True) - actual = constant_folding.ConstantFolding.apply(testee) - assert actual == expected - - testee = im.call("less")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE) - expected = im.literal_from_value(True) - actual = constant_folding.ConstantFolding.apply(testee) - assert actual == expected - - testee = im.call("less")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE) - expected = im.literal_from_value(False) - actual = constant_folding.ConstantFolding.apply(testee) - assert actual == expected - - testee = im.call("greater")(itir.InfinityLiteral.POSITIVE, im.literal_from_value(1)) - expected = im.literal_from_value(True) - actual = constant_folding.ConstantFolding.apply(testee) - assert actual == expected - - testee = im.call("greater")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) - expected = im.literal_from_value(False) - actual = constant_folding.ConstantFolding.apply(testee) - assert actual == expected - - testee = im.call("less")(itir.InfinityLiteral.POSITIVE, im.literal_from_value(1)) - expected = im.literal_from_value(False) - actual = constant_folding.ConstantFolding.apply(testee) - assert actual == expected - - testee = im.call("less")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)) - expected = im.literal_from_value(True) - actual = constant_folding.ConstantFolding.apply(testee) - assert actual == expected From 4722c0507bbe23a5c414d76d2d80b968db3f0fca Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 10 Jul 2025 09:48:45 +0200 Subject: [PATCH 118/124] remove resolved todos --- .../transforms_tests/test_domain_inference.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 0a090dc525..0d9a55ceef 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -8,7 +8,7 @@ # TODO(SF-N): test scan operator -from typing import Iterable, Literal, Optional, Union +from typing import Iterable, Literal, Optional import numpy as np import pytest @@ -26,6 +26,7 @@ from gt4py.next.iterator.transforms.constant_folding import ConstantFolding from gt4py.next.type_system import type_specifications as ts + float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) JDim = common.Dimension(value="JDim", kind=common.DimensionKind.HORIZONTAL) @@ -1250,10 +1251,6 @@ def test_concat_where(offset_provider): assert expected_domains == constant_fold_accessed_domains(actual_domains) -# Todo: 2 dimensional test with cond im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 4)}) -# Todo: nested concat wheres - - def test_concat_where_two_dimensions(offset_provider): domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 30)}) domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 10)}) From 06a70e32900e6470052b196ed23f2d21ee81b837 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 10 Jul 2025 12:33:36 +0200 Subject: [PATCH 119/124] refactorings --- src/gt4py/next/ffront/foast_to_gtir.py | 3 -- src/gt4py/next/iterator/builtins.py | 6 --- src/gt4py/next/iterator/embedded.py | 5 --- .../concat_where/transform_to_as_fieldop.py | 4 +- .../iterator/transforms/constant_folding.py | 43 ++++++++++--------- 5 files changed, 25 insertions(+), 36 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 98e47e31ae..aabcc062bc 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -408,9 +408,6 @@ def create_if( def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: domain, true_branch, false_branch = self.visit(node.args, **kwargs) - # TODO: use this case again. breaks domain inference in fused_velocity_advection_stencil_1_to_7 - # because some tuple elements are never accessed and the collapse tuple - # does not propagate across concat where return im.concat_where(domain, true_branch, false_branch) def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index d2bcded0cc..e3f45f6c74 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -412,11 +412,6 @@ def concat_where(*args): raise BackendNotSelectedError() -@builtin_dispatch -def in_(*args): - raise BackendNotSelectedError() - - UNARY_MATH_NUMBER_BUILTINS = {"abs", "neg"} UNARY_LOGICAL_BUILTINS = {"not_"} UNARY_MATH_FP_BUILTINS = { @@ -505,7 +500,6 @@ def in_(*args): "tuple_get", "unstructured_domain", "concat_where", - "in_", *ARITHMETIC_BUILTINS, *TYPE_BUILTINS, } diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index d7c455763e..3888ccf2de 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1807,11 +1807,6 @@ def concat_where(*args): raise NotImplementedError("To be implemented in frontend embedded.") -@builtins.in_.register(EMBEDDED) -def in_(*args): - raise NotImplementedError("To be implemented in frontend embedded.") - - def closure( domain_: runtime.CartesianDomain | runtime.UnstructuredDomain, sten: Callable[..., Any], diff --git a/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py b/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py index a693770ad8..108488add6 100644 --- a/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py @@ -25,8 +25,8 @@ def _in(pos: itir.Expr, domain: itir.Expr) -> itir.Expr: """ Given a position and a domain return an expression that evaluates to `True` if the position is inside the domain. - `in_({i, j, k}, u⟨ Iₕ: [i0, i1[, Iₕ: [j0, j1[, Iₕ: [k0, k1[ ⟩` - -> `i0 <= i < i1 & j0 <= j < j1 & k0 <= k < k1` + pos = `{i, j, k}`, domain = `u⟨ Iₕ: [i0, i1[, Iₕ: [j0, j1[, Iₕ: [k0, k1[ ⟩` + -> `((i0 <= i) & (i < i1)) & ((j0 <= j) & (j < j1)) & ((k0 <= k)l & (k < k1))` """ ret = [ im.and_( diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index bfb8378be6..48653ba5b5 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -244,36 +244,36 @@ def transform_fold_if(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: def transform_fold_infinity_arithmetic(self, node: ir.FunCall) -> Optional[ir.Node]: if cpm.is_call_to(node, "plus"): - for arg in node.args: - # `a + inf` -> `inf` - if arg == ir.InfinityLiteral.POSITIVE: - return ir.InfinityLiteral.POSITIVE - # `a + (-inf)` -> `-inf` - if arg == ir.InfinityLiteral.NEGATIVE: - return ir.InfinityLiteral.NEGATIVE + # `a + +/-inf` -> `+/-inf` + a, b = node.args + assert not (isinstance(a, ir.InfinityLiteral) and isinstance(b, ir.InfinityLiteral)) + for arg in a, b: + if isinstance(arg, ir.InfinityLiteral): + return arg if cpm.is_call_to(node, "minimum"): - a, b = node.args - for arg, other_arg in ((a, b), (b, a)): - # `minimum(inf, a)` -> `a` - if arg == ir.InfinityLiteral.POSITIVE: - return other_arg + if ir.InfinityLiteral.NEGATIVE in node.args: # `minimum(-inf, a)` -> `-inf` - if arg == ir.InfinityLiteral.NEGATIVE: - return ir.InfinityLiteral.NEGATIVE + return ir.InfinityLiteral.NEGATIVE + if ir.InfinityLiteral.POSITIVE in node.args: + # `minimum(inf, a)` -> `a` + a, b = node.args + return b if a == ir.InfinityLiteral.POSITIVE else a if cpm.is_call_to(node, "maximum"): - a, b = node.args - for arg, other_arg in ((a, b), (b, a)): + if ir.InfinityLiteral.POSITIVE in node.args: # `maximum(inf, a)` -> `inf` - if arg == ir.InfinityLiteral.POSITIVE: - return ir.InfinityLiteral.POSITIVE + return ir.InfinityLiteral.POSITIVE + if ir.InfinityLiteral.NEGATIVE in node.args: # `maximum(-inf, a)` -> `a` - if arg == ir.InfinityLiteral.NEGATIVE: - return other_arg + a, b = node.args + return b if a == ir.InfinityLiteral.NEGATIVE else a if cpm.is_call_to(node, ("less", "less_equal")): a, b = node.args + # we don't handle `inf < inf` or `-inf < -inf`.args + assert a != b or not isinstance(a, ir.InfinityLiteral) + # `-inf < v` -> `True` # `v < inf` -> `True` if a == ir.InfinityLiteral.NEGATIVE or b == ir.InfinityLiteral.POSITIVE: @@ -285,6 +285,9 @@ def transform_fold_infinity_arithmetic(self, node: ir.FunCall) -> Optional[ir.No if cpm.is_call_to(node, ("greater", "greater_equal")): a, b = node.args + # we don't handle `inf > inf` or `-inf > -inf`.args + assert a != b or not isinstance(a, ir.InfinityLiteral) + # `inf > v` -> `True` # `v > -inf ` -> `True` if a == ir.InfinityLiteral.POSITIVE or b == ir.InfinityLiteral.NEGATIVE: From 9a116c16df8515b542dbc8f2616d02583230a3f6 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 10 Jul 2025 23:53:52 +0200 Subject: [PATCH 120/124] cleanup test --- .../feature_tests/ffront_tests/test_concat_where.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 691810c9ef..7be2ad6999 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -259,19 +259,17 @@ def testee(interior: cases.KField, boundary: cases.KField) -> cases.KField: def test_lap_like(cartesian_case): @gtx.field_operator def testee( - input: cases.IJField, boundary: np.int32, shape: tuple[np.int32, np.int32] + inp: cases.IJField, boundary: np.int32, shape: tuple[np.int32, np.int32] ) -> cases.IJField: # TODO add support for multi-dimensional concat_where masks return concat_where( (IDim == 0) | (IDim == shape[0] - 1), boundary, - concat_where((JDim == 0) | (JDim == shape[1] - 1), boundary, input), + concat_where((JDim == 0) | (JDim == shape[1] - 1), boundary, inp), ) out = cases.allocate(cartesian_case, testee, cases.RETURN)() - input = cases.allocate( - cartesian_case, testee, "input", domain=out.domain.slice_at[1:-1, 1:-1] - )() + inp = cases.allocate(cartesian_case, testee, "inp", domain=out.domain.slice_at[1:-1, 1:-1])() boundary = 2 ref = np.full(out.domain.shape, np.nan) @@ -279,8 +277,8 @@ def testee( ref[:, 0] = boundary ref[-1, :] = boundary ref[:, -1] = boundary - ref[1:-1, 1:-1] = input.asnumpy() - cases.verify(cartesian_case, testee, input, boundary, out.domain.shape, out=out, ref=ref) + ref[1:-1, 1:-1] = inp.asnumpy() + cases.verify(cartesian_case, testee, inp, boundary, out.domain.shape, out=out, ref=ref) @pytest.mark.uses_tuple_returns From 8635b2c45074f71c636c7e60969afc9e177c7b31 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 10 Jul 2025 23:55:34 +0200 Subject: [PATCH 121/124] fix[next]: symbol clash in inline_lambda --- .../iterator/transforms/inline_lambdas.py | 17 +-- .../transforms_tests/test_inline_lambdas.py | 109 ++++++++++++++++++ 2 files changed, 118 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index b7eb45d156..a41f74ebc1 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -71,9 +71,9 @@ def inline_lambda( # see todo above if eligible ) ) - syms: set[str] = node.fun.expr.pre_walk_values().if_isinstance(ir.Sym).getattr("id").to_set() + syms: set[str] = node.fun.pre_walk_values().if_isinstance(ir.Sym).getattr("id").to_set() clashes = refs & syms - expr = node.fun.expr + fun = node.fun if clashes: # TODO(tehrengruber): find a better way of generating new symbols in `name_map` that don't collide with each other. E.g. this must still work: # (lambda arg, arg_: (lambda arg_: ...)(arg))(a, b) # noqa: ERA001 [commented-out-code] @@ -82,14 +82,17 @@ def inline_lambda( # see todo above for sym in clashes: name_map[sym] = ir_misc.unique_symbol(sym, refs | syms | {*name_map.values()}) - expr = RenameSymbols().visit(expr, name_map=name_map) + # Let's rename the symbols (including params) of the function. + # If we would like to preserve the original param names, we could alternatively + # rename the eligible symrefs in `args`. + fun = RenameSymbols().visit(fun, name_map=name_map) symbol_map = { param.id: arg - for param, arg, eligible in zip(node.fun.params, node.args, eligible_params) + for param, arg, eligible in zip(fun.params, node.args, eligible_params) if eligible } - new_expr = RemapSymbolRefs().visit(expr, symbol_map=symbol_map) + new_expr = RemapSymbolRefs().visit(fun.expr, symbol_map=symbol_map) if all(eligible_params): new_expr.location = node.location @@ -97,9 +100,7 @@ def inline_lambda( # see todo above new_expr = ir.FunCall( fun=ir.Lambda( params=[ - param - for param, eligible in zip(node.fun.params, eligible_params) - if not eligible + param for param, eligible in zip(fun.params, eligible_params) if not eligible ], expr=new_expr, ), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index c10d48ad06..00f1fb1a1b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -55,6 +55,17 @@ False: im.plus(im.call("opaque")(), im.call("opaque")()), }, ), + ( + # symbol clash when partially inlining (opcount preserving) + "symbol_clash", + im.call(im.lambda_("x", "y")(im.call("f")("x", im.plus("x", "y"))))(im.plus("y", "y"), "x"), + { + True: im.call(im.lambda_("x_")(im.call("f")("x_", im.plus("x_", "x"))))( + im.plus("y", "y") + ), + False: im.call("f")(im.plus("y", "y"), im.plus(im.plus("y", "y"), "x")), + }, + ), ] @@ -91,3 +102,101 @@ def test_type_preservation(): testee.type = testee.annex.type = ts.ScalarType(kind=ts.ScalarKind.FLOAT32) inlined = InlineLambdas.apply(testee) assert inlined.type == inlined.annex.type == ts.ScalarType(kind=ts.ScalarKind.FLOAT32) + + +def test_dbg1(): + # (λ(foo, bar) → + # (λ(foo, bar) → + # if_(c0, + # foo, + # if_(c1, foo, bar)))( + # if_(c2, foo, bar), foo + # ))(a, b) + + testee = im.call( + im.lambda_("foo", "bar")( + im.call( + im.lambda_("foo", "bar")( + im.if_(im.ref("c0"), "foo", im.if_(im.ref("c1"), "foo", "bar")) + ) + )( + im.if_(im.ref("c2"), "foo", "bar"), + "foo", + ) + ) + )("a", "b") + print(testee) + inlined = InlineLambdas.apply(testee, opcount_preserving=True) + print(inlined) + inlined = InlineLambdas.apply(inlined, opcount_preserving=False) + print(inlined) + + # if c0 then b else if c1 then b else b + + # expected: + # if_(c0, + # if_(c2, a, b), + # if_(c1, + # if_(c2, a, b), + # a) + # ) + expected = im.if_( + im.ref("c0"), + im.if_(im.ref("c2"), "a", "b"), + im.if_(im.ref("c1"), im.if_(im.ref("c2"), "a", "b"), "a"), + ) + print(expected) + assert inlined == expected + + +# def test_dbg2(): +# testee = im.call( +# im.lambda_("x", "y")(im.multiplies_(im.call(im.lambda_("x")(im.plus("x", 1)))("y"), "x")) +# )(im.plus("x", "x"), "x") + +# print(testee) +# inlined = InlineLambdas.apply(testee, opcount_preserving=True) +# print(inlined) + + +def test_dbg2(): + testee = im.call( + im.lambda_("x", "y")( + im.call(im.lambda_("x", "y")(im.call("f")("x", im.plus("x", "y"))))( + im.plus("y", "y"), "x" + ) + ) + )("a", "b") + + print(testee) + inlined = InlineLambdas.apply(testee, opcount_preserving=True) + print(inlined) + inlined = InlineLambdas.apply(inlined, opcount_preserving=False) + print(inlined) + + direct = InlineLambdas.apply(testee, opcount_preserving=False) + print(direct) + + +def test_dbg3(): + testee = im.call(im.lambda_("x", "y")(im.call("f")("x", im.plus("x", "y"))))( + im.plus("y", "y"), "x" + ) + + print(testee) + # inlined = InlineLambdas.apply(testee, opcount_preserving=True) + inlined = InlineLambdas.apply(testee, opcount_preserving=True) + print(inlined) + # inlined = inline_lambda(testee, opcount_preserving=False) + # print(inlined) + + # expected = (λ(x) → f(x_, x_ + x))(y + y) + expected = im.call(im.lambda_("x_")(im.call("f")("x_", im.plus("x_", "x"))))(im.plus("y", "y")) + print(expected) + assert inlined == expected + + # inlined = InlineLambdas.apply(inlined, opcount_preserving=False) + # print(inlined) + + # direct = InlineLambdas.apply(testee, opcount_preserving=False) + # print(direct) From c72f4945e5d0c98dc50c7900f822e57cb8474106 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 11 Jul 2025 07:15:05 +0200 Subject: [PATCH 122/124] cleanup --- .../transforms_tests/test_inline_lambdas.py | 98 ------------------- 1 file changed, 98 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index 00f1fb1a1b..314d14cd7f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -102,101 +102,3 @@ def test_type_preservation(): testee.type = testee.annex.type = ts.ScalarType(kind=ts.ScalarKind.FLOAT32) inlined = InlineLambdas.apply(testee) assert inlined.type == inlined.annex.type == ts.ScalarType(kind=ts.ScalarKind.FLOAT32) - - -def test_dbg1(): - # (λ(foo, bar) → - # (λ(foo, bar) → - # if_(c0, - # foo, - # if_(c1, foo, bar)))( - # if_(c2, foo, bar), foo - # ))(a, b) - - testee = im.call( - im.lambda_("foo", "bar")( - im.call( - im.lambda_("foo", "bar")( - im.if_(im.ref("c0"), "foo", im.if_(im.ref("c1"), "foo", "bar")) - ) - )( - im.if_(im.ref("c2"), "foo", "bar"), - "foo", - ) - ) - )("a", "b") - print(testee) - inlined = InlineLambdas.apply(testee, opcount_preserving=True) - print(inlined) - inlined = InlineLambdas.apply(inlined, opcount_preserving=False) - print(inlined) - - # if c0 then b else if c1 then b else b - - # expected: - # if_(c0, - # if_(c2, a, b), - # if_(c1, - # if_(c2, a, b), - # a) - # ) - expected = im.if_( - im.ref("c0"), - im.if_(im.ref("c2"), "a", "b"), - im.if_(im.ref("c1"), im.if_(im.ref("c2"), "a", "b"), "a"), - ) - print(expected) - assert inlined == expected - - -# def test_dbg2(): -# testee = im.call( -# im.lambda_("x", "y")(im.multiplies_(im.call(im.lambda_("x")(im.plus("x", 1)))("y"), "x")) -# )(im.plus("x", "x"), "x") - -# print(testee) -# inlined = InlineLambdas.apply(testee, opcount_preserving=True) -# print(inlined) - - -def test_dbg2(): - testee = im.call( - im.lambda_("x", "y")( - im.call(im.lambda_("x", "y")(im.call("f")("x", im.plus("x", "y"))))( - im.plus("y", "y"), "x" - ) - ) - )("a", "b") - - print(testee) - inlined = InlineLambdas.apply(testee, opcount_preserving=True) - print(inlined) - inlined = InlineLambdas.apply(inlined, opcount_preserving=False) - print(inlined) - - direct = InlineLambdas.apply(testee, opcount_preserving=False) - print(direct) - - -def test_dbg3(): - testee = im.call(im.lambda_("x", "y")(im.call("f")("x", im.plus("x", "y"))))( - im.plus("y", "y"), "x" - ) - - print(testee) - # inlined = InlineLambdas.apply(testee, opcount_preserving=True) - inlined = InlineLambdas.apply(testee, opcount_preserving=True) - print(inlined) - # inlined = inline_lambda(testee, opcount_preserving=False) - # print(inlined) - - # expected = (λ(x) → f(x_, x_ + x))(y + y) - expected = im.call(im.lambda_("x_")(im.call("f")("x_", im.plus("x_", "x"))))(im.plus("y", "y")) - print(expected) - assert inlined == expected - - # inlined = InlineLambdas.apply(inlined, opcount_preserving=False) - # print(inlined) - - # direct = InlineLambdas.apply(testee, opcount_preserving=False) - # print(direct) From 9fe9c5c8edf2915d85190c975e972822ac29fb7c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 11 Jul 2025 08:18:09 +0200 Subject: [PATCH 123/124] address review comments --- tests/next_tests/integration_tests/cases.py | 5 ++--- .../iterator_tests/transforms_tests/test_collapse_tuple.py | 6 ++++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index abf5707cf6..967cf0ab11 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -356,12 +356,11 @@ def allocate( if sizes: assert not domain and all(dim in case.default_sizes for dim in sizes) domain = { - dim: (0, sizes[dim] if dim in sizes else default_size) + dim: (0, sizes.get(dim, default_size)) for dim, default_size in case.default_sizes.items() } - if not domain: - domain = {dim: (0, size) for dim, size in case.default_sizes.items()} + domain = domain or {dim: (0, size) for dim, size in case.default_sizes.items()} if not isinstance(domain, gtx.Domain): domain = gtx.domain(domain) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index f0ecd8aff0..97ffb0ee6d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -5,12 +5,14 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from gt4py.next import common import pytest + +from gt4py.next import common from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple -from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.type_system import type_specifications as it_ts +from gt4py.next.type_system import type_specifications as ts + int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) Vertex = common.Dimension(value="Vertex", kind=common.DimensionKind.HORIZONTAL) From 8f71edfeef3dee2ce38c49455bc01df855fb8d66 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 11 Jul 2025 08:31:42 +0200 Subject: [PATCH 124/124] cleanup todo --- src/gt4py/next/iterator/transforms/global_tmps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index daeec2b675..b3c81ca2d0 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -329,7 +329,7 @@ def create_global_tmps( This pass looks at all `as_fieldop` calls and transforms field-typed subexpressions of its arguments into temporaries. """ - # TODO: document why to keep existing domains, add test + # TODO(tehrengruber): document why to keep existing domains and add test program = infer_domain.infer_program( program, offset_provider=offset_provider,