From 768e4585e09563155edece4e1839dacc80331641 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 2 Mar 2025 12:51:06 +0100 Subject: [PATCH 01/93] First draft --- src/gt4py/next/ffront/past_to_itir.py | 4 +++- src/gt4py/next/iterator/builtins.py | 6 ++++++ src/gt4py/next/iterator/embedded.py | 5 +++++ src/gt4py/next/iterator/runtime.py | 5 +++++ .../next/iterator/type_system/type_synthesizer.py | 7 +++++++ .../next/program_processors/codegens/gtfn/codegen.py | 12 ++++++++++++ .../next/program_processors/codegens/gtfn/gtfn_ir.py | 1 + .../codegens/gtfn/itir_to_gtfn_ir.py | 8 ++++++++ src/gt4py/next/program_processors/runners/gtfn.py | 4 ++-- 9 files changed, 49 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 4bc1dfb2f8..b480203050 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -349,7 +349,9 @@ def _construct_itir_domain_arg( domain_args_kind = [] for dim_i, dim in enumerate(out_dims): # an expression for the range of a dimension - dim_range = itir.SymRef(id=_range_arg_from_field(out_field.id, dim_i)) + dim_range = im.call("get_domain")( + out_field.id, itir.AxisLiteral(value=dim.value, kind=dim.kind) + ) dim_start, dim_stop = im.tuple_get(0, dim_range), im.tuple_get(1, dim_range) # bounds lower: itir.Expr diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 8e5f7addca..f197107542 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -402,6 +402,11 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] raise BackendNotSelectedError() +@builtin_dispatch +def get_domain(*args): + raise BackendNotSelectedError() + + UNARY_MATH_NUMBER_BUILTINS = {"abs", "neg"} UNARY_LOGICAL_BUILTINS = {"not_"} UNARY_MATH_FP_BUILTINS = { @@ -474,6 +479,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "cartesian_domain", "cast_", "deref", + "get_domain", 
"if_", "index", # `index(dim)` creates a dim-field that has the current index at each point "shift", diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index da0516d26b..6e5bb4608c 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1683,6 +1683,11 @@ def set_at(expr: common.Field, domain: common.DomainLike, target: common.Mutable operators._tuple_assign_field(target, expr, common.domain(domain)) +@runtime.get_domain.register(EMBEDDED) +def get_domain(field: common.Field, dim: common.Dimension) -> tuple[int, int]: + return (field.domain[dim].unit_range.start, field.domain[dim].unit_range.stop) + + @runtime.if_stmt.register(EMBEDDED) def if_stmt(cond: bool, true_branch: Callable[[], None], false_branch: Callable[[], None]) -> None: """ diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index c9a5b15de7..c831f33f26 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -213,6 +213,11 @@ def set_at(*args): return BackendNotSelectedError() +@builtin_dispatch +def get_domain(*args): + return BackendNotSelectedError() + + @builtin_dispatch def if_stmt(*args): return BackendNotSelectedError() diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 131b773dd2..d7baa72f12 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -329,6 +329,13 @@ def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: return applied_as_fieldop +@_register_builtin_type_synthesizer +def get_domain(field: ts.FieldType, dim: ts.DimensionType) -> ts.TupleType: + return ts.TupleType( + types=[ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()))] * 2 + ) + + @_register_builtin_type_synthesizer def scan( scan_pass: TypeSynthesizer, direction: ts.ScalarType, 
init: ts.ScalarType diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 969e203689..a9f596effc 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -110,6 +110,8 @@ def visit_Literal(self, node: gtfn_ir.Literal, **kwargs: Any) -> str: return self.asfloat(node.value) case "bool": return node.value.lower() + case "axis_literal": + return node.value + "_t" case _: # TODO(tehrengruber): we should probably shouldn't just allow anything here. Revisit. return node.value @@ -272,6 +274,16 @@ def visit_Program(self, node: gtfn_ir.Program, **kwargs: Any) -> Union[str, Coll #include #include + namespace gridtools::fn { + // TODO(tehrengruber): `typename gridtools::sid::lower_bounds_type, typename gridtools::sid::upper_bounds_type` + // fails as type used for index calculations in gtfn differs + template + GT_FUNCTION gridtools::tuple get_domain(S &&sid, D) { + return {gridtools::host_device::at_key(gridtools::sid::get_lower_bounds(sid)), + gridtools::host_device::at_key(gridtools::sid::get_upper_bounds(sid))}; + } + } + namespace generated{ namespace gtfn = ::gridtools::fn; diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 831694791a..e45ce983e5 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -226,6 +226,7 @@ class TemporaryAllocation(Node): "can_deref", "cartesian_domain", "unstructured_domain", + "get_domain", "named_range", "reduce", "index", diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 104e2eccc1..2bc25de203 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ 
b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -485,6 +485,14 @@ def _visit_unstructured_domain(self, node: itir.FunCall, **kwargs: Any) -> Node: tagged_sizes=sizes, tagged_offsets=domain_offsets, connectivities=connectivities ) + def _visit_get_domain(self, node: itir.FunCall, **kwargs: Any) -> Node: + field, dim = node.args + + return FunCall( + fun=SymRef(id="get_domain"), + args=[self.visit(field, **kwargs), self.visit(dim, **kwargs)], + ) + def visit_FunCall(self, node: itir.FunCall, **kwargs: Any) -> Node: if isinstance(node.fun, itir.SymRef): if node.fun.id in self._unary_op_map: diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index a8961fd9bc..2fa273322e 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -22,7 +22,7 @@ from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler -from gt4py.next.otf.compilation.build_systems import compiledb +from gt4py.next.otf.compilation.build_systems import cmake from gt4py.next.program_processors.codegens.gtfn import gtfn_module @@ -141,7 +141,7 @@ class Params: lambda: config.CMAKE_BUILD_TYPE ) builder_factory: compiler.BuildSystemProjectGenerator = factory.LazyAttribute( - lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type) + lambda o: cmake.CMakeFactory(cmake_build_type=o.cmake_build_type) ) cached_translation = factory.Trait( From ac7db53d988ea9635b70cdf9eb6341bbce10681e Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 2 Mar 2025 12:54:03 +0100 Subject: [PATCH 02/93] Remove debugging leftovers --- src/gt4py/next/program_processors/runners/gtfn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 
2fa273322e..a8961fd9bc 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -22,7 +22,7 @@ from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler -from gt4py.next.otf.compilation.build_systems import cmake +from gt4py.next.otf.compilation.build_systems import compiledb from gt4py.next.program_processors.codegens.gtfn import gtfn_module @@ -141,7 +141,7 @@ class Params: lambda: config.CMAKE_BUILD_TYPE ) builder_factory: compiler.BuildSystemProjectGenerator = factory.LazyAttribute( - lambda o: cmake.CMakeFactory(cmake_build_type=o.cmake_build_type) + lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type) ) cached_translation = factory.Trait( From 61b4a0942d5b8ef8b68dd56a692201d0706cc176 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 16 Jul 2025 19:52:42 +0200 Subject: [PATCH 03/93] Add transformation for get_domain to named_range --- .../transforms/transform_get_domain.py | 100 ++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 src/gt4py/next/iterator/transforms/transform_get_domain.py diff --git a/src/gt4py/next/iterator/transforms/transform_get_domain.py b/src/gt4py/next/iterator/transforms/transform_get_domain.py new file mode 100644 index 0000000000..d42dc250b6 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/transform_get_domain.py @@ -0,0 +1,100 @@ +import dataclasses +from typing import Dict + +from gt4py.eve import PreserveLocationVisitor, NodeTranslator +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.ir_utils import ir_makers as im + + +@dataclasses.dataclass(frozen=True) +class TransformGetDomain(PreserveLocationVisitor, NodeTranslator): + """ + Transforms `get_domain` calls into `named_range` calls with given 
size. + + Example: + >>> from gt4py.next.type_system import type_specifications as ts + >>> from gt4py import next as gtx + >>> float64_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + >>> IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) + >>> KDim = common.Dimension(value="KDim", kind=common.DimensionKind.VERTICAL) + >>> Vertex = common.Dimension(value="Vertex", kind=common.DimensionKind.HORIZONTAL) + >>> float_i_field = ts.FieldType(dims=[IDim], dtype=float64_type) + + >>> unstructured_domain_get = im.call("unstructured_domain")( + ... im.call("get_domain")("out", im.axis_literal(Vertex)), + ... im.call("get_domain")("out", im.axis_literal(KDim)), + ... ) + + >>> unstructured_domain = im.call("unstructured_domain")( + ... im.call("named_range")(im.axis_literal(Vertex), 0, 10), + ... im.call("named_range")(im.axis_literal(KDim), 0, 20), + ... ) + + >>> ir = itir.Program( + ... id="test", + ... function_definitions=[], + ... params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], + ... declarations=[], + ... body=[ + ... itir.SetAt( + ... expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), + ... domain=unstructured_domain_get, + ... target=im.ref("out"), + ... ), + ... ], + ... ) + + >>> sizes = {"out": gtx.domain({Vertex: (0,10), KDim: (0,20)})} + + >>> result = TransformGetDomain.apply(ir, sizes=sizes) + >>> print(result) + test(inp, out) { + out @ u⟨ Vertexₕ: [0, 10[, KDimᵥ: [0, 20[ ⟩ ← (⇑deref)(inp); + } + + >>> ir = itir.Program( + ... id="test", + ... function_definitions=[], + ... params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], + ... declarations=[], + ... body=[ + ... itir.SetAt( + ... expr=im.as_fieldop(im.ref("deref"), unstructured_domain_get)(im.ref("inp")), # TODO: unstructured_domain_get raises AssertionError in domain_utils.py line 77: assert cpm.is_call_to(named_range, "named_range") + ... domain=unstructured_domain_get, + ... target=im.ref("out"), + ... ), + ... ], + ... 
) + + >>> result = TransformGetDomain.apply(ir, sizes=sizes) + >>> print (result) # TODO: this test still fails because of the AssertionError + test(inp, out) { + out @ u⟨ Vertexₕ: [0, 10[, KDimᵥ: [0, 20[ ⟩ + ← as_fieldop(deref, u⟨ Vertexₕ: [0, 10[, KDimᵥ: [0, 20[ ⟩)(inp); + } + """ + + @classmethod + def apply(cls, program: itir.Program, sizes: Dict[str, common.Domain]): + return cls().visit(program, sizes=sizes) + + def visit_FunCall(self, node: itir.SetAt, **kwargs) -> itir.FunCall: + sizes = kwargs["sizes"] + + if cpm.is_call_to(node, "get_domain"): + ref, dim = node.args + if isinstance(ref, itir.SymRef): + assert ref.id in sizes, f"Symbol '{ref.id}' not found in sizes Dict." + input_dims = sizes[ref.id].dims + index = next((i for i, d in enumerate(input_dims) if d.value == dim.value), None) + assert index is not None, f"Dimension {dim.value} not found in {input_dims}" + dim = input_dims[index] + start = sizes[ref.id].ranges[index].start + stop = sizes[ref.id].ranges[index].stop + return im.call("named_range")(im.axis_literal(dim), start, stop) + + # TODO: handle tuples: get_domain(tuple_get(0, "out")) + + return self.generic_visit(node, sizes=sizes) \ No newline at end of file From 4a65e7a47808eaed6695f7886c384fe05d2b34c3 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 17 Jul 2025 15:43:37 +0200 Subject: [PATCH 04/93] Add tuple suppoert --- .../transforms/transform_get_domain.py | 243 ++++++++++++------ 1 file changed, 163 insertions(+), 80 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/transform_get_domain.py b/src/gt4py/next/iterator/transforms/transform_get_domain.py index d42dc250b6..fc5a894d66 100644 --- a/src/gt4py/next/iterator/transforms/transform_get_domain.py +++ b/src/gt4py/next/iterator/transforms/transform_get_domain.py @@ -1,80 +1,147 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + import dataclasses from typing import Dict -from gt4py.eve import PreserveLocationVisitor, NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.transforms import collapse_tuple @dataclasses.dataclass(frozen=True) class TransformGetDomain(PreserveLocationVisitor, NodeTranslator): """ - Transforms `get_domain` calls into `named_range` calls with given size. - - Example: - >>> from gt4py.next.type_system import type_specifications as ts - >>> from gt4py import next as gtx - >>> float64_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - >>> IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) - >>> KDim = common.Dimension(value="KDim", kind=common.DimensionKind.VERTICAL) - >>> Vertex = common.Dimension(value="Vertex", kind=common.DimensionKind.HORIZONTAL) - >>> float_i_field = ts.FieldType(dims=[IDim], dtype=float64_type) - - >>> unstructured_domain_get = im.call("unstructured_domain")( - ... im.call("get_domain")("out", im.axis_literal(Vertex)), - ... im.call("get_domain")("out", im.axis_literal(KDim)), - ... ) - - >>> unstructured_domain = im.call("unstructured_domain")( - ... im.call("named_range")(im.axis_literal(Vertex), 0, 10), - ... im.call("named_range")(im.axis_literal(KDim), 0, 20), - ... ) - - >>> ir = itir.Program( - ... id="test", - ... function_definitions=[], - ... params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], - ... declarations=[], - ... body=[ - ... itir.SetAt( - ... expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), - ... domain=unstructured_domain_get, - ... target=im.ref("out"), - ... ), - ... ], - ... 
) - - >>> sizes = {"out": gtx.domain({Vertex: (0,10), KDim: (0,20)})} - - >>> result = TransformGetDomain.apply(ir, sizes=sizes) - >>> print(result) - test(inp, out) { - out @ u⟨ Vertexₕ: [0, 10[, KDimᵥ: [0, 20[ ⟩ ← (⇑deref)(inp); - } - - >>> ir = itir.Program( - ... id="test", - ... function_definitions=[], - ... params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], - ... declarations=[], - ... body=[ - ... itir.SetAt( - ... expr=im.as_fieldop(im.ref("deref"), unstructured_domain_get)(im.ref("inp")), # TODO: unstructured_domain_get raises AssertionError in domain_utils.py line 77: assert cpm.is_call_to(named_range, "named_range") - ... domain=unstructured_domain_get, - ... target=im.ref("out"), - ... ), - ... ], - ... ) - - >>> result = TransformGetDomain.apply(ir, sizes=sizes) - >>> print (result) # TODO: this test still fails because of the AssertionError - test(inp, out) { - out @ u⟨ Vertexₕ: [0, 10[, KDimᵥ: [0, 20[ ⟩ - ← as_fieldop(deref, u⟨ Vertexₕ: [0, 10[, KDimᵥ: [0, 20[ ⟩)(inp); - } - """ + Transforms `get_domain` calls into `named_range` calls with given size. + + Example: + >>> from gt4py.next.type_system import type_specifications as ts + >>> from gt4py import next as gtx + >>> float64_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + >>> IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) + >>> KDim = common.Dimension(value="KDim", kind=common.DimensionKind.VERTICAL) + >>> Vertex = common.Dimension(value="Vertex", kind=common.DimensionKind.HORIZONTAL) + >>> float_i_field = ts.FieldType(dims=[IDim], dtype=float64_type) + + + >>> sizes = { + ... "out": gtx.domain({Vertex: (0, 10), KDim: (0, 20)}), + ... "a": (gtx.domain({Vertex: (0, 5)}), gtx.domain({Vertex: (0, 7)})), + ... "b": gtx.domain({KDim: (0, 3)}), + ... "c": gtx.domain({KDim: (0, 4)}), + ... } + + >>> unstructured_domain_get_out = im.call("unstructured_domain")( + ... im.call("get_domain")("out", im.axis_literal(Vertex)), + ... 
im.call("get_domain")("out", im.axis_literal(KDim)), + ... ) + >>> ir = itir.Program( + ... id="test1", + ... function_definitions=[], + ... params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], + ... declarations=[], + ... body=[ + ... itir.SetAt( + ... expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), + ... domain=unstructured_domain_get_out, + ... target=im.ref("out"), + ... ), + ... ], + ... ) + >>> result = TransformGetDomain.apply(ir, sizes=sizes) + >>> print(result) + test1(inp, out) { + out @ u⟨ Vertexₕ: [0, 10[, KDimᵥ: [0, 20[ ⟩ ← (⇑deref)(inp); + } + + >>> unstructured_domain = im.call( + ... "unstructured_domain" + ... )( # TODO: remove once the AssertionError is fixed + ... im.call("named_range")(im.axis_literal(Vertex), 0, 10), + ... im.call("named_range")(im.axis_literal(KDim), 0, 20), + ... ) + >>> ir = itir.Program( + ... id="test2", + ... function_definitions=[], + ... params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], + ... declarations=[], + ... body=[ + ... itir.SetAt( + ... expr=im.as_fieldop(im.ref("deref"), unstructured_domain_get_out)( + ... im.ref("inp") + ... ), # TODO: unstructured_domain_get raises AssertionError in domain_utils.py line 77: assert cpm.is_call_to(named_range, "named_range") + ... domain=unstructured_domain_get_out, + ... target=im.ref("out"), + ... ), + ... ], + ... ) + >>> result = TransformGetDomain.apply(ir, sizes=sizes) + >>> print(result) # TODO: this test still fails because of the AssertionError + test2(inp, out) { + out @ u⟨ Vertexₕ: [0, 10[, KDimᵥ: [0, 20[ ⟩ + ← as_fieldop(deref, u⟨ Vertexₕ: [0, 10[, KDimᵥ: [0, 20[ ⟩)(inp); + } + + >>> unstructured_domain_get_a = im.call("unstructured_domain")( + ... im.call("get_domain")(im.tuple_get(0, "a"), im.axis_literal(Vertex)) + ... ) + >>> ir = itir.Program( + ... id="test3", + ... function_definitions=[], + ... params=[im.sym("inp", float_i_field), im.sym("a")], + ... declarations=[], + ... body=[ + ... itir.SetAt( + ... 
expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), + ... domain=unstructured_domain_get_a, + ... target=im.tuple_get(0, "a"), + ... ), + ... ], + ... ) + >>> result = TransformGetDomain.apply(ir, sizes=sizes) + >>> print(result) + test3(inp, a) { + a[0] @ u⟨ Vertexₕ: [0, 5[ ⟩ ← (⇑deref)(inp); + } + + >>> t0 = im.make_tuple("b", "c") + >>> t1 = im.make_tuple("d", "e") + >>> tup = im.make_tuple(im.tuple_get(0, t0), im.tuple_get(1, t1)) + >>> unstructured_domain_get_make_tuple_b = im.call("unstructured_domain")( + ... im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim)) + ... ) + >>> ir = itir.Program( + ... id="test4", + ... function_definitions=[], + ... params=[ + ... im.sym("inp", float_i_field), + ... im.sym("b"), + ... im.sym("c"), + ... im.sym("d"), + ... im.sym("e"), + ... ], + ... declarations=[], + ... body=[ + ... itir.SetAt( + ... expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), + ... domain=unstructured_domain_get_make_tuple_b, + ... target=im.ref("b"), + ... ), + ... ], + ... ) + >>> result = TransformGetDomain.apply(ir, sizes=sizes) + >>> print(result) + test4(inp, b, c, d, e) { + b @ u⟨ KDimᵥ: [0, 3[ ⟩ ← (⇑deref)(inp); + } + """ @classmethod def apply(cls, program: itir.Program, sizes: Dict[str, common.Domain]): @@ -83,18 +150,34 @@ def apply(cls, program: itir.Program, sizes: Dict[str, common.Domain]): def visit_FunCall(self, node: itir.SetAt, **kwargs) -> itir.FunCall: sizes = kwargs["sizes"] - if cpm.is_call_to(node, "get_domain"): - ref, dim = node.args - if isinstance(ref, itir.SymRef): - assert ref.id in sizes, f"Symbol '{ref.id}' not found in sizes Dict." 
- input_dims = sizes[ref.id].dims - index = next((i for i, d in enumerate(input_dims) if d.value == dim.value), None) - assert index is not None, f"Dimension {dim.value} not found in {input_dims}" - dim = input_dims[index] - start = sizes[ref.id].ranges[index].start - stop = sizes[ref.id].ranges[index].stop - return im.call("named_range")(im.axis_literal(dim), start, stop) + if not cpm.is_call_to(node, "get_domain"): + return self.generic_visit(node, sizes=sizes) - # TODO: handle tuples: get_domain(tuple_get(0, "out")) + field, dim = node.args - return self.generic_visit(node, sizes=sizes) \ No newline at end of file + if cpm.is_call_to(field, "tuple_get"): + ref = field.args[1] + if isinstance(ref, itir.SymRef): + assert ref.id in sizes, f"Symbol '{ref.id}' not found in sizes Dict." + domain = sizes[ref.id][int(field.args[0].value)] + else: + field = collapse_tuple.CollapseTuple.apply( + field, within_stencil=False, allow_undeclared_symbols=True + ) + return self.visit(im.call("get_domain")(field, dim), sizes=sizes) + elif isinstance(field, itir.SymRef): + assert field.id in sizes, f"Symbol '{field.id}' not found in sizes Dict." + domain = sizes[field.id] + else: + raise NotImplementedError( + "Only calls to tuple_get or SymRefs are supported as first argument of get_domain." 
+ ) + + input_dims = domain.dims + index = next((i for i, d in enumerate(input_dims) if d.value == dim.value), None) + assert index is not None, f"Dimension {dim.value} not found in {input_dims}" + + dim = input_dims[index] + start = domain.ranges[index].start + stop = domain.ranges[index].stop + return im.call("named_range")(im.axis_literal(dim), start, stop) From 28804a32279bef49fbd80e095a80bd85bc64e4fd Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 17 Jul 2025 18:21:29 +0200 Subject: [PATCH 05/93] Move tests to new file --- .../transforms/transform_get_domain.py | 96 +-------- .../test_transform_get_domain.py | 192 ++++++++++++++++++ 2 files changed, 196 insertions(+), 92 deletions(-) create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py diff --git a/src/gt4py/next/iterator/transforms/transform_get_domain.py b/src/gt4py/next/iterator/transforms/transform_get_domain.py index fc5a894d66..2e4ee85aa3 100644 --- a/src/gt4py/next/iterator/transforms/transform_get_domain.py +++ b/src/gt4py/next/iterator/transforms/transform_get_domain.py @@ -22,20 +22,12 @@ class TransformGetDomain(PreserveLocationVisitor, NodeTranslator): Transforms `get_domain` calls into `named_range` calls with given size. Example: - >>> from gt4py.next.type_system import type_specifications as ts >>> from gt4py import next as gtx - >>> float64_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - >>> IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) >>> KDim = common.Dimension(value="KDim", kind=common.DimensionKind.VERTICAL) >>> Vertex = common.Dimension(value="Vertex", kind=common.DimensionKind.HORIZONTAL) - >>> float_i_field = ts.FieldType(dims=[IDim], dtype=float64_type) - >>> sizes = { ... "out": gtx.domain({Vertex: (0, 10), KDim: (0, 20)}), - ... "a": (gtx.domain({Vertex: (0, 5)}), gtx.domain({Vertex: (0, 7)})), - ... "b": gtx.domain({KDim: (0, 3)}), - ... "c": gtx.domain({KDim: (0, 4)}), ... 
} >>> unstructured_domain_get_out = im.call("unstructured_domain")( @@ -43,9 +35,9 @@ class TransformGetDomain(PreserveLocationVisitor, NodeTranslator): ... im.call("get_domain")("out", im.axis_literal(KDim)), ... ) >>> ir = itir.Program( - ... id="test1", + ... id="test", ... function_definitions=[], - ... params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], + ... params=[im.sym("inp"), im.sym("out")], ... declarations=[], ... body=[ ... itir.SetAt( @@ -57,90 +49,9 @@ class TransformGetDomain(PreserveLocationVisitor, NodeTranslator): ... ) >>> result = TransformGetDomain.apply(ir, sizes=sizes) >>> print(result) - test1(inp, out) { + test(inp, out) { out @ u⟨ Vertexₕ: [0, 10[, KDimᵥ: [0, 20[ ⟩ ← (⇑deref)(inp); } - - >>> unstructured_domain = im.call( - ... "unstructured_domain" - ... )( # TODO: remove once the AssertionError is fixed - ... im.call("named_range")(im.axis_literal(Vertex), 0, 10), - ... im.call("named_range")(im.axis_literal(KDim), 0, 20), - ... ) - >>> ir = itir.Program( - ... id="test2", - ... function_definitions=[], - ... params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], - ... declarations=[], - ... body=[ - ... itir.SetAt( - ... expr=im.as_fieldop(im.ref("deref"), unstructured_domain_get_out)( - ... im.ref("inp") - ... ), # TODO: unstructured_domain_get raises AssertionError in domain_utils.py line 77: assert cpm.is_call_to(named_range, "named_range") - ... domain=unstructured_domain_get_out, - ... target=im.ref("out"), - ... ), - ... ], - ... ) - >>> result = TransformGetDomain.apply(ir, sizes=sizes) - >>> print(result) # TODO: this test still fails because of the AssertionError - test2(inp, out) { - out @ u⟨ Vertexₕ: [0, 10[, KDimᵥ: [0, 20[ ⟩ - ← as_fieldop(deref, u⟨ Vertexₕ: [0, 10[, KDimᵥ: [0, 20[ ⟩)(inp); - } - - >>> unstructured_domain_get_a = im.call("unstructured_domain")( - ... im.call("get_domain")(im.tuple_get(0, "a"), im.axis_literal(Vertex)) - ... ) - >>> ir = itir.Program( - ... id="test3", - ... 
function_definitions=[], - ... params=[im.sym("inp", float_i_field), im.sym("a")], - ... declarations=[], - ... body=[ - ... itir.SetAt( - ... expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), - ... domain=unstructured_domain_get_a, - ... target=im.tuple_get(0, "a"), - ... ), - ... ], - ... ) - >>> result = TransformGetDomain.apply(ir, sizes=sizes) - >>> print(result) - test3(inp, a) { - a[0] @ u⟨ Vertexₕ: [0, 5[ ⟩ ← (⇑deref)(inp); - } - - >>> t0 = im.make_tuple("b", "c") - >>> t1 = im.make_tuple("d", "e") - >>> tup = im.make_tuple(im.tuple_get(0, t0), im.tuple_get(1, t1)) - >>> unstructured_domain_get_make_tuple_b = im.call("unstructured_domain")( - ... im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim)) - ... ) - >>> ir = itir.Program( - ... id="test4", - ... function_definitions=[], - ... params=[ - ... im.sym("inp", float_i_field), - ... im.sym("b"), - ... im.sym("c"), - ... im.sym("d"), - ... im.sym("e"), - ... ], - ... declarations=[], - ... body=[ - ... itir.SetAt( - ... expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), - ... domain=unstructured_domain_get_make_tuple_b, - ... target=im.ref("b"), - ... ), - ... ], - ... ) - >>> result = TransformGetDomain.apply(ir, sizes=sizes) - >>> print(result) - test4(inp, b, c, d, e) { - b @ u⟨ KDimᵥ: [0, 3[ ⟩ ← (⇑deref)(inp); - } """ @classmethod @@ -159,6 +70,7 @@ def visit_FunCall(self, node: itir.SetAt, **kwargs) -> itir.FunCall: ref = field.args[1] if isinstance(ref, itir.SymRef): assert ref.id in sizes, f"Symbol '{ref.id}' not found in sizes Dict." + assert isinstance(sizes[ref.id], tuple), "A domain-tuple must be passed." 
domain = sizes[ref.id][int(field.args[0].value)] else: field = collapse_tuple.CollapseTuple.apply( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py new file mode 100644 index 0000000000..5bb226b93f --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py @@ -0,0 +1,192 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + + +from gt4py import next as gtx +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.transform_get_domain import TransformGetDomain + +KDim = common.Dimension(value="KDim", kind=common.DimensionKind.VERTICAL) +Vertex = common.Dimension(value="Vertex", kind=common.DimensionKind.HORIZONTAL) + + +def test_get_domain(): + sizes = {"out": gtx.domain({Vertex: (0, 10), KDim: (0, 20)})} + + unstructured_domain_get = im.call("unstructured_domain")( + im.call("get_domain")("out", im.axis_literal(Vertex)), + im.call("get_domain")("out", im.axis_literal(KDim)), + ) + testee = itir.Program( + id="test", + function_definitions=[], + params=[im.sym("inp"), im.sym("out")], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), + domain=unstructured_domain_get, + target=im.ref("out"), + ), + ], + ) + + unstructured_domain = im.call("unstructured_domain")( + im.call("named_range")(im.axis_literal(Vertex), 0, 10), + im.call("named_range")(im.axis_literal(KDim), 0, 20), + ) + expected = itir.Program( + id="test", + function_definitions=[], + params=[im.sym("inp"), im.sym("out")], + declarations=[], + body=[ + itir.SetAt( + 
expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), + domain=unstructured_domain, + target=im.ref("out"), + ), + ], + ) + + actual = TransformGetDomain.apply(testee, sizes=sizes) + assert actual == expected + + +def test_get_domain_inside_as_fieldop(): + sizes = {"out": gtx.domain({Vertex: (0, 10), KDim: (0, 20)})} + + unstructured_domain_get = im.call("unstructured_domain")( + im.call("get_domain")("out", im.axis_literal(Vertex)), + im.call("get_domain")("out", im.axis_literal(KDim)), + ) + testee = itir.Program( + id="test", + function_definitions=[], + params=[im.sym("inp"), im.sym("out")], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.ref("deref"), unstructured_domain_get)( + im.ref("inp") + ), # TODO: unstructured_domain_get raises AssertionError in domain_utils.py line 77: assert cpm.is_call_to(named_range, "named_range") + domain=unstructured_domain_get, + target=im.ref("out"), + ), + ], + ) + + unstructured_domain = im.call("unstructured_domain")( + im.call("named_range")(im.axis_literal(Vertex), 0, 10), + im.call("named_range")(im.axis_literal(KDim), 0, 20), + ) + expected = itir.Program( + id="test", + function_definitions=[], + params=[im.sym("inp"), im.sym("out")], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.ref("deref"), unstructured_domain)(im.ref("inp")), + domain=unstructured_domain, + target=im.ref("out"), + ), + ], + ) + + actual = TransformGetDomain.apply(testee, sizes=sizes) + assert actual == expected # TODO: this test still fails because of the AssertionError + + +def test_get_domain_tuples(): + sizes = {"out": (gtx.domain({Vertex: (0, 5)}), gtx.domain({Vertex: (0, 7)}))} + + unstructured_domain_get = im.call("unstructured_domain")( + im.call("get_domain")(im.tuple_get(0, "out"), im.axis_literal(Vertex)) + ) + + testee = itir.Program( + id="test", + function_definitions=[], + params=[im.sym("inp"), im.sym("out")], + declarations=[], + body=[ + itir.SetAt( + 
expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), + domain=unstructured_domain_get, + target=im.tuple_get(0, "out"), + ), + ], + ) + + unstructured_domain = im.call("unstructured_domain")( + im.call("named_range")(im.axis_literal(Vertex), 0, 5), + ) + expected = itir.Program( + id="test", + function_definitions=[], + params=[im.sym("inp"), im.sym("out")], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), + domain=unstructured_domain, + target=im.tuple_get(0, "out"), + ), + ], + ) + + actual = TransformGetDomain.apply(testee, sizes=sizes) + assert actual == expected + + +def test_get_domain_nested_tuples(): + sizes = {"a": gtx.domain({KDim: (0, 3)})} + + t0 = im.make_tuple("a", "b") + t1 = im.make_tuple("c", "d") + tup = im.make_tuple(im.tuple_get(0, t0), im.tuple_get(1, t1)) + unstructured_domain_get = im.call("unstructured_domain")( + im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim)) + ) + + testee = itir.Program( + id="test", + function_definitions=[], + params=[im.sym("inp"), im.sym("a"), im.sym("b"), im.sym("c"), im.sym("d")], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), + domain=unstructured_domain_get, + target=im.ref("a"), + ), + ], + ) + + unstructured_domain = im.call("unstructured_domain")( + im.call("named_range")(im.axis_literal(KDim), 0, 3), + ) + expected = itir.Program( + id="test", + function_definitions=[], + params=[im.sym("inp"), im.sym("a"), im.sym("b"), im.sym("c"), im.sym("d")], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), + domain=unstructured_domain, + target=im.ref("a"), + ), + ], + ) + + actual = TransformGetDomain.apply(testee, sizes=sizes) + assert actual == expected From 5455f34333dd6e41a6b161a6fd57a3ef15e1f2fa Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 23 Jul 2025 16:04:52 +0200 Subject: [PATCH 06/93] Update TransformGetDomain to return a tuple, introduce 
named_range in ir.makers and update tests --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 6 +- .../transforms/transform_get_domain.py | 22 ++++-- .../iterator_tests/test_type_inference.py | 28 ++++---- .../test_transform_get_domain.py | 72 +++++++++++-------- .../gtfn_tests/test_itir_to_gtfn_ir.py | 2 +- 5 files changed, 78 insertions(+), 52 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 739aa5d90d..e6477253fc 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -445,7 +445,7 @@ def domain( grid_type = f"{grid_type!s}_domain" expr = call(grid_type)( *[ - call("named_range")( + named_range( axis_literal(d), r[0], r[1], @@ -457,6 +457,10 @@ def domain( return expr +def named_range(dim: itir.AxisLiteral, start: itir.Expr, stop: itir.Expr) -> itir.FunCall: + return call("named_range")(dim, start, stop) + + def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> Callable: """ Create an `as_fieldop` call. diff --git a/src/gt4py/next/iterator/transforms/transform_get_domain.py b/src/gt4py/next/iterator/transforms/transform_get_domain.py index 2e4ee85aa3..d2a94f20be 100644 --- a/src/gt4py/next/iterator/transforms/transform_get_domain.py +++ b/src/gt4py/next/iterator/transforms/transform_get_domain.py @@ -19,7 +19,7 @@ @dataclasses.dataclass(frozen=True) class TransformGetDomain(PreserveLocationVisitor, NodeTranslator): """ - Transforms `get_domain` calls into `named_range` calls with given size. + Transforms `get_domain` calls into a tuple containing start and stop. Example: >>> from gt4py import next as gtx @@ -31,8 +31,16 @@ class TransformGetDomain(PreserveLocationVisitor, NodeTranslator): ... } >>> unstructured_domain_get_out = im.call("unstructured_domain")( - ... im.call("get_domain")("out", im.axis_literal(Vertex)), - ... im.call("get_domain")("out", im.axis_literal(KDim)), + ... im.named_range( + ... 
im.axis_literal(Vertex), + ... im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Vertex))), + ... im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Vertex))), + ... ), + ... im.named_range( + ... im.axis_literal(KDim), + ... im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(KDim))), + ... im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(KDim))), + ... ), ... ) >>> ir = itir.Program( ... id="test", @@ -50,7 +58,7 @@ class TransformGetDomain(PreserveLocationVisitor, NodeTranslator): >>> result = TransformGetDomain.apply(ir, sizes=sizes) >>> print(result) test(inp, out) { - out @ u⟨ Vertexₕ: [0, 10[, KDimᵥ: [0, 20[ ⟩ ← (⇑deref)(inp); + out @ u⟨ Vertexₕ: [{0, 10}[0], {0, 10}[1][, KDimᵥ: [{0, 20}[0], {0, 20}[1][ ⟩ ← (⇑deref)(inp); } """ @@ -58,7 +66,7 @@ class TransformGetDomain(PreserveLocationVisitor, NodeTranslator): def apply(cls, program: itir.Program, sizes: Dict[str, common.Domain]): return cls().visit(program, sizes=sizes) - def visit_FunCall(self, node: itir.SetAt, **kwargs) -> itir.FunCall: + def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.FunCall: sizes = kwargs["sizes"] if not cpm.is_call_to(node, "get_domain"): @@ -89,7 +97,7 @@ def visit_FunCall(self, node: itir.SetAt, **kwargs) -> itir.FunCall: index = next((i for i, d in enumerate(input_dims) if d.value == dim.value), None) assert index is not None, f"Dimension {dim.value} not found in {input_dims}" - dim = input_dims[index] start = domain.ranges[index].start stop = domain.ranges[index].stop - return im.call("named_range")(im.axis_literal(dim), start, stop) + node = im.make_tuple(start, stop) + return node diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index a0361e7ba2..378fe562db 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -95,20 +95,20 @@ 
def expression_test_cases(): (im.call("make_const_list")(True), ts.ListType(element_type=bool_type)), (im.list_get(0, im.ref("l", ts.ListType(element_type=bool_type))), bool_type), ( - im.call("named_range")( + im.named_range( itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 ), it_ts.NamedRangeType(dim=Vertex), ), ( im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + im.named_range(itir.AxisLiteral(value="IDim"), 0, 1) ), ts.DomainType(dims=[IDim]), ), ( im.call("unstructured_domain")( - im.call("named_range")( + im.named_range( itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 ) ), @@ -239,13 +239,13 @@ def expression_test_cases(): im.as_fieldop( im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + im.named_range(itir.AxisLiteral(value="IDim"), 0, 1) ), )(im.ref("inp", float_i_field), 1.0), im.as_fieldop( "deref", im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + im.named_range(itir.AxisLiteral(value="IDim"), 0, 1) ), )(im.ref("inp", float_i_field)), ), @@ -390,7 +390,7 @@ def test_cast_first_arg_inference(): def test_cartesian_fencil_definition(): cartesian_domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + im.named_range(itir.AxisLiteral(value="IDim"), 0, 1) ) testee = itir.Program( @@ -420,10 +420,10 @@ def test_cartesian_fencil_definition(): def test_unstructured_fencil_definition(): mesh = simple_mesh(None) unstructured_domain = im.call("unstructured_domain")( - im.call("named_range")( + im.named_range( itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 ), - im.call("named_range")( + im.named_range( itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 ), ) @@ -458,7 +458,7 @@ def test_unstructured_fencil_definition(): def 
test_function_definition(): cartesian_domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + im.named_range(itir.AxisLiteral(value="IDim"), 0, 1) ) testee = itir.Program( @@ -489,10 +489,10 @@ def test_function_definition(): def test_fencil_with_nb_field_input(): mesh = simple_mesh(None) unstructured_domain = im.call("unstructured_domain")( - im.call("named_range")( + im.named_range( itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 ), - im.call("named_range")( + im.named_range( itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 ), ) @@ -522,7 +522,7 @@ def test_fencil_with_nb_field_input(): def test_program_tuple_setat_short_target(): cartesian_domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + im.named_range(itir.AxisLiteral(value="IDim"), 0, 1) ) testee = itir.Program( @@ -553,7 +553,7 @@ def test_program_tuple_setat_short_target(): def test_program_setat_without_domain(): cartesian_domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + im.named_range(itir.AxisLiteral(value="IDim"), 0, 1) ) testee = itir.Program( @@ -577,7 +577,7 @@ def test_program_setat_without_domain(): def test_if_stmt(): cartesian_domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + im.named_range(itir.AxisLiteral(value="IDim"), 0, 1) ) testee = itir.IfStmt( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py index 5bb226b93f..b3229163c1 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py @@ -6,24 +6,41 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause +from next_tests.integration_tests.cases import ( + IDim, + KDim, + Vertex, +) from gt4py import next as gtx -from gt4py.next import common +from gt4py.next import Domain from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.transform_get_domain import TransformGetDomain -KDim = common.Dimension(value="KDim", kind=common.DimensionKind.VERTICAL) -Vertex = common.Dimension(value="Vertex", kind=common.DimensionKind.HORIZONTAL) +IOff = gtx.FieldOffset("IOff", source=IDim, target=(IDim,)) + + +def construct_domains(domain_resolved: Domain, symbol_name: str, type: str): + named_ranges_get, named_ranges_resolved = [], [] + + for dim, range_ in zip(domain_resolved.dims, domain_resolved.ranges): + get_domain_call = im.call("get_domain")(symbol_name, im.axis_literal(dim)) + named_ranges_get.append( + im.named_range(im.axis_literal(dim), im.tuple_get(0, get_domain_call), im.tuple_get(1, get_domain_call)) + ) + bounds_tuple = im.make_tuple(range_.start, range_.stop) + named_ranges_resolved.append( + im.named_range(im.axis_literal(dim), im.tuple_get(0, bounds_tuple), im.tuple_get(1, bounds_tuple)) + ) + + return im.call(type)(*named_ranges_resolved), im.call(type)(*named_ranges_get) def test_get_domain(): sizes = {"out": gtx.domain({Vertex: (0, 10), KDim: (0, 20)})} + unstructured_domain, unstructured_domain_get = construct_domains(sizes["out"], "out", "unstructured_domain") - unstructured_domain_get = im.call("unstructured_domain")( - im.call("get_domain")("out", im.axis_literal(Vertex)), - im.call("get_domain")("out", im.axis_literal(KDim)), - ) testee = itir.Program( id="test", function_definitions=[], @@ -38,10 +55,6 @@ def test_get_domain(): ], ) - unstructured_domain = im.call("unstructured_domain")( - im.call("named_range")(im.axis_literal(Vertex), 0, 10), - im.call("named_range")(im.axis_literal(KDim), 0, 20), - ) expected = itir.Program( id="test", 
function_definitions=[], @@ -62,11 +75,8 @@ def test_get_domain(): def test_get_domain_inside_as_fieldop(): sizes = {"out": gtx.domain({Vertex: (0, 10), KDim: (0, 20)})} + unstructured_domain, unstructured_domain_get = construct_domains(sizes["out"], "out", "unstructured_domain") - unstructured_domain_get = im.call("unstructured_domain")( - im.call("get_domain")("out", im.axis_literal(Vertex)), - im.call("get_domain")("out", im.axis_literal(KDim)), - ) testee = itir.Program( id="test", function_definitions=[], @@ -76,17 +86,13 @@ def test_get_domain_inside_as_fieldop(): itir.SetAt( expr=im.as_fieldop(im.ref("deref"), unstructured_domain_get)( im.ref("inp") - ), # TODO: unstructured_domain_get raises AssertionError in domain_utils.py line 77: assert cpm.is_call_to(named_range, "named_range") + ), domain=unstructured_domain_get, target=im.ref("out"), ), ], ) - unstructured_domain = im.call("unstructured_domain")( - im.call("named_range")(im.axis_literal(Vertex), 0, 10), - im.call("named_range")(im.axis_literal(KDim), 0, 20), - ) expected = itir.Program( id="test", function_definitions=[], @@ -102,14 +108,21 @@ def test_get_domain_inside_as_fieldop(): ) actual = TransformGetDomain.apply(testee, sizes=sizes) - assert actual == expected # TODO: this test still fails because of the AssertionError + assert actual == expected def test_get_domain_tuples(): sizes = {"out": (gtx.domain({Vertex: (0, 5)}), gtx.domain({Vertex: (0, 7)}))} unstructured_domain_get = im.call("unstructured_domain")( - im.call("get_domain")(im.tuple_get(0, "out"), im.axis_literal(Vertex)) + im.named_range(im.axis_literal(Vertex), + im.tuple_get(0, im.call("get_domain")(im.tuple_get(0, "out"), im.axis_literal(Vertex))), + im.tuple_get(1, im.call("get_domain")(im.tuple_get(0, "out"), im.axis_literal(Vertex))) + ) + ) + unstructured_domain = im.call("unstructured_domain")( + im.named_range(im.axis_literal(Vertex), im.tuple_get(0, im.make_tuple(0, 5)), + im.tuple_get(1, im.make_tuple(0, 5))), ) testee = 
itir.Program( @@ -126,9 +139,6 @@ def test_get_domain_tuples(): ], ) - unstructured_domain = im.call("unstructured_domain")( - im.call("named_range")(im.axis_literal(Vertex), 0, 5), - ) expected = itir.Program( id="test", function_definitions=[], @@ -154,7 +164,14 @@ def test_get_domain_nested_tuples(): t1 = im.make_tuple("c", "d") tup = im.make_tuple(im.tuple_get(0, t0), im.tuple_get(1, t1)) unstructured_domain_get = im.call("unstructured_domain")( - im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim)) + im.named_range(im.axis_literal(KDim), + im.tuple_get(0, im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim))), + im.tuple_get(1, im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim))) + ) + ) + unstructured_domain = im.call("unstructured_domain")( + im.named_range(im.axis_literal(KDim), im.tuple_get(0, im.make_tuple(0, 3)), + im.tuple_get(1, im.make_tuple(0, 3))), ) testee = itir.Program( @@ -171,9 +188,6 @@ def test_get_domain_nested_tuples(): ], ) - unstructured_domain = im.call("unstructured_domain")( - im.call("named_range")(im.axis_literal(KDim), 0, 3), - ) expected = itir.Program( id="test", function_definitions=[], diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py index 50e8fa43f0..7da31a35bf 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py @@ -39,7 +39,7 @@ def test_unapplied_funcall_to_function_object(): def test_get_domains(): - domain = im.call("cartesian_domain")(im.call("named_range")(itir.AxisLiteral(value="D"), 1, 2)) + domain = im.call("cartesian_domain")(im.named_range(itir.AxisLiteral(value="D"), 1, 2)) testee = itir.Program( id="foo", function_definitions=[], From 
1b002f787d6137f69216ebaac5799babc432314b Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 24 Jul 2025 18:01:54 +0200 Subject: [PATCH 07/93] Compute actual temporary sizes in unstructured case --- .../next/iterator/ir_utils/domain_utils.py | 28 +- .../test_transform_get_domain.py | 657 ++++++++++++++++-- 2 files changed, 634 insertions(+), 51 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 52853899c4..be9f19c560 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -12,6 +12,8 @@ import functools from typing import Any, Callable, Iterable, Literal, Mapping, Optional +import numpy as np + from gt4py.next import common from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im @@ -129,27 +131,37 @@ def translate( trace_shifts.Sentinel.ALL_NEIGHBORS, trace_shifts.Sentinel.VALUE, ] - horizontal_sizes: dict[str, itir.Expr] + horizontal_sizes: dict[str, tuple[itir.Expr, itir.Expr]] + old_dim = connectivity_type.source_dim + new_dim = connectivity_type.codomain if symbolic_domain_sizes is not None: horizontal_sizes = { - k: im.ensure_expr(v) for k, v in symbolic_domain_sizes.items() + k: (im.literal(str(0), builtins.INTEGER_INDEX_BUILTIN), im.ensure_expr(v)) + for k, v in symbolic_domain_sizes.items() } else: # note: ugly but cheap re-computation, but should disappear assert common.is_offset_provider(offset_provider) horizontal_sizes = { - k: im.literal(str(v), builtins.INTEGER_INDEX_BUILTIN) + k: ( + im.literal(str(0), builtins.INTEGER_INDEX_BUILTIN), + im.literal(str(v), builtins.INTEGER_INDEX_BUILTIN), + ) for k, v in _max_domain_sizes_by_location_type(offset_provider).items() } - - old_dim = connectivity_type.source_dim - new_dim = connectivity_type.codomain + min_ = np.min( + offset_provider[off.value].ndarray[:, val.value] + ) # 
TODO: multible shifts? multible occurences of that dimension? + max_ = np.max(offset_provider[off.value].ndarray[:, val.value]) + 1 + horizontal_sizes[new_dim.value] = ( + im.literal(str(min_), builtins.INTEGER_INDEX_BUILTIN), + im.literal(str(max_), builtins.INTEGER_INDEX_BUILTIN), + ) assert new_dim not in new_ranges or old_dim == new_dim new_range = SymbolicRange( - im.literal("0", builtins.INTEGER_INDEX_BUILTIN), - horizontal_sizes[new_dim.value], + horizontal_sizes[new_dim.value][0], horizontal_sizes[new_dim.value][1] ) new_ranges = dict( (dim, range_) if dim != old_dim else (new_dim, new_range) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py index b3229163c1..9eef26c8c9 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py @@ -5,21 +5,64 @@ # # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause +from typing import Optional +import pytest from next_tests.integration_tests.cases import ( IDim, KDim, Vertex, + unstructured_case, + exec_alloc_descriptor, ) from gt4py import next as gtx from gt4py.next import Domain from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import inline_fundefs, global_tmps, inline_lambdas from gt4py.next.iterator.transforms.transform_get_domain import TransformGetDomain +from gt4py.next.iterator.type_system import inference as type_inference +from gt4py.next.type_system import type_specifications as ts +from tests.next_tests.integration_tests import cases +from tests.next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + simple_cartesian_grid, + Edge, + simple_mesh, +) IOff = gtx.FieldOffset("IOff", source=IDim, target=(IDim,)) +float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) +int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) +v_field_type = ts.FieldType(dims=[Vertex], dtype=float_type) +ve_field_type = ts.FieldType(dims=[Edge, Vertex], dtype=float_type) +e_field_type = ts.FieldType(dims=[Edge], dtype=float_type) +i_field_type = ts.FieldType(dims=[IDim], dtype=float_type) + + +# TODO: maybe check if domains are consistent in global_tmps + + +# override mesh descriptor to contain only the simple mesh +@pytest.fixture +def mesh_descriptor(exec_alloc_descriptor): + return simple_mesh(exec_alloc_descriptor.allocator) + + +def program_factory( + params: list[itir.Sym], + body: list[itir.SetAt], + declarations: Optional[list[itir.Temporary]] = None, +) -> itir.Program: + return itir.Program( + id="testee", + function_definitions=[], + params=params, + declarations=declarations or [], + body=body, + ) + def construct_domains(domain_resolved: Domain, symbol_name: str, type: str): named_ranges_get, named_ranges_resolved = [], [] @@ -27,11 +70,17 @@ def 
construct_domains(domain_resolved: Domain, symbol_name: str, type: str): for dim, range_ in zip(domain_resolved.dims, domain_resolved.ranges): get_domain_call = im.call("get_domain")(symbol_name, im.axis_literal(dim)) named_ranges_get.append( - im.named_range(im.axis_literal(dim), im.tuple_get(0, get_domain_call), im.tuple_get(1, get_domain_call)) + im.named_range( + im.axis_literal(dim), + im.tuple_get(0, get_domain_call), + im.tuple_get(1, get_domain_call), + ) ) bounds_tuple = im.make_tuple(range_.start, range_.stop) named_ranges_resolved.append( - im.named_range(im.axis_literal(dim), im.tuple_get(0, bounds_tuple), im.tuple_get(1, bounds_tuple)) + im.named_range( + im.axis_literal(dim), im.tuple_get(0, bounds_tuple), im.tuple_get(1, bounds_tuple) + ) ) return im.call(type)(*named_ranges_resolved), im.call(type)(*named_ranges_get) @@ -39,11 +88,11 @@ def construct_domains(domain_resolved: Domain, symbol_name: str, type: str): def test_get_domain(): sizes = {"out": gtx.domain({Vertex: (0, 10), KDim: (0, 20)})} - unstructured_domain, unstructured_domain_get = construct_domains(sizes["out"], "out", "unstructured_domain") + unstructured_domain, unstructured_domain_get = construct_domains( + sizes["out"], "out", "unstructured_domain" + ) - testee = itir.Program( - id="test", - function_definitions=[], + testee = program_factory( params=[im.sym("inp"), im.sym("out")], declarations=[], body=[ @@ -55,9 +104,7 @@ def test_get_domain(): ], ) - expected = itir.Program( - id="test", - function_definitions=[], + expected = program_factory( params=[im.sym("inp"), im.sym("out")], declarations=[], body=[ @@ -75,27 +122,23 @@ def test_get_domain(): def test_get_domain_inside_as_fieldop(): sizes = {"out": gtx.domain({Vertex: (0, 10), KDim: (0, 20)})} - unstructured_domain, unstructured_domain_get = construct_domains(sizes["out"], "out", "unstructured_domain") + unstructured_domain, unstructured_domain_get = construct_domains( + sizes["out"], "out", "unstructured_domain" + ) - 
testee = itir.Program( - id="test", - function_definitions=[], + testee = program_factory( params=[im.sym("inp"), im.sym("out")], declarations=[], body=[ itir.SetAt( - expr=im.as_fieldop(im.ref("deref"), unstructured_domain_get)( - im.ref("inp") - ), + expr=im.as_fieldop(im.ref("deref"), unstructured_domain_get)(im.ref("inp")), domain=unstructured_domain_get, target=im.ref("out"), ), ], ) - expected = itir.Program( - id="test", - function_definitions=[], + expected = program_factory( params=[im.sym("inp"), im.sym("out")], declarations=[], body=[ @@ -115,19 +158,21 @@ def test_get_domain_tuples(): sizes = {"out": (gtx.domain({Vertex: (0, 5)}), gtx.domain({Vertex: (0, 7)}))} unstructured_domain_get = im.call("unstructured_domain")( - im.named_range(im.axis_literal(Vertex), - im.tuple_get(0, im.call("get_domain")(im.tuple_get(0, "out"), im.axis_literal(Vertex))), - im.tuple_get(1, im.call("get_domain")(im.tuple_get(0, "out"), im.axis_literal(Vertex))) - ) + im.named_range( + im.axis_literal(Vertex), + im.tuple_get(0, im.call("get_domain")(im.tuple_get(0, "out"), im.axis_literal(Vertex))), + im.tuple_get(1, im.call("get_domain")(im.tuple_get(0, "out"), im.axis_literal(Vertex))), + ) ) unstructured_domain = im.call("unstructured_domain")( - im.named_range(im.axis_literal(Vertex), im.tuple_get(0, im.make_tuple(0, 5)), - im.tuple_get(1, im.make_tuple(0, 5))), + im.named_range( + im.axis_literal(Vertex), + im.tuple_get(0, im.make_tuple(0, 5)), + im.tuple_get(1, im.make_tuple(0, 5)), + ), ) - testee = itir.Program( - id="test", - function_definitions=[], + testee = program_factory( params=[im.sym("inp"), im.sym("out")], declarations=[], body=[ @@ -139,9 +184,7 @@ def test_get_domain_tuples(): ], ) - expected = itir.Program( - id="test", - function_definitions=[], + expected = program_factory( params=[im.sym("inp"), im.sym("out")], declarations=[], body=[ @@ -164,19 +207,21 @@ def test_get_domain_nested_tuples(): t1 = im.make_tuple("c", "d") tup = 
im.make_tuple(im.tuple_get(0, t0), im.tuple_get(1, t1)) unstructured_domain_get = im.call("unstructured_domain")( - im.named_range(im.axis_literal(KDim), - im.tuple_get(0, im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim))), - im.tuple_get(1, im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim))) - ) + im.named_range( + im.axis_literal(KDim), + im.tuple_get(0, im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim))), + im.tuple_get(1, im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim))), + ) ) unstructured_domain = im.call("unstructured_domain")( - im.named_range(im.axis_literal(KDim), im.tuple_get(0, im.make_tuple(0, 3)), - im.tuple_get(1, im.make_tuple(0, 3))), + im.named_range( + im.axis_literal(KDim), + im.tuple_get(0, im.make_tuple(0, 3)), + im.tuple_get(1, im.make_tuple(0, 3)), + ), ) - testee = itir.Program( - id="test", - function_definitions=[], + testee = program_factory( params=[im.sym("inp"), im.sym("a"), im.sym("b"), im.sym("c"), im.sym("d")], declarations=[], body=[ @@ -188,9 +233,7 @@ def test_get_domain_nested_tuples(): ], ) - expected = itir.Program( - id="test", - function_definitions=[], + expected = program_factory( params=[im.sym("inp"), im.sym("a"), im.sym("b"), im.sym("c"), im.sym("d")], declarations=[], body=[ @@ -204,3 +247,531 @@ def test_get_domain_nested_tuples(): actual = TransformGetDomain.apply(testee, sizes=sizes) assert actual == expected + + +@pytest.fixture +def testee(): + @gtx.field_operator + def testee_tmp(x: cases.IField) -> cases.IField: + y = x(IOff[2]) + return y(IOff[3]) + + # @gtx.field_operator + # def testee_tmp(x: cases.IField) -> cases.IField: + # return x(IOff[1]) + + # @gtx.field_operator + # def testee_op(x: cases.IField) -> cases.IField: + # return testee_tmp(x) + + @gtx.field_operator + def testee_op(x: cases.IField) -> cases.IField: + return testee_tmp(x) + testee_tmp(x) + + @gtx.program(static_domain_sizes=True, grid_type=gtx.GridType.UNSTRUCTURED) + def 
prog( + inp: cases.IField, + out: cases.IField, + ): + testee_op(inp, out=out) + + return prog + + +def test_get_domain_inference_temporary_symbols(testee, unstructured_case): + sizes = {"out": gtx.domain({IDim: (0, 20)})} + ir = inline_fundefs.InlineFundefs().visit(testee.gtir) + ir = inline_fundefs.prune_unreferenced_fundefs(ir) + ir = type_inference.infer( + ir, offset_provider_type={**unstructured_case.offset_provider, "IOff": IDim} + ) + ir = inline_lambdas.InlineLambdas.apply(ir) + ir = global_tmps.create_global_tmps( + ir, offset_provider={**unstructured_case.offset_provider, "IOff": IDim} + ) + ir = TransformGetDomain.apply(ir, sizes=sizes) + # ir = infer_domain.infer_program( + # ir, offset_provider={**unstructured_case.offset_provider, "IOff": IDim} + # ) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps + unstructured_domain = im.call("unstructured_domain")( + im.call("named_range")( + im.axis_literal(IDim), + im.tuple_get(0, im.make_tuple(0, 20)), + im.tuple_get(1, im.make_tuple(0, 20)), + ) + ) + + unstructured_domain_p3 = im.call("unstructured_domain")( + im.call("named_range")( + im.axis_literal(IDim), + im.plus(im.tuple_get(0, im.make_tuple(0, 20)), 3), + im.plus(im.tuple_get(1, im.make_tuple(0, 20)), 3), + ) + ) + + expected = itir.Program( + id="prog", + function_definitions=[], + params=[im.sym("inp"), im.sym("out"), im.sym("__inp_0_range"), im.sym("__out_0_range")], + declarations=[ + itir.Temporary(id="__tmp_1", domain=unstructured_domain, dtype=int_type), + itir.Temporary(id="__tmp_2", domain=unstructured_domain_p3, dtype=int_type), + ], + body=[ + itir.SetAt( + target=im.ref("__tmp_2"), + expr=im.as_fieldop( + im.lambda_("__it")(im.deref(im.shift("IOff", 2)("__it"))), + unstructured_domain_p3, + )("inp"), + domain=unstructured_domain_p3, + ), + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop( + im.lambda_("__it_")(im.deref(im.shift("IOff", 3)("__it_"))), 
unstructured_domain + )("__tmp_2"), + domain=unstructured_domain, + ), + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop( + im.lambda_("__arg0", "__arg1")(im.plus(im.deref("__arg0"), im.deref("__arg1"))), + unstructured_domain, + )("__tmp_1", "__tmp_1"), + domain=unstructured_domain, + ), + ], + ) + assert ir == expected + + +def test_trivial_shift(unstructured_case): + sizes = {"out": gtx.domain({Edge: (0, 18), Vertex: (0, 9)})} + unstructured_domain_get_E = im.call("unstructured_domain")( + im.call("named_range")( + im.axis_literal(Edge), + im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))), + ) + ) + + unstructured_domain_E = im.call("unstructured_domain")( + im.call("named_range")( + im.axis_literal(Edge), + im.tuple_get(0, im.make_tuple(0, 18)), + im.tuple_get(1, im.make_tuple(0, 18)), + ) + ) + + unstructured_domain_V_p1_expected = im.call("unstructured_domain")( + im.call("named_range")(im.axis_literal(Vertex), 1, 9), + ) + + offset_provider = unstructured_case.offset_provider + testee = program_factory( + params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))))( + im.as_fieldop("deref")("vertex_values") + ), + domain=unstructured_domain_get_E, + ) + ], + ) + + ir = inline_fundefs.InlineFundefs().visit(testee) + ir = inline_fundefs.prune_unreferenced_fundefs(ir) + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) + ir = TransformGetDomain.apply(ir, sizes=sizes) + # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps + + expected = program_factory( + params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], + declarations=[ + itir.Temporary(id="__tmp_1", 
domain=unstructured_domain_V_p1_expected, dtype=float_type) + ], + body=[ + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop("deref", unstructured_domain_V_p1_expected)("vertex_values"), + domain=unstructured_domain_V_p1_expected, + ), + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop( + im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))), unstructured_domain_E + )("__tmp_1"), + domain=unstructured_domain_E, + ), + ], + ) + + assert ir == expected + + +def test_trivial_shift_switched(unstructured_case): + sizes = {"out": gtx.domain({Edge: (0, 18), Vertex: (0, 9)})} + unstructured_domain_get_E = im.call("unstructured_domain")( + im.call("named_range")( + im.axis_literal(Edge), + im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))), + ) + ) + + unstructured_domain_E = im.call("unstructured_domain")( + im.call("named_range")( + im.axis_literal(Edge), + im.tuple_get(0, im.make_tuple(0, 18)), + im.tuple_get(1, im.make_tuple(0, 18)), + ) + ) + + offset_provider = unstructured_case.offset_provider + testee = program_factory( + params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref")( + im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))))( + "vertex_values" + ) + ), + domain=unstructured_domain_get_E, + ) + ], + ) + + ir = inline_fundefs.InlineFundefs().visit(testee) + ir = inline_fundefs.prune_unreferenced_fundefs(ir) + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) + ir = TransformGetDomain.apply(ir, sizes=sizes) + # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps + + expected = program_factory( + params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], + 
declarations=[itir.Temporary(id="__tmp_1", domain=unstructured_domain_E, dtype=float_type)], + body=[ + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop( + im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))), unstructured_domain_E + )("vertex_values"), + domain=unstructured_domain_E, + ), + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref", unstructured_domain_E)("__tmp_1"), + domain=unstructured_domain_E, + ), + ], + ) + + assert ir == expected + + +def test_two_shifts(unstructured_case): + sizes = {"out": gtx.domain({Edge: (0, 18), Vertex: (0, 9)})} + unstructured_domain_get_E = im.call("unstructured_domain")( + im.call("named_range")( + im.axis_literal(Edge), + im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))), + ) + ) + + unstructured_domain_E = im.call("unstructured_domain")( + im.call("named_range")( + im.axis_literal(Edge), + im.tuple_get(0, im.make_tuple(0, 18)), + im.tuple_get(1, im.make_tuple(0, 18)), + ) + ) + + unstructured_domain_V_expected = im.call("unstructured_domain")( + im.call("named_range")(im.axis_literal(Vertex), 0, 9), + ) + offset_provider = unstructured_case.offset_provider + testee = program_factory( + params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop( + im.lambda_("x")( + im.plus( + im.deref(im.shift("E2V", 0)("x")), im.deref(im.shift("E2V", 1)("x")) + ) + ) + )(im.as_fieldop("deref")("vertex_values")), + domain=unstructured_domain_get_E, + ) + ], + ) + + ir = inline_fundefs.InlineFundefs().visit(testee) + ir = inline_fundefs.prune_unreferenced_fundefs(ir) + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) + ir = TransformGetDomain.apply(ir, sizes=sizes) + # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is 
already done in create_global_tmps + + expected = program_factory( + params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], + declarations=[ + itir.Temporary(id="__tmp_1", domain=unstructured_domain_V_expected, dtype=float_type) + ], + body=[ + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop("deref", unstructured_domain_V_expected)("vertex_values"), + domain=unstructured_domain_V_expected, + ), + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop( + im.lambda_("x")( + im.plus( + im.deref(im.shift("E2V", 0)("x")), im.deref(im.shift("E2V", 1)("x")) + ) + ), + unstructured_domain_E, + )("__tmp_1"), + domain=unstructured_domain_E, + ), + ], + ) + + assert ir == expected + + +def test_nested_shift(unstructured_case): + sizes = {"out": gtx.domain({Edge: (0, 18), Vertex: (0, 9)})} + unstructured_domain_V = im.call("unstructured_domain")( + im.call("named_range")( + im.axis_literal(Vertex), + im.tuple_get(0, im.make_tuple(0, 9)), + im.tuple_get(1, im.make_tuple(0, 9)), + ) + ) + unstructured_domain_get_V = im.call("unstructured_domain")( + im.call("named_range")( + im.axis_literal(Vertex), + im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Vertex))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Vertex))), + ) + ) + + unstructured_domain_V_p1_expected = im.call("unstructured_domain")( + im.call("named_range")(im.axis_literal(Vertex), 1, 9), + ) + + unstructured_domain_E_918_expected = im.call("unstructured_domain")( + im.call("named_range")(im.axis_literal(Edge), 9, 18), + ) + + offset_provider = unstructured_case.offset_provider + + testee = program_factory( + params=[im.sym("vertex_values", v_field_type), im.sym("out", v_field_type)], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("V2E", 1)("x"))))( + im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))))( + im.as_fieldop("deref")("vertex_values") + ) + ), + 
domain=unstructured_domain_get_V, + ) + ], + ) + + # testee = program_factory( + # params=[im.sym("vertex_values", v_field_type), im.sym("out", v_field_type)], + # body=[ + # itir.SetAt( + # target=im.ref("out"), + # expr=im.as_fieldop( + # im.lambda_("x")(im.deref(im.shift("E2V", 1)( + # im.shift("V2E", 1)("x")))))(im.as_fieldop("deref")("vertex_values")), + # domain=unstructured_domain_get_V, # TODO: why is the order switched in here? + # ) + # ], + # ) + + ir = inline_fundefs.InlineFundefs().visit(testee) + ir = inline_fundefs.prune_unreferenced_fundefs(ir) + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) + ir = TransformGetDomain.apply(ir, sizes=sizes) + # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps + + expected = program_factory( + params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], + declarations=[ + itir.Temporary( + id="__tmp_1", domain=unstructured_domain_E_918_expected, dtype=float_type + ), + itir.Temporary( + id="__tmp_2", domain=unstructured_domain_V_p1_expected, dtype=float_type + ), + ], + body=[ + itir.SetAt( + target=im.ref("__tmp_2"), + expr=im.as_fieldop("deref", unstructured_domain_V_p1_expected)("vertex_values"), + domain=unstructured_domain_V_p1_expected, + ), + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop( + im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))), + unstructured_domain_E_918_expected, + )("__tmp_2"), + domain=unstructured_domain_E_918_expected, + ), + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop( + im.lambda_("x")(im.deref(im.shift("V2E", 1)("x"))), unstructured_domain_V + )("__tmp_1"), + domain=unstructured_domain_V, + ), + ], + ) + + assert ir == expected + + +def test_trivial_cartesian(): # TODO: fix/remove? 
+ grid = simple_cartesian_grid() + offset_provider = {"Ioff": grid.offset_provider["Ioff"]} + sizes = {"out": gtx.domain({IDim: (0, 8)})} + cartesian_domain, cartesian_domain_get = construct_domains( + sizes["out"], "out", "cartesian_domain" + ) + cartesian_domain_p1 = im.call("cartesian_domain")( + im.call("named_range")( + im.axis_literal(IDim), + im.plus(im.tuple_get(0, im.make_tuple(0, 8)), 1), + im.plus(im.tuple_get(1, im.make_tuple(0, 8)), 1), + ) + ) + + testee = program_factory( + params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("Ioff", 1)("x"))))( + im.as_fieldop("deref")("i_values") + ), + domain=cartesian_domain_get, + ) + ], + ) + + ir = inline_fundefs.InlineFundefs().visit(testee) + ir = inline_fundefs.prune_unreferenced_fundefs(ir) + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) + ir = TransformGetDomain.apply(ir, sizes=sizes) + # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps + + expected = program_factory( + params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], + declarations=[itir.Temporary(id="__tmp_1", domain=cartesian_domain_p1, dtype=float_type)], + body=[ + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop("deref", cartesian_domain_p1)("i_values"), + domain=cartesian_domain_p1, + ), + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop( + im.lambda_("x")(im.deref(im.shift("Ioff", 1)("x"))), cartesian_domain + )("__tmp_1"), + domain=cartesian_domain, + ), + ], + ) + + assert ir == expected + + +def test_trivial_cartesian_forward(): # TODO: fix/remove? 
+ grid = simple_cartesian_grid() + offset_provider = {"Ioff": grid.offset_provider["Ioff"]} + sizes = {"out": gtx.domain({IDim: (0, 8)})} + + cartesian_domain_get = im.call("cartesian_domain")( + im.call("named_range")( + im.axis_literal(IDim), + im.minus(im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(IDim))), 4), + im.minus(im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(IDim))), 4), + ) + ) + testee = program_factory( + params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop( + im.lambda_("x")(im.deref(im.shift("Ioff", 2)("x"))), + )( + im.as_fieldop( + im.lambda_("x")(im.deref(im.shift("Ioff", 2)("x"))), + )("i_values") + ), + domain=cartesian_domain_get, + ) + ], + ) + + ir = inline_fundefs.InlineFundefs().visit(testee) + ir = inline_fundefs.prune_unreferenced_fundefs(ir) + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) + ir = TransformGetDomain.apply(ir, sizes=sizes) + # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps + + cartesian_domain_m2 = im.call("cartesian_domain")( + im.call("named_range")( + im.axis_literal(IDim), + im.minus(im.tuple_get(0, im.make_tuple(0, 8)), 2), + im.minus(im.tuple_get(1, im.make_tuple(0, 8)), 2), + ) + ) + + cartesian_domain_m4 = im.call("cartesian_domain")( + im.call("named_range")( + im.axis_literal(IDim), + im.minus(im.tuple_get(0, im.make_tuple(0, 8)), 4), + im.minus(im.tuple_get(1, im.make_tuple(0, 8)), 4), + ) + ) + + expected = program_factory( + params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], + declarations=[itir.Temporary(id="__tmp_1", domain=cartesian_domain_m2, dtype=float_type)], + body=[ + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop( + im.lambda_("x")(im.deref(im.shift("Ioff", 2)("x"))), cartesian_domain_m2 + 
)("i_values"), + domain=cartesian_domain_m2, + ), + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop( + im.lambda_("x")(im.deref(im.shift("Ioff", 2)("x"))), cartesian_domain_m4 + )("__tmp_1"), + domain=cartesian_domain_m4, + ), + ], + ) + + assert ir == expected From 0a84e7ea082fbe821bf8dae956d1f0fb59a767b0 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 24 Jul 2025 18:21:27 +0200 Subject: [PATCH 08/93] Fix tests --- .../transforms_tests/test_transform_get_domain.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py index 9eef26c8c9..1db04014f7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py @@ -9,6 +9,7 @@ import pytest from next_tests.integration_tests.cases import ( + IField, IDim, KDim, Vertex, @@ -24,7 +25,6 @@ from gt4py.next.iterator.transforms.transform_get_domain import TransformGetDomain from gt4py.next.iterator.type_system import inference as type_inference from gt4py.next.type_system import type_specifications as ts -from tests.next_tests.integration_tests import cases from tests.next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( simple_cartesian_grid, Edge, @@ -252,26 +252,26 @@ def test_get_domain_nested_tuples(): @pytest.fixture def testee(): @gtx.field_operator - def testee_tmp(x: cases.IField) -> cases.IField: + def testee_tmp(x: IField) -> IField: y = x(IOff[2]) return y(IOff[3]) # @gtx.field_operator - # def testee_tmp(x: cases.IField) -> cases.IField: + # def testee_tmp(x: IField) -> IField: # return x(IOff[1]) # @gtx.field_operator - # def testee_op(x: cases.IField) -> cases.IField: + # def testee_op(x: IField) -> IField: # return 
testee_tmp(x) @gtx.field_operator - def testee_op(x: cases.IField) -> cases.IField: + def testee_op(x: IField) -> IField: return testee_tmp(x) + testee_tmp(x) @gtx.program(static_domain_sizes=True, grid_type=gtx.GridType.UNSTRUCTURED) def prog( - inp: cases.IField, - out: cases.IField, + inp: IField, + out: IField, ): testee_op(inp, out=out) From 81f9de61092bffb74046112045cb3c222ea0e94a Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 24 Jul 2025 20:03:50 +0200 Subject: [PATCH 09/93] Only compute where values are used --- .../next/iterator/ir_utils/domain_utils.py | 23 +++- .../test_transform_get_domain.py | 100 +++++++++--------- 2 files changed, 68 insertions(+), 55 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index be9f19c560..a1ab2c8daa 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -17,7 +17,7 @@ from gt4py.next import common from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.iterator.transforms import trace_shifts +from gt4py.next.iterator.transforms import collapse_tuple, trace_shifts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding @@ -149,10 +149,23 @@ def translate( ) for k, v in _max_domain_sizes_by_location_type(offset_provider).items() } - min_ = np.min( - offset_provider[off.value].ndarray[:, val.value] - ) # TODO: multible shifts? multible occurences of that dimension? 
- max_ = np.max(offset_provider[off.value].ndarray[:, val.value]) + 1 + start = 0 + stop = -1 + start_ = collapse_tuple.CollapseTuple.apply( + new_ranges[old_dim].start, + within_stencil=False, + allow_undeclared_symbols=True, + ) + stop_ = collapse_tuple.CollapseTuple.apply( + new_ranges[old_dim].stop, + within_stencil=False, + allow_undeclared_symbols=True, + ) + if isinstance(start_, itir.Literal) and isinstance(stop_, itir.Literal): + start = int(start_.value) + stop = int(stop_.value) + min_ = np.min(offset_provider[off.value].ndarray[start:stop, val.value]) + max_ = np.max(offset_provider[off.value].ndarray[start:stop, val.value]) + 1 horizontal_sizes[new_dim.value] = ( im.literal(str(min_), builtins.INTEGER_INDEX_BUILTIN), im.literal(str(max_), builtins.INTEGER_INDEX_BUILTIN), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py index 1db04014f7..1b12b28974 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py @@ -21,7 +21,7 @@ from gt4py.next import Domain from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms import inline_fundefs, global_tmps, inline_lambdas +from gt4py.next.iterator.transforms import inline_fundefs, global_tmps, inline_lambdas, infer_domain from gt4py.next.iterator.transforms.transform_get_domain import TransformGetDomain from gt4py.next.iterator.type_system import inference as type_inference from gt4py.next.type_system import type_specifications as ts @@ -286,13 +286,13 @@ def test_get_domain_inference_temporary_symbols(testee, unstructured_case): ir, offset_provider_type={**unstructured_case.offset_provider, "IOff": IDim} ) ir = inline_lambdas.InlineLambdas.apply(ir) + ir = 
TransformGetDomain.apply(ir, sizes=sizes) ir = global_tmps.create_global_tmps( ir, offset_provider={**unstructured_case.offset_provider, "IOff": IDim} ) - ir = TransformGetDomain.apply(ir, sizes=sizes) - # ir = infer_domain.infer_program( - # ir, offset_provider={**unstructured_case.offset_provider, "IOff": IDim} - # ) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps + ir = infer_domain.infer_program( + ir, offset_provider={**unstructured_case.offset_provider, "IOff": IDim} + ) unstructured_domain = im.call("unstructured_domain")( im.call("named_range")( im.axis_literal(IDim), @@ -347,7 +347,7 @@ def test_get_domain_inference_temporary_symbols(testee, unstructured_case): def test_trivial_shift(unstructured_case): - sizes = {"out": gtx.domain({Edge: (0, 18), Vertex: (0, 9)})} + sizes = {"out": gtx.domain({Edge: (9, 13), Vertex: (0, 9)})} unstructured_domain_get_E = im.call("unstructured_domain")( im.call("named_range")( im.axis_literal(Edge), @@ -359,15 +359,19 @@ def test_trivial_shift(unstructured_case): unstructured_domain_E = im.call("unstructured_domain")( im.call("named_range")( im.axis_literal(Edge), - im.tuple_get(0, im.make_tuple(0, 18)), - im.tuple_get(1, im.make_tuple(0, 18)), + im.tuple_get(0, im.make_tuple(9, 13)), + im.tuple_get(1, im.make_tuple(9, 13)), ) ) - unstructured_domain_V_p1_expected = im.call("unstructured_domain")( + unstructured_domain_V_p1 = im.call("unstructured_domain")( im.call("named_range")(im.axis_literal(Vertex), 1, 9), ) + unstructured_domain_V_37 = im.call("unstructured_domain")( + im.call("named_range")(im.axis_literal(Vertex), 3, 7), + ) + offset_provider = unstructured_case.offset_provider testee = program_factory( params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], @@ -384,20 +388,20 @@ def test_trivial_shift(unstructured_case): ir = inline_fundefs.InlineFundefs().visit(testee) ir = inline_fundefs.prune_unreferenced_fundefs(ir) - ir = 
global_tmps.create_global_tmps(ir, offset_provider=offset_provider) ir = TransformGetDomain.apply(ir, sizes=sizes) - # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) + ir = infer_domain.infer_program(ir, offset_provider=offset_provider) expected = program_factory( params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], declarations=[ - itir.Temporary(id="__tmp_1", domain=unstructured_domain_V_p1_expected, dtype=float_type) + itir.Temporary(id="__tmp_1", domain=unstructured_domain_V_37, dtype=float_type) ], body=[ itir.SetAt( target=im.ref("__tmp_1"), - expr=im.as_fieldop("deref", unstructured_domain_V_p1_expected)("vertex_values"), - domain=unstructured_domain_V_p1_expected, + expr=im.as_fieldop("deref", unstructured_domain_V_37)("vertex_values"), + domain=unstructured_domain_V_37, ), itir.SetAt( target=im.ref("out"), @@ -448,9 +452,9 @@ def test_trivial_shift_switched(unstructured_case): ir = inline_fundefs.InlineFundefs().visit(testee) ir = inline_fundefs.prune_unreferenced_fundefs(ir) - ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) ir = TransformGetDomain.apply(ir, sizes=sizes) - # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) + ir = infer_domain.infer_program(ir, offset_provider=offset_provider) expected = program_factory( params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], @@ -492,7 +496,7 @@ def test_two_shifts(unstructured_case): ) ) - unstructured_domain_V_expected = im.call("unstructured_domain")( + unstructured_domain_V = im.call("unstructured_domain")( 
im.call("named_range")(im.axis_literal(Vertex), 0, 9), ) offset_provider = unstructured_case.offset_provider @@ -515,20 +519,18 @@ def test_two_shifts(unstructured_case): ir = inline_fundefs.InlineFundefs().visit(testee) ir = inline_fundefs.prune_unreferenced_fundefs(ir) - ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) ir = TransformGetDomain.apply(ir, sizes=sizes) - # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) + ir = infer_domain.infer_program(ir, offset_provider=offset_provider) expected = program_factory( params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], - declarations=[ - itir.Temporary(id="__tmp_1", domain=unstructured_domain_V_expected, dtype=float_type) - ], + declarations=[itir.Temporary(id="__tmp_1", domain=unstructured_domain_V, dtype=float_type)], body=[ itir.SetAt( target=im.ref("__tmp_1"), - expr=im.as_fieldop("deref", unstructured_domain_V_expected)("vertex_values"), - domain=unstructured_domain_V_expected, + expr=im.as_fieldop("deref", unstructured_domain_V)("vertex_values"), + domain=unstructured_domain_V, ), itir.SetAt( target=im.ref("out"), @@ -565,11 +567,11 @@ def test_nested_shift(unstructured_case): ) ) - unstructured_domain_V_p1_expected = im.call("unstructured_domain")( - im.call("named_range")(im.axis_literal(Vertex), 1, 9), + unstructured_domain_V_39 = im.call("unstructured_domain")( + im.call("named_range")(im.axis_literal(Vertex), 3, 9), ) - unstructured_domain_E_918_expected = im.call("unstructured_domain")( + unstructured_domain_E_918 = im.call("unstructured_domain")( im.call("named_range")(im.axis_literal(Edge), 9, 18), ) @@ -605,33 +607,29 @@ def test_nested_shift(unstructured_case): ir = inline_fundefs.InlineFundefs().visit(testee) ir = 
inline_fundefs.prune_unreferenced_fundefs(ir) - ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) ir = TransformGetDomain.apply(ir, sizes=sizes) - # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) + ir = infer_domain.infer_program(ir, offset_provider=offset_provider) expected = program_factory( params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], declarations=[ - itir.Temporary( - id="__tmp_1", domain=unstructured_domain_E_918_expected, dtype=float_type - ), - itir.Temporary( - id="__tmp_2", domain=unstructured_domain_V_p1_expected, dtype=float_type - ), + itir.Temporary(id="__tmp_1", domain=unstructured_domain_E_918, dtype=float_type), + itir.Temporary(id="__tmp_2", domain=unstructured_domain_V_39, dtype=float_type), ], body=[ itir.SetAt( target=im.ref("__tmp_2"), - expr=im.as_fieldop("deref", unstructured_domain_V_p1_expected)("vertex_values"), - domain=unstructured_domain_V_p1_expected, + expr=im.as_fieldop("deref", unstructured_domain_V_39)("vertex_values"), + domain=unstructured_domain_V_39, ), itir.SetAt( target=im.ref("__tmp_1"), expr=im.as_fieldop( im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))), - unstructured_domain_E_918_expected, + unstructured_domain_E_918, )("__tmp_2"), - domain=unstructured_domain_E_918_expected, + domain=unstructured_domain_E_918, ), itir.SetAt( target=im.ref("out"), @@ -649,15 +647,15 @@ def test_nested_shift(unstructured_case): def test_trivial_cartesian(): # TODO: fix/remove? 
grid = simple_cartesian_grid() offset_provider = {"Ioff": grid.offset_provider["Ioff"]} - sizes = {"out": gtx.domain({IDim: (0, 8)})} + sizes = {"out": gtx.domain({IDim: (2, 7)})} cartesian_domain, cartesian_domain_get = construct_domains( sizes["out"], "out", "cartesian_domain" ) - cartesian_domain_p1 = im.call("cartesian_domain")( + cartesian_domain_27_p1 = im.call("cartesian_domain")( im.call("named_range")( im.axis_literal(IDim), - im.plus(im.tuple_get(0, im.make_tuple(0, 8)), 1), - im.plus(im.tuple_get(1, im.make_tuple(0, 8)), 1), + im.plus(im.tuple_get(0, im.make_tuple(2, 7)), 1), + im.plus(im.tuple_get(1, im.make_tuple(2, 7)), 1), ) ) @@ -676,18 +674,20 @@ def test_trivial_cartesian(): # TODO: fix/remove? ir = inline_fundefs.InlineFundefs().visit(testee) ir = inline_fundefs.prune_unreferenced_fundefs(ir) - ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) ir = TransformGetDomain.apply(ir, sizes=sizes) - # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) + ir = infer_domain.infer_program(ir, offset_provider=offset_provider) expected = program_factory( params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], - declarations=[itir.Temporary(id="__tmp_1", domain=cartesian_domain_p1, dtype=float_type)], + declarations=[ + itir.Temporary(id="__tmp_1", domain=cartesian_domain_27_p1, dtype=float_type) + ], body=[ itir.SetAt( target=im.ref("__tmp_1"), - expr=im.as_fieldop("deref", cartesian_domain_p1)("i_values"), - domain=cartesian_domain_p1, + expr=im.as_fieldop("deref", cartesian_domain_27_p1)("i_values"), + domain=cartesian_domain_27_p1, ), itir.SetAt( target=im.ref("out"), @@ -733,9 +733,9 @@ def test_trivial_cartesian_forward(): # TODO: fix/remove? 
ir = inline_fundefs.InlineFundefs().visit(testee) ir = inline_fundefs.prune_unreferenced_fundefs(ir) - ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) ir = TransformGetDomain.apply(ir, sizes=sizes) - # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) + ir = infer_domain.infer_program(ir, offset_provider=offset_provider) cartesian_domain_m2 = im.call("cartesian_domain")( im.call("named_range")( From c291c26c793677121094f9f17236aa47693a9a01 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 25 Jul 2025 09:59:58 +0200 Subject: [PATCH 10/93] Fix some tests --- .../next/iterator/ir_utils/domain_utils.py | 9 +++++-- .../test_transform_get_domain.py | 24 ++++++++----------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index a1ab2c8daa..5239859138 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -164,8 +164,13 @@ def translate( if isinstance(start_, itir.Literal) and isinstance(stop_, itir.Literal): start = int(start_.value) stop = int(stop_.value) - min_ = np.min(offset_provider[off.value].ndarray[start:stop, val.value]) - max_ = np.max(offset_provider[off.value].ndarray[start:stop, val.value]) + 1 + + off_index = ( + slice(None) if val == trace_shifts.Sentinel.ALL_NEIGHBORS else val.value + ) + accessed = offset_provider[off.value].ndarray[start:stop, off_index] + min_ = np.min(accessed) + max_ = np.max(accessed) + 1 horizontal_sizes[new_dim.value] = ( im.literal(str(min_), builtins.INTEGER_INDEX_BUILTIN), im.literal(str(max_), builtins.INTEGER_INDEX_BUILTIN), diff --git 
a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py index 1b12b28974..441678223e 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py @@ -25,7 +25,7 @@ from gt4py.next.iterator.transforms.transform_get_domain import TransformGetDomain from gt4py.next.iterator.type_system import inference as type_inference from gt4py.next.type_system import type_specifications as ts -from tests.next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( simple_cartesian_grid, Edge, simple_mesh, @@ -290,9 +290,9 @@ def test_get_domain_inference_temporary_symbols(testee, unstructured_case): ir = global_tmps.create_global_tmps( ir, offset_provider={**unstructured_case.offset_provider, "IOff": IDim} ) - ir = infer_domain.infer_program( - ir, offset_provider={**unstructured_case.offset_provider, "IOff": IDim} - ) + # ir = infer_domain.infer_program( + # ir, offset_provider={**unstructured_case.offset_provider, "IOff": IDim} + # ) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps unstructured_domain = im.call("unstructured_domain")( im.call("named_range")( im.axis_literal(IDim), @@ -364,10 +364,6 @@ def test_trivial_shift(unstructured_case): ) ) - unstructured_domain_V_p1 = im.call("unstructured_domain")( - im.call("named_range")(im.axis_literal(Vertex), 1, 9), - ) - unstructured_domain_V_37 = im.call("unstructured_domain")( im.call("named_range")(im.axis_literal(Vertex), 3, 7), ) @@ -390,7 +386,7 @@ def test_trivial_shift(unstructured_case): ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = TransformGetDomain.apply(ir, sizes=sizes) ir = 
global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - ir = infer_domain.infer_program(ir, offset_provider=offset_provider) + # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps expected = program_factory( params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], @@ -454,7 +450,7 @@ def test_trivial_shift_switched(unstructured_case): ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = TransformGetDomain.apply(ir, sizes=sizes) ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - ir = infer_domain.infer_program(ir, offset_provider=offset_provider) + # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps expected = program_factory( params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], @@ -521,7 +517,7 @@ def test_two_shifts(unstructured_case): ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = TransformGetDomain.apply(ir, sizes=sizes) ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - ir = infer_domain.infer_program(ir, offset_provider=offset_provider) + # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps expected = program_factory( params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], @@ -609,7 +605,7 @@ def test_nested_shift(unstructured_case): ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = TransformGetDomain.apply(ir, sizes=sizes) ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - ir = infer_domain.infer_program(ir, offset_provider=offset_provider) + # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # 
TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps expected = program_factory( params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], @@ -676,7 +672,7 @@ def test_trivial_cartesian(): # TODO: fix/remove? ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = TransformGetDomain.apply(ir, sizes=sizes) ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - ir = infer_domain.infer_program(ir, offset_provider=offset_provider) + # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps expected = program_factory( params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], @@ -735,7 +731,7 @@ def test_trivial_cartesian_forward(): # TODO: fix/remove? ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = TransformGetDomain.apply(ir, sizes=sizes) ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - ir = infer_domain.infer_program(ir, offset_provider=offset_provider) + # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps cartesian_domain_m2 = im.call("cartesian_domain")( im.call("named_range")( From 649363d5cca20d601e89d47b8bd2278ac4a51506 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 25 Jul 2025 12:16:57 +0200 Subject: [PATCH 11/93] Update tests --- .../test_transform_get_domain.py | 119 +----------------- 1 file changed, 2 insertions(+), 117 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py index 441678223e..30d8dd17ac 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py +++ 
b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py @@ -249,103 +249,6 @@ def test_get_domain_nested_tuples(): assert actual == expected -@pytest.fixture -def testee(): - @gtx.field_operator - def testee_tmp(x: IField) -> IField: - y = x(IOff[2]) - return y(IOff[3]) - - # @gtx.field_operator - # def testee_tmp(x: IField) -> IField: - # return x(IOff[1]) - - # @gtx.field_operator - # def testee_op(x: IField) -> IField: - # return testee_tmp(x) - - @gtx.field_operator - def testee_op(x: IField) -> IField: - return testee_tmp(x) + testee_tmp(x) - - @gtx.program(static_domain_sizes=True, grid_type=gtx.GridType.UNSTRUCTURED) - def prog( - inp: IField, - out: IField, - ): - testee_op(inp, out=out) - - return prog - - -def test_get_domain_inference_temporary_symbols(testee, unstructured_case): - sizes = {"out": gtx.domain({IDim: (0, 20)})} - ir = inline_fundefs.InlineFundefs().visit(testee.gtir) - ir = inline_fundefs.prune_unreferenced_fundefs(ir) - ir = type_inference.infer( - ir, offset_provider_type={**unstructured_case.offset_provider, "IOff": IDim} - ) - ir = inline_lambdas.InlineLambdas.apply(ir) - ir = TransformGetDomain.apply(ir, sizes=sizes) - ir = global_tmps.create_global_tmps( - ir, offset_provider={**unstructured_case.offset_provider, "IOff": IDim} - ) - # ir = infer_domain.infer_program( - # ir, offset_provider={**unstructured_case.offset_provider, "IOff": IDim} - # ) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps - unstructured_domain = im.call("unstructured_domain")( - im.call("named_range")( - im.axis_literal(IDim), - im.tuple_get(0, im.make_tuple(0, 20)), - im.tuple_get(1, im.make_tuple(0, 20)), - ) - ) - - unstructured_domain_p3 = im.call("unstructured_domain")( - im.call("named_range")( - im.axis_literal(IDim), - im.plus(im.tuple_get(0, im.make_tuple(0, 20)), 3), - im.plus(im.tuple_get(1, im.make_tuple(0, 20)), 3), - ) - ) - - expected = itir.Program( 
- id="prog", - function_definitions=[], - params=[im.sym("inp"), im.sym("out"), im.sym("__inp_0_range"), im.sym("__out_0_range")], - declarations=[ - itir.Temporary(id="__tmp_1", domain=unstructured_domain, dtype=int_type), - itir.Temporary(id="__tmp_2", domain=unstructured_domain_p3, dtype=int_type), - ], - body=[ - itir.SetAt( - target=im.ref("__tmp_2"), - expr=im.as_fieldop( - im.lambda_("__it")(im.deref(im.shift("IOff", 2)("__it"))), - unstructured_domain_p3, - )("inp"), - domain=unstructured_domain_p3, - ), - itir.SetAt( - target=im.ref("__tmp_1"), - expr=im.as_fieldop( - im.lambda_("__it_")(im.deref(im.shift("IOff", 3)("__it_"))), unstructured_domain - )("__tmp_2"), - domain=unstructured_domain, - ), - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop( - im.lambda_("__arg0", "__arg1")(im.plus(im.deref("__arg0"), im.deref("__arg1"))), - unstructured_domain, - )("__tmp_1", "__tmp_1"), - domain=unstructured_domain, - ), - ], - ) - assert ir == expected - - def test_trivial_shift(unstructured_case): sizes = {"out": gtx.domain({Edge: (9, 13), Vertex: (0, 9)})} unstructured_domain_get_E = im.call("unstructured_domain")( @@ -386,7 +289,6 @@ def test_trivial_shift(unstructured_case): ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = TransformGetDomain.apply(ir, sizes=sizes) ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps expected = program_factory( params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], @@ -450,7 +352,6 @@ def test_trivial_shift_switched(unstructured_case): ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = TransformGetDomain.apply(ir, sizes=sizes) ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain 
inference does not seem to be necessary anymore since it is already done in create_global_tmps expected = program_factory( params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], @@ -517,7 +418,6 @@ def test_two_shifts(unstructured_case): ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = TransformGetDomain.apply(ir, sizes=sizes) ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps expected = program_factory( params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], @@ -588,24 +488,10 @@ def test_nested_shift(unstructured_case): ], ) - # testee = program_factory( - # params=[im.sym("vertex_values", v_field_type), im.sym("out", v_field_type)], - # body=[ - # itir.SetAt( - # target=im.ref("out"), - # expr=im.as_fieldop( - # im.lambda_("x")(im.deref(im.shift("E2V", 1)( - # im.shift("V2E", 1)("x")))))(im.as_fieldop("deref")("vertex_values")), - # domain=unstructured_domain_get_V, # TODO: why is the order switched in here? - # ) - # ], - # ) - ir = inline_fundefs.InlineFundefs().visit(testee) ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = TransformGetDomain.apply(ir, sizes=sizes) ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps expected = program_factory( params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], @@ -640,7 +526,7 @@ def test_nested_shift(unstructured_case): assert ir == expected -def test_trivial_cartesian(): # TODO: fix/remove? 
+def test_trivial_cartesian(): grid = simple_cartesian_grid() offset_provider = {"Ioff": grid.offset_provider["Ioff"]} sizes = {"out": gtx.domain({IDim: (2, 7)})} @@ -698,7 +584,7 @@ def test_trivial_cartesian(): # TODO: fix/remove? assert ir == expected -def test_trivial_cartesian_forward(): # TODO: fix/remove? +def test_trivial_cartesian_forward(): grid = simple_cartesian_grid() offset_provider = {"Ioff": grid.offset_provider["Ioff"]} sizes = {"out": gtx.domain({IDim: (0, 8)})} @@ -731,7 +617,6 @@ def test_trivial_cartesian_forward(): # TODO: fix/remove? ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = TransformGetDomain.apply(ir, sizes=sizes) ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference does not seem to be necessary anymore since it is already done in create_global_tmps cartesian_domain_m2 = im.call("cartesian_domain")( im.call("named_range")( From 453bc0c235726f04aaca0d7ea9530985665d9dd9 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 25 Jul 2025 12:39:18 +0200 Subject: [PATCH 12/93] Minor --- .../iterator_tests/transforms_tests/test_transform_get_domain.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py index 30d8dd17ac..270360401d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py @@ -558,7 +558,6 @@ def test_trivial_cartesian(): ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = TransformGetDomain.apply(ir, sizes=sizes) ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - # ir = infer_domain.infer_program(ir, offset_provider=offset_provider) # TODO: domain inference 
does not seem to be necessary anymore since it is already done in create_global_tmps expected = program_factory( params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], From 373181042f631f36e470ec1e41c3b9c0953b9eaa Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 30 Jul 2025 09:23:00 +0200 Subject: [PATCH 13/93] Get domain from tuple element --- src/gt4py/next/ffront/past_to_itir.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 04265985bc..1cd8f65b90 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -9,11 +9,13 @@ from __future__ import annotations import dataclasses +import functools from typing import Any, Optional, cast import devtools from gt4py.eve import NodeTranslator, concepts, traits +from gt4py.eve import utils as eve_utils from gt4py.next import common, config, errors from gt4py.next.ffront import ( fbuiltins, @@ -362,13 +364,17 @@ def _construct_itir_domain_arg( " fields defined on the same dimensions. This error should be " " caught in type deduction already." 
) + # if the out_field is a (potentially nested) tuple we get the domain from its first + # element + first_out_el_path = eve_utils.first(type_info.primitive_constituents(out_field.type, with_path_arg=True))[1] + first_out_el = functools.reduce(lambda expr, i: im.tuple_get(i, expr), first_out_el_path, out_field.id) domain_args = [] domain_args_kind = [] for dim_i, dim in enumerate(out_dims): # an expression for the range of a dimension dim_range = im.call("get_domain")( - out_field.id, itir.AxisLiteral(value=dim.value, kind=dim.kind) + first_out_el, itir.AxisLiteral(value=dim.value, kind=dim.kind) ) dim_start, dim_stop = im.tuple_get(0, dim_range), im.tuple_get(1, dim_range) # bounds From 2c486edb6bbe872feae0e75248415d14390083af Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 6 Aug 2025 20:42:42 +0200 Subject: [PATCH 14/93] Refactor tests --- .../test_temporary_domain_inference.py | 411 ++++++++++++ .../test_transform_get_domain.py | 631 +++--------------- 2 files changed, 488 insertions(+), 554 deletions(-) create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py new file mode 100644 index 0000000000..ea630681a7 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py @@ -0,0 +1,411 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from typing import Optional, Dict + +import pytest +from next_tests.integration_tests.cases import ( + IDim, + Vertex, + unstructured_case, + exec_alloc_descriptor, +) +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + simple_cartesian_grid, + Edge, + simple_mesh +) + +from gt4py import next as gtx +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ( + ir_makers as im, +) +from gt4py.next.iterator.transforms import inline_fundefs, global_tmps +from gt4py.next.iterator.transforms.transform_get_domain import TransformGetDomain +from gt4py.next.type_system import type_specifications as ts + +IOff = gtx.FieldOffset("IOff", source=IDim, target=(IDim,)) + +float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) +v_field_type = ts.FieldType(dims=[Vertex], dtype=float_type) +e_field_type = ts.FieldType(dims=[Edge], dtype=float_type) +i_field_type = ts.FieldType(dims=[IDim], dtype=float_type) + + +# override mesh descriptor to contain only the simple mesh +@pytest.fixture +def mesh_descriptor(exec_alloc_descriptor): + return simple_mesh(exec_alloc_descriptor.allocator) + + +def program_factory( + params: list[itir.Sym], + body: list[itir.SetAt], + declarations: Optional[list[itir.Temporary]] = None, +) -> itir.Program: + return itir.Program( + id="testee", + function_definitions=[], + params=params, + declarations=declarations or [], + body=body, + ) + + +def run_test_program( + testee: itir.Program, expected: itir.Program, sizes: Dict[str, common.Domain], offset_provider: common.OffsetProvider +) -> None: + + ir = inline_fundefs.InlineFundefs().visit(testee) + ir = inline_fundefs.prune_unreferenced_fundefs(ir) + ir = TransformGetDomain.apply(ir, sizes=sizes) + actual_program = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) + + assert actual_program == expected + + +def 
test_trivial_shift(unstructured_case): + sizes = {"out": gtx.domain({Edge: (9, 13), Vertex: (0, 9)})} + unstructured_domain_get_E = im.domain(common.GridType.UNSTRUCTURED, + {Edge: (im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))))} + ) + + unstructured_domain_E = im.domain(common.GridType.UNSTRUCTURED, + {Edge: (im.tuple_get(0, im.make_tuple(9, 13)), + im.tuple_get(1, im.make_tuple(9, 13)))} + ) + + unstructured_domain_V_37 = im.domain(common.GridType.UNSTRUCTURED,{Vertex: (3, 7)}) + + offset_provider = unstructured_case.offset_provider + testee = program_factory( + params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))))( + im.as_fieldop("deref")("vertex_values") + ), + domain=unstructured_domain_get_E, + ) + ], + ) + + expected = program_factory( + params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], + declarations=[ + itir.Temporary(id="__tmp_1", domain=unstructured_domain_V_37, dtype=float_type) + ], + body=[ + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop("deref", unstructured_domain_V_37)("vertex_values"), + domain=unstructured_domain_V_37, + ), + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop( + im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))), unstructured_domain_E + )("__tmp_1"), + domain=unstructured_domain_E, + ), + ], + ) + + run_test_program(testee, expected, sizes, offset_provider) + + +def test_trivial_shift_switched(unstructured_case): + sizes = {"out": gtx.domain({Edge: (2, 16), Vertex: (0, 9)})} + unstructured_domain_get_E = im.domain(common.GridType.UNSTRUCTURED, + {Edge: (im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))))} + ) + + unstructured_domain_E = 
im.domain(common.GridType.UNSTRUCTURED, + {Edge: (im.tuple_get(0, im.make_tuple(2, 16)), + im.tuple_get(1, im.make_tuple(2, 16)))} + ) + + offset_provider = unstructured_case.offset_provider + testee = program_factory( + params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref")( + im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))))( + "vertex_values" + ) + ), + domain=unstructured_domain_get_E, + ) + ], + ) + + expected = program_factory( + params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], + declarations=[itir.Temporary(id="__tmp_1", domain=unstructured_domain_E, dtype=float_type)], + body=[ + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop( + im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))), unstructured_domain_E + )("vertex_values"), + domain=unstructured_domain_E, + ), + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref", unstructured_domain_E)("__tmp_1"), + domain=unstructured_domain_E, + ), + ], + ) + + run_test_program(testee, expected, sizes, offset_provider) + + +def test_two_shifts(unstructured_case): + sizes = {"out": gtx.domain({Edge: (0, 18), Vertex: (0, 9)})} + unstructured_domain_get_E = im.domain(common.GridType.UNSTRUCTURED, + {Edge: (im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))))} + ) + + unstructured_domain_E = im.domain(common.GridType.UNSTRUCTURED, + {Edge: (im.tuple_get(0, im.make_tuple(0, 18)), + im.tuple_get(1, im.make_tuple(0, 18)))} + ) + + unstructured_domain_V = im.domain(common.GridType.UNSTRUCTURED,{Vertex: ( 0, 9)}) + + offset_provider = unstructured_case.offset_provider + testee = program_factory( + params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop( + im.lambda_("x")( + im.plus( + 
im.deref(im.shift("E2V", 0)("x")), im.deref(im.shift("E2V", 1)("x")) + ) + ) + )(im.as_fieldop("deref")("vertex_values")), + domain=unstructured_domain_get_E, + ) + ], + ) + + expected = program_factory( + params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], + declarations=[itir.Temporary(id="__tmp_1", domain=unstructured_domain_V, dtype=float_type)], + body=[ + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop("deref", unstructured_domain_V)("vertex_values"), + domain=unstructured_domain_V, + ), + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop( + im.lambda_("x")( + im.plus( + im.deref(im.shift("E2V", 0)("x")), im.deref(im.shift("E2V", 1)("x")) + ) + ), + unstructured_domain_E, + )("__tmp_1"), + domain=unstructured_domain_E, + ), + ], + ) + + run_test_program(testee, expected, sizes, offset_provider) + + +def test_nested_shift(unstructured_case): + sizes = {"out": gtx.domain({Edge: (0, 18), Vertex: (0, 9)})} + unstructured_domain_V = im.domain(common.GridType.UNSTRUCTURED, + {Vertex: (im.tuple_get(0, im.make_tuple(0, 9)), + im.tuple_get(1, im.make_tuple(0, 9)))} + ) + unstructured_domain_get_V = im.domain(common.GridType.UNSTRUCTURED, + {Vertex: (im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Vertex))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Vertex))))} + ) + + unstructured_domain_V_39 = im.domain(common.GridType.UNSTRUCTURED,{Vertex: ( 3, 9)}) + + unstructured_domain_E_918 = im.domain(common.GridType.UNSTRUCTURED,{Edge: ( 9, 18)}) + + offset_provider = unstructured_case.offset_provider + + testee = program_factory( + params=[im.sym("vertex_values", v_field_type), im.sym("out", v_field_type)], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("V2E", 1)("x"))))( + im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))))( + im.as_fieldop("deref")("vertex_values") + ) + ), + domain=unstructured_domain_get_V, + ) + ], + ) + + 
expected = program_factory( + params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], + declarations=[ + itir.Temporary(id="__tmp_1", domain=unstructured_domain_E_918, dtype=float_type), + itir.Temporary(id="__tmp_2", domain=unstructured_domain_V_39, dtype=float_type), + ], + body=[ + itir.SetAt( + target=im.ref("__tmp_2"), + expr=im.as_fieldop("deref", unstructured_domain_V_39)("vertex_values"), + domain=unstructured_domain_V_39, + ), + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop( + im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))), + unstructured_domain_E_918, + )("__tmp_2"), + domain=unstructured_domain_E_918, + ), + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop( + im.lambda_("x")(im.deref(im.shift("V2E", 1)("x"))), unstructured_domain_V + )("__tmp_1"), + domain=unstructured_domain_V, + ), + ], + ) + + run_test_program(testee, expected, sizes, offset_provider) + + +def test_trivial_cartesian(): + grid = simple_cartesian_grid() + offset_provider = {"Ioff": grid.offset_provider["Ioff"]} + sizes = {"out": gtx.domain({IDim: (2, 7)})} + + cartesian_domain = im.domain(common.GridType.CARTESIAN, + {IDim: (im.tuple_get(0, im.make_tuple(2, 7)), + im.tuple_get(1, im.make_tuple(2, 7)))} + ) + cartesian_domain_get = im.domain(common.GridType.CARTESIAN, + {IDim: (im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(IDim))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(IDim))))} + ) + + cartesian_domain_27_p1 = im.domain(common.GridType.CARTESIAN, + {IDim: (im.plus(im.tuple_get(0, im.make_tuple(2, 7)), 1), + im.plus(im.tuple_get(1, im.make_tuple(2, 7)), 1))} + ) + testee = program_factory( + params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("Ioff", 1)("x"))))( + im.as_fieldop("deref")("i_values") + ), + domain=cartesian_domain_get, + ) + ], + ) + + expected = program_factory( + 
params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], + declarations=[ + itir.Temporary(id="__tmp_1", domain=cartesian_domain_27_p1, dtype=float_type) + ], + body=[ + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop("deref", cartesian_domain_27_p1)("i_values"), + domain=cartesian_domain_27_p1, + ), + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop( + im.lambda_("x")(im.deref(im.shift("Ioff", 1)("x"))), cartesian_domain + )("__tmp_1"), + domain=cartesian_domain, + ), + ], + ) + + run_test_program(testee, expected, sizes, offset_provider) + + +def test_trivial_cartesian_forward(): + grid = simple_cartesian_grid() + offset_provider = {"Ioff": grid.offset_provider["Ioff"]} + sizes = {"out": gtx.domain({IDim: (0, 8)})} + + cartesian_domain_get = im.domain(common.GridType.CARTESIAN, + {IDim: (im.minus(im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(IDim))), 4), + im.minus(im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(IDim))), 4))} + ) + testee = program_factory( + params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop( + im.lambda_("x")(im.deref(im.shift("Ioff", 2)("x"))), + )( + im.as_fieldop( + im.lambda_("x")(im.deref(im.shift("Ioff", 2)("x"))), + )("i_values") + ), + domain=cartesian_domain_get, + ) + ], + ) + + cartesian_domain_m2 = im.domain(common.GridType.CARTESIAN, + {IDim: (im.minus(im.tuple_get(0, im.make_tuple(0, 8)), 2), + im.minus(im.tuple_get(1, im.make_tuple(0, 8)), 2))} + ) + + cartesian_domain_m4 = im.domain(common.GridType.CARTESIAN, + {IDim: (im.minus(im.tuple_get(0, im.make_tuple(0, 8)), 4), + im.minus(im.tuple_get(1, im.make_tuple(0, 8)), 4))} + ) + expected = program_factory( + params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], + declarations=[itir.Temporary(id="__tmp_1", domain=cartesian_domain_m2, dtype=float_type)], + body=[ + itir.SetAt( + target=im.ref("__tmp_1"), + 
expr=im.as_fieldop( + im.lambda_("x")(im.deref(im.shift("Ioff", 2)("x"))), cartesian_domain_m2 + )("i_values"), + domain=cartesian_domain_m2, + ), + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop( + im.lambda_("x")(im.deref(im.shift("Ioff", 2)("x"))), cartesian_domain_m4 + )("__tmp_1"), + domain=cartesian_domain_m4, + ), + ], + ) + + run_test_program(testee, expected, sizes, offset_provider) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py index 270360401d..be45b6980f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py @@ -5,7 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Optional +from typing import Dict, Union import pytest from next_tests.integration_tests.cases import ( @@ -18,12 +18,10 @@ ) from gt4py import next as gtx -from gt4py.next import Domain +from gt4py.next import Domain, common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms import inline_fundefs, global_tmps, inline_lambdas, infer_domain from gt4py.next.iterator.transforms.transform_get_domain import TransformGetDomain -from gt4py.next.iterator.type_system import inference as type_inference from gt4py.next.type_system import type_specifications as ts from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( simple_cartesian_grid, @@ -31,59 +29,63 @@ simple_mesh, ) -IOff = gtx.FieldOffset("IOff", source=IDim, target=(IDim,)) - -float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) -int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) -v_field_type = ts.FieldType(dims=[Vertex], dtype=float_type) -ve_field_type 
= ts.FieldType(dims=[Edge, Vertex], dtype=float_type) -e_field_type = ts.FieldType(dims=[Edge], dtype=float_type) -i_field_type = ts.FieldType(dims=[IDim], dtype=float_type) - - -# TODO: maybe check if domains are consistent in global_tmps - - -# override mesh descriptor to contain only the simple mesh -@pytest.fixture -def mesh_descriptor(exec_alloc_descriptor): - return simple_mesh(exec_alloc_descriptor.allocator) - def program_factory( - params: list[itir.Sym], + params: list[str], body: list[itir.SetAt], - declarations: Optional[list[itir.Temporary]] = None, ) -> itir.Program: return itir.Program( id="testee", function_definitions=[], - params=params, - declarations=declarations or [], + params=[im.sym(par) for par in params], + declarations=[], body=body, ) -def construct_domains(domain_resolved: Domain, symbol_name: str, type: str): - named_ranges_get, named_ranges_resolved = [], [] +def setat_factory( + domain: common.Domain, + target: str, +) -> itir.SetAt: + return itir.SetAt( + expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), + domain=domain, + target=im.ref(target), + ) - for dim, range_ in zip(domain_resolved.dims, domain_resolved.ranges): - get_domain_call = im.call("get_domain")(symbol_name, im.axis_literal(dim)) - named_ranges_get.append( - im.named_range( - im.axis_literal(dim), - im.tuple_get(0, get_domain_call), - im.tuple_get(1, get_domain_call), - ) - ) - bounds_tuple = im.make_tuple(range_.start, range_.stop) - named_ranges_resolved.append( - im.named_range( - im.axis_literal(dim), im.tuple_get(0, bounds_tuple), im.tuple_get(1, bounds_tuple) - ) - ) - return im.call(type)(*named_ranges_resolved), im.call(type)(*named_ranges_get) +def run_test_program( + params: list[str], + sizes: Dict[str, common.Domain], + target: str, + domain: itir.Expr, + domain_get: itir.Expr, +) -> None: + testee = program_factory( + params=params, + body=[setat_factory(domain=domain_get, target=im.ref(target))], + ) + expected = program_factory( + params=params, + 
body=[setat_factory(domain=domain, target=im.ref(target))], + ) + actual = TransformGetDomain.apply(testee, sizes=sizes) + assert actual == expected + + +def construct_domains( + domain_resolved: Domain, symbol_name: str, type: Union[common.GridType, str] +) -> tuple[itir.FunCall, itir.FunCall]: + ranges_get = {} + ragnes_resolved = {} + + for dim, r in zip(domain_resolved.dims, domain_resolved.ranges): + get_call = im.call("get_domain")(symbol_name, im.axis_literal(dim)) + ranges_get[dim] = (im.tuple_get(0, get_call), im.tuple_get(1, get_call)) + bounds = im.make_tuple(r.start, r.stop) + ragnes_resolved[dim] = (im.tuple_get(0, bounds), im.tuple_get(1, bounds)) + + return im.domain(type, ragnes_resolved), im.domain(type, ranges_get) def test_get_domain(): @@ -92,32 +94,7 @@ def test_get_domain(): sizes["out"], "out", "unstructured_domain" ) - testee = program_factory( - params=[im.sym("inp"), im.sym("out")], - declarations=[], - body=[ - itir.SetAt( - expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), - domain=unstructured_domain_get, - target=im.ref("out"), - ), - ], - ) - - expected = program_factory( - params=[im.sym("inp"), im.sym("out")], - declarations=[], - body=[ - itir.SetAt( - expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), - domain=unstructured_domain, - target=im.ref("out"), - ), - ], - ) - - actual = TransformGetDomain.apply(testee, sizes=sizes) - assert actual == expected + run_test_program(["inp", "out"], sizes, "out", unstructured_domain, unstructured_domain_get) def test_get_domain_inside_as_fieldop(): @@ -127,8 +104,7 @@ def test_get_domain_inside_as_fieldop(): ) testee = program_factory( - params=[im.sym("inp"), im.sym("out")], - declarations=[], + params=["inp", "out"], body=[ itir.SetAt( expr=im.as_fieldop(im.ref("deref"), unstructured_domain_get)(im.ref("inp")), @@ -139,8 +115,7 @@ def test_get_domain_inside_as_fieldop(): ) expected = program_factory( - params=[im.sym("inp"), im.sym("out")], - declarations=[], + params=["inp", "out"], 
body=[ itir.SetAt( expr=im.as_fieldop(im.ref("deref"), unstructured_domain)(im.ref("inp")), @@ -157,47 +132,25 @@ def test_get_domain_inside_as_fieldop(): def test_get_domain_tuples(): sizes = {"out": (gtx.domain({Vertex: (0, 5)}), gtx.domain({Vertex: (0, 7)}))} - unstructured_domain_get = im.call("unstructured_domain")( - im.named_range( - im.axis_literal(Vertex), - im.tuple_get(0, im.call("get_domain")(im.tuple_get(0, "out"), im.axis_literal(Vertex))), - im.tuple_get(1, im.call("get_domain")(im.tuple_get(0, "out"), im.axis_literal(Vertex))), - ) + unstructured_domain_get = im.domain( + common.GridType.UNSTRUCTURED, + { + Vertex: ( + im.tuple_get( + 0, im.call("get_domain")(im.tuple_get(0, "out"), im.axis_literal(Vertex)) + ), + im.tuple_get( + 1, im.call("get_domain")(im.tuple_get(0, "out"), im.axis_literal(Vertex)) + ), + ) + }, ) - unstructured_domain = im.call("unstructured_domain")( - im.named_range( - im.axis_literal(Vertex), - im.tuple_get(0, im.make_tuple(0, 5)), - im.tuple_get(1, im.make_tuple(0, 5)), - ), + unstructured_domain = im.domain( + common.GridType.UNSTRUCTURED, + {Vertex: (im.tuple_get(0, im.make_tuple(0, 5)), im.tuple_get(1, im.make_tuple(0, 5)))}, ) - testee = program_factory( - params=[im.sym("inp"), im.sym("out")], - declarations=[], - body=[ - itir.SetAt( - expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), - domain=unstructured_domain_get, - target=im.tuple_get(0, "out"), - ), - ], - ) - - expected = program_factory( - params=[im.sym("inp"), im.sym("out")], - declarations=[], - body=[ - itir.SetAt( - expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), - domain=unstructured_domain, - target=im.tuple_get(0, "out"), - ), - ], - ) - - actual = TransformGetDomain.apply(testee, sizes=sizes) - assert actual == expected + run_test_program(["inp", "out"], sizes, "out", unstructured_domain, unstructured_domain_get) def test_get_domain_nested_tuples(): @@ -206,452 +159,22 @@ def test_get_domain_nested_tuples(): t0 = im.make_tuple("a", "b") t1 = 
im.make_tuple("c", "d") tup = im.make_tuple(im.tuple_get(0, t0), im.tuple_get(1, t1)) - unstructured_domain_get = im.call("unstructured_domain")( - im.named_range( - im.axis_literal(KDim), - im.tuple_get(0, im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim))), - im.tuple_get(1, im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim))), - ) - ) - unstructured_domain = im.call("unstructured_domain")( - im.named_range( - im.axis_literal(KDim), - im.tuple_get(0, im.make_tuple(0, 3)), - im.tuple_get(1, im.make_tuple(0, 3)), - ), - ) - - testee = program_factory( - params=[im.sym("inp"), im.sym("a"), im.sym("b"), im.sym("c"), im.sym("d")], - declarations=[], - body=[ - itir.SetAt( - expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), - domain=unstructured_domain_get, - target=im.ref("a"), - ), - ], - ) - - expected = program_factory( - params=[im.sym("inp"), im.sym("a"), im.sym("b"), im.sym("c"), im.sym("d")], - declarations=[], - body=[ - itir.SetAt( - expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), - domain=unstructured_domain, - target=im.ref("a"), - ), - ], - ) - - actual = TransformGetDomain.apply(testee, sizes=sizes) - assert actual == expected - - -def test_trivial_shift(unstructured_case): - sizes = {"out": gtx.domain({Edge: (9, 13), Vertex: (0, 9)})} - unstructured_domain_get_E = im.call("unstructured_domain")( - im.call("named_range")( - im.axis_literal(Edge), - im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), - im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))), - ) - ) - - unstructured_domain_E = im.call("unstructured_domain")( - im.call("named_range")( - im.axis_literal(Edge), - im.tuple_get(0, im.make_tuple(9, 13)), - im.tuple_get(1, im.make_tuple(9, 13)), - ) - ) - - unstructured_domain_V_37 = im.call("unstructured_domain")( - im.call("named_range")(im.axis_literal(Vertex), 3, 7), - ) - - offset_provider = unstructured_case.offset_provider - testee = program_factory( - 
params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], - body=[ - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))))( - im.as_fieldop("deref")("vertex_values") - ), - domain=unstructured_domain_get_E, - ) - ], - ) - - ir = inline_fundefs.InlineFundefs().visit(testee) - ir = inline_fundefs.prune_unreferenced_fundefs(ir) - ir = TransformGetDomain.apply(ir, sizes=sizes) - ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - - expected = program_factory( - params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], - declarations=[ - itir.Temporary(id="__tmp_1", domain=unstructured_domain_V_37, dtype=float_type) - ], - body=[ - itir.SetAt( - target=im.ref("__tmp_1"), - expr=im.as_fieldop("deref", unstructured_domain_V_37)("vertex_values"), - domain=unstructured_domain_V_37, - ), - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop( - im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))), unstructured_domain_E - )("__tmp_1"), - domain=unstructured_domain_E, - ), - ], - ) - - assert ir == expected - - -def test_trivial_shift_switched(unstructured_case): - sizes = {"out": gtx.domain({Edge: (0, 18), Vertex: (0, 9)})} - unstructured_domain_get_E = im.call("unstructured_domain")( - im.call("named_range")( - im.axis_literal(Edge), - im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), - im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))), - ) - ) - - unstructured_domain_E = im.call("unstructured_domain")( - im.call("named_range")( - im.axis_literal(Edge), - im.tuple_get(0, im.make_tuple(0, 18)), - im.tuple_get(1, im.make_tuple(0, 18)), - ) - ) - offset_provider = unstructured_case.offset_provider - testee = program_factory( - params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], - body=[ - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop("deref")( - 
im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))))( - "vertex_values" - ) - ), - domain=unstructured_domain_get_E, - ) - ], - ) - - ir = inline_fundefs.InlineFundefs().visit(testee) - ir = inline_fundefs.prune_unreferenced_fundefs(ir) - ir = TransformGetDomain.apply(ir, sizes=sizes) - ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - - expected = program_factory( - params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], - declarations=[itir.Temporary(id="__tmp_1", domain=unstructured_domain_E, dtype=float_type)], - body=[ - itir.SetAt( - target=im.ref("__tmp_1"), - expr=im.as_fieldop( - im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))), unstructured_domain_E - )("vertex_values"), - domain=unstructured_domain_E, - ), - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop("deref", unstructured_domain_E)("__tmp_1"), - domain=unstructured_domain_E, - ), - ], - ) - - assert ir == expected - - -def test_two_shifts(unstructured_case): - sizes = {"out": gtx.domain({Edge: (0, 18), Vertex: (0, 9)})} - unstructured_domain_get_E = im.call("unstructured_domain")( - im.call("named_range")( - im.axis_literal(Edge), - im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), - im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))), - ) - ) - - unstructured_domain_E = im.call("unstructured_domain")( - im.call("named_range")( - im.axis_literal(Edge), - im.tuple_get(0, im.make_tuple(0, 18)), - im.tuple_get(1, im.make_tuple(0, 18)), - ) - ) - - unstructured_domain_V = im.call("unstructured_domain")( - im.call("named_range")(im.axis_literal(Vertex), 0, 9), - ) - offset_provider = unstructured_case.offset_provider - testee = program_factory( - params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], - body=[ - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop( - im.lambda_("x")( - im.plus( - im.deref(im.shift("E2V", 0)("x")), im.deref(im.shift("E2V", 1)("x")) - ) - ) - 
)(im.as_fieldop("deref")("vertex_values")), - domain=unstructured_domain_get_E, + unstructured_domain_get = im.domain( + common.GridType.UNSTRUCTURED, + { + KDim: ( + im.tuple_get(0, im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim))), + im.tuple_get(1, im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim))), ) - ], + }, ) - ir = inline_fundefs.InlineFundefs().visit(testee) - ir = inline_fundefs.prune_unreferenced_fundefs(ir) - ir = TransformGetDomain.apply(ir, sizes=sizes) - ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - - expected = program_factory( - params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], - declarations=[itir.Temporary(id="__tmp_1", domain=unstructured_domain_V, dtype=float_type)], - body=[ - itir.SetAt( - target=im.ref("__tmp_1"), - expr=im.as_fieldop("deref", unstructured_domain_V)("vertex_values"), - domain=unstructured_domain_V, - ), - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop( - im.lambda_("x")( - im.plus( - im.deref(im.shift("E2V", 0)("x")), im.deref(im.shift("E2V", 1)("x")) - ) - ), - unstructured_domain_E, - )("__tmp_1"), - domain=unstructured_domain_E, - ), - ], - ) - - assert ir == expected - - -def test_nested_shift(unstructured_case): - sizes = {"out": gtx.domain({Edge: (0, 18), Vertex: (0, 9)})} - unstructured_domain_V = im.call("unstructured_domain")( - im.call("named_range")( - im.axis_literal(Vertex), - im.tuple_get(0, im.make_tuple(0, 9)), - im.tuple_get(1, im.make_tuple(0, 9)), - ) - ) - unstructured_domain_get_V = im.call("unstructured_domain")( - im.call("named_range")( - im.axis_literal(Vertex), - im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Vertex))), - im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Vertex))), - ) - ) - - unstructured_domain_V_39 = im.call("unstructured_domain")( - im.call("named_range")(im.axis_literal(Vertex), 3, 9), + unstructured_domain = im.domain( + common.GridType.UNSTRUCTURED, + 
{KDim: (im.tuple_get(0, im.make_tuple(0, 3)), im.tuple_get(1, im.make_tuple(0, 3)))}, ) - unstructured_domain_E_918 = im.call("unstructured_domain")( - im.call("named_range")(im.axis_literal(Edge), 9, 18), + run_test_program( + ["inp", "a", "b", "c", "d"], sizes, "a", unstructured_domain, unstructured_domain_get ) - - offset_provider = unstructured_case.offset_provider - - testee = program_factory( - params=[im.sym("vertex_values", v_field_type), im.sym("out", v_field_type)], - body=[ - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("V2E", 1)("x"))))( - im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))))( - im.as_fieldop("deref")("vertex_values") - ) - ), - domain=unstructured_domain_get_V, - ) - ], - ) - - ir = inline_fundefs.InlineFundefs().visit(testee) - ir = inline_fundefs.prune_unreferenced_fundefs(ir) - ir = TransformGetDomain.apply(ir, sizes=sizes) - ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - - expected = program_factory( - params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], - declarations=[ - itir.Temporary(id="__tmp_1", domain=unstructured_domain_E_918, dtype=float_type), - itir.Temporary(id="__tmp_2", domain=unstructured_domain_V_39, dtype=float_type), - ], - body=[ - itir.SetAt( - target=im.ref("__tmp_2"), - expr=im.as_fieldop("deref", unstructured_domain_V_39)("vertex_values"), - domain=unstructured_domain_V_39, - ), - itir.SetAt( - target=im.ref("__tmp_1"), - expr=im.as_fieldop( - im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))), - unstructured_domain_E_918, - )("__tmp_2"), - domain=unstructured_domain_E_918, - ), - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop( - im.lambda_("x")(im.deref(im.shift("V2E", 1)("x"))), unstructured_domain_V - )("__tmp_1"), - domain=unstructured_domain_V, - ), - ], - ) - - assert ir == expected - - -def test_trivial_cartesian(): - grid = simple_cartesian_grid() - offset_provider = {"Ioff": 
grid.offset_provider["Ioff"]} - sizes = {"out": gtx.domain({IDim: (2, 7)})} - cartesian_domain, cartesian_domain_get = construct_domains( - sizes["out"], "out", "cartesian_domain" - ) - cartesian_domain_27_p1 = im.call("cartesian_domain")( - im.call("named_range")( - im.axis_literal(IDim), - im.plus(im.tuple_get(0, im.make_tuple(2, 7)), 1), - im.plus(im.tuple_get(1, im.make_tuple(2, 7)), 1), - ) - ) - - testee = program_factory( - params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], - body=[ - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("Ioff", 1)("x"))))( - im.as_fieldop("deref")("i_values") - ), - domain=cartesian_domain_get, - ) - ], - ) - - ir = inline_fundefs.InlineFundefs().visit(testee) - ir = inline_fundefs.prune_unreferenced_fundefs(ir) - ir = TransformGetDomain.apply(ir, sizes=sizes) - ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - - expected = program_factory( - params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], - declarations=[ - itir.Temporary(id="__tmp_1", domain=cartesian_domain_27_p1, dtype=float_type) - ], - body=[ - itir.SetAt( - target=im.ref("__tmp_1"), - expr=im.as_fieldop("deref", cartesian_domain_27_p1)("i_values"), - domain=cartesian_domain_27_p1, - ), - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop( - im.lambda_("x")(im.deref(im.shift("Ioff", 1)("x"))), cartesian_domain - )("__tmp_1"), - domain=cartesian_domain, - ), - ], - ) - - assert ir == expected - - -def test_trivial_cartesian_forward(): - grid = simple_cartesian_grid() - offset_provider = {"Ioff": grid.offset_provider["Ioff"]} - sizes = {"out": gtx.domain({IDim: (0, 8)})} - - cartesian_domain_get = im.call("cartesian_domain")( - im.call("named_range")( - im.axis_literal(IDim), - im.minus(im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(IDim))), 4), - im.minus(im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(IDim))), 4), - ) - ) - testee 
= program_factory( - params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], - body=[ - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop( - im.lambda_("x")(im.deref(im.shift("Ioff", 2)("x"))), - )( - im.as_fieldop( - im.lambda_("x")(im.deref(im.shift("Ioff", 2)("x"))), - )("i_values") - ), - domain=cartesian_domain_get, - ) - ], - ) - - ir = inline_fundefs.InlineFundefs().visit(testee) - ir = inline_fundefs.prune_unreferenced_fundefs(ir) - ir = TransformGetDomain.apply(ir, sizes=sizes) - ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - - cartesian_domain_m2 = im.call("cartesian_domain")( - im.call("named_range")( - im.axis_literal(IDim), - im.minus(im.tuple_get(0, im.make_tuple(0, 8)), 2), - im.minus(im.tuple_get(1, im.make_tuple(0, 8)), 2), - ) - ) - - cartesian_domain_m4 = im.call("cartesian_domain")( - im.call("named_range")( - im.axis_literal(IDim), - im.minus(im.tuple_get(0, im.make_tuple(0, 8)), 4), - im.minus(im.tuple_get(1, im.make_tuple(0, 8)), 4), - ) - ) - - expected = program_factory( - params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], - declarations=[itir.Temporary(id="__tmp_1", domain=cartesian_domain_m2, dtype=float_type)], - body=[ - itir.SetAt( - target=im.ref("__tmp_1"), - expr=im.as_fieldop( - im.lambda_("x")(im.deref(im.shift("Ioff", 2)("x"))), cartesian_domain_m2 - )("i_values"), - domain=cartesian_domain_m2, - ), - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop( - im.lambda_("x")(im.deref(im.shift("Ioff", 2)("x"))), cartesian_domain_m4 - )("__tmp_1"), - domain=cartesian_domain_m4, - ), - ], - ) - - assert ir == expected From abb93ef4da874cc5bb27e6502397f3c44a827a86 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 7 Aug 2025 12:54:55 +0200 Subject: [PATCH 15/93] Cleanup and restrict domains in tests --- .../test_temporary_domain_inference.py | 106 +++++++++--------- .../test_transform_get_domain.py | 26 ++--- 2 files changed, 64 insertions(+), 68 
deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py index ea630681a7..f2fce5e324 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py @@ -46,9 +46,9 @@ def mesh_descriptor(exec_alloc_descriptor): def program_factory( - params: list[itir.Sym], - body: list[itir.SetAt], - declarations: Optional[list[itir.Temporary]] = None, + params: list[itir.Sym], + body: list[itir.SetAt], + declarations: Optional[list[itir.Temporary]] = None, ) -> itir.Program: return itir.Program( id="testee", @@ -60,9 +60,9 @@ def program_factory( def run_test_program( - testee: itir.Program, expected: itir.Program, sizes: Dict[str, common.Domain], offset_provider: common.OffsetProvider + testee: itir.Program, expected: itir.Program, sizes: Dict[str, common.Domain], + offset_provider: common.OffsetProvider ) -> None: - ir = inline_fundefs.InlineFundefs().visit(testee) ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = TransformGetDomain.apply(ir, sizes=sizes) @@ -74,16 +74,16 @@ def run_test_program( def test_trivial_shift(unstructured_case): sizes = {"out": gtx.domain({Edge: (9, 13), Vertex: (0, 9)})} unstructured_domain_get_E = im.domain(common.GridType.UNSTRUCTURED, - {Edge: (im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), - im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))))} + {Edge: (im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))))} ) unstructured_domain_E = im.domain(common.GridType.UNSTRUCTURED, - {Edge: (im.tuple_get(0, im.make_tuple(9, 13)), - im.tuple_get(1, im.make_tuple(9, 13)))} - ) + {Edge: (im.tuple_get(0, im.make_tuple(9, 13)), 
+ im.tuple_get(1, im.make_tuple(9, 13)))} + ) - unstructured_domain_V_37 = im.domain(common.GridType.UNSTRUCTURED,{Vertex: (3, 7)}) + unstructured_domain_V_37 = im.domain(common.GridType.UNSTRUCTURED, {Vertex: (3, 7)}) offset_provider = unstructured_case.offset_provider testee = program_factory( @@ -126,13 +126,13 @@ def test_trivial_shift(unstructured_case): def test_trivial_shift_switched(unstructured_case): sizes = {"out": gtx.domain({Edge: (2, 16), Vertex: (0, 9)})} unstructured_domain_get_E = im.domain(common.GridType.UNSTRUCTURED, - {Edge: (im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), - im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))))} + {Edge: (im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))))} ) unstructured_domain_E = im.domain(common.GridType.UNSTRUCTURED, - {Edge: (im.tuple_get(0, im.make_tuple(2, 16)), - im.tuple_get(1, im.make_tuple(2, 16)))} + {Edge: (im.tuple_get(0, im.make_tuple(2, 16)), + im.tuple_get(1, im.make_tuple(2, 16)))} ) offset_provider = unstructured_case.offset_provider @@ -174,18 +174,18 @@ def test_trivial_shift_switched(unstructured_case): def test_two_shifts(unstructured_case): - sizes = {"out": gtx.domain({Edge: (0, 18), Vertex: (0, 9)})} + sizes = {"out": gtx.domain({Edge: (3, 8), Vertex: (0, 9)})} unstructured_domain_get_E = im.domain(common.GridType.UNSTRUCTURED, - {Edge: (im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), - im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))))} + {Edge: (im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))))} ) unstructured_domain_E = im.domain(common.GridType.UNSTRUCTURED, - {Edge: (im.tuple_get(0, im.make_tuple(0, 18)), - im.tuple_get(1, im.make_tuple(0, 18)))} + {Edge: (im.tuple_get(0, im.make_tuple(3, 8)), + im.tuple_get(1, 
im.make_tuple(3, 8)))} ) - unstructured_domain_V = im.domain(common.GridType.UNSTRUCTURED,{Vertex: ( 0, 9)}) + unstructured_domain_V_39 = im.domain(common.GridType.UNSTRUCTURED, {Vertex: (3, 9)}) offset_provider = unstructured_case.offset_provider testee = program_factory( @@ -207,12 +207,12 @@ def test_two_shifts(unstructured_case): expected = program_factory( params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], - declarations=[itir.Temporary(id="__tmp_1", domain=unstructured_domain_V, dtype=float_type)], + declarations=[itir.Temporary(id="__tmp_1", domain=unstructured_domain_V_39, dtype=float_type)], body=[ itir.SetAt( target=im.ref("__tmp_1"), - expr=im.as_fieldop("deref", unstructured_domain_V)("vertex_values"), - domain=unstructured_domain_V, + expr=im.as_fieldop("deref", unstructured_domain_V_39)("vertex_values"), + domain=unstructured_domain_V_39, ), itir.SetAt( target=im.ref("out"), @@ -233,19 +233,20 @@ def test_two_shifts(unstructured_case): def test_nested_shift(unstructured_case): - sizes = {"out": gtx.domain({Edge: (0, 18), Vertex: (0, 9)})} + sizes = {"out": gtx.domain({Edge: (0, 18), Vertex: (3, 7)})} unstructured_domain_V = im.domain(common.GridType.UNSTRUCTURED, - {Vertex: (im.tuple_get(0, im.make_tuple(0, 9)), - im.tuple_get(1, im.make_tuple(0, 9)))} + {Vertex: (im.tuple_get(0, im.make_tuple(3, 7)), + im.tuple_get(1, im.make_tuple(3, 7)))} ) unstructured_domain_get_V = im.domain(common.GridType.UNSTRUCTURED, - {Vertex: (im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Vertex))), - im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Vertex))))} + {Vertex: ( + im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Vertex))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Vertex))))} ) - unstructured_domain_V_39 = im.domain(common.GridType.UNSTRUCTURED,{Vertex: ( 3, 9)}) + unstructured_domain_V_69 = im.domain(common.GridType.UNSTRUCTURED, {Vertex: (6, 9)}) - unstructured_domain_E_918 
= im.domain(common.GridType.UNSTRUCTURED,{Edge: ( 9, 18)}) + unstructured_domain_E_1216 = im.domain(common.GridType.UNSTRUCTURED, {Edge: (12, 16)}) offset_provider = unstructured_case.offset_provider @@ -254,7 +255,7 @@ def test_nested_shift(unstructured_case): body=[ itir.SetAt( target=im.ref("out"), - expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("V2E", 1)("x"))))( + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("V2E", 3)("x"))))( im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))))( im.as_fieldop("deref")("vertex_values") ) @@ -267,27 +268,27 @@ def test_nested_shift(unstructured_case): expected = program_factory( params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], declarations=[ - itir.Temporary(id="__tmp_1", domain=unstructured_domain_E_918, dtype=float_type), - itir.Temporary(id="__tmp_2", domain=unstructured_domain_V_39, dtype=float_type), + itir.Temporary(id="__tmp_1", domain=unstructured_domain_E_1216, dtype=float_type), + itir.Temporary(id="__tmp_2", domain=unstructured_domain_V_69, dtype=float_type), ], body=[ itir.SetAt( target=im.ref("__tmp_2"), - expr=im.as_fieldop("deref", unstructured_domain_V_39)("vertex_values"), - domain=unstructured_domain_V_39, + expr=im.as_fieldop("deref", unstructured_domain_V_69)("vertex_values"), + domain=unstructured_domain_V_69, ), itir.SetAt( target=im.ref("__tmp_1"), expr=im.as_fieldop( im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))), - unstructured_domain_E_918, + unstructured_domain_E_1216, )("__tmp_2"), - domain=unstructured_domain_E_918, + domain=unstructured_domain_E_1216, ), itir.SetAt( target=im.ref("out"), expr=im.as_fieldop( - im.lambda_("x")(im.deref(im.shift("V2E", 1)("x"))), unstructured_domain_V + im.lambda_("x")(im.deref(im.shift("V2E", 3)("x"))), unstructured_domain_V )("__tmp_1"), domain=unstructured_domain_V, ), @@ -303,17 +304,17 @@ def test_trivial_cartesian(): sizes = {"out": gtx.domain({IDim: (2, 7)})} cartesian_domain = 
im.domain(common.GridType.CARTESIAN, - {IDim: (im.tuple_get(0, im.make_tuple(2, 7)), - im.tuple_get(1, im.make_tuple(2, 7)))} + {IDim: (im.tuple_get(0, im.make_tuple(2, 7)), + im.tuple_get(1, im.make_tuple(2, 7)))} ) cartesian_domain_get = im.domain(common.GridType.CARTESIAN, - {IDim: (im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(IDim))), - im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(IDim))))} + {IDim: (im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(IDim))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(IDim))))} ) cartesian_domain_27_p1 = im.domain(common.GridType.CARTESIAN, - {IDim: (im.plus(im.tuple_get(0, im.make_tuple(2, 7)), 1), - im.plus(im.tuple_get(1, im.make_tuple(2, 7)), 1))} + {IDim: (im.plus(im.tuple_get(0, im.make_tuple(2, 7)), 1), + im.plus(im.tuple_get(1, im.make_tuple(2, 7)), 1))} ) testee = program_factory( params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], @@ -355,11 +356,12 @@ def test_trivial_cartesian(): def test_trivial_cartesian_forward(): grid = simple_cartesian_grid() offset_provider = {"Ioff": grid.offset_provider["Ioff"]} - sizes = {"out": gtx.domain({IDim: (0, 8)})} + sizes = {"out": gtx.domain({IDim: (2, 7)})} cartesian_domain_get = im.domain(common.GridType.CARTESIAN, - {IDim: (im.minus(im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(IDim))), 4), - im.minus(im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(IDim))), 4))} + {IDim: ( + im.minus(im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(IDim))), 4), + im.minus(im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(IDim))), 4))} ) testee = program_factory( params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], @@ -379,13 +381,13 @@ def test_trivial_cartesian_forward(): ) cartesian_domain_m2 = im.domain(common.GridType.CARTESIAN, - {IDim: (im.minus(im.tuple_get(0, im.make_tuple(0, 8)), 2), - im.minus(im.tuple_get(1, im.make_tuple(0, 8)), 2))} 
+ {IDim: (im.minus(im.tuple_get(0, im.make_tuple(2, 7)), 2), + im.minus(im.tuple_get(1, im.make_tuple(2, 7)), 2))} ) cartesian_domain_m4 = im.domain(common.GridType.CARTESIAN, - {IDim: (im.minus(im.tuple_get(0, im.make_tuple(0, 8)), 4), - im.minus(im.tuple_get(1, im.make_tuple(0, 8)), 4))} + {IDim: (im.minus(im.tuple_get(0, im.make_tuple(2, 7)), 4), + im.minus(im.tuple_get(1, im.make_tuple(2, 7)), 4))} ) expected = program_factory( params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py index be45b6980f..41fb2a1122 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py @@ -22,17 +22,11 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.transform_get_domain import TransformGetDomain -from gt4py.next.type_system import type_specifications as ts -from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( - simple_cartesian_grid, - Edge, - simple_mesh, -) def program_factory( - params: list[str], - body: list[itir.SetAt], + params: list[str], + body: list[itir.SetAt], ) -> itir.Program: return itir.Program( id="testee", @@ -44,8 +38,8 @@ def program_factory( def setat_factory( - domain: common.Domain, - target: str, + domain: common.Domain, + target: str, ) -> itir.SetAt: return itir.SetAt( expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), @@ -55,11 +49,11 @@ def setat_factory( def run_test_program( - params: list[str], - sizes: Dict[str, common.Domain], - target: str, - domain: itir.Expr, - domain_get: itir.Expr, + params: list[str], + sizes: Dict[str, common.Domain], + target: str, + domain: itir.Expr, + 
domain_get: itir.Expr, ) -> None: testee = program_factory( params=params, @@ -74,7 +68,7 @@ def run_test_program( def construct_domains( - domain_resolved: Domain, symbol_name: str, type: Union[common.GridType, str] + domain_resolved: Domain, symbol_name: str, type: Union[common.GridType, str] ) -> tuple[itir.FunCall, itir.FunCall]: ranges_get = {} ragnes_resolved = {} From 7b5b9b4831223b4bb02b8bb17c5846459fd3d4ab Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 7 Aug 2025 12:57:30 +0200 Subject: [PATCH 16/93] Refactor and clean up tests --- .../test_transform_get_domain.py | 232 ++++++++---------- 1 file changed, 100 insertions(+), 132 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py index b3229163c1..41fb2a1122 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py @@ -5,99 +5,111 @@ # # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause +from typing import Dict, Union +import pytest from next_tests.integration_tests.cases import ( + IField, IDim, KDim, Vertex, + unstructured_case, + exec_alloc_descriptor, ) from gt4py import next as gtx -from gt4py.next import Domain +from gt4py.next import Domain, common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.transform_get_domain import TransformGetDomain -IOff = gtx.FieldOffset("IOff", source=IDim, target=(IDim,)) +def program_factory( + params: list[str], + body: list[itir.SetAt], +) -> itir.Program: + return itir.Program( + id="testee", + function_definitions=[], + params=[im.sym(par) for par in params], + declarations=[], + body=body, + ) -def construct_domains(domain_resolved: Domain, symbol_name: str, type: str): - named_ranges_get, named_ranges_resolved = [], [] - for dim, range_ in zip(domain_resolved.dims, domain_resolved.ranges): - get_domain_call = im.call("get_domain")(symbol_name, im.axis_literal(dim)) - named_ranges_get.append( - im.named_range(im.axis_literal(dim), im.tuple_get(0, get_domain_call), im.tuple_get(1, get_domain_call)) - ) - bounds_tuple = im.make_tuple(range_.start, range_.stop) - named_ranges_resolved.append( - im.named_range(im.axis_literal(dim), im.tuple_get(0, bounds_tuple), im.tuple_get(1, bounds_tuple)) - ) +def setat_factory( + domain: common.Domain, + target: str, +) -> itir.SetAt: + return itir.SetAt( + expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), + domain=domain, + target=im.ref(target), + ) - return im.call(type)(*named_ranges_resolved), im.call(type)(*named_ranges_get) +def run_test_program( + params: list[str], + sizes: Dict[str, common.Domain], + target: str, + domain: itir.Expr, + domain_get: itir.Expr, +) -> None: + testee = program_factory( + params=params, + body=[setat_factory(domain=domain_get, target=im.ref(target))], + ) + expected = program_factory( + params=params, + 
body=[setat_factory(domain=domain, target=im.ref(target))], + ) + actual = TransformGetDomain.apply(testee, sizes=sizes) + assert actual == expected -def test_get_domain(): - sizes = {"out": gtx.domain({Vertex: (0, 10), KDim: (0, 20)})} - unstructured_domain, unstructured_domain_get = construct_domains(sizes["out"], "out", "unstructured_domain") - testee = itir.Program( - id="test", - function_definitions=[], - params=[im.sym("inp"), im.sym("out")], - declarations=[], - body=[ - itir.SetAt( - expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), - domain=unstructured_domain_get, - target=im.ref("out"), - ), - ], - ) +def construct_domains( + domain_resolved: Domain, symbol_name: str, type: Union[common.GridType, str] +) -> tuple[itir.FunCall, itir.FunCall]: + ranges_get = {} + ragnes_resolved = {} - expected = itir.Program( - id="test", - function_definitions=[], - params=[im.sym("inp"), im.sym("out")], - declarations=[], - body=[ - itir.SetAt( - expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), - domain=unstructured_domain, - target=im.ref("out"), - ), - ], + for dim, r in zip(domain_resolved.dims, domain_resolved.ranges): + get_call = im.call("get_domain")(symbol_name, im.axis_literal(dim)) + ranges_get[dim] = (im.tuple_get(0, get_call), im.tuple_get(1, get_call)) + bounds = im.make_tuple(r.start, r.stop) + ragnes_resolved[dim] = (im.tuple_get(0, bounds), im.tuple_get(1, bounds)) + + return im.domain(type, ragnes_resolved), im.domain(type, ranges_get) + + +def test_get_domain(): + sizes = {"out": gtx.domain({Vertex: (0, 10), KDim: (0, 20)})} + unstructured_domain, unstructured_domain_get = construct_domains( + sizes["out"], "out", "unstructured_domain" ) - actual = TransformGetDomain.apply(testee, sizes=sizes) - assert actual == expected + run_test_program(["inp", "out"], sizes, "out", unstructured_domain, unstructured_domain_get) def test_get_domain_inside_as_fieldop(): sizes = {"out": gtx.domain({Vertex: (0, 10), KDim: (0, 20)})} - unstructured_domain, 
unstructured_domain_get = construct_domains(sizes["out"], "out", "unstructured_domain") + unstructured_domain, unstructured_domain_get = construct_domains( + sizes["out"], "out", "unstructured_domain" + ) - testee = itir.Program( - id="test", - function_definitions=[], - params=[im.sym("inp"), im.sym("out")], - declarations=[], + testee = program_factory( + params=["inp", "out"], body=[ itir.SetAt( - expr=im.as_fieldop(im.ref("deref"), unstructured_domain_get)( - im.ref("inp") - ), + expr=im.as_fieldop(im.ref("deref"), unstructured_domain_get)(im.ref("inp")), domain=unstructured_domain_get, target=im.ref("out"), ), ], ) - expected = itir.Program( - id="test", - function_definitions=[], - params=[im.sym("inp"), im.sym("out")], - declarations=[], + expected = program_factory( + params=["inp", "out"], body=[ itir.SetAt( expr=im.as_fieldop(im.ref("deref"), unstructured_domain)(im.ref("inp")), @@ -114,47 +126,25 @@ def test_get_domain_inside_as_fieldop(): def test_get_domain_tuples(): sizes = {"out": (gtx.domain({Vertex: (0, 5)}), gtx.domain({Vertex: (0, 7)}))} - unstructured_domain_get = im.call("unstructured_domain")( - im.named_range(im.axis_literal(Vertex), - im.tuple_get(0, im.call("get_domain")(im.tuple_get(0, "out"), im.axis_literal(Vertex))), - im.tuple_get(1, im.call("get_domain")(im.tuple_get(0, "out"), im.axis_literal(Vertex))) - ) - ) - unstructured_domain = im.call("unstructured_domain")( - im.named_range(im.axis_literal(Vertex), im.tuple_get(0, im.make_tuple(0, 5)), - im.tuple_get(1, im.make_tuple(0, 5))), - ) - - testee = itir.Program( - id="test", - function_definitions=[], - params=[im.sym("inp"), im.sym("out")], - declarations=[], - body=[ - itir.SetAt( - expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), - domain=unstructured_domain_get, - target=im.tuple_get(0, "out"), - ), - ], + unstructured_domain_get = im.domain( + common.GridType.UNSTRUCTURED, + { + Vertex: ( + im.tuple_get( + 0, im.call("get_domain")(im.tuple_get(0, "out"), 
im.axis_literal(Vertex)) + ), + im.tuple_get( + 1, im.call("get_domain")(im.tuple_get(0, "out"), im.axis_literal(Vertex)) + ), + ) + }, ) - - expected = itir.Program( - id="test", - function_definitions=[], - params=[im.sym("inp"), im.sym("out")], - declarations=[], - body=[ - itir.SetAt( - expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), - domain=unstructured_domain, - target=im.tuple_get(0, "out"), - ), - ], + unstructured_domain = im.domain( + common.GridType.UNSTRUCTURED, + {Vertex: (im.tuple_get(0, im.make_tuple(0, 5)), im.tuple_get(1, im.make_tuple(0, 5)))}, ) - actual = TransformGetDomain.apply(testee, sizes=sizes) - assert actual == expected + run_test_program(["inp", "out"], sizes, "out", unstructured_domain, unstructured_domain_get) def test_get_domain_nested_tuples(): @@ -163,44 +153,22 @@ def test_get_domain_nested_tuples(): t0 = im.make_tuple("a", "b") t1 = im.make_tuple("c", "d") tup = im.make_tuple(im.tuple_get(0, t0), im.tuple_get(1, t1)) - unstructured_domain_get = im.call("unstructured_domain")( - im.named_range(im.axis_literal(KDim), - im.tuple_get(0, im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim))), - im.tuple_get(1, im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim))) - ) - ) - unstructured_domain = im.call("unstructured_domain")( - im.named_range(im.axis_literal(KDim), im.tuple_get(0, im.make_tuple(0, 3)), - im.tuple_get(1, im.make_tuple(0, 3))), - ) - testee = itir.Program( - id="test", - function_definitions=[], - params=[im.sym("inp"), im.sym("a"), im.sym("b"), im.sym("c"), im.sym("d")], - declarations=[], - body=[ - itir.SetAt( - expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), - domain=unstructured_domain_get, - target=im.ref("a"), - ), - ], + unstructured_domain_get = im.domain( + common.GridType.UNSTRUCTURED, + { + KDim: ( + im.tuple_get(0, im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim))), + im.tuple_get(1, im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim))), + ) 
+ }, ) - expected = itir.Program( - id="test", - function_definitions=[], - params=[im.sym("inp"), im.sym("a"), im.sym("b"), im.sym("c"), im.sym("d")], - declarations=[], - body=[ - itir.SetAt( - expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), - domain=unstructured_domain, - target=im.ref("a"), - ), - ], + unstructured_domain = im.domain( + common.GridType.UNSTRUCTURED, + {KDim: (im.tuple_get(0, im.make_tuple(0, 3)), im.tuple_get(1, im.make_tuple(0, 3)))}, ) - actual = TransformGetDomain.apply(testee, sizes=sizes) - assert actual == expected + run_test_program( + ["inp", "a", "b", "c", "d"], sizes, "a", unstructured_domain, unstructured_domain_get + ) From 891a97a1c936fe7435fb1a8fbc30c8a27db63604 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 7 Aug 2025 14:58:28 +0200 Subject: [PATCH 17/93] Add check on mesh order --- .../next/iterator/ir_utils/domain_utils.py | 9 ++++++ .../test_temporary_domain_inference.py | 30 +++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 5239859138..b6bc206e88 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -10,6 +10,7 @@ import dataclasses import functools +import warnings from typing import Any, Callable, Iterable, Literal, Mapping, Optional import numpy as np @@ -171,6 +172,14 @@ def translate( accessed = offset_provider[off.value].ndarray[start:stop, off_index] min_ = np.min(accessed) max_ = np.max(accessed) + 1 + + if (covered := np.unique(accessed).size) < ( + max_ - min_) / 2: + warnings.warn( + f"For {new_dim} the accessed range [{min_}, {max_}[ covers {max_ - min_} values, " + f"but only {covered} are actually present and {max_ - min_ - covered} were added " + f"in between {accessed}. 
Please consider reordering the mesh.") + horizontal_sizes[new_dim.value] = ( im.literal(str(min_), builtins.INTEGER_INDEX_BUILTIN), im.literal(str(max_), builtins.INTEGER_INDEX_BUILTIN), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py index f2fce5e324..beb068c877 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py @@ -123,6 +123,36 @@ def test_trivial_shift(unstructured_case): run_test_program(testee, expected, sizes, offset_provider) +def test_trivial_shift_warning(unstructured_case): + with pytest.warns(UserWarning, match=r"For Vertex\[horizontal\] the accessed range \[3, 9\[ covers 6 values, " + r"but only 2 are actually present and 4 were added in between \[8 3\]\. " + r"Please consider reordering the mesh\."): + sizes = {"out": gtx.domain({Edge: (8, 10), Vertex: (0, 9)})} + unstructured_domain_get_E = im.domain(common.GridType.UNSTRUCTURED, + {Edge: (im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))))} + ) + + offset_provider = unstructured_case.offset_provider + testee = program_factory( + params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))))( + im.as_fieldop("deref")("vertex_values") + ), + domain=unstructured_domain_get_E, + ) + ], + ) + ir = inline_fundefs.InlineFundefs().visit(testee) + ir = inline_fundefs.prune_unreferenced_fundefs(ir) + ir = TransformGetDomain.apply(ir, sizes=sizes) + + global_tmps.create_global_tmps(ir, offset_provider=offset_provider) + + def test_trivial_shift_switched(unstructured_case): sizes = 
{"out": gtx.domain({Edge: (2, 16), Vertex: (0, 9)})} unstructured_domain_get_E = im.domain(common.GridType.UNSTRUCTURED, From 958183b17510e8cccd582b9b54dc9cd9c08dd8cf Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 7 Aug 2025 15:03:31 +0200 Subject: [PATCH 18/93] Reformat --- .../next/iterator/ir_utils/domain_utils.py | 13 +- .../test_temporary_domain_inference.py | 201 +++++++++++------- .../test_transform_get_domain.py | 20 +- 3 files changed, 146 insertions(+), 88 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index b6bc206e88..9aa116a5ed 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -173,12 +173,15 @@ def translate( min_ = np.min(accessed) max_ = np.max(accessed) + 1 - if (covered := np.unique(accessed).size) < ( - max_ - min_) / 2: + if (covered := np.unique(accessed).size) < (max_ - min_) / 2: warnings.warn( - f"For {new_dim} the accessed range [{min_}, {max_}[ covers {max_ - min_} values, " - f"but only {covered} are actually present and {max_ - min_ - covered} were added " - f"in between {accessed}. Please consider reordering the mesh.") + UserWarning( + f"For {new_dim} the accessed range [{min_}, {max_}[ covers {max_ - min_} values, " + f"but only {covered} are actually present and {max_ - min_ - covered} were added " + f"in between {accessed}. Please consider reordering the mesh." 
+ ), + stacklevel=2, + ) horizontal_sizes[new_dim.value] = ( im.literal(str(min_), builtins.INTEGER_INDEX_BUILTIN), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py index beb068c877..7bad2473e9 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py @@ -18,7 +18,7 @@ from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( simple_cartesian_grid, Edge, - simple_mesh + simple_mesh, ) from gt4py import next as gtx @@ -46,9 +46,9 @@ def mesh_descriptor(exec_alloc_descriptor): def program_factory( - params: list[itir.Sym], - body: list[itir.SetAt], - declarations: Optional[list[itir.Temporary]] = None, + params: list[itir.Sym], + body: list[itir.SetAt], + declarations: Optional[list[itir.Temporary]] = None, ) -> itir.Program: return itir.Program( id="testee", @@ -60,8 +60,10 @@ def program_factory( def run_test_program( - testee: itir.Program, expected: itir.Program, sizes: Dict[str, common.Domain], - offset_provider: common.OffsetProvider + testee: itir.Program, + expected: itir.Program, + sizes: Dict[str, common.Domain], + offset_provider: common.OffsetProvider, ) -> None: ir = inline_fundefs.InlineFundefs().visit(testee) ir = inline_fundefs.prune_unreferenced_fundefs(ir) @@ -73,15 +75,20 @@ def run_test_program( def test_trivial_shift(unstructured_case): sizes = {"out": gtx.domain({Edge: (9, 13), Vertex: (0, 9)})} - unstructured_domain_get_E = im.domain(common.GridType.UNSTRUCTURED, - {Edge: (im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), - im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))))} - ) + unstructured_domain_get_E = im.domain( + common.GridType.UNSTRUCTURED, + { + Edge: ( + 
im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))), + ) + }, + ) - unstructured_domain_E = im.domain(common.GridType.UNSTRUCTURED, - {Edge: (im.tuple_get(0, im.make_tuple(9, 13)), - im.tuple_get(1, im.make_tuple(9, 13)))} - ) + unstructured_domain_E = im.domain( + common.GridType.UNSTRUCTURED, + {Edge: (im.tuple_get(0, im.make_tuple(9, 13)), im.tuple_get(1, im.make_tuple(9, 13)))}, + ) unstructured_domain_V_37 = im.domain(common.GridType.UNSTRUCTURED, {Vertex: (3, 7)}) @@ -124,14 +131,22 @@ def test_trivial_shift(unstructured_case): def test_trivial_shift_warning(unstructured_case): - with pytest.warns(UserWarning, match=r"For Vertex\[horizontal\] the accessed range \[3, 9\[ covers 6 values, " - r"but only 2 are actually present and 4 were added in between \[8 3\]\. " - r"Please consider reordering the mesh\."): + with pytest.warns( + UserWarning, + match=r"For Vertex\[horizontal\] the accessed range \[3, 9\[ covers 6 values, " + r"but only 2 are actually present and 4 were added in between \[8 3\]\. 
" + r"Please consider reordering the mesh\.", + ): sizes = {"out": gtx.domain({Edge: (8, 10), Vertex: (0, 9)})} - unstructured_domain_get_E = im.domain(common.GridType.UNSTRUCTURED, - {Edge: (im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), - im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))))} - ) + unstructured_domain_get_E = im.domain( + common.GridType.UNSTRUCTURED, + { + Edge: ( + im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))), + ) + }, + ) offset_provider = unstructured_case.offset_provider testee = program_factory( @@ -155,15 +170,20 @@ def test_trivial_shift_warning(unstructured_case): def test_trivial_shift_switched(unstructured_case): sizes = {"out": gtx.domain({Edge: (2, 16), Vertex: (0, 9)})} - unstructured_domain_get_E = im.domain(common.GridType.UNSTRUCTURED, - {Edge: (im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), - im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))))} - ) + unstructured_domain_get_E = im.domain( + common.GridType.UNSTRUCTURED, + { + Edge: ( + im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))), + ) + }, + ) - unstructured_domain_E = im.domain(common.GridType.UNSTRUCTURED, - {Edge: (im.tuple_get(0, im.make_tuple(2, 16)), - im.tuple_get(1, im.make_tuple(2, 16)))} - ) + unstructured_domain_E = im.domain( + common.GridType.UNSTRUCTURED, + {Edge: (im.tuple_get(0, im.make_tuple(2, 16)), im.tuple_get(1, im.make_tuple(2, 16)))}, + ) offset_provider = unstructured_case.offset_provider testee = program_factory( @@ -205,15 +225,20 @@ def test_trivial_shift_switched(unstructured_case): def test_two_shifts(unstructured_case): sizes = {"out": gtx.domain({Edge: (3, 8), Vertex: (0, 9)})} - unstructured_domain_get_E = im.domain(common.GridType.UNSTRUCTURED, - {Edge: (im.tuple_get(0, 
im.call("get_domain")("out", im.axis_literal(Edge))), - im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))))} - ) + unstructured_domain_get_E = im.domain( + common.GridType.UNSTRUCTURED, + { + Edge: ( + im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Edge))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Edge))), + ) + }, + ) - unstructured_domain_E = im.domain(common.GridType.UNSTRUCTURED, - {Edge: (im.tuple_get(0, im.make_tuple(3, 8)), - im.tuple_get(1, im.make_tuple(3, 8)))} - ) + unstructured_domain_E = im.domain( + common.GridType.UNSTRUCTURED, + {Edge: (im.tuple_get(0, im.make_tuple(3, 8)), im.tuple_get(1, im.make_tuple(3, 8)))}, + ) unstructured_domain_V_39 = im.domain(common.GridType.UNSTRUCTURED, {Vertex: (3, 9)}) @@ -237,7 +262,9 @@ def test_two_shifts(unstructured_case): expected = program_factory( params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], - declarations=[itir.Temporary(id="__tmp_1", domain=unstructured_domain_V_39, dtype=float_type)], + declarations=[ + itir.Temporary(id="__tmp_1", domain=unstructured_domain_V_39, dtype=float_type) + ], body=[ itir.SetAt( target=im.ref("__tmp_1"), @@ -264,15 +291,19 @@ def test_two_shifts(unstructured_case): def test_nested_shift(unstructured_case): sizes = {"out": gtx.domain({Edge: (0, 18), Vertex: (3, 7)})} - unstructured_domain_V = im.domain(common.GridType.UNSTRUCTURED, - {Vertex: (im.tuple_get(0, im.make_tuple(3, 7)), - im.tuple_get(1, im.make_tuple(3, 7)))} - ) - unstructured_domain_get_V = im.domain(common.GridType.UNSTRUCTURED, - {Vertex: ( - im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Vertex))), - im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Vertex))))} - ) + unstructured_domain_V = im.domain( + common.GridType.UNSTRUCTURED, + {Vertex: (im.tuple_get(0, im.make_tuple(3, 7)), im.tuple_get(1, im.make_tuple(3, 7)))}, + ) + unstructured_domain_get_V = im.domain( + common.GridType.UNSTRUCTURED, + { + 
Vertex: ( + im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Vertex))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Vertex))), + ) + }, + ) unstructured_domain_V_69 = im.domain(common.GridType.UNSTRUCTURED, {Vertex: (6, 9)}) @@ -333,19 +364,29 @@ def test_trivial_cartesian(): offset_provider = {"Ioff": grid.offset_provider["Ioff"]} sizes = {"out": gtx.domain({IDim: (2, 7)})} - cartesian_domain = im.domain(common.GridType.CARTESIAN, - {IDim: (im.tuple_get(0, im.make_tuple(2, 7)), - im.tuple_get(1, im.make_tuple(2, 7)))} - ) - cartesian_domain_get = im.domain(common.GridType.CARTESIAN, - {IDim: (im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(IDim))), - im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(IDim))))} - ) - - cartesian_domain_27_p1 = im.domain(common.GridType.CARTESIAN, - {IDim: (im.plus(im.tuple_get(0, im.make_tuple(2, 7)), 1), - im.plus(im.tuple_get(1, im.make_tuple(2, 7)), 1))} - ) + cartesian_domain = im.domain( + common.GridType.CARTESIAN, + {IDim: (im.tuple_get(0, im.make_tuple(2, 7)), im.tuple_get(1, im.make_tuple(2, 7)))}, + ) + cartesian_domain_get = im.domain( + common.GridType.CARTESIAN, + { + IDim: ( + im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(IDim))), + im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(IDim))), + ) + }, + ) + + cartesian_domain_27_p1 = im.domain( + common.GridType.CARTESIAN, + { + IDim: ( + im.plus(im.tuple_get(0, im.make_tuple(2, 7)), 1), + im.plus(im.tuple_get(1, im.make_tuple(2, 7)), 1), + ) + }, + ) testee = program_factory( params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], body=[ @@ -388,11 +429,15 @@ def test_trivial_cartesian_forward(): offset_provider = {"Ioff": grid.offset_provider["Ioff"]} sizes = {"out": gtx.domain({IDim: (2, 7)})} - cartesian_domain_get = im.domain(common.GridType.CARTESIAN, - {IDim: ( - im.minus(im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(IDim))), 4), - 
im.minus(im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(IDim))), 4))} - ) + cartesian_domain_get = im.domain( + common.GridType.CARTESIAN, + { + IDim: ( + im.minus(im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(IDim))), 4), + im.minus(im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(IDim))), 4), + ) + }, + ) testee = program_factory( params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], body=[ @@ -410,15 +455,25 @@ def test_trivial_cartesian_forward(): ], ) - cartesian_domain_m2 = im.domain(common.GridType.CARTESIAN, - {IDim: (im.minus(im.tuple_get(0, im.make_tuple(2, 7)), 2), - im.minus(im.tuple_get(1, im.make_tuple(2, 7)), 2))} - ) + cartesian_domain_m2 = im.domain( + common.GridType.CARTESIAN, + { + IDim: ( + im.minus(im.tuple_get(0, im.make_tuple(2, 7)), 2), + im.minus(im.tuple_get(1, im.make_tuple(2, 7)), 2), + ) + }, + ) - cartesian_domain_m4 = im.domain(common.GridType.CARTESIAN, - {IDim: (im.minus(im.tuple_get(0, im.make_tuple(2, 7)), 4), - im.minus(im.tuple_get(1, im.make_tuple(2, 7)), 4))} - ) + cartesian_domain_m4 = im.domain( + common.GridType.CARTESIAN, + { + IDim: ( + im.minus(im.tuple_get(0, im.make_tuple(2, 7)), 4), + im.minus(im.tuple_get(1, im.make_tuple(2, 7)), 4), + ) + }, + ) expected = program_factory( params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], declarations=[itir.Temporary(id="__tmp_1", domain=cartesian_domain_m2, dtype=float_type)], diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py index 41fb2a1122..08d27fcd22 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py @@ -25,8 +25,8 @@ def program_factory( - params: list[str], - body: list[itir.SetAt], + params: list[str], + body: 
list[itir.SetAt], ) -> itir.Program: return itir.Program( id="testee", @@ -38,8 +38,8 @@ def program_factory( def setat_factory( - domain: common.Domain, - target: str, + domain: common.Domain, + target: str, ) -> itir.SetAt: return itir.SetAt( expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), @@ -49,11 +49,11 @@ def setat_factory( def run_test_program( - params: list[str], - sizes: Dict[str, common.Domain], - target: str, - domain: itir.Expr, - domain_get: itir.Expr, + params: list[str], + sizes: Dict[str, common.Domain], + target: str, + domain: itir.Expr, + domain_get: itir.Expr, ) -> None: testee = program_factory( params=params, @@ -68,7 +68,7 @@ def run_test_program( def construct_domains( - domain_resolved: Domain, symbol_name: str, type: Union[common.GridType, str] + domain_resolved: Domain, symbol_name: str, type: Union[common.GridType, str] ) -> tuple[itir.FunCall, itir.FunCall]: ranges_get = {} ragnes_resolved = {} From 466ec0f4b308e9c4f7c57703f86c78b7c6f531ef Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 7 Aug 2025 15:05:04 +0200 Subject: [PATCH 19/93] Reformat --- .../test_transform_get_domain.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py index 41fb2a1122..08d27fcd22 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py @@ -25,8 +25,8 @@ def program_factory( - params: list[str], - body: list[itir.SetAt], + params: list[str], + body: list[itir.SetAt], ) -> itir.Program: return itir.Program( id="testee", @@ -38,8 +38,8 @@ def program_factory( def setat_factory( - domain: common.Domain, - target: str, + domain: common.Domain, + target: str, ) -> itir.SetAt: return itir.SetAt( 
expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), @@ -49,11 +49,11 @@ def setat_factory( def run_test_program( - params: list[str], - sizes: Dict[str, common.Domain], - target: str, - domain: itir.Expr, - domain_get: itir.Expr, + params: list[str], + sizes: Dict[str, common.Domain], + target: str, + domain: itir.Expr, + domain_get: itir.Expr, ) -> None: testee = program_factory( params=params, @@ -68,7 +68,7 @@ def run_test_program( def construct_domains( - domain_resolved: Domain, symbol_name: str, type: Union[common.GridType, str] + domain_resolved: Domain, symbol_name: str, type: Union[common.GridType, str] ) -> tuple[itir.FunCall, itir.FunCall]: ranges_get = {} ragnes_resolved = {} From 0e4eb57ae840829b75e0f36e4ce8b0c48dd82435 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Mon, 18 Aug 2025 14:27:04 +0200 Subject: [PATCH 20/93] Rename get_domain to get_domain_range --- src/gt4py/next/ffront/past_to_itir.py | 13 ++++++++----- src/gt4py/next/iterator/builtins.py | 4 ++-- src/gt4py/next/iterator/embedded.py | 4 ++-- src/gt4py/next/iterator/runtime.py | 2 +- .../next/iterator/type_system/type_synthesizer.py | 5 +++-- .../program_processors/codegens/gtfn/codegen.py | 3 ++- .../program_processors/codegens/gtfn/gtfn_ir.py | 2 +- .../codegens/gtfn/itir_to_gtfn_ir.py | 4 ++-- 8 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 79c6a4b36b..f6f9c088fa 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -14,8 +14,7 @@ import devtools -from gt4py.eve import NodeTranslator, concepts, traits -from gt4py.eve import utils as eve_utils +from gt4py.eve import NodeTranslator, concepts, traits, utils as eve_utils from gt4py.next import common, config, errors from gt4py.next.ffront import ( fbuiltins, @@ -366,14 +365,18 @@ def _construct_itir_domain_arg( ) # if the out_field is a (potentially nested) tuple we get the domain from its first 
# element - first_out_el_path = eve_utils.first(type_info.primitive_constituents(out_field.type, with_path_arg=True))[1] - first_out_el = functools.reduce(lambda expr, i: im.tuple_get(i, expr), first_out_el_path, out_field.id) + first_out_el_path = eve_utils.first( + type_info.primitive_constituents(out_field.type, with_path_arg=True) + )[1] + first_out_el = functools.reduce( + lambda expr, i: im.tuple_get(i, expr), first_out_el_path, out_field.id + ) domain_args = [] domain_args_kind = [] for dim_i, dim in enumerate(out_dims): # an expression for the range of a dimension - dim_range = im.call("get_domain")( + dim_range = im.call("get_domain_range")( first_out_el, itir.AxisLiteral(value=dim.value, kind=dim.kind) ) dim_start, dim_stop = im.tuple_get(0, dim_range), im.tuple_get(1, dim_range) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 153105413b..e54c6ea3d7 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -413,7 +413,7 @@ def concat_where(*args): @builtin_dispatch -def get_domain(*args): +def get_domain_range(*args): raise BackendNotSelectedError() @@ -490,7 +490,7 @@ def get_domain(*args): "cartesian_domain", "cast_", "deref", - "get_domain", + "get_domain_range", "if_", "index", # `index(dim)` creates a dim-field that has the current index at each point "shift", diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 739f7f2dfa..9e5c9a0efb 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1678,8 +1678,8 @@ def set_at(expr: common.Field, domain: common.DomainLike, target: common.Mutable operators._tuple_assign_field(target, expr, common.domain(domain)) -@runtime.get_domain.register(EMBEDDED) -def get_domain(field: common.Field, dim: common.Dimension) -> tuple[int, int]: +@runtime.get_domain_range.register(EMBEDDED) +def get_domain_range(field: common.Field, dim: common.Dimension) -> tuple[int, 
int]: return (field.domain[dim].unit_range.start, field.domain[dim].unit_range.stop) diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index c927cda843..a995385b0f 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -216,7 +216,7 @@ def set_at(*args): @builtin_dispatch -def get_domain(*args): +def get_domain_range(*args): return BackendNotSelectedError() diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 258020ea7e..cc0759b79b 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -589,9 +589,10 @@ def applied_as_fieldop( @_register_builtin_type_synthesizer -def get_domain(field: ts.FieldType, dim: ts.DimensionType) -> ts.TupleType: +def get_domain_range(field: ts.FieldType, dim: ts.DimensionType) -> ts.TupleType: return ts.TupleType( - types=[ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()))] * 2 + types=[ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()))] + * 2 ) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index bf84bb8519..f142ff006e 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -273,11 +273,12 @@ def visit_Program(self, node: gtfn_ir.Program, **kwargs: Any) -> Union[str, Coll #include #include + // TODO(tehrengruber): This should disappear as soon as we introduce a proper builtin. 
namespace gridtools::fn { // TODO(tehrengruber): `typename gridtools::sid::lower_bounds_type, typename gridtools::sid::upper_bounds_type` // fails as type used for index calculations in gtfn differs template - GT_FUNCTION gridtools::tuple get_domain(S &&sid, D) { + GT_FUNCTION gridtools::tuple get_domain_range(S &&sid, D) { return {gridtools::host_device::at_key(gridtools::sid::get_lower_bounds(sid)), gridtools::host_device::at_key(gridtools::sid::get_upper_bounds(sid))}; } diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 1e5023f5be..aa5a94991c 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -231,7 +231,7 @@ class TemporaryAllocation(Node): "can_deref", "cartesian_domain", "unstructured_domain", - "get_domain", + "get_domain_range", "named_range", "reduce", "index", diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index dac7823451..02fb4fdbe2 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -471,11 +471,11 @@ def _visit_unstructured_domain(self, node: itir.FunCall, **kwargs: Any) -> Node: tagged_sizes=sizes, tagged_offsets=domain_offsets, connectivities=connectivities ) - def _visit_get_domain(self, node: itir.FunCall, **kwargs: Any) -> Node: + def _visit_get_domain_range(self, node: itir.FunCall, **kwargs: Any) -> Node: field, dim = node.args return FunCall( - fun=SymRef(id="get_domain"), + fun=SymRef(id="get_domain_range"), args=[self.visit(field, **kwargs), self.visit(dim, **kwargs)], ) From d916337482d1ac52152b3ebcfda70e3f4bcd00a2 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Mon, 18 Aug 2025 15:15:55 +0200 Subject: [PATCH 21/93] Remove compile time args --- 
.../next/advanced/ToolchainWalkthrough.md | 10 ++-- src/gt4py/next/backend.py | 2 +- src/gt4py/next/ffront/past_process_args.py | 44 +++----------- src/gt4py/next/ffront/past_to_itir.py | 27 --------- src/gt4py/next/iterator/embedded.py | 5 -- src/gt4py/next/otf/arguments.py | 60 +------------------ .../runners/dace/workflow/decoration.py | 6 +- .../next/program_processors/runners/gtfn.py | 1 - .../gtfn_tests/test_gtfn_module.py | 6 +- 9 files changed, 20 insertions(+), 141 deletions(-) diff --git a/docs/user/next/advanced/ToolchainWalkthrough.md b/docs/user/next/advanced/ToolchainWalkthrough.md index 4d71c0ffe9..8cb8293b7a 100644 --- a/docs/user/next/advanced/ToolchainWalkthrough.md +++ b/docs/user/next/advanced/ToolchainWalkthrough.md @@ -134,13 +134,13 @@ So far we have gotten away with empty compile time arguments, now we need to sup ```python jit_args = gtx.otf.arguments.JITArgs.from_signature( - gtx.ones(domain={I: 10}, dtype=gtx.float64), - out=gtx.zeros(domain={I: 10}, dtype=gtx.float64), - offset_provider=OFFSET_PROVIDER, + gtx.ones(domain={I: 10}, dtype=gtx.float64), + out=gtx.zeros(domain={I: 10}, dtype=gtx.float64), + offset_provider=OFFSET_PROVIDER, ) -aot_args = gtx.otf.arguments.CompileTimeArgs.from_concrete_no_size( - *jit_args.args, **jit_args.kwargs +aot_args = gtx.otf.arguments.CompileTimeArgs.from_concrete( + *jit_args.args, **jit_args.kwargs ) ``` diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index e8fa6b2ac5..a20c72c24b 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -159,7 +159,7 @@ def __call__( def jit(self, program: INPUT_DATA, *args: Any, **kwargs: Any) -> stages.CompiledProgram: if not isinstance(program, IT_PRG): args, kwargs = signature.convert_to_positional(program, *args, **kwargs) - aot_args = arguments.CompileTimeArgs.from_concrete_no_size(*args, **kwargs) + aot_args = arguments.CompileTimeArgs.from_concrete(*args, **kwargs) return self.compile(program, aot_args) def compile(self, 
program: INPUT_DATA, compile_time_args: CARG) -> stages.CompiledProgram: diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index f0360e05ba..95db2837dd 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -24,14 +24,14 @@ def transform_program_args(inp: AOT_PRG) -> AOT_PRG: - rewritten_args, size_args, kwargs = _process_args( + rewritten_args, rewritten_kwargs = _process_args( past_node=inp.data.past_node, args=inp.args.args, kwargs=inp.args.kwargs ) return toolchain.CompilableProgram( data=inp.data, args=arguments.CompileTimeArgs( - args=tuple((*rewritten_args, *(size_args))), - kwargs=kwargs, + args=rewritten_args, + kwargs=rewritten_kwargs, offset_provider=inp.args.offset_provider, column_axis=inp.args.column_axis, ), @@ -65,50 +65,20 @@ def _process_args( past_node: past.Program, args: Sequence[ts.TypeSpec | arguments.StaticArg], kwargs: dict[str, ts.TypeSpec | arguments.StaticArg], -) -> tuple[tuple, tuple, dict[str, Any]]: +) -> tuple[tuple, dict[str, Any]]: if not isinstance(past_node.type, ts_ffront.ProgramType): raise TypeError("Can not process arguments for PAST programs prior to type inference.") args, kwargs = type_info.canonicalize_arguments(past_node.type, args, kwargs) + + # validate arguments arg_types = tuple(arg.type_ if isinstance(arg, arguments.StaticArg) else arg for arg in args) kwarg_types = { k: (v.type_ if isinstance(v, arguments.StaticArg) else v) for k, v in kwargs.items() } _validate_args(past_node=past_node, arg_types=arg_types, kwarg_types=kwarg_types) - implicit_domain = any( - isinstance(stmt, past.Call) and "domain" not in stmt.kwargs for stmt in past_node.body - ) - - # extract size of all field arguments - size_args: list[ts.TypeSpec] = [] - rewritten_args = list(args) - for param_idx, param in enumerate(past_node.params): - if implicit_domain and isinstance(param.type, (ts.FieldType, ts.TupleType)): - # TODO(tehrengruber): 
Previously this function was called with the actual arguments - # not their type. The check using the shape here is not functional anymore and - # should instead be placed in a proper location. - ranges_and_dims = [ - *_field_constituents_range_and_dims(arg_types[param_idx], param.type) - ] - # check that all non-scalar like constituents have the same shape and dimension, e.g. - # for `(scalar, (field1, field2))` the two fields need to have the same shape and - # dimension - if ranges_and_dims: - range_, dims = ranges_and_dims[0] - if not all( - el_range == range_ and el_dims == dims - for (el_range, el_dims) in ranges_and_dims - ): - raise ValueError( - "Constituents of composite arguments (e.g. the elements of a" - " tuple) need to have the same shape and dimensions." - ) - index_type = ts.ScalarType(kind=ts.ScalarKind.INT32) - size_args.extend( - range_ if range_ else [ts.TupleType(types=[index_type, index_type])] * len(dims) # type: ignore[arg-type] # shape is always empty - ) - return tuple(rewritten_args), tuple(size_args), kwargs + return args, kwargs def _field_constituents_range_and_dims( diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index f6f9c088fa..f0c8060e65 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -223,32 +223,6 @@ def apply( ) -> itir.Program: return cls(grid_type=grid_type).visit(node, function_definitions=function_definitions) - def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: - """Generate symbols for each field param and dimension.""" - size_params = [] - for param in node.params: - fields_dims: list[list[common.Dimension]] = ( - type_info.primitive_constituents(param.type) - .if_isinstance(ts.FieldType) - .getattr("dims") - .filter(lambda dims: len(dims) > 0) - .to_list() - ) - if len(fields_dims) > 0: # otherwise `param` has no constituent which is of `FieldType` - assert all(field_dims == fields_dims[0] for 
field_dims in fields_dims) - index_type = ts.ScalarType( - kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) - ) - for dim_idx in range(len(fields_dims[0])): - size_params.append( - itir.Sym( - id=_range_arg_from_field(param.id, dim_idx), - type=ts.TupleType(types=[index_type, index_type]), - ) - ) - - return size_params - def visit_Program( self, node: past.Program, @@ -265,7 +239,6 @@ def visit_Program( implicit_domain = False if any("domain" not in body_entry.kwargs for body_entry in node.body): - params = params + self._gen_size_params_from_program(node) implicit_domain = True set_ats = [self._visit_field_operator_call(stmt, **kwargs) for stmt in node.body] diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 9e5c9a0efb..b04ba8c42d 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1862,11 +1862,6 @@ def fendef_embedded(fun: Callable[..., None], *args: Any, **kwargs: Any): common.UnitRange(0, 0), # empty: indicates column operation, will update later ) - import inspect - - if len(args) < len(inspect.getfullargspec(fun).args): - args = (*args, *arguments.iter_size_args(args)) - with embedded_context.update(**context_vars): fun(*args) diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 39540baebb..980c9849e2 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -55,7 +55,7 @@ def offset_provider_type(self) -> common.OffsetProviderType: return common.offset_provider_to_type(self.offset_provider) @classmethod - def from_concrete_no_size(cls, *args: Any, **kwargs: Any) -> Self: + def from_concrete(cls, *args: Any, **kwargs: Any) -> Self: """Convert concrete GTX program arguments into their compile-time counterparts.""" compile_args = tuple(type_translation.from_value(arg) for arg in args) kwargs_copy = kwargs.copy() @@ -69,17 +69,6 @@ def from_concrete_no_size(cls, *args: Any, **kwargs: Any) -> Self: 
}, ) - @classmethod - def from_concrete(cls, *args: Any, **kwargs: Any) -> Self: - """Convert concrete GTX program arguments to compile-time, adding (compile-time) dimension size arguments.""" - no_size = cls.from_concrete_no_size(*args, **kwargs) - return cls( - args=(*no_size.args, *iter_size_compile_args(no_size.args)), - offset_provider=no_size.offset_provider, - column_axis=no_size.column_axis, - kwargs=no_size.kwargs, - ) - @classmethod def empty(cls) -> Self: return cls(tuple(), {}, {}, None) @@ -88,7 +77,7 @@ def empty(cls) -> Self: def jit_to_aot_args( inp: JITArgs, ) -> CompileTimeArgs: - return CompileTimeArgs.from_concrete_no_size(*inp.args, **inp.kwargs) + return CompileTimeArgs.from_concrete(*inp.args, **inp.kwargs) def adapted_jit_to_aot_args_factory() -> workflow.Workflow[ @@ -110,47 +99,4 @@ def find_first_field(tuple_arg: tuple[Any, ...]) -> Optional[common.Field]: return element case _: pass - return None - - -def iter_size_args(args: tuple[Any, ...]) -> Iterator[tuple[int, int]]: - """ - Yield the size of each field argument in each dimension. - - This can be used to generate domain size arguments for FieldView Programs that use an implicit domain. - """ - for arg in args: - match arg: - case tuple(): - # we only need the first field, because all fields in a tuple must have the same dims and sizes - first_field = find_first_field(arg) - if first_field: - yield from iter_size_args((first_field,)) - case common.Field(): - for range_ in arg.domain.ranges: - assert isinstance(range_, common.UnitRange) - yield (range_.start, range_.stop) - case _: - pass - - -def iter_size_compile_args( - args: Iterable[ts.TypeSpec | StaticArg], -) -> Iterator[ts.TypeSpec]: - """ - Yield a compile-time size argument for every compile-time field argument in each dimension. - - This can be used inside transformation workflows to generate compile-time domain size arguments for FieldView Programs that use an implicit domain. 
- """ - for arg in args: - type_ = arg.type_ if isinstance(arg, StaticArg) else arg - field_constituents: list[ts.FieldType] = typing.cast( - list[ts.FieldType], - type_info.primitive_constituents(type_).if_isinstance(ts.FieldType).to_list(), - ) - if field_constituents: - # we only need the first field, because all fields in a tuple must have the same dims and sizes - index_type = ts.ScalarType(kind=ts.ScalarKind.INT32) - yield from [ - ts.TupleType(types=[index_type, index_type]) for _ in field_constituents[0].dims - ] + return None \ No newline at end of file diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index b551381354..054e148e01 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -38,11 +38,7 @@ def decorated_program( if out is not None: args = (*args, out) - if fun.implicit_domain: - # Generate implicit domain size arguments only if necessary - size_args = arguments.iter_size_args(args) - args = (*args, *size_args) - + # TODO: this doesn't belong here and should by done in the dace backend if not fun.sdfg_program._lastargs: # First call, the SDFG is not intitalized, so forward the call to `CompiledSDFG` # to proper initilize it. 
Later calls to this SDFG will be handled through diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index e395bcf991..ae326184d8 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -74,7 +74,6 @@ def decorated_program( # generate implicit domain size arguments only if necessary, using `iter_size_args()` inp( *converted_args, - *(arguments.iter_size_args(args) if inp.implicit_domain else ()), *conn_args, **opt_kwargs, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index 510c03e314..acfbede4eb 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -77,7 +77,7 @@ def test_codegen(program_example): module = gtfn_module.translate_program_cpu( stages.CompilableProgram( data=fencil, - args=arguments.CompileTimeArgs.from_concrete_no_size( + args=arguments.CompileTimeArgs.from_concrete( *parameters, **{"offset_provider": {}} ), ) @@ -91,7 +91,7 @@ def test_hash_and_diskcache(program_example, tmp_path): fencil, parameters = program_example compilable_program = stages.CompilableProgram( data=fencil, - args=arguments.CompileTimeArgs.from_concrete_no_size( + args=arguments.CompileTimeArgs.from_concrete( *parameters, **{"offset_provider": {}} ), ) @@ -135,7 +135,7 @@ def test_gtfn_file_cache(program_example): fencil, parameters = program_example compilable_program = stages.CompilableProgram( data=fencil, - args=arguments.CompileTimeArgs.from_concrete_no_size( + args=arguments.CompileTimeArgs.from_concrete( *parameters, **{"offset_provider": {}} ), ) From 25e24e9afc0ff71868efdca13cf631a8289fa2be Mon Sep 17 00:00:00 2001 From: 
tehrengruber Date: Mon, 18 Aug 2025 15:16:18 +0200 Subject: [PATCH 22/93] Fix format --- docs/user/next/advanced/ToolchainWalkthrough.md | 10 ++++------ src/gt4py/next/ffront/past_to_itir.py | 2 +- src/gt4py/next/iterator/embedded.py | 1 - src/gt4py/next/otf/arguments.py | 6 +++--- .../runners/dace/workflow/decoration.py | 2 +- src/gt4py/next/program_processors/runners/gtfn.py | 2 +- .../codegens_tests/gtfn_tests/test_gtfn_module.py | 12 +++--------- 7 files changed, 13 insertions(+), 22 deletions(-) diff --git a/docs/user/next/advanced/ToolchainWalkthrough.md b/docs/user/next/advanced/ToolchainWalkthrough.md index 8cb8293b7a..d730eed37e 100644 --- a/docs/user/next/advanced/ToolchainWalkthrough.md +++ b/docs/user/next/advanced/ToolchainWalkthrough.md @@ -134,14 +134,12 @@ So far we have gotten away with empty compile time arguments, now we need to sup ```python jit_args = gtx.otf.arguments.JITArgs.from_signature( - gtx.ones(domain={I: 10}, dtype=gtx.float64), - out=gtx.zeros(domain={I: 10}, dtype=gtx.float64), - offset_provider=OFFSET_PROVIDER, + gtx.ones(domain={I: 10}, dtype=gtx.float64), + out=gtx.zeros(domain={I: 10}, dtype=gtx.float64), + offset_provider=OFFSET_PROVIDER, ) -aot_args = gtx.otf.arguments.CompileTimeArgs.from_concrete( - *jit_args.args, **jit_args.kwargs -) +aot_args = gtx.otf.arguments.CompileTimeArgs.from_concrete(*jit_args.args, **jit_args.kwargs) ``` ```python diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index f0c8060e65..e7b9412ebf 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -25,7 +25,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.stages import AOT_PRG -from gt4py.next.iterator import builtins, ir as itir +from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms import remap_symbols from gt4py.next.otf import arguments, stages, workflow diff --git 
a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index b04ba8c42d..c43a3422b5 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -54,7 +54,6 @@ ) from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import builtins, runtime -from gt4py.next.otf import arguments from gt4py.next.type_system import type_specifications as ts, type_translation diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 980c9849e2..04f4bcf1e8 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -10,14 +10,14 @@ import dataclasses import typing -from typing import Any, Generic, Iterable, Iterator, Optional +from typing import Any, Generic, Optional from typing_extensions import Self from gt4py._core import definitions as core_defs from gt4py.next import common from gt4py.next.otf import toolchain, workflow -from gt4py.next.type_system import type_info, type_specifications as ts, type_translation +from gt4py.next.type_system import type_specifications as ts, type_translation DATA_T = typing.TypeVar("DATA_T") @@ -99,4 +99,4 @@ def find_first_field(tuple_arg: tuple[Any, ...]) -> Optional[common.Field]: return element case _: pass - return None \ No newline at end of file + return None diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 054e148e01..240de755fe 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -15,7 +15,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common as gtx_common, config, metrics, utils as gtx_utils -from gt4py.next.otf import arguments, stages +from gt4py.next.otf import stages from gt4py.next.program_processors.runners.dace import sdfg_callable, workflow as dace_worflow from . 
import common as dace_common diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index ae326184d8..29cacd06e1 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -20,7 +20,7 @@ from gt4py._core import locking from gt4py.next import backend, common, config, field_utils, metrics from gt4py.next.embedded import nd_array_field -from gt4py.next.otf import arguments, recipes, stages, workflow +from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler from gt4py.next.otf.compilation.build_systems import compiledb diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index acfbede4eb..6eba78bb23 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -77,9 +77,7 @@ def test_codegen(program_example): module = gtfn_module.translate_program_cpu( stages.CompilableProgram( data=fencil, - args=arguments.CompileTimeArgs.from_concrete( - *parameters, **{"offset_provider": {}} - ), + args=arguments.CompileTimeArgs.from_concrete(*parameters, **{"offset_provider": {}}), ) ) assert module.entry_point.name == fencil.id @@ -91,9 +89,7 @@ def test_hash_and_diskcache(program_example, tmp_path): fencil, parameters = program_example compilable_program = stages.CompilableProgram( data=fencil, - args=arguments.CompileTimeArgs.from_concrete( - *parameters, **{"offset_provider": {}} - ), + args=arguments.CompileTimeArgs.from_concrete(*parameters, **{"offset_provider": {}}), ) hash = stages.fingerprint_compilable_program(compilable_program) @@ -135,9 +131,7 @@ def 
test_gtfn_file_cache(program_example): fencil, parameters = program_example compilable_program = stages.CompilableProgram( data=fencil, - args=arguments.CompileTimeArgs.from_concrete( - *parameters, **{"offset_provider": {}} - ), + args=arguments.CompileTimeArgs.from_concrete(*parameters, **{"offset_provider": {}}), ) cached_gtfn_translation_step = gtfn.GTFNBackendFactory( gpu=False, cached=True, otf_workflow__cached_translation=True From e856a1899b1254a6f5d1168e071909cb91727acf Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Mon, 18 Aug 2025 15:46:13 +0200 Subject: [PATCH 23/93] Fix failing tests --- tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py index cbaa84454d..ae70d27ea8 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py @@ -99,8 +99,6 @@ def test_copy_lowering(copy_program_def, gtir_identity_fundef): params=[ P(itir.Sym, id=eve.SymbolName("in_field")), P(itir.Sym, id=eve.SymbolName("out")), - P(itir.Sym, id=eve.SymbolName("__in_field_0_range")), - P(itir.Sym, id=eve.SymbolName("__out_0_range")), ], body=[set_at_pattern], ) @@ -191,8 +189,6 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) params=[ P(itir.Sym, id=eve.SymbolName("in_field")), P(itir.Sym, id=eve.SymbolName("out")), - P(itir.Sym, id=eve.SymbolName("__in_field_0_range")), - P(itir.Sym, id=eve.SymbolName("__out_0_range")), ], body=[set_at_pattern], ) From 00a11a62695b0a9a1bd9510cb9d8ec5622e67a11 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Mon, 18 Aug 2025 16:30:59 +0200 Subject: [PATCH 24/93] Fix failing tests --- .../ffront_tests/test_past_to_gtir.py | 73 +++++++------------ 1 file changed, 25 insertions(+), 48 deletions(-) diff --git 
a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py index ae70d27ea8..aede44283e 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py @@ -8,6 +8,7 @@ import re +from typing import Literal import pytest @@ -40,6 +41,25 @@ def gtir_identity_fundef(): ) +def get_domain_range_pattern(field: str, dim: str, idx: Literal[0, 1]): + return P( + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")), + args=[ + P( + itir.Literal, + value=str(idx), + type=ts.ScalarType(kind=ts.ScalarKind.INT32), + ), + P( + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("get_domain_range")), + args=[P(itir.SymRef, id=eve.SymbolRef(field)), P(itir.AxisLiteral, value=dim)], + ), + ], + ) + + def test_copy_lowering(copy_program_def, gtir_identity_fundef): past_node = ProgramParser.apply_to_function(copy_program_def) itir_node = ProgramLowering.apply( @@ -58,30 +78,8 @@ def test_copy_lowering(copy_program_def, gtir_identity_fundef): fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), args=[ P(itir.AxisLiteral, value="IDim"), - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")), - args=[ - P( - itir.Literal, - value="0", - type=ts.ScalarType(kind=ts.ScalarKind.INT32), - ), - P(itir.SymRef, id=eve.SymbolRef("__out_0_range")), - ], - ), - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")), - args=[ - P( - itir.Literal, - value="1", - type=ts.ScalarType(kind=ts.ScalarKind.INT32), - ), - P(itir.SymRef, id=eve.SymbolRef("__out_0_range")), - ], - ), + get_domain_range_pattern("out", "IDim", 0), + get_domain_range_pattern("out", "IDim", 1), ], ) ], @@ -128,18 +126,7 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) itir.FunCall, fun=P(itir.SymRef, id=eve.SymbolRef("plus")), args=[ - P( - itir.FunCall, - fun=P(itir.SymRef, 
id=eve.SymbolRef("tuple_get")), - args=[ - P( - itir.Literal, - value="0", - type=ts.ScalarType(kind=ts.ScalarKind.INT32), - ), - P(itir.SymRef, id=eve.SymbolRef("__out_0_range")), - ], - ), + get_domain_range_pattern("out", "IDim", 0), P( itir.Literal, value="1", @@ -155,18 +142,7 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) itir.FunCall, fun=P(itir.SymRef, id=eve.SymbolRef("plus")), args=[ - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")), - args=[ - P( - itir.Literal, - value="0", - type=ts.ScalarType(kind=ts.ScalarKind.INT32), - ), - P(itir.SymRef, id=eve.SymbolRef("__out_0_range")), - ], - ), + get_domain_range_pattern("out", "IDim", 0), P( itir.Literal, value="2", @@ -183,6 +159,7 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) ], ), ) + program_pattern = P( itir.Program, id=eve.SymbolName("copy_restrict_program"), From 97e47a294a8aad6c9332b42174dbfe367dbc8d82 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 21 Aug 2025 00:23:18 +0200 Subject: [PATCH 25/93] Cleanup --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 29 +++--- src/gt4py/next/iterator/ir_utils/misc.py | 11 ++- .../iterator/transforms/constant_folding.py | 21 ++--- .../transforms/transform_get_domain.py | 92 +++++++++++-------- .../iterator_tests/test_type_inference.py | 12 +-- .../test_transform_get_domain.py | 20 ++-- 6 files changed, 106 insertions(+), 79 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index de94defe87..b0f4f2c22b 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import typing -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, TypeAlias, Union from gt4py._core import definitions as core_defs from gt4py.next import common @@ -15,6 +15,9 @@ 
from gt4py.next.type_system import type_specifications as ts, type_translation +ExprLike: TypeAlias = Union[str, core_defs.Scalar, common.Dimension, itir.Expr] + + def sym(sym_or_name: Union[str, itir.Sym], type_: str | ts.TypeSpec | None = None) -> itir.Sym: """ Convert to Sym if necessary. @@ -68,7 +71,7 @@ def ref( return ref -def ensure_expr(literal_or_expr: Union[str, core_defs.Scalar, itir.Expr]) -> itir.Expr: +def ensure_expr(expr_like: ExprLike) -> itir.Expr: """ Convert literals into a SymRef or Literal and let expressions pass unchanged. @@ -83,14 +86,16 @@ def ensure_expr(literal_or_expr: Union[str, core_defs.Scalar, itir.Expr]) -> iti >>> ensure_expr(itir.OffsetLiteral(value="i")) OffsetLiteral(value='i') """ - if isinstance(literal_or_expr, str): - return ref(literal_or_expr) - elif core_defs.is_scalar_type(literal_or_expr): - return literal_from_value(literal_or_expr) - elif literal_or_expr is None: + if isinstance(expr_like, str): + return ref(expr_like) + elif core_defs.is_scalar_type(expr_like): + return literal_from_value(expr_like) + elif expr_like is None: return itir.NoneLiteral() - assert isinstance(literal_or_expr, itir.Expr), literal_or_expr - return literal_or_expr + elif isinstance(expr_like, common.Dimension): + return axis_literal(expr_like) + assert isinstance(expr_like, itir.Expr), expr_like + return expr_like def ensure_offset(str_or_offset: Union[str, int, itir.OffsetLiteral]) -> itir.OffsetLiteral: @@ -463,7 +468,7 @@ def domain( expr = call(grid_type)( *[ named_range( - axis_literal(d), + d, r[0], r[1], ) @@ -474,7 +479,9 @@ def domain( return expr -def named_range(dim: itir.AxisLiteral, start: itir.Expr, stop: itir.Expr) -> itir.FunCall: +def named_range( + dim: itir.AxisLiteral | common.Dimension, start: itir.Expr, stop: itir.Expr +) -> itir.FunCall: return call("named_range")(dim, start, stop) diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index 00ff9abbd9..ce5c0c085d 100644 --- 
a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -11,11 +11,13 @@ from typing import Callable, Iterable, TypeVar from gt4py import eve +from gt4py._core import definitions as core_defs from gt4py.eve import utils as eve_utils from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import embedded, ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import inline_lambdas +from gt4py.next.type_system import type_specifications as ts @dataclasses.dataclass(frozen=True) @@ -253,3 +255,10 @@ def unique_symbol(sym: SymOrStr, reserved_names: Iterable[str]) -> SymOrStr: name = name + "_" return name + + +def value_from_literal(literal: itir.Literal) -> core_defs.Scalar: + if literal.type.kind == ts.ScalarKind.BOOL: + assert literal.value in ["True", "False"] + return literal.value == "True" + return getattr(embedded, str(literal.type))(literal.value) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 48653ba5b5..f9269314fb 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -15,18 +15,13 @@ from typing import Optional from gt4py import eve -from gt4py._core import definitions as core_defs from gt4py.next.iterator import builtins, embedded, ir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + ir_makers as im, + misc as ir_misc, +) from gt4py.next.iterator.transforms import fixed_point_transformation -from gt4py.next.type_system import type_specifications as ts - - -def _value_from_literal(literal: ir.Literal) -> core_defs.Scalar: - if literal.type.kind == ts.ScalarKind.BOOL: - assert literal.value in ["True", "False"] - return 
literal.value == "True" - return getattr(embedded, str(literal.type))(literal.value) class UndoCanonicalizeMinus(eve.NodeTranslator): @@ -42,11 +37,11 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: a, b = node.args if cpm.is_call_to(b, "neg"): return im.minus(a, b.args[0]) - if isinstance(b, ir.Literal) and (val := _value_from_literal(b)) < 0: + if isinstance(b, ir.Literal) and (val := ir_misc.value_from_literal(b)) < 0: return im.minus(a, -val) # type: ignore[operator] # if val would represent an unsigend int, `-` is not supported, but error would be somewhere else if cpm.is_call_to(a, "neg"): return im.minus(b, a.args[0]) - if isinstance(a, ir.Literal) and (val := _value_from_literal(a)) < 0: + if isinstance(a, ir.Literal) and (val := ir_misc.value_from_literal(a)) < 0: return im.minus(b, -val) # type: ignore[operator] # if val would represent an unsigend int, `-` is not supported, but error would be somewhere else return node @@ -217,7 +212,7 @@ def transform_fold_arithmetic_builtins(self, node: ir.FunCall, **kwargs) -> Opti if node.fun.id in builtins.ARITHMETIC_BUILTINS: fun = getattr(embedded, str(node.fun.id)) arg_values = [ - _value_from_literal(arg) # type: ignore[arg-type] # arg type already established in if condition + ir_misc.value_from_literal(arg) # type: ignore[arg-type] # arg type already established in if condition for arg in node.args ] return im.literal_from_value(fun(*arg_values)) diff --git a/src/gt4py/next/iterator/transforms/transform_get_domain.py b/src/gt4py/next/iterator/transforms/transform_get_domain.py index d2a94f20be..c34ba61a28 100644 --- a/src/gt4py/next/iterator/transforms/transform_get_domain.py +++ b/src/gt4py/next/iterator/transforms/transform_get_domain.py @@ -9,15 +9,46 @@ import dataclasses from typing import Dict +from gt4py._core import definitions as core_defs from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next import common from gt4py.next.iterator import ir as itir -from 
gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.iterator.transforms import collapse_tuple +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + ir_makers as im, + misc as ir_misc, +) + + +DomainOrTupleThereof = tuple["DomainOrTupleThereof", ...] | common.Domain + + +class _DomainDeduction(NodeTranslator): + def visit_Node(self, node: itir.Node, **kwargs): + return None # means we could not deduce the domain + + def visit_SymRef( + self, node: itir.SymRef, *, sizes: Dict[str, common.Domain], **kwargs + ) -> DomainOrTupleThereof | None: + return sizes.get(node.id, None) + + def visit_Literal(self, node: itir.Literal, **kwargs) -> core_defs.Scalar: + return ir_misc.value_from_literal(node) + + def visit_FunCall(self, node, **kwargs): + args = self.generic_visit(node.args, **kwargs) + + if cpm.is_call_to(node, "tuple_get"): + idx, expr = args + return expr[idx] + elif cpm.is_call_to(node, "make_tuple"): + return tuple(args) + + return node @dataclasses.dataclass(frozen=True) -class TransformGetDomain(PreserveLocationVisitor, NodeTranslator): +class TransformGetDomainRange(PreserveLocationVisitor, NodeTranslator): """ Transforms `get_domain` calls into a tuple containing start and stop. @@ -30,16 +61,16 @@ class TransformGetDomain(PreserveLocationVisitor, NodeTranslator): ... "out": gtx.domain({Vertex: (0, 10), KDim: (0, 20)}), ... } - >>> unstructured_domain_get_out = im.call("unstructured_domain")( + >>> output_domain = im.call("unstructured_domain")( ... im.named_range( - ... im.axis_literal(Vertex), - ... im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(Vertex))), - ... im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(Vertex))), + ... Vertex, + ... im.tuple_get(0, im.call("get_domain_range")("out", Vertex)), + ... im.tuple_get(1, im.call("get_domain_range")("out", Vertex)), ... ), ... im.named_range( - ... im.axis_literal(KDim), - ... 
im.tuple_get(0, im.call("get_domain")("out", im.axis_literal(KDim))), - ... im.tuple_get(1, im.call("get_domain")("out", im.axis_literal(KDim))), + ... KDim, + ... im.tuple_get(0, im.call("get_domain_range")("out", KDim)), + ... im.tuple_get(1, im.call("get_domain_range")("out", KDim)), ... ), ... ) >>> ir = itir.Program( @@ -49,13 +80,13 @@ class TransformGetDomain(PreserveLocationVisitor, NodeTranslator): ... declarations=[], ... body=[ ... itir.SetAt( - ... expr=im.as_fieldop(im.ref("deref"))(im.ref("inp")), - ... domain=unstructured_domain_get_out, + ... expr=im.as_fieldop("deref")(im.ref("inp")), + ... domain=output_domain, ... target=im.ref("out"), ... ), ... ], ... ) - >>> result = TransformGetDomain.apply(ir, sizes=sizes) + >>> result = TransformGetDomainRange.apply(ir, sizes=sizes) >>> print(result) test(inp, out) { out @ u⟨ Vertexₕ: [{0, 10}[0], {0, 10}[1][, KDimᵥ: [{0, 20}[0], {0, 20}[1][ ⟩ ← (⇑deref)(inp); @@ -67,35 +98,22 @@ def apply(cls, program: itir.Program, sizes: Dict[str, common.Domain]): return cls().visit(program, sizes=sizes) def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.FunCall: - sizes = kwargs["sizes"] - - if not cpm.is_call_to(node, "get_domain"): - return self.generic_visit(node, sizes=sizes) + if not cpm.is_call_to(node, "get_domain_range"): + return self.generic_visit(node, **kwargs) field, dim = node.args + assert isinstance(dim, itir.AxisLiteral) + domain = _DomainDeduction().visit(field, sizes=kwargs["sizes"]) - if cpm.is_call_to(field, "tuple_get"): - ref = field.args[1] - if isinstance(ref, itir.SymRef): - assert ref.id in sizes, f"Symbol '{ref.id}' not found in sizes Dict." - assert isinstance(sizes[ref.id], tuple), "A domain-tuple must be passed." 
- domain = sizes[ref.id][int(field.args[0].value)] - else: - field = collapse_tuple.CollapseTuple.apply( - field, within_stencil=False, allow_undeclared_symbols=True - ) - return self.visit(im.call("get_domain")(field, dim), sizes=sizes) - elif isinstance(field, itir.SymRef): - assert field.id in sizes, f"Symbol '{field.id}' not found in sizes Dict." - domain = sizes[field.id] - else: - raise NotImplementedError( - "Only calls to tuple_get or SymRefs are supported as first argument of get_domain." + if not isinstance(domain, common.Domain): + raise ValueError( + "Could not deduce domain of field expression. Must be a 'SymRef' or " + "tuple expression thereof, but got:\n" + f"'{field}'." ) - input_dims = domain.dims - index = next((i for i, d in enumerate(input_dims) if d.value == dim.value), None) - assert index is not None, f"Dimension {dim.value} not found in {input_dims}" + index = next((i for i, d in enumerate(domain.dims) if d.value == dim.value), None) + assert index is not None, f"Dimension {dim.value} not found in {domain.dims}" start = domain.ranges[index].start stop = domain.ranges[index].stop diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 798ea0c011..0dbb9357c3 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -100,9 +100,7 @@ def expression_test_cases(): it_ts.NamedRangeType(dim=Vertex), ), ( - im.call("cartesian_domain")( - im.named_range(itir.AxisLiteral(value="IDim"), 0, 1) - ), + im.call("cartesian_domain")(im.named_range(itir.AxisLiteral(value="IDim"), 0, 1)), ts.DomainType(dims=[IDim]), ), ( @@ -422,9 +420,7 @@ def test_unstructured_fencil_definition(): im.named_range( itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 ), - im.named_range( - itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 - 
), + im.named_range(itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1), ) testee = itir.Program( @@ -491,9 +487,7 @@ def test_fencil_with_nb_field_input(): im.named_range( itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 ), - im.named_range( - itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 - ), + im.named_range(itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1), ) testee = itir.Program( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py index 08d27fcd22..2335e3b4d7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py @@ -21,7 +21,7 @@ from gt4py.next import Domain, common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms.transform_get_domain import TransformGetDomain +from gt4py.next.iterator.transforms.transform_get_domain import TransformGetDomainRange def program_factory( @@ -63,7 +63,7 @@ def run_test_program( params=params, body=[setat_factory(domain=domain, target=im.ref(target))], ) - actual = TransformGetDomain.apply(testee, sizes=sizes) + actual = TransformGetDomainRange.apply(testee, sizes=sizes) assert actual == expected @@ -74,7 +74,7 @@ def construct_domains( ragnes_resolved = {} for dim, r in zip(domain_resolved.dims, domain_resolved.ranges): - get_call = im.call("get_domain")(symbol_name, im.axis_literal(dim)) + get_call = im.call("get_domain_range")(symbol_name, im.axis_literal(dim)) ranges_get[dim] = (im.tuple_get(0, get_call), im.tuple_get(1, get_call)) bounds = im.make_tuple(r.start, r.stop) ragnes_resolved[dim] = (im.tuple_get(0, bounds), im.tuple_get(1, bounds)) @@ -119,7 +119,7 
@@ def test_get_domain_inside_as_fieldop(): ], ) - actual = TransformGetDomain.apply(testee, sizes=sizes) + actual = TransformGetDomainRange.apply(testee, sizes=sizes) assert actual == expected @@ -131,10 +131,10 @@ def test_get_domain_tuples(): { Vertex: ( im.tuple_get( - 0, im.call("get_domain")(im.tuple_get(0, "out"), im.axis_literal(Vertex)) + 0, im.call("get_domain_range")(im.tuple_get(0, "out"), im.axis_literal(Vertex)) ), im.tuple_get( - 1, im.call("get_domain")(im.tuple_get(0, "out"), im.axis_literal(Vertex)) + 1, im.call("get_domain_range")(im.tuple_get(0, "out"), im.axis_literal(Vertex)) ), ) }, @@ -158,8 +158,12 @@ def test_get_domain_nested_tuples(): common.GridType.UNSTRUCTURED, { KDim: ( - im.tuple_get(0, im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim))), - im.tuple_get(1, im.call("get_domain")(im.tuple_get(0, tup), im.axis_literal(KDim))), + im.tuple_get( + 0, im.call("get_domain_range")(im.tuple_get(0, tup), im.axis_literal(KDim)) + ), + im.tuple_get( + 1, im.call("get_domain_range")(im.tuple_get(0, tup), im.axis_literal(KDim)) + ), ) }, ) From 564ca1604b3d728a671eeab2dfa1e2b585189807 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sat, 23 Aug 2025 14:06:42 +0200 Subject: [PATCH 26/93] Cleanup tests --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 23 +++- .../iterator/transforms/collapse_tuple.py | 16 ++- .../test_transform_get_domain.py | 110 ++++-------------- 3 files changed, 60 insertions(+), 89 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index b0f4f2c22b..c16a61d6ae 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import typing -from typing import Any, Callable, Optional, TypeAlias, Union +from typing import Any, Callable, Iterable, Optional, TypeAlias, Union from gt4py._core import definitions as core_defs from gt4py.next 
import common @@ -479,6 +479,27 @@ def domain( return expr +def get_field_domain( + grid_type: Union[common.GridType, str], + field: str | itir.Expr, + dims: Iterable[common.Dimension] | None = None, +) -> itir.Expr: + if isinstance(field, itir.Expr) and isinstance(field.type, ts.FieldType): + assert dims is None or all(d1 == d2 for d1, d2 in zip(field.type.dims, dims, strict=True)) + dims = field.type.dims + + return domain( + grid_type, + { + dim: ( + tuple_get(0, call("get_domain_range")(field, dim)), + tuple_get(1, call("get_domain_range")(field, dim)), + ) + for dim in dims + }, + ) + + def named_range( dim: itir.AxisLiteral | common.Dimension, start: itir.Expr, stop: itir.Expr ) -> itir.FunCall: diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 85a4854998..a9fe4c3d98 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -15,6 +15,8 @@ import re from typing import Literal, Optional +from prompt_toolkit.layout.processors import Transformation + from gt4py import eve from gt4py.eve import utils as eve_utils from gt4py.next import common @@ -217,7 +219,19 @@ def apply( False, ], "Parameter 'within_stencil' mandatory if node is not a 'Program'." 
- if not ignore_tuple_size: + requires_types = False + if enabled_transformations & ( + cls.Transformation.PROPAGATE_TO_IF_ON_TUPLES_CPS + | cls.Transformation.FLATTEN_AS_FIELDOP_ARGS + ): + requires_types = True + elif ( + not ignore_tuple_size + and enabled_transformations & cls.Transformation.COLLAPSE_MAKE_TUPLE_TUPLE_GET + ): + requires_types = True + + if requires_types: node = itir_type_inference.infer( node, offset_provider_type=offset_provider_type, diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py index 2335e3b4d7..e012e4e9dc 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py @@ -22,6 +22,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.transform_get_domain import TransformGetDomainRange +from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple def program_factory( @@ -64,115 +65,50 @@ def run_test_program( body=[setat_factory(domain=domain, target=im.ref(target))], ) actual = TransformGetDomainRange.apply(testee, sizes=sizes) + actual = CollapseTuple.apply( + actual, enabled_transformations=CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE + ) assert actual == expected -def construct_domains( - domain_resolved: Domain, symbol_name: str, type: Union[common.GridType, str] -) -> tuple[itir.FunCall, itir.FunCall]: - ranges_get = {} - ragnes_resolved = {} - - for dim, r in zip(domain_resolved.dims, domain_resolved.ranges): - get_call = im.call("get_domain_range")(symbol_name, im.axis_literal(dim)) - ranges_get[dim] = (im.tuple_get(0, get_call), im.tuple_get(1, get_call)) - bounds = im.make_tuple(r.start, r.stop) - ragnes_resolved[dim] = (im.tuple_get(0, bounds), 
im.tuple_get(1, bounds)) - - return im.domain(type, ragnes_resolved), im.domain(type, ranges_get) +def domain_as_expr(domain: gtx.Domain) -> itir.Expr: + return im.domain( + common.GridType.UNSTRUCTURED, + {d: (r.start, r.stop) for d, r in zip(domain.dims, domain.ranges)}, + ) def test_get_domain(): sizes = {"out": gtx.domain({Vertex: (0, 10), KDim: (0, 20)})} - unstructured_domain, unstructured_domain_get = construct_domains( - sizes["out"], "out", "unstructured_domain" - ) + get_domain_expr = im.get_field_domain(common.GridType.UNSTRUCTURED, "out", sizes["out"].dims) - run_test_program(["inp", "out"], sizes, "out", unstructured_domain, unstructured_domain_get) - - -def test_get_domain_inside_as_fieldop(): - sizes = {"out": gtx.domain({Vertex: (0, 10), KDim: (0, 20)})} - unstructured_domain, unstructured_domain_get = construct_domains( - sizes["out"], "out", "unstructured_domain" - ) - - testee = program_factory( - params=["inp", "out"], - body=[ - itir.SetAt( - expr=im.as_fieldop(im.ref("deref"), unstructured_domain_get)(im.ref("inp")), - domain=unstructured_domain_get, - target=im.ref("out"), - ), - ], - ) - - expected = program_factory( - params=["inp", "out"], - body=[ - itir.SetAt( - expr=im.as_fieldop(im.ref("deref"), unstructured_domain)(im.ref("inp")), - domain=unstructured_domain, - target=im.ref("out"), - ), - ], - ) - - actual = TransformGetDomainRange.apply(testee, sizes=sizes) - assert actual == expected + run_test_program(["inp", "out"], sizes, "out", domain_as_expr(sizes["out"]), get_domain_expr) def test_get_domain_tuples(): sizes = {"out": (gtx.domain({Vertex: (0, 5)}), gtx.domain({Vertex: (0, 7)}))} - unstructured_domain_get = im.domain( - common.GridType.UNSTRUCTURED, - { - Vertex: ( - im.tuple_get( - 0, im.call("get_domain_range")(im.tuple_get(0, "out"), im.axis_literal(Vertex)) - ), - im.tuple_get( - 1, im.call("get_domain_range")(im.tuple_get(0, "out"), im.axis_literal(Vertex)) - ), - ) - }, - ) - unstructured_domain = im.domain( - 
common.GridType.UNSTRUCTURED, - {Vertex: (im.tuple_get(0, im.make_tuple(0, 5)), im.tuple_get(1, im.make_tuple(0, 5)))}, + get_domain_expr = im.get_field_domain( + common.GridType.UNSTRUCTURED, im.tuple_get(1, "out"), sizes["out"][1].dims ) - run_test_program(["inp", "out"], sizes, "out", unstructured_domain, unstructured_domain_get) + run_test_program(["inp", "out"], sizes, "out", domain_as_expr(sizes["out"][1]), get_domain_expr) def test_get_domain_nested_tuples(): sizes = {"a": gtx.domain({KDim: (0, 3)})} - t0 = im.make_tuple("a", "b") - t1 = im.make_tuple("c", "d") - tup = im.make_tuple(im.tuple_get(0, t0), im.tuple_get(1, t1)) - - unstructured_domain_get = im.domain( + get_domain_expr = im.get_field_domain( common.GridType.UNSTRUCTURED, - { - KDim: ( - im.tuple_get( - 0, im.call("get_domain_range")(im.tuple_get(0, tup), im.axis_literal(KDim)) - ), - im.tuple_get( - 1, im.call("get_domain_range")(im.tuple_get(0, tup), im.axis_literal(KDim)) - ), - ) - }, - ) - - unstructured_domain = im.domain( - common.GridType.UNSTRUCTURED, - {KDim: (im.tuple_get(0, im.make_tuple(0, 3)), im.tuple_get(1, im.make_tuple(0, 3)))}, + im.tuple_get( + 0, + im.make_tuple( + im.tuple_get(0, im.make_tuple("a", "b")), im.tuple_get(1, im.make_tuple("c", "d")) + ), + ), + sizes["a"].dims, ) run_test_program( - ["inp", "a", "b", "c", "d"], sizes, "a", unstructured_domain, unstructured_domain_get + ["inp", "a", "b", "c", "d"], sizes, "a", domain_as_expr(sizes["a"]), get_domain_expr ) From 0aae1fbe4b17cdee81d852fe52edf114785e0feb Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sat, 23 Aug 2025 14:38:47 +0200 Subject: [PATCH 27/93] Fix failing test --- .../transforms_tests/test_temporary_domain_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py index 
c75caf20a3..d36fc865ca 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py @@ -163,7 +163,7 @@ def test_trivial_shift_warning(mesh_descriptor): ) ir = inline_fundefs.InlineFundefs().visit(testee) ir = inline_fundefs.prune_unreferenced_fundefs(ir) - ir = TransformGetDomain.apply(ir, sizes=sizes) + ir = TransformGetDomainRange.apply(ir, sizes=sizes) global_tmps.create_global_tmps(ir, offset_provider=offset_provider) From a35bfb5b77ee4e158253b354aca5f90663005c4a Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sat, 23 Aug 2025 14:41:14 +0200 Subject: [PATCH 28/93] Cleanup --- .../{transform_get_domain.py => transform_get_domain_range.py} | 0 ...ansform_get_domain.py => test_transform_get_domain_range.py} | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename src/gt4py/next/iterator/transforms/{transform_get_domain.py => transform_get_domain_range.py} (100%) rename tests/next_tests/unit_tests/iterator_tests/transforms_tests/{test_transform_get_domain.py => test_transform_get_domain_range.py} (97%) diff --git a/src/gt4py/next/iterator/transforms/transform_get_domain.py b/src/gt4py/next/iterator/transforms/transform_get_domain_range.py similarity index 100% rename from src/gt4py/next/iterator/transforms/transform_get_domain.py rename to src/gt4py/next/iterator/transforms/transform_get_domain_range.py diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py similarity index 97% rename from tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py rename to tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py index e012e4e9dc..a4864ff00e 100644 --- 
a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py @@ -21,7 +21,7 @@ from gt4py.next import Domain, common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms.transform_get_domain import TransformGetDomainRange +from gt4py.next.iterator.transforms.transform_get_domain_range import TransformGetDomainRange from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple From 325db750e5afe1979b4f2925fc20f782b9dd5e1c Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sat, 23 Aug 2025 14:41:37 +0200 Subject: [PATCH 29/93] Cleanup --- .../transforms_tests/test_temporary_domain_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py index d36fc865ca..dad33c4d2a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py @@ -28,7 +28,7 @@ ir_makers as im, ) from gt4py.next.iterator.transforms import inline_fundefs, global_tmps -from gt4py.next.iterator.transforms.transform_get_domain import TransformGetDomainRange +from gt4py.next.iterator.transforms.transform_get_domain_range import TransformGetDomainRange from gt4py.next.type_system import type_specifications as ts IOff = gtx.FieldOffset("IOff", source=IDim, target=(IDim,)) From f246582ee78b64c4781d659b9888eac8a6d55cc2 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 2 Sep 2025 14:34:48 +0200 Subject: [PATCH 30/93] Address review comments --- .../ADRs/next/0020-Runtime-domains.md | 23 +++++++++++++++++++ 
.../codegens/gtfn/codegen.py | 7 ++++-- .../runners/dace/workflow/decoration.py | 1 - 3 files changed, 28 insertions(+), 3 deletions(-) create mode 100644 docs/development/ADRs/next/0020-Runtime-domains.md diff --git a/docs/development/ADRs/next/0020-Runtime-domains.md b/docs/development/ADRs/next/0020-Runtime-domains.md new file mode 100644 index 0000000000..f971675060 --- /dev/null +++ b/docs/development/ADRs/next/0020-Runtime-domains.md @@ -0,0 +1,23 @@ +--- +tags: [] +--- + +# [Runtime domains] + +- **Status**: valid +- **Authors**: Till Ehrengruber (@tehrengruber) +- **Created**: 2025-09-01 +- **Updated**: 2025-09-01 + +The mechanism for representing domains in the IR and accessing their values at runtime has been updated with the introduction of a new builtin on GTIR called `get_domain_range(field, dimension)`. + +## History + +In the early days the domain of a field was represented by a set of implicit (scalar) parameters that were added in the frontend (e.g. from PAST -> GTIR). This mechanism increased the complexity of the frontend and created artificial differences between the frontend representation (PAST) and GTIR. + +## Decision + +`get_domain_range(field, dimension) -> (start, stop)` takes a field and a dimension as an input and returns a tuple of integers containing the start and stop indices of the given dimension in the domain of the field. Eventually we want a builtin that returns the entire domain of a field. However, this was withdrawn due to the effort required to design & implement it properly. We identified the following issues: + +- The gtfn backend currently selects a backend, i.e. cartesian or unstructured, based on the type of the domain object on which a stencil is executed. Additionally, for unstructured all neighbor tables are stored in the domain. A field or sid, however, does not and should not have a notion of a backend, so constructing a domain from a sid is not possible without hacks. 
We should therefore not encode the backend in the domain. The same applies for the connectivities. +- The domain in gtfn consists of a start index and a length, instead of a start and stop index. Some trials were done for the cartesian domain (by @tehrengruber), but this has not been pursued any further because of the above issue. diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index f142ff006e..73df7397c6 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -275,8 +275,11 @@ def visit_Program(self, node: gtfn_ir.Program, **kwargs: Any) -> Union[str, Coll // TODO(tehrengruber): This should disappear as soon as we introduce a proper builtin. namespace gridtools::fn { - // TODO(tehrengruber): `typename gridtools::sid::lower_bounds_type, typename gridtools::sid::upper_bounds_type` - // fails as type used for index calculations in gtfn differs + """ + # TODO(tehrengruber): The return type should be + # `typename gridtools::sid::lower_bounds_type, typename gridtools::sid::upper_bounds_type`, + # but fails as type used for index calculations in gtfn differs + """ template GT_FUNCTION gridtools::tuple get_domain_range(S &&sid, D) { return {gridtools::host_device::at_key(gridtools::sid::get_lower_bounds(sid)), diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index c61d3d30cf..cc37d98e10 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -37,7 +37,6 @@ def decorated_program( if out is not None: args = (*args, out) - # TODO: this doesn't belong here and should by done in the dace backend if not fun.sdfg_program._lastargs: # First call, the SDFG is not intitalized, so forward the call to 
`CompiledSDFG` # to proper initilize it. Later calls to this SDFG will be handled through From 53a5ac8fd7eebefd2c57bdcd453dfd11dd8906a2 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 2 Sep 2025 14:40:21 +0200 Subject: [PATCH 31/93] Remove implicit domain --- src/gt4py/next/ffront/past_to_itir.py | 6 ------ src/gt4py/next/iterator/ir.py | 1 - .../next/otf/compilation/build_systems/compiledb.py | 3 --- src/gt4py/next/otf/compilation/cache.py | 1 - src/gt4py/next/otf/compilation/compiler.py | 9 ++------- src/gt4py/next/otf/stages.py | 7 ------- .../next/program_processors/codegens/gtfn/gtfn_module.py | 1 - .../runners/dace/workflow/compilation.py | 5 +---- .../runners/dace/workflow/translation.py | 1 - src/gt4py/next/program_processors/runners/gtfn.py | 2 +- .../compilation_tests/build_systems_tests/conftest.py | 1 - tests/next_tests/unit_tests/otf_tests/test_languages.py | 2 -- .../runners_tests/dace_tests/test_dace_bindings.py | 1 - 13 files changed, 4 insertions(+), 36 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index e7b9412ebf..0faca56dd1 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -109,7 +109,6 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: params=itir_program.params, declarations=itir_program.declarations, body=body, - implicit_domain=itir_program.implicit_domain, ) if config.DEBUG or inp.data.debug: @@ -237,10 +236,6 @@ def visit_Program( params = self.visit(node.params) - implicit_domain = False - if any("domain" not in body_entry.kwargs for body_entry in node.body): - implicit_domain = True - set_ats = [self._visit_field_operator_call(stmt, **kwargs) for stmt in node.body] return itir.Program( id=node.id, @@ -248,7 +243,6 @@ def visit_Program( params=params, declarations=[], body=set_ats, - implicit_domain=implicit_domain, ) def _visit_field_operator_call(self, node: past.Call, **kwargs: Any) -> itir.SetAt: diff --git 
a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index e25eaeee1e..79ccc83cd2 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -139,7 +139,6 @@ class Program(Node, ValidatedSymbolTableTrait): params: List[Sym] declarations: List[Temporary] body: List[Stmt] - implicit_domain: bool = False _NODE_SYMBOLS_: ClassVar[List[Sym]] = [ Sym(id=name) for name in sorted(BUILTINS) diff --git a/src/gt4py/next/otf/compilation/build_systems/compiledb.py b/src/gt4py/next/otf/compilation/build_systems/compiledb.py index 9017ae1ff1..afff250e46 100644 --- a/src/gt4py/next/otf/compilation/build_systems/compiledb.py +++ b/src/gt4py/next/otf/compilation/build_systems/compiledb.py @@ -66,7 +66,6 @@ def __call__( cmake_flags=self.cmake_extra_flags or [], language=source.program_source.language, language_settings=source.program_source.language_settings, - implicit_domain=source.program_source.implicit_domain, ) compiledb_template = _cc_get_compiledb( @@ -256,7 +255,6 @@ def _cc_prototype_program_source( cmake_flags: list[str], language: type[SrcL], language_settings: languages.LanguageWithHeaderFilesSettings, - implicit_domain: bool, ) -> stages.ProgramSource: name = _cc_prototype_program_name(deps, build_type.value, cmake_flags) return stages.ProgramSource( @@ -265,7 +263,6 @@ def _cc_prototype_program_source( library_deps=deps, language=language, language_settings=language_settings, - implicit_domain=implicit_domain, ) diff --git a/src/gt4py/next/otf/compilation/cache.py b/src/gt4py/next/otf/compilation/cache.py index 430e1d931e..43ceb71fc3 100644 --- a/src/gt4py/next/otf/compilation/cache.py +++ b/src/gt4py/next/otf/compilation/cache.py @@ -39,7 +39,6 @@ def _serialize_source(source: stages.ProgramSource) -> str: params: {", ".join(parameters)} deps: {", ".join(dependencies)} src: {source.source_code} - implicit_domain: {source.implicit_domain} """ diff --git a/src/gt4py/next/otf/compilation/compiler.py 
b/src/gt4py/next/otf/compilation/compiler.py index e6479faf46..e03fa84e50 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -65,7 +65,7 @@ class Compiler( def __call__( self, inp: stages.CompilableSource[SourceLanguageType, LanguageSettingsType, languages.Python], - ) -> stages.ExtendedCompiledProgram: + ) -> stages.CompiledProgram: src_dir = cache.get_cache_folder(inp, self.cache_lifetime) # If we are compiling the same program at the same time (e.g. multiple MPI ranks), @@ -87,12 +87,7 @@ def __call__( importer.import_from_path(src_dir / new_data.module), new_data.entry_point_name ) - @dataclasses.dataclass(frozen=True) - class Wrapper(stages.ExtendedCompiledProgram): - implicit_domain: bool = inp.program_source.implicit_domain - __call__: stages.CompiledProgram = compiled_prog - - return Wrapper() + return compiled_prog class CompilerFactory(factory.Factory): diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index c5bc8bec34..d2eabbed3d 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -82,7 +82,6 @@ class ProgramSource(Generic[SrcL, SettingT]): library_deps: tuple[interface.LibraryDependency, ...] language: type[SrcL] language_settings: SettingT - implicit_domain: bool def __post_init__(self) -> None: if not isinstance(self.language_settings, self.language.settings_class): @@ -144,12 +143,6 @@ class CompiledProgram(Protocol): def __call__(self, *args: Any, **kwargs: Any) -> None: ... -class ExtendedCompiledProgram(CompiledProgram, Protocol): - """Executable python representation of a program with extra info.""" - - implicit_domain: bool - - def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]: """ Filter out multiple occurrences of the same ``interface.LibraryDependency``. 
diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 90441ec61a..327404324e 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -263,7 +263,6 @@ def __call__( source_code=source_code, language=self._language(), language_settings=self._language_settings(), - implicit_domain=inp.data.implicit_domain, ) return module diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index ed36ef1947..e6cb4d8e9f 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -48,7 +48,7 @@ def _get_sdfg_ctype_arglist_callback( return module.__dict__[bind_func_name] -class CompiledDaceProgram(stages.ExtendedCompiledProgram): +class CompiledDaceProgram(stages.CompiledProgram): sdfg_program: dace.CompiledSDFG # Sorted list of SDFG arguments as they appear in program ABI and corresponding data type; @@ -66,10 +66,8 @@ def __init__( program: dace.CompiledSDFG, bind_func_name: str, binding_source: stages.BindingSource[languages.SDFG, languages.Python], - implicit_domain: bool, ): self.sdfg_program = program - self.implicit_domain = implicit_domain # `dace.CompiledSDFG.arglist()` returns an ordered dictionary that maps the argument # name to its data type, in the same order as arguments appear in the program ABI. 
@@ -132,7 +130,6 @@ def __call__( sdfg_program, self.bind_func_name, inp.binding_source, - inp.program_source.implicit_domain, ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index af1d312665..d1cac1f6f0 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -291,7 +291,6 @@ def __call__( language_settings=languages.LanguageSettings( formatter_key="", formatter_style="", file_extension="sdfg" ), - implicit_domain=inp.data.implicit_domain, ) ) return module diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 29cacd06e1..5fc38944a4 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -46,7 +46,7 @@ def convert_arg(arg: Any) -> Any: def convert_args( - inp: stages.ExtendedCompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU + inp: stages.CompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU ) -> stages.CompiledProgram: def decorated_program( *args: Any, diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py index 75a756f93f..97c848bea9 100644 --- a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py @@ -75,7 +75,6 @@ def make_program_source(name: str) -> stages.ProgramSource: library_deps=[interface.LibraryDependency("gridtools_cpu", "master")], language=languages.CPP, language_settings=cpp_interface.CPP_DEFAULT, - implicit_domain=False, ) diff --git a/tests/next_tests/unit_tests/otf_tests/test_languages.py 
b/tests/next_tests/unit_tests/otf_tests/test_languages.py index 7a3fc0c007..95b48d359c 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_languages.py +++ b/tests/next_tests/unit_tests/otf_tests/test_languages.py @@ -22,7 +22,6 @@ def test_basic_settings_with_cpp_rejected(): language_settings=languages.LanguageSettings( formatter_key="cpp", formatter_style="llvm", file_extension="cpp" ), - implicit_domain=False, ) @@ -38,5 +37,4 @@ def test_header_files_settings_with_cpp_accepted(): file_extension="cpp", header_extension="hpp", ), - implicit_domain=False, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py index d161c69f68..9a813b1b27 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py @@ -176,7 +176,6 @@ def test_bind_sdfg(persistent_config): library_deps=tuple(), language=languages.SDFG, language_settings=_language_settings(), - implicit_domain=False, ) ) From 9c7d9e24e1ad8b0bd5ee005a01de8a0a5565d495 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 2 Sep 2025 14:41:34 +0200 Subject: [PATCH 32/93] Fix format --- .../next/program_processors/codegens/gtfn/codegen.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 73df7397c6..029a674912 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -275,11 +275,11 @@ def visit_Program(self, node: gtfn_ir.Program, **kwargs: Any) -> Union[str, Coll // TODO(tehrengruber): This should disappear as soon as we introduce a proper builtin. 
namespace gridtools::fn { - """ - # TODO(tehrengruber): The return type should be - # `typename gridtools::sid::lower_bounds_type, typename gridtools::sid::upper_bounds_type`, - # but fails as type used for index calculations in gtfn differs - """ + """ + # TODO(tehrengruber): The return type should be + # `typename gridtools::sid::lower_bounds_type, typename gridtools::sid::upper_bounds_type`, + # but fails as type used for index calculations in gtfn differs + """ template GT_FUNCTION gridtools::tuple get_domain_range(S &&sid, D) { return {gridtools::host_device::at_key(gridtools::sid::get_lower_bounds(sid)), From 30f4b7b2e5f5f8143bc25a5dad7d6c1b6fa1a607 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 11 Sep 2025 21:02:34 +0200 Subject: [PATCH 33/93] Add ArgumentDescriptor mechanism, reworking static args and enabling compile time domain folding --- src/gt4py/next/ffront/decorator.py | 24 +- src/gt4py/next/ffront/past_process_args.py | 9 +- src/gt4py/next/ffront/past_to_itir.py | 21 +- src/gt4py/next/ffront/stages.py | 2 +- .../iterator/transforms/constant_folding.py | 2 +- .../transforms/transform_get_domain_range.py | 5 +- src/gt4py/next/otf/arguments.py | 123 +++++++--- src/gt4py/next/otf/compiled_program.py | 210 +++++++++--------- .../ffront_tests/test_compiled_program.py | 2 +- .../ffront_tests/test_program.py | 2 +- .../otf_tests/test_compiled_program.py | 9 +- 11 files changed, 245 insertions(+), 164 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index fe3e2410fc..df643bc775 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -110,6 +110,7 @@ class Program: static_params: ( Sequence[str] | None ) # if the user requests static params, they will be used later to initialize CompiledPrograms + static_domains: bool _extended_offset_provider_cache: eve_utils.CustomMapping = dataclasses.field( default_factory=lambda: 
eve_utils.CustomMapping(common.hash_offset_provider_unsafe), @@ -125,6 +126,7 @@ def from_function( grid_type: common.GridType | None = None, enable_jit: bool | None = None, static_params: Sequence[str] | None = None, + static_domains: bool = False, connectivities: Optional[ common.OffsetProvider ] = None, # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information @@ -136,6 +138,7 @@ def from_function( connectivities=connectivities, enable_jit=enable_jit, static_params=static_params, + static_domains=static_domains ) # needed in testing @@ -273,13 +276,29 @@ def _compiled_programs(self) -> compiled_program.CompiledProgramsPool: if self.static_params is None: object.__setattr__(self, "static_params", ()) + def path_to_expr(path: Sequence[int]): + return "".join(map(lambda idx: f"[{idx}]", path)) + + static_domain_args = [] + if self.static_domains: + func_type = self.past_stage.past_node.type.definition + param_types = func_type.pos_or_kw_args | func_type.kw_only_args + for name, type_ in param_types.items(): + for el_type, path in type_info.primitive_constituents(type_, with_path_arg=True): + static_domain_args.append(f"{name}{path_to_expr(path)}") + + argument_descriptor_mapping = { + arguments.StaticArg: self.static_params, + arguments.FieldDomainDescriptor: static_domain_args + } + program_type = self.past_stage.past_node.type assert isinstance(program_type, ts_ffront.ProgramType) return compiled_program.CompiledProgramsPool( backend=self.backend, definition_stage=self.definition_stage, program_type=program_type, - static_params=self.static_params, + argument_descriptor_mapping=argument_descriptor_mapping, ) def _extend_offset_provider( @@ -574,6 +593,7 @@ def program( grid_type: common.GridType | None, enable_jit: bool | None, static_params: Sequence[str] | None, + static_domains: bool, frozen: bool, ) -> Callable[[types.FunctionType], Program]: ... 
@@ -586,6 +606,7 @@ def program( grid_type: common.GridType | None = None, enable_jit: bool | None = None, # only relevant if static_params are set static_params: Sequence[str] | None = None, + static_domains: bool = False, frozen: bool = False, ) -> Program | FrozenProgram | Callable[[types.FunctionType], Program | FrozenProgram]: """ @@ -614,6 +635,7 @@ def program_inner(definition: types.FunctionType) -> Program: ), grid_type=grid_type, enable_jit=enable_jit, + static_domains=static_domains, static_params=static_params, ) if frozen: diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index 95db2837dd..b5dc5a12ea 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -34,6 +34,7 @@ def transform_program_args(inp: AOT_PRG) -> AOT_PRG: kwargs=rewritten_kwargs, offset_provider=inp.args.offset_provider, column_axis=inp.args.column_axis, + argument_descriptors=inp.args.argument_descriptors, ), ) @@ -70,13 +71,7 @@ def _process_args( raise TypeError("Can not process arguments for PAST programs prior to type inference.") args, kwargs = type_info.canonicalize_arguments(past_node.type, args, kwargs) - - # validate arguments - arg_types = tuple(arg.type_ if isinstance(arg, arguments.StaticArg) else arg for arg in args) - kwarg_types = { - k: (v.type_ if isinstance(v, arguments.StaticArg) else v) for k, v in kwargs.items() - } - _validate_args(past_node=past_node, arg_types=arg_types, kwarg_types=kwarg_types) + _validate_args(past_node=past_node, arg_types=args, kwarg_types=kwargs) return args, kwargs diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index e7b9412ebf..4586cdf92a 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -27,9 +27,10 @@ from gt4py.next.ffront.stages import AOT_PRG from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im 
-from gt4py.next.iterator.transforms import remap_symbols +from gt4py.next.iterator.transforms import remap_symbols, transform_get_domain_range from gt4py.next.otf import arguments, stages, workflow from gt4py.next.type_system import type_info, type_specifications as ts +from gt4py.next import utils # FIXME[#1582](tehrengruber): This should only depend on the program not the arguments. Remove @@ -95,12 +96,11 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type ) - static_args_index = { - i: arg.value for i, arg in enumerate(inp.args.args) if isinstance(arg, arguments.StaticArg) - } + static_arg_descriptors = inp.args.argument_descriptors[arguments.StaticArg] + if not all(isinstance(arg_descriptor, arguments.StaticArg) for arg_descriptor in static_arg_descriptors.values()): + raise NotImplementedError("Only top-level arguments can be static.") static_args = { - itir_program.params[i].id: im.literal_from_tuple_value(value) - for i, value in static_args_index.items() + name: im.literal_from_tuple_value(descr.value) for name, descr in static_arg_descriptors.items() } body = remap_symbols.RemapSymbolRefs().visit(itir_program.body, symbol_map=static_args) itir_program = itir.Program( @@ -112,6 +112,15 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: implicit_domain=itir_program.implicit_domain, ) + if arguments.FieldDomainDescriptor in inp.args.argument_descriptors: + field_domains = { + param: utils.tree_map(lambda x: x.domain)(v) for param, v in inp.args.argument_descriptors[arguments.FieldDomainDescriptor].items() + } + itir_program = transform_get_domain_range.TransformGetDomainRange.apply( + itir_program, + sizes=field_domains + ) + if config.DEBUG or inp.data.debug: devtools.debug(itir_program) diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 2594a13973..caa97ef247 100644 --- a/src/gt4py/next/ffront/stages.py +++ 
b/src/gt4py/next/ffront/stages.py @@ -129,7 +129,7 @@ def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgo @add_content_to_fingerprint.register def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None: - for key, value in sorted(obj.items()): + for key, value in sorted(obj.items(), key=lambda x: (x[0].__module__, x[0].__qualname__) if isinstance(x[0], type) else x[0]): add_content_to_fingerprint(key, hasher) add_content_to_fingerprint(value, hasher) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index f9269314fb..0b9321ef33 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -77,7 +77,7 @@ class Transformation(enum.Flag): # `maximum(maximum(a, 1), 1)` -> `maximum(a, 1)` FOLD_MIN_MAX = enum.auto() - # `maximum(a + 1), a)` -> `a + 1` + # `maximum(a + 1, a)` -> `a + 1` # `maximum(a + 1, a + (-1))` -> `a + maximum(1, -1)` FOLD_MIN_MAX_PLUS = enum.auto() diff --git a/src/gt4py/next/iterator/transforms/transform_get_domain_range.py b/src/gt4py/next/iterator/transforms/transform_get_domain_range.py index c34ba61a28..27de217d10 100644 --- a/src/gt4py/next/iterator/transforms/transform_get_domain_range.py +++ b/src/gt4py/next/iterator/transforms/transform_get_domain_range.py @@ -115,7 +115,4 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.FunCall: index = next((i for i, d in enumerate(domain.dims) if d.value == dim.value), None) assert index is not None, f"Dimension {dim.value} not found in {domain.dims}" - start = domain.ranges[index].start - stop = domain.ranges[index].stop - node = im.make_tuple(start, stop) - return node + return im.make_tuple(domain.ranges[index].start, domain.ranges[index].stop) diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 04f4bcf1e8..2d60e122ff 100644 --- a/src/gt4py/next/otf/arguments.py 
+++ b/src/gt4py/next/otf/arguments.py @@ -9,24 +9,94 @@ from __future__ import annotations import dataclasses +import functools import typing -from typing import Any, Generic, Optional +from pickletools import ArgumentDescriptor +from typing import Any, Generic, Optional, TypeAlias from typing_extensions import Self from gt4py._core import definitions as core_defs -from gt4py.next import common +from gt4py.eve import extended_typing +from gt4py.next import common, utils, Field, errors from gt4py.next.otf import toolchain, workflow -from gt4py.next.type_system import type_specifications as ts, type_translation - +from gt4py.next.type_system import type_specifications as ts, type_translation, type_info +from gt4py.next.type_system.type_info import apply_to_primitive_constituents DATA_T = typing.TypeVar("DATA_T") +T = typing.TypeVar("T") +TOrTupleOf: TypeAlias = T | tuple["TupleOf[T]", ...] +ArgumentDescriptorT = typing.TypeVar("ArgumentDescriptorT", bound=ArgumentDescriptor) + +class PartialValue(Generic[ArgumentDescriptorT]): + attrs: dict[str, Any] + items: dict[Any, Any] + + def __init__(self): + object.__setattr__(self, "attrs", {}) + object.__setattr__(self, "items", {}) + def __setattr__(self, key: str, value: Any) -> None: + object.__getattribute__(self, "attrs")[key] = value + + def __setitem__(self, key: Any, value: Any) -> None: + object.__getattribute__(self, "items")[key] = value + + @property + def empty(self): + return not self.attrs and not self.items + +class ArgumentDescriptor: + def validate(self, name: str, type_: ts.TypeSpec): + """ + Validate argument descriptor. + + This function is called when the type of the argument is available. The name is merely + given to give good error messages. + """ + pass + + @classmethod + def attribute_extractor(cls, arg_expr: str) -> dict[str, str]: + """ + Return a mapping from the attributes of our descriptor to the expressions to retrieve them. + + E.g. 
if `arg_expr` would be `myarg` and the result of this function + `{'value': 'my_arg.value'}` then the descriptor is constructed as + `ArgumentDescriptor(value=my_arg.value)`. We use expression here such that we can compute + a cache key by just hashing `my_arg.value` instead of first constructing the descriptor. + """ + ... + +@dataclasses.dataclass(frozen=True) +class StaticArg(ArgumentDescriptor, Generic[T]): + value: TOrTupleOf[core_defs.ScalarT] + + def validate(self, name: str, type_: ts.TypeSpec): + if not type_info.is_type_or_tuple_of_type(type_, ts.ScalarType): + raise errors.DSLTypeError( + message=f"Invalid static argument '{name}' with type '{type_}' (only scalars or (nested) tuples of scalars can be static).", + location=None, + ) + + actual_type = type_translation.from_value(self.value) + if actual_type != type_: + raise errors.DSLTypeError( + message=f"Invalid static argument '{name}', expected '{type_}', but static value '{self.value}' has type '{actual_type}'.", + location=None, + ) + + @classmethod + def attribute_extractor(cls, arg_expr: str): + return {"value": arg_expr} @dataclasses.dataclass(frozen=True) -class StaticArg(Generic[core_defs.ScalarT]): - value: core_defs.ScalarT | tuple[core_defs.ScalarT | tuple, ...] - type_: ts.TypeSpec +class FieldDomainDescriptor(ArgumentDescriptor): + domain: common.Domain + + @classmethod + def attribute_extractor(cls, arg_expr: str): + return {"domain": f"({arg_expr}).domain"} @dataclasses.dataclass(frozen=True) @@ -45,10 +115,11 @@ def from_signature(cls, *args: Any, **kwargs: Any) -> Self: class CompileTimeArgs: """Compile-time standins for arguments to a GTX program to be used in ahead-of-time compilation.""" - args: tuple[ts.TypeSpec | StaticArg, ...] - kwargs: dict[str, ts.TypeSpec | StaticArg] + args: tuple[ts.TypeSpec, ...] 
+ kwargs: dict[str, ts.TypeSpec] offset_provider: common.OffsetProvider # TODO(havogt): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information column_axis: Optional[common.Dimension] + argument_descriptors: dict[type[ArgumentDescriptor], PartialValue[ArgumentDescriptor]] @property def offset_provider_type(self) -> common.OffsetProviderType: @@ -57,21 +128,23 @@ def offset_provider_type(self) -> common.OffsetProviderType: @classmethod def from_concrete(cls, *args: Any, **kwargs: Any) -> Self: """Convert concrete GTX program arguments into their compile-time counterparts.""" - compile_args = tuple(type_translation.from_value(arg) for arg in args) - kwargs_copy = kwargs.copy() - offset_provider = kwargs_copy.pop("offset_provider", {}) + kwargs = kwargs.copy() + offset_provider = kwargs.pop("offset_provider", {}) + column_axis = kwargs.pop("column_axis", None) + compile_args = tuple(StaticArg.from_value(arg) for arg in args) + compile_kwargs = { + k: StaticArg.from_value(v) for k, v in kwargs.items() if v is not None + } return cls( args=compile_args, + kwargs=compile_kwargs, offset_provider=offset_provider, - column_axis=kwargs_copy.pop("column_axis", None), - kwargs={ - k: type_translation.from_value(v) for k, v in kwargs_copy.items() if v is not None - }, + column_axis=column_axis, ) @classmethod def empty(cls) -> Self: - return cls(tuple(), {}, {}, None) + return cls(tuple(), {}, {}, None, {}) def jit_to_aot_args( @@ -85,18 +158,4 @@ def adapted_jit_to_aot_args_factory() -> workflow.Workflow[ toolchain.CompilableProgram[DATA_T, CompileTimeArgs], ]: """Wrap `jit_to_aot` into a workflow adapter to fit into backend transform workflows.""" - return toolchain.ArgsOnlyAdapter(jit_to_aot_args) - - -def find_first_field(tuple_arg: tuple[Any, ...]) -> Optional[common.Field]: - for element in tuple_arg: - match element: - case tuple(): - found = find_first_field(element) - if found: - return found - case common.Field(): - 
return element - case _: - pass - return None + return toolchain.ArgsOnlyAdapter(jit_to_aot_args) \ No newline at end of file diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 174ca8edb1..b7022a65a4 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -13,7 +13,7 @@ import functools import itertools from collections.abc import Sequence -from typing import Any, TypeAlias +from typing import Any, TypeAlias, Callable, DefaultDict from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing, utils as eve_utils @@ -37,75 +37,16 @@ def _hash_compiled_program_unsafe(cp_key: CompiledProgramsKey) -> int: assert common.is_offset_provider_type(offset_provider) return hash((values, id(offset_provider))) +def _make_tuple_expr(el_exprs): + return "".join((f"{el},") for el in el_exprs) -def _validate_types( - program_name: str, - static_args: dict[str, ScalarOrTupleOfScalars], - program_type: ts_ffront.ProgramType, -) -> None: - unknown_args = list( - set(static_args.keys()) - set(program_type.definition.pos_or_kw_args.keys()) - ) - if unknown_args: - raise errors.DSLTypeError( - message=f"Invalid static arguments provided for '{program_name}' with type '{program_type}', the following are not parameters of the program:\n" - + ("\n".join([f" - '{arg}'" for arg in unknown_args])), - location=None, - ) - - param_types = program_type.definition.pos_or_kw_args - - if type_errors := [ - f"'{name}' with type '{param_types[name]}' cannot be static." 
- for name in static_args - if not type_info.is_type_or_tuple_of_type(param_types[name], ts.ScalarType) - ]: - raise errors.DSLTypeError( - message=f"Invalid static arguments provided for '{program_name}' with type '{program_type}' (only scalars or (nested) tuples of scalars can be static):\n" - + ("\n".join([f" - {error}" for error in type_errors])), - location=None, - ) - - static_arg_types = {name: tt.from_value(value) for name, value in static_args.items()} - types_from_values = [ - static_arg_types[name] # use type of the provided value for the static arguments - if name in static_arg_types - else type_ # else use the type from progam_type which will never mismatch - for name, type_ in program_type.definition.pos_or_kw_args.items() - ] - assert not program_type.definition.pos_only_args - assert not program_type.definition.kw_only_args - - if mismatch_errors := list( - type_info.function_signature_incompatibilities( - program_type, args=types_from_values, kwargs={} - ) - ): - raise errors.DSLTypeError( - message=f"Invalid static argument types when trying to compile '{program_name}' with type '{program_type}':\n" - + ("\n".join([f" - {error}" for error in mismatch_errors])), - location=None, - ) - - -def _sanitize_static_args( - program_name: str, - static_args: dict[str, ScalarOrTupleOfScalars], - program_type: ts_ffront.ProgramType, -) -> dict[str, ScalarOrTupleOfScalars]: - """ - Sanitize static arguments to be used in the program compilation. - - This function will convert all values to their corresponding type - and check that the types are compatible with the program type. 
- """ - _validate_types(program_name, static_args, program_type) - - return { - name: tt.unsafe_cast_to(value, program_type.definition.pos_or_kw_args[name]) # type: ignore[arg-type] # checked in _validate_types - for name, value in static_args.items() - } - +def _get_type_of_param_expr(program_type: ts_ffront.ProgramType, expr: str): + # TODO: error handling + func_type = program_type.definition + params = func_type.pos_or_kw_args | func_type.kw_only_args + vars = {param: type_info.apply_to_primitive_constituents(lambda x: x, type_, tuple_constructor=lambda *els: tuple(els)) + for param, type_ in params.items()} + return eval(expr, vars) @dataclasses.dataclass class CompiledProgramsPool: @@ -125,7 +66,10 @@ class CompiledProgramsPool: backend: gtx_backend.Backend definition_stage: ffront_stages.ProgramDefinition program_type: ts_ffront.ProgramType - static_params: Sequence[str] | None = None # not ordered + #: mapping from an argument descriptor type to a list of parameters or expression thereof + #: e.g. `{arguments.StaticArg: ["static_int_param"]}` + #: Note: The list is not ordered. + argument_descriptor_mapping: dict[type[arguments.ArgumentDescriptor], list[str]] | None _compiled_programs: eve_utils.CustomMapping = dataclasses.field( default_factory=lambda: eve_utils.CustomMapping(_hash_compiled_program_unsafe), @@ -154,7 +98,8 @@ def __call__( it is an error. """ args, kwargs = type_info.canonicalize_arguments(self.program_type, args, kwargs) - static_args_values = tuple(args[i] for i in self._static_arg_indices) + static_args_values = self._argument_descriptor_cache_key_from_args(*args, **kwargs) + # TODO: dispatching over offset provider type is wrong. especially when we use compile time domains. test? 
key = (static_args_values, self._offset_provider_to_type_unsafe(offset_provider)) try: self._compiled_programs[key](*args, **kwargs, offset_provider=offset_provider) @@ -164,45 +109,102 @@ def __call__( program(*args, **kwargs, offset_provider=offset_provider) except KeyError as e: if enable_jit: - assert self.static_params is not None - static_args = { - name: value - for name, value in zip(self.static_params, static_args_values, strict=True) - } - self._compile_variant(static_args=static_args, offset_provider=offset_provider) + assert self.argument_descriptor_mapping is not None + self._compile_variant( + argument_descriptors=self._make_argument_descriptors(*args, **kwargs), + offset_provider=offset_provider + ) return self( *args, offset_provider=offset_provider, enable_jit=False, **kwargs ) # passing `enable_jit=False` because a cache miss should be a hard-error in this call` raise RuntimeError("No program compiled for this set of static arguments.") from e + # TODO: test that compares _argument_descriptor_cache_key_from_args with _argument_descriptor_cache_key_from_descriptor @functools.cached_property - def _static_arg_indices(self) -> tuple[int, ...]: - if self.static_params is None: - # this could also be done in `__call__` but would be an extra if in the fast path - self.static_params = () + def _argument_descriptor_cache_key_from_args(self) -> Callable: + func_type = self.program_type.definition + params = func_type.pos_only_args + list(func_type.pos_or_kw_args.keys()) + list(func_type.kw_only_args.keys()) + elements = [] + for descriptor_cls, arg_exprs in self.argument_descriptor_mapping.items(): + for arg_expr in arg_exprs: + attr_extractor = descriptor_cls.attribute_extractor(arg_expr) + elements.extend(attr_extractor.values()) + return eval(f"""lambda {",".join(params)}: ({_make_tuple_expr(elements)})""") + + def _argument_descriptor_cache_key_from_descriptor(self, argument_descriptors: dict[type[arguments.ArgumentDescriptor], dict[str, 
arguments.ArgumentDescriptor]]) -> tuple: + elements = [] + for descriptor_cls, arg_exprs in self.argument_descriptor_mapping.items(): + for arg_expr in arg_exprs: + attr_extractor = descriptor_cls.attribute_extractor(arg_expr) + attrs = attr_extractor.keys() + for attr in attrs: + elements.append(getattr(eval(f"{arg_expr}", argument_descriptors[descriptor_cls]), attr)) + return tuple(elements) - all_params = list(self.program_type.definition.pos_or_kw_args.keys()) - return tuple(all_params.index(p) for p in self.static_params) + @functools.cached_property + def _make_argument_descriptors(self) -> Callable: + def make_dict_expr(exprs: dict[str, str]): + return "{"+",".join((f"'{k}': {v}" for k, v in exprs.items()))+"}" + + # for each argument expression build a lambda function that constructs (the attributes of) + # its argument descriptor + func_type = self.program_type.definition + params = func_type.pos_only_args + list(func_type.pos_or_kw_args.keys()) + list( + func_type.kw_only_args.keys()) + descriptor_attrs = DefaultDict(dict) + for descriptor_cls, arg_exprs in self.argument_descriptor_mapping.items(): + for arg_expr in arg_exprs: + attr_exprs = descriptor_cls.attribute_extractor(arg_expr) + descriptor_attrs[descriptor_cls][arg_expr] = eval(f"""lambda {",".join(params)}: {make_dict_expr(attr_exprs)}""") + + def _impl(*args, **kwargs): + descriptors = {} + for descriptor_cls, expr_descriptor_attr_mapping in descriptor_attrs.items(): + descriptors[descriptor_cls] = {} + for expr, attr in expr_descriptor_attr_mapping.items(): + descriptor = descriptor_cls(**attr(*args, **kwargs)) + descriptors[descriptor_cls][expr] = descriptor + self.validate_argument_descriptors(descriptors) + return descriptors + return _impl + + def validate_argument_descriptors(self, all_descriptors: dict[type[arguments.ArgumentDescriptor], dict[str, arguments.ArgumentDescriptor]]): + for descriptors in all_descriptors.values(): + for expr, descriptor in descriptors.items(): + param_type = 
_get_type_of_param_expr(self.program_type, expr) # TODO: error handling if type is wrong + descriptor.validate(expr, param_type) def _compile_variant( self, - static_args: dict[str, ScalarOrTupleOfScalars], + argument_descriptors: dict[type[arguments.ArgumentDescriptor], dict[str, arguments.ArgumentDescriptor]], offset_provider: common.OffsetProviderType | common.OffsetProvider, ) -> None: - assert self.static_params is not None - static_args = _sanitize_static_args( - self.definition_stage.definition.__name__, static_args, self.program_type - ) - - args = tuple( - arguments.StaticArg(value=static_args[name], type_=type_) - if name in self.static_params - else type_ - for name, type_ in self.program_type.definition.pos_or_kw_args.items() - ) + if self.argument_descriptor_mapping is None: + self.argument_descriptor_mapping = {descr_cls: descriptor_expr_mapping.keys() for descr_cls, descriptor_expr_mapping in argument_descriptors.items()} + else: + for descr_cls, descriptor_expr_mapping in argument_descriptors.items(): + if (expected := set(self.argument_descriptor_mapping[descr_cls])) != (got := set(descriptor_expr_mapping.keys())): + raise ValueError( + f"Argument descriptor {descr_cls.__name__} must be the same for all compiled programs. Got {list(got)}, expected {list(expected)}." 
+ ) + + self.validate_argument_descriptors(argument_descriptors) + + assert self.argument_descriptor_mapping is not None + + named_params = self.program_type.definition.pos_or_kw_args.keys() | self.program_type.definition.kw_only_args.keys() + + # TODO: use full structure + # TODO: check that + structured_descriptors = DefaultDict(lambda: {k: arguments.PartialValue() for k in named_params}) + for descriptor_cls, descriptor_expr_mapping in argument_descriptors.items(): + assert "__descriptor" not in structured_descriptors + for expr, descriptor in descriptor_expr_mapping.items(): + # TODO: gracefully catch error + exec(f"{expr} = __descriptor", {"__descriptor": descriptor}, structured_descriptors[descriptor_cls]) key = ( - tuple(static_args[p] for p in self.static_params), + self._argument_descriptor_cache_key_from_descriptor(structured_descriptors), self._offset_provider_to_type_unsafe(offset_provider), ) if key in self._compiled_programs: @@ -211,8 +213,9 @@ def _compile_variant( compile_time_args = arguments.CompileTimeArgs( offset_provider=offset_provider, # type:ignore[arg-type] # TODO(havogt): resolve OffsetProviderType vs OffsetProvider column_axis=None, # TODO(havogt): column_axis seems to a unused, even for programs with scans - args=args, - kwargs={}, + args=tuple(self.program_type.definition.pos_only_args) + tuple(self.program_type.definition.pos_or_kw_args.values()), + kwargs=self.program_type.definition.kw_only_args, + argument_descriptors=argument_descriptors, ) self._compiled_programs[key] = _async_compilation_pool.submit( self.backend.compile, self.definition_stage, compile_time_args=compile_time_args @@ -250,17 +253,12 @@ def compile( pool.compile(static_arg0=[0], static_arg1=[2]).compile(static_arg=[1], static_arg1=[3]) will compile for (0,2), (1,3) """ - if self.static_params is None: - self.static_params = tuple(static_args.keys()) - elif set(self.static_params) != set(static_args.keys()): - raise ValueError( - f"Static arguments must be the same 
for all compiled programs. Got {list(static_args.keys())}, expected {self.static_params}." - ) - for offset_provider in offset_providers: # not included in product for better type checking for static_values in itertools.product(*static_args.values()): self._compile_variant( - dict(zip(static_args.keys(), static_values, strict=True)), + argument_descriptors={ + arguments.StaticArg: dict(zip(static_args.keys(), map(arguments.StaticArg, static_values), strict=True)), + }, offset_provider=offset_provider, ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index 55d87c1eae..dae2c7a975 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -748,7 +748,7 @@ def test_compile_variants_non_existing_param(cartesian_case, compile_variants_te def test_compile_variants_wrong_type(cartesian_case, compile_variants_testee_not_compiled): - with pytest.raises(errors.DSLTypeError, match="Expected.*'scalar_int'.*int32"): + with pytest.raises(errors.DSLTypeError, match="'scalar_int'.*expected.*int32"): compile_variants_testee_not_compiled.compile(scalar_int=[1.0], offset_provider={}) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index 27c4252e14..88613e9c99 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -82,7 +82,7 @@ def shift_by_one_program(in_field: cases.IFloatField, out_field: cases.IFloatFie def test_copy_execution(cartesian_case, copy_program_def): - copy_program = gtx.program(copy_program_def, backend=cartesian_case.backend) + copy_program = 
gtx.program(copy_program_def, backend=cartesian_case.backend, static_domains=True) cases.verify_with_default_data(cartesian_case, copy_program, ref=lambda in_field: in_field) diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index e4a07da525..7a47552d40 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -138,12 +138,13 @@ def _verify_program_has_expected_true_value(program: itir.Program): def test_inlining_of_scalars_works(): - args = prog.past_stage.past_node.type.definition.pos_or_kw_args - args = [arguments.StaticArg(value=True, type_=v) if k == "cond" else v for k, v in args.items()] - input_pair = toolchain.CompilableProgram( data=prog.definition_stage, - args=arguments.CompileTimeArgs(args=args, kwargs={}, offset_provider={}, column_axis=None), + args=arguments.CompileTimeArgs( + args=list(prog.past_stage.past_node.type.definition.pos_or_kw_args.values()), + kwargs={}, offset_provider={}, column_axis=None, + argument_descriptors={arguments.StaticArg: {"cond": arguments.StaticArg(value=True)}} + ), ) transformed = backend.DEFAULT_TRANSFORMS(input_pair).data From 19f8f7af1717a342a9e011f61b5a60feeeb7d473 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 11 Sep 2025 21:15:19 +0200 Subject: [PATCH 34/93] Add argument descriptors and rework static args mechanism --- src/gt4py/next/ffront/decorator.py | 9 +- src/gt4py/next/ffront/past_process_args.py | 9 +- src/gt4py/next/ffront/past_to_itir.py | 9 +- src/gt4py/next/ffront/stages.py | 2 +- src/gt4py/next/otf/arguments.py | 115 +++++++--- src/gt4py/next/otf/compiled_program.py | 210 +++++++++--------- .../ffront_tests/test_compiled_program.py | 2 +- .../otf_tests/test_compiled_program.py | 9 +- 8 files changed, 208 insertions(+), 157 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py 
b/src/gt4py/next/ffront/decorator.py index fe3e2410fc..b4f5bb42b5 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -273,13 +273,20 @@ def _compiled_programs(self) -> compiled_program.CompiledProgramsPool: if self.static_params is None: object.__setattr__(self, "static_params", ()) + def path_to_expr(path: Sequence[int]): + return "".join(map(lambda idx: f"[{idx}]", path)) + + argument_descriptor_mapping = { + arguments.StaticArg: self.static_params, + } + program_type = self.past_stage.past_node.type assert isinstance(program_type, ts_ffront.ProgramType) return compiled_program.CompiledProgramsPool( backend=self.backend, definition_stage=self.definition_stage, program_type=program_type, - static_params=self.static_params, + argument_descriptor_mapping=argument_descriptor_mapping, ) def _extend_offset_provider( diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index 95db2837dd..b5dc5a12ea 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -34,6 +34,7 @@ def transform_program_args(inp: AOT_PRG) -> AOT_PRG: kwargs=rewritten_kwargs, offset_provider=inp.args.offset_provider, column_axis=inp.args.column_axis, + argument_descriptors=inp.args.argument_descriptors, ), ) @@ -70,13 +71,7 @@ def _process_args( raise TypeError("Can not process arguments for PAST programs prior to type inference.") args, kwargs = type_info.canonicalize_arguments(past_node.type, args, kwargs) - - # validate arguments - arg_types = tuple(arg.type_ if isinstance(arg, arguments.StaticArg) else arg for arg in args) - kwarg_types = { - k: (v.type_ if isinstance(v, arguments.StaticArg) else v) for k, v in kwargs.items() - } - _validate_args(past_node=past_node, arg_types=arg_types, kwarg_types=kwarg_types) + _validate_args(past_node=past_node, arg_types=args, kwarg_types=kwargs) return args, kwargs diff --git a/src/gt4py/next/ffront/past_to_itir.py 
b/src/gt4py/next/ffront/past_to_itir.py index c002e92d44..3315aecdcc 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -95,12 +95,11 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type ) - static_args_index = { - i: arg.value for i, arg in enumerate(inp.args.args) if isinstance(arg, arguments.StaticArg) - } + static_arg_descriptors = inp.args.argument_descriptors[arguments.StaticArg] + if not all(isinstance(arg_descriptor, arguments.StaticArg) for arg_descriptor in static_arg_descriptors.values()): + raise NotImplementedError("Only top-level arguments can be static.") static_args = { - itir_program.params[i].id: im.literal_from_tuple_value(value) - for i, value in static_args_index.items() + name: im.literal_from_tuple_value(descr.value) for name, descr in static_arg_descriptors.items() } body = remap_symbols.RemapSymbolRefs().visit(itir_program.body, symbol_map=static_args) itir_program = itir.Program( diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 2594a13973..caa97ef247 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -129,7 +129,7 @@ def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgo @add_content_to_fingerprint.register def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None: - for key, value in sorted(obj.items()): + for key, value in sorted(obj.items(), key=lambda x: (x[0].__module__, x[0].__qualname__) if isinstance(x[0], type) else x[0]): add_content_to_fingerprint(key, hasher) add_content_to_fingerprint(value, hasher) diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 04f4bcf1e8..885dfb9d8a 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -9,24 +9,86 @@ from __future__ import annotations import dataclasses +import functools 
import typing -from typing import Any, Generic, Optional +from pickletools import ArgumentDescriptor +from typing import Any, Generic, Optional, TypeAlias from typing_extensions import Self from gt4py._core import definitions as core_defs -from gt4py.next import common +from gt4py.eve import extended_typing +from gt4py.next import common, utils, errors from gt4py.next.otf import toolchain, workflow -from gt4py.next.type_system import type_specifications as ts, type_translation - +from gt4py.next.type_system import type_specifications as ts, type_translation, type_info +from gt4py.next.type_system.type_info import apply_to_primitive_constituents DATA_T = typing.TypeVar("DATA_T") +T = typing.TypeVar("T") +TOrTupleOf: TypeAlias = T | tuple["TupleOf[T]", ...] +ArgumentDescriptorT = typing.TypeVar("ArgumentDescriptorT", bound=ArgumentDescriptor) + +class PartialValue(Generic[ArgumentDescriptorT]): + attrs: dict[str, Any] + items: dict[Any, Any] + + def __init__(self): + object.__setattr__(self, "attrs", {}) + object.__setattr__(self, "items", {}) + + def __setattr__(self, key: str, value: Any) -> None: + object.__getattribute__(self, "attrs")[key] = value + + def __setitem__(self, key: Any, value: Any) -> None: + object.__getattribute__(self, "items")[key] = value + + @property + def empty(self): + return not self.attrs and not self.items +class ArgumentDescriptor: + def validate(self, name: str, type_: ts.TypeSpec): + """ + Validate argument descriptor. + + This function is called when the type of the argument is available. The name is merely + given to give good error messages. + """ + pass + + @classmethod + def attribute_extractor(cls, arg_expr: str) -> dict[str, str]: + """ + Return a mapping from the attributes of our descriptor to the expressions to retrieve them. + + E.g. if `arg_expr` would be `myarg` and the result of this function + `{'value': 'my_arg.value'}` then the descriptor is constructed as + `ArgumentDescriptor(value=my_arg.value)`. 
We use expression here such that we can compute + a cache key by just hashing `my_arg.value` instead of first constructing the descriptor. + """ + ... @dataclasses.dataclass(frozen=True) -class StaticArg(Generic[core_defs.ScalarT]): - value: core_defs.ScalarT | tuple[core_defs.ScalarT | tuple, ...] - type_: ts.TypeSpec +class StaticArg(ArgumentDescriptor, Generic[T]): + value: TOrTupleOf[core_defs.ScalarT] + + def validate(self, name: str, type_: ts.TypeSpec): + if not type_info.is_type_or_tuple_of_type(type_, ts.ScalarType): + raise errors.DSLTypeError( + message=f"Invalid static argument '{name}' with type '{type_}' (only scalars or (nested) tuples of scalars can be static).", + location=None, + ) + + actual_type = type_translation.from_value(self.value) + if actual_type != type_: + raise errors.DSLTypeError( + message=f"Invalid static argument '{name}', expected '{type_}', but static value '{self.value}' has type '{actual_type}'.", + location=None, + ) + + @classmethod + def attribute_extractor(cls, arg_expr: str): + return {"value": arg_expr} @dataclasses.dataclass(frozen=True) @@ -45,10 +107,11 @@ def from_signature(cls, *args: Any, **kwargs: Any) -> Self: class CompileTimeArgs: """Compile-time standins for arguments to a GTX program to be used in ahead-of-time compilation.""" - args: tuple[ts.TypeSpec | StaticArg, ...] - kwargs: dict[str, ts.TypeSpec | StaticArg] + args: tuple[ts.TypeSpec, ...] 
+ kwargs: dict[str, ts.TypeSpec] offset_provider: common.OffsetProvider # TODO(havogt): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information column_axis: Optional[common.Dimension] + argument_descriptors: dict[type[ArgumentDescriptor], PartialValue[ArgumentDescriptor]] @property def offset_provider_type(self) -> common.OffsetProviderType: @@ -57,21 +120,23 @@ def offset_provider_type(self) -> common.OffsetProviderType: @classmethod def from_concrete(cls, *args: Any, **kwargs: Any) -> Self: """Convert concrete GTX program arguments into their compile-time counterparts.""" - compile_args = tuple(type_translation.from_value(arg) for arg in args) - kwargs_copy = kwargs.copy() - offset_provider = kwargs_copy.pop("offset_provider", {}) + kwargs = kwargs.copy() + offset_provider = kwargs.pop("offset_provider", {}) + column_axis = kwargs.pop("column_axis", None) + compile_args = tuple(StaticArg.from_value(arg) for arg in args) + compile_kwargs = { + k: StaticArg.from_value(v) for k, v in kwargs.items() if v is not None + } return cls( args=compile_args, + kwargs=compile_kwargs, offset_provider=offset_provider, - column_axis=kwargs_copy.pop("column_axis", None), - kwargs={ - k: type_translation.from_value(v) for k, v in kwargs_copy.items() if v is not None - }, + column_axis=column_axis, ) @classmethod def empty(cls) -> Self: - return cls(tuple(), {}, {}, None) + return cls(tuple(), {}, {}, None, {}) def jit_to_aot_args( @@ -85,18 +150,4 @@ def adapted_jit_to_aot_args_factory() -> workflow.Workflow[ toolchain.CompilableProgram[DATA_T, CompileTimeArgs], ]: """Wrap `jit_to_aot` into a workflow adapter to fit into backend transform workflows.""" - return toolchain.ArgsOnlyAdapter(jit_to_aot_args) - - -def find_first_field(tuple_arg: tuple[Any, ...]) -> Optional[common.Field]: - for element in tuple_arg: - match element: - case tuple(): - found = find_first_field(element) - if found: - return found - case common.Field(): - 
return element - case _: - pass - return None + return toolchain.ArgsOnlyAdapter(jit_to_aot_args) \ No newline at end of file diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 174ca8edb1..b7022a65a4 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -13,7 +13,7 @@ import functools import itertools from collections.abc import Sequence -from typing import Any, TypeAlias +from typing import Any, TypeAlias, Callable, DefaultDict from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing, utils as eve_utils @@ -37,75 +37,16 @@ def _hash_compiled_program_unsafe(cp_key: CompiledProgramsKey) -> int: assert common.is_offset_provider_type(offset_provider) return hash((values, id(offset_provider))) +def _make_tuple_expr(el_exprs): + return "".join((f"{el},") for el in el_exprs) -def _validate_types( - program_name: str, - static_args: dict[str, ScalarOrTupleOfScalars], - program_type: ts_ffront.ProgramType, -) -> None: - unknown_args = list( - set(static_args.keys()) - set(program_type.definition.pos_or_kw_args.keys()) - ) - if unknown_args: - raise errors.DSLTypeError( - message=f"Invalid static arguments provided for '{program_name}' with type '{program_type}', the following are not parameters of the program:\n" - + ("\n".join([f" - '{arg}'" for arg in unknown_args])), - location=None, - ) - - param_types = program_type.definition.pos_or_kw_args - - if type_errors := [ - f"'{name}' with type '{param_types[name]}' cannot be static." 
- for name in static_args - if not type_info.is_type_or_tuple_of_type(param_types[name], ts.ScalarType) - ]: - raise errors.DSLTypeError( - message=f"Invalid static arguments provided for '{program_name}' with type '{program_type}' (only scalars or (nested) tuples of scalars can be static):\n" - + ("\n".join([f" - {error}" for error in type_errors])), - location=None, - ) - - static_arg_types = {name: tt.from_value(value) for name, value in static_args.items()} - types_from_values = [ - static_arg_types[name] # use type of the provided value for the static arguments - if name in static_arg_types - else type_ # else use the type from progam_type which will never mismatch - for name, type_ in program_type.definition.pos_or_kw_args.items() - ] - assert not program_type.definition.pos_only_args - assert not program_type.definition.kw_only_args - - if mismatch_errors := list( - type_info.function_signature_incompatibilities( - program_type, args=types_from_values, kwargs={} - ) - ): - raise errors.DSLTypeError( - message=f"Invalid static argument types when trying to compile '{program_name}' with type '{program_type}':\n" - + ("\n".join([f" - {error}" for error in mismatch_errors])), - location=None, - ) - - -def _sanitize_static_args( - program_name: str, - static_args: dict[str, ScalarOrTupleOfScalars], - program_type: ts_ffront.ProgramType, -) -> dict[str, ScalarOrTupleOfScalars]: - """ - Sanitize static arguments to be used in the program compilation. - - This function will convert all values to their corresponding type - and check that the types are compatible with the program type. 
- """ - _validate_types(program_name, static_args, program_type) - - return { - name: tt.unsafe_cast_to(value, program_type.definition.pos_or_kw_args[name]) # type: ignore[arg-type] # checked in _validate_types - for name, value in static_args.items() - } - +def _get_type_of_param_expr(program_type: ts_ffront.ProgramType, expr: str): + # TODO: error handling + func_type = program_type.definition + params = func_type.pos_or_kw_args | func_type.kw_only_args + vars = {param: type_info.apply_to_primitive_constituents(lambda x: x, type_, tuple_constructor=lambda *els: tuple(els)) + for param, type_ in params.items()} + return eval(expr, vars) @dataclasses.dataclass class CompiledProgramsPool: @@ -125,7 +66,10 @@ class CompiledProgramsPool: backend: gtx_backend.Backend definition_stage: ffront_stages.ProgramDefinition program_type: ts_ffront.ProgramType - static_params: Sequence[str] | None = None # not ordered + #: mapping from an argument descriptor type to a list of parameters or expression thereof + #: e.g. `{arguments.StaticArg: ["static_int_param"]}` + #: Note: The list is not ordered. + argument_descriptor_mapping: dict[type[arguments.ArgumentDescriptor], list[str]] | None _compiled_programs: eve_utils.CustomMapping = dataclasses.field( default_factory=lambda: eve_utils.CustomMapping(_hash_compiled_program_unsafe), @@ -154,7 +98,8 @@ def __call__( it is an error. """ args, kwargs = type_info.canonicalize_arguments(self.program_type, args, kwargs) - static_args_values = tuple(args[i] for i in self._static_arg_indices) + static_args_values = self._argument_descriptor_cache_key_from_args(*args, **kwargs) + # TODO: dispatching over offset provider type is wrong. especially when we use compile time domains. test? 
key = (static_args_values, self._offset_provider_to_type_unsafe(offset_provider)) try: self._compiled_programs[key](*args, **kwargs, offset_provider=offset_provider) @@ -164,45 +109,102 @@ def __call__( program(*args, **kwargs, offset_provider=offset_provider) except KeyError as e: if enable_jit: - assert self.static_params is not None - static_args = { - name: value - for name, value in zip(self.static_params, static_args_values, strict=True) - } - self._compile_variant(static_args=static_args, offset_provider=offset_provider) + assert self.argument_descriptor_mapping is not None + self._compile_variant( + argument_descriptors=self._make_argument_descriptors(*args, **kwargs), + offset_provider=offset_provider + ) return self( *args, offset_provider=offset_provider, enable_jit=False, **kwargs ) # passing `enable_jit=False` because a cache miss should be a hard-error in this call` raise RuntimeError("No program compiled for this set of static arguments.") from e + # TODO: test that compares _argument_descriptor_cache_key_from_args with _argument_descriptor_cache_key_from_descriptor @functools.cached_property - def _static_arg_indices(self) -> tuple[int, ...]: - if self.static_params is None: - # this could also be done in `__call__` but would be an extra if in the fast path - self.static_params = () + def _argument_descriptor_cache_key_from_args(self) -> Callable: + func_type = self.program_type.definition + params = func_type.pos_only_args + list(func_type.pos_or_kw_args.keys()) + list(func_type.kw_only_args.keys()) + elements = [] + for descriptor_cls, arg_exprs in self.argument_descriptor_mapping.items(): + for arg_expr in arg_exprs: + attr_extractor = descriptor_cls.attribute_extractor(arg_expr) + elements.extend(attr_extractor.values()) + return eval(f"""lambda {",".join(params)}: ({_make_tuple_expr(elements)})""") + + def _argument_descriptor_cache_key_from_descriptor(self, argument_descriptors: dict[type[arguments.ArgumentDescriptor], dict[str, 
arguments.ArgumentDescriptor]]) -> tuple: + elements = [] + for descriptor_cls, arg_exprs in self.argument_descriptor_mapping.items(): + for arg_expr in arg_exprs: + attr_extractor = descriptor_cls.attribute_extractor(arg_expr) + attrs = attr_extractor.keys() + for attr in attrs: + elements.append(getattr(eval(f"{arg_expr}", argument_descriptors[descriptor_cls]), attr)) + return tuple(elements) - all_params = list(self.program_type.definition.pos_or_kw_args.keys()) - return tuple(all_params.index(p) for p in self.static_params) + @functools.cached_property + def _make_argument_descriptors(self) -> Callable: + def make_dict_expr(exprs: dict[str, str]): + return "{"+",".join((f"'{k}': {v}" for k, v in exprs.items()))+"}" + + # for each argument expression build a lambda function that constructs (the attributes of) + # its argument descriptor + func_type = self.program_type.definition + params = func_type.pos_only_args + list(func_type.pos_or_kw_args.keys()) + list( + func_type.kw_only_args.keys()) + descriptor_attrs = DefaultDict(dict) + for descriptor_cls, arg_exprs in self.argument_descriptor_mapping.items(): + for arg_expr in arg_exprs: + attr_exprs = descriptor_cls.attribute_extractor(arg_expr) + descriptor_attrs[descriptor_cls][arg_expr] = eval(f"""lambda {",".join(params)}: {make_dict_expr(attr_exprs)}""") + + def _impl(*args, **kwargs): + descriptors = {} + for descriptor_cls, expr_descriptor_attr_mapping in descriptor_attrs.items(): + descriptors[descriptor_cls] = {} + for expr, attr in expr_descriptor_attr_mapping.items(): + descriptor = descriptor_cls(**attr(*args, **kwargs)) + descriptors[descriptor_cls][expr] = descriptor + self.validate_argument_descriptors(descriptors) + return descriptors + return _impl + + def validate_argument_descriptors(self, all_descriptors: dict[type[arguments.ArgumentDescriptor], dict[str, arguments.ArgumentDescriptor]]): + for descriptors in all_descriptors.values(): + for expr, descriptor in descriptors.items(): + param_type = 
_get_type_of_param_expr(self.program_type, expr) # TODO: error handling if type is wrong + descriptor.validate(expr, param_type) def _compile_variant( self, - static_args: dict[str, ScalarOrTupleOfScalars], + argument_descriptors: dict[type[arguments.ArgumentDescriptor], dict[str, arguments.ArgumentDescriptor]], offset_provider: common.OffsetProviderType | common.OffsetProvider, ) -> None: - assert self.static_params is not None - static_args = _sanitize_static_args( - self.definition_stage.definition.__name__, static_args, self.program_type - ) - - args = tuple( - arguments.StaticArg(value=static_args[name], type_=type_) - if name in self.static_params - else type_ - for name, type_ in self.program_type.definition.pos_or_kw_args.items() - ) + if self.argument_descriptor_mapping is None: + self.argument_descriptor_mapping = {descr_cls: descriptor_expr_mapping.keys() for descr_cls, descriptor_expr_mapping in argument_descriptors.items()} + else: + for descr_cls, descriptor_expr_mapping in argument_descriptors.items(): + if (expected := set(self.argument_descriptor_mapping[descr_cls])) != (got := set(descriptor_expr_mapping.keys())): + raise ValueError( + f"Argument descriptor {descr_cls.__name__} must be the same for all compiled programs. Got {list(got)}, expected {list(expected)}." 
+ ) + + self.validate_argument_descriptors(argument_descriptors) + + assert self.argument_descriptor_mapping is not None + + named_params = self.program_type.definition.pos_or_kw_args.keys() | self.program_type.definition.kw_only_args.keys() + + # TODO: use full structure + # TODO: check that + structured_descriptors = DefaultDict(lambda: {k: arguments.PartialValue() for k in named_params}) + for descriptor_cls, descriptor_expr_mapping in argument_descriptors.items(): + assert "__descriptor" not in structured_descriptors + for expr, descriptor in descriptor_expr_mapping.items(): + # TODO: gracefully catch error + exec(f"{expr} = __descriptor", {"__descriptor": descriptor}, structured_descriptors[descriptor_cls]) key = ( - tuple(static_args[p] for p in self.static_params), + self._argument_descriptor_cache_key_from_descriptor(structured_descriptors), self._offset_provider_to_type_unsafe(offset_provider), ) if key in self._compiled_programs: @@ -211,8 +213,9 @@ def _compile_variant( compile_time_args = arguments.CompileTimeArgs( offset_provider=offset_provider, # type:ignore[arg-type] # TODO(havogt): resolve OffsetProviderType vs OffsetProvider column_axis=None, # TODO(havogt): column_axis seems to a unused, even for programs with scans - args=args, - kwargs={}, + args=tuple(self.program_type.definition.pos_only_args) + tuple(self.program_type.definition.pos_or_kw_args.values()), + kwargs=self.program_type.definition.kw_only_args, + argument_descriptors=argument_descriptors, ) self._compiled_programs[key] = _async_compilation_pool.submit( self.backend.compile, self.definition_stage, compile_time_args=compile_time_args @@ -250,17 +253,12 @@ def compile( pool.compile(static_arg0=[0], static_arg1=[2]).compile(static_arg=[1], static_arg1=[3]) will compile for (0,2), (1,3) """ - if self.static_params is None: - self.static_params = tuple(static_args.keys()) - elif set(self.static_params) != set(static_args.keys()): - raise ValueError( - f"Static arguments must be the same 
for all compiled programs. Got {list(static_args.keys())}, expected {self.static_params}." - ) - for offset_provider in offset_providers: # not included in product for better type checking for static_values in itertools.product(*static_args.values()): self._compile_variant( - dict(zip(static_args.keys(), static_values, strict=True)), + argument_descriptors={ + arguments.StaticArg: dict(zip(static_args.keys(), map(arguments.StaticArg, static_values), strict=True)), + }, offset_provider=offset_provider, ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index 55d87c1eae..dae2c7a975 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -748,7 +748,7 @@ def test_compile_variants_non_existing_param(cartesian_case, compile_variants_te def test_compile_variants_wrong_type(cartesian_case, compile_variants_testee_not_compiled): - with pytest.raises(errors.DSLTypeError, match="Expected.*'scalar_int'.*int32"): + with pytest.raises(errors.DSLTypeError, match="'scalar_int'.*expected.*int32"): compile_variants_testee_not_compiled.compile(scalar_int=[1.0], offset_provider={}) diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index e4a07da525..7a47552d40 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -138,12 +138,13 @@ def _verify_program_has_expected_true_value(program: itir.Program): def test_inlining_of_scalars_works(): - args = prog.past_stage.past_node.type.definition.pos_or_kw_args - args = [arguments.StaticArg(value=True, type_=v) if k == "cond" else v for k, v in args.items()] - input_pair = 
toolchain.CompilableProgram( data=prog.definition_stage, - args=arguments.CompileTimeArgs(args=args, kwargs={}, offset_provider={}, column_axis=None), + args=arguments.CompileTimeArgs( + args=list(prog.past_stage.past_node.type.definition.pos_or_kw_args.values()), + kwargs={}, offset_provider={}, column_axis=None, + argument_descriptors={arguments.StaticArg: {"cond": arguments.StaticArg(value=True)}} + ), ) transformed = backend.DEFAULT_TRANSFORMS(input_pair).data From 1702aecb0cfc148c9aaaffe03172c5efa389b687 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 11 Sep 2025 21:22:59 +0200 Subject: [PATCH 35/93] Cleanup --- src/gt4py/next/ffront/decorator.py | 3 --- src/gt4py/next/otf/arguments.py | 18 +++++++----------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index b4f5bb42b5..2bd8f1a8b5 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -273,9 +273,6 @@ def _compiled_programs(self) -> compiled_program.CompiledProgramsPool: if self.static_params is None: object.__setattr__(self, "static_params", ()) - def path_to_expr(path: Sequence[int]): - return "".join(map(lambda idx: f"[{idx}]", path)) - argument_descriptor_mapping = { arguments.StaticArg: self.static_params, } diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 885dfb9d8a..ed76e8ca42 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -120,18 +120,14 @@ def offset_provider_type(self) -> common.OffsetProviderType: @classmethod def from_concrete(cls, *args: Any, **kwargs: Any) -> Self: """Convert concrete GTX program arguments into their compile-time counterparts.""" - kwargs = kwargs.copy() - offset_provider = kwargs.pop("offset_provider", {}) - column_axis = kwargs.pop("column_axis", None) - compile_args = tuple(StaticArg.from_value(arg) for arg in args) - compile_kwargs = { - k: StaticArg.from_value(v) for 
k, v in kwargs.items() if v is not None - } + kwargs_copy = kwargs.copy() return cls( - args=compile_args, - kwargs=compile_kwargs, - offset_provider=offset_provider, - column_axis=column_axis, + args=tuple(type_translation.from_value(arg) for arg in args), + offset_provider=kwargs_copy.pop("offset_provider", {}), + column_axis=kwargs_copy.pop("column_axis", None), + kwargs={ + k: type_translation.from_value(v) for k, v in kwargs_copy.items() if v is not None + }, ) @classmethod From e5aac5074dda09147b63e14a1225e35ae3ef343f Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 11 Sep 2025 22:25:46 +0200 Subject: [PATCH 36/93] Cleanup --- src/gt4py/next/otf/arguments.py | 33 +++--- src/gt4py/next/otf/compiled_program.py | 51 +++++---- .../otf_tests/test_compiled_program.py | 101 +++--------------- 3 files changed, 63 insertions(+), 122 deletions(-) diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index ed76e8ca42..b91c68b7ee 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -9,6 +9,7 @@ from __future__ import annotations import dataclasses +import enum import functools import typing from pickletools import ArgumentDescriptor @@ -28,24 +29,6 @@ TOrTupleOf: TypeAlias = T | tuple["TupleOf[T]", ...] 
ArgumentDescriptorT = typing.TypeVar("ArgumentDescriptorT", bound=ArgumentDescriptor) -class PartialValue(Generic[ArgumentDescriptorT]): - attrs: dict[str, Any] - items: dict[Any, Any] - - def __init__(self): - object.__setattr__(self, "attrs", {}) - object.__setattr__(self, "items", {}) - - def __setattr__(self, key: str, value: Any) -> None: - object.__getattribute__(self, "attrs")[key] = value - - def __setitem__(self, key: Any, value: Any) -> None: - object.__getattribute__(self, "items")[key] = value - - @property - def empty(self): - return not self.attrs and not self.items - class ArgumentDescriptor: def validate(self, name: str, type_: ts.TypeSpec): """ @@ -72,6 +55,11 @@ def attribute_extractor(cls, arg_expr: str) -> dict[str, str]: class StaticArg(ArgumentDescriptor, Generic[T]): value: TOrTupleOf[core_defs.ScalarT] + def __post_init__(self): + # transform enum value into the actual value + if isinstance(self.value, enum.Enum): + object.__setattr__(self, "value", self.value.value) + def validate(self, name: str, type_: ts.TypeSpec): if not type_info.is_type_or_tuple_of_type(type_, ts.ScalarType): raise errors.DSLTypeError( @@ -91,6 +79,13 @@ def attribute_extractor(cls, arg_expr: str): return {"value": arg_expr} +class _RuntimeArgument: + pass + +#: Sentinel value used to describe that there is no ArgumentDescriptor for an argument. Can be +#: used by transformation passes for consistency checks. 
+RUNTIME_ARGUMENT = _RuntimeArgument() + @dataclasses.dataclass(frozen=True) class JITArgs: """Concrete (runtime) arguments to a GTX program in a format that can be passed into the toolchain.""" @@ -111,7 +106,7 @@ class CompileTimeArgs: kwargs: dict[str, ts.TypeSpec] offset_provider: common.OffsetProvider # TODO(havogt): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information column_axis: Optional[common.Dimension] - argument_descriptors: dict[type[ArgumentDescriptor], PartialValue[ArgumentDescriptor]] + argument_descriptors: dict[type[ArgumentDescriptor], extended_typing.MaybeNestedInTuple[ArgumentDescriptor | _RuntimeArgument]] @property def offset_provider_type(self) -> common.OffsetProviderType: diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index b7022a65a4..86e8a1c0a4 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -40,13 +40,15 @@ def _hash_compiled_program_unsafe(cp_key: CompiledProgramsKey) -> int: def _make_tuple_expr(el_exprs): return "".join((f"{el},") for el in el_exprs) +def _make_param_context_from_func_type(func_type: ts.FunctionType, type_map: Callable[[ts.TypeSpec], Any] = lambda x: x): + params = func_type.pos_or_kw_args | func_type.kw_only_args + return {param: type_info.apply_to_primitive_constituents( + type_map, type_, tuple_constructor=lambda *els: tuple(els) + ) for param, type_ in params.items()} + def _get_type_of_param_expr(program_type: ts_ffront.ProgramType, expr: str): # TODO: error handling - func_type = program_type.definition - params = func_type.pos_or_kw_args | func_type.kw_only_args - vars = {param: type_info.apply_to_primitive_constituents(lambda x: x, type_, tuple_constructor=lambda *els: tuple(els)) - for param, type_ in params.items()} - return eval(expr, vars) + return eval(expr, _make_param_context_from_func_type(program_type.definition)) @dataclasses.dataclass class 
CompiledProgramsPool: @@ -81,11 +83,12 @@ class CompiledProgramsPool: init=False, ) # cache the offset provider type in order to avoid recomputing it at each program call - def __postinit__(self) -> None: + def __post_init__(self) -> None: # TODO(havogt): We currently don't support pos_only or kw_only args at the program level. # This check makes sure we don't miss updating this code if we add support for them in the future. assert not self.program_type.definition.kw_only_args assert not self.program_type.definition.pos_only_args + self._validate_argument_descriptor_mapping() def __call__( self, *args: Any, offset_provider: common.OffsetProvider, enable_jit: bool, **kwargs: Any @@ -164,16 +167,31 @@ def _impl(*args, **kwargs): for expr, attr in expr_descriptor_attr_mapping.items(): descriptor = descriptor_cls(**attr(*args, **kwargs)) descriptors[descriptor_cls][expr] = descriptor - self.validate_argument_descriptors(descriptors) + self._validate_argument_descriptors(descriptors) return descriptors return _impl - def validate_argument_descriptors(self, all_descriptors: dict[type[arguments.ArgumentDescriptor], dict[str, arguments.ArgumentDescriptor]]): + def _validate_argument_descriptors(self, all_descriptors: dict[type[arguments.ArgumentDescriptor], dict[str, arguments.ArgumentDescriptor]]): for descriptors in all_descriptors.values(): for expr, descriptor in descriptors.items(): param_type = _get_type_of_param_expr(self.program_type, expr) # TODO: error handling if type is wrong descriptor.validate(expr, param_type) + def _validate_argument_descriptor_mapping(self): + if self.argument_descriptor_mapping is None: + return + context = _make_param_context_from_func_type(self.program_type.definition, lambda x: None) + for descr_cls, exprs in self.argument_descriptor_mapping.items(): + for expr in exprs: + try: + assert eval(expr, context) == None + except: + raise errors.DSLTypeError( + message=f"Invalid parameter expression '{expr}' for '{descr_cls.__name__}'. 
" + f"Must be the name of a parameter or an access to one of its elements.", + location=None, + ) + def _compile_variant( self, argument_descriptors: dict[type[arguments.ArgumentDescriptor], dict[str, arguments.ArgumentDescriptor]], @@ -181,6 +199,7 @@ def _compile_variant( ) -> None: if self.argument_descriptor_mapping is None: self.argument_descriptor_mapping = {descr_cls: descriptor_expr_mapping.keys() for descr_cls, descriptor_expr_mapping in argument_descriptors.items()} + self._validate_argument_descriptor_mapping() else: for descr_cls, descriptor_expr_mapping in argument_descriptors.items(): if (expected := set(self.argument_descriptor_mapping[descr_cls])) != (got := set(descriptor_expr_mapping.keys())): @@ -188,19 +207,15 @@ def _compile_variant( f"Argument descriptor {descr_cls.__name__} must be the same for all compiled programs. Got {list(got)}, expected {list(expected)}." ) - self.validate_argument_descriptors(argument_descriptors) - - assert self.argument_descriptor_mapping is not None - - named_params = self.program_type.definition.pos_or_kw_args.keys() | self.program_type.definition.kw_only_args.keys() + self._validate_argument_descriptors(argument_descriptors) - # TODO: use full structure - # TODO: check that - structured_descriptors = DefaultDict(lambda: {k: arguments.PartialValue() for k in named_params}) + structured_descriptors = {} for descriptor_cls, descriptor_expr_mapping in argument_descriptors.items(): - assert "__descriptor" not in structured_descriptors + structured_descriptors[descriptor_cls] = _make_param_context_from_func_type(self.program_type.definition, lambda x: arguments.RUNTIME_ARGUMENT) + assert "__descriptor" not in structured_descriptors[descriptor_cls] for expr, descriptor in descriptor_expr_mapping.items(): - # TODO: gracefully catch error + # note: we don't need to handle any errors here since the `expr` has been validated + # in `_validate_argument_descriptor_mapping` exec(f"{expr} = __descriptor", {"__descriptor": 
descriptor}, structured_descriptors[descriptor_cls]) key = ( diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index 7a47552d40..1ff565a566 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -16,99 +16,30 @@ from gt4py.next.iterator import ir as itir from gt4py.next.program_processors.runners import gtfn +def test_static_arg_from_enum(): + class SomeEnum(eve.IntEnum): + FOO = 1 -class SomeEnum(eve.IntEnum): - FOO = 1 - - -@pytest.mark.parametrize( - "value, type_, expected", - [ - (gtx.int32(1), ts.ScalarType(kind=ts.ScalarKind.INT32), gtx.int32(1)), - (gtx.int64(1), ts.ScalarType(kind=ts.ScalarKind.INT64), gtx.int64(1)), - (1, ts.ScalarType(kind=ts.ScalarKind.INT32), gtx.int32(1)), - (True, ts.ScalarType(kind=ts.ScalarKind.BOOL), True), - (False, ts.ScalarType(kind=ts.ScalarKind.BOOL), False), - (SomeEnum.FOO, ts.ScalarType(kind=ts.ScalarKind.INT32), gtx.int32(1)), - ( - (1, (2.0, gtx.float32(3.0))), - ts.TupleType( - types=[ - ts.ScalarType(kind=ts.ScalarKind.INT32), - ts.TupleType( - types=[ - ts.ScalarType(kind=ts.ScalarKind.FLOAT64), - ts.ScalarType(kind=ts.ScalarKind.FLOAT32), - ] - ), - ] - ), - (gtx.float32(1), (gtx.float64(2.0), gtx.float32(3.0))), - ), - ], -) -def test_sanitize_static_args(value, type_, expected): - program_type = ts_ffront.ProgramType( - definition=ts.FunctionType( - pos_only_args=[], - pos_or_kw_args={ - "testee": type_, - }, - kw_only_args={}, - returns=ts.VoidType(), - ) - ) + static_arg = arguments.StaticArg(value=SomeEnum.FOO) + assert static_arg.value == 1 - result = compiled_program._sanitize_static_args( - "testee_program", {"testee": value}, program_type - ) - assert result == {"testee": expected} - - -def test_sanitize_static_args_non_scalar_type(): - program_type = ts_ffront.ProgramType( - definition=ts.FunctionType( - pos_only_args=[], - 
pos_or_kw_args={ - "foo": ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) - }, - kw_only_args={}, - returns=ts.VoidType(), - ) - ) + +def test_static_args_non_scalar_type(): with pytest.raises( errors.DSLTypeError, - match="foo.*cannot be static", + match="only scalars.*can be static", ): - compiled_program._sanitize_static_args("foo_program", {"foo": gtx.int32(1)}, program_type) + static_arg = arguments.StaticArg(value=1) + static_arg.validate("foo", ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))) def test_sanitize_static_args_wrong_type(): - program_type = ts_ffront.ProgramType( - definition=ts.FunctionType( - pos_only_args=[], - pos_or_kw_args={"foo": ts.ScalarType(kind=ts.ScalarKind.INT32)}, - kw_only_args={}, - returns=ts.VoidType(), - ) - ) - with pytest.raises(errors.DSLTypeError, match="got 'int64'"): - compiled_program._sanitize_static_args("foo_program", {"foo": gtx.int64(1)}, program_type) - - -def test_sanitize_static_args_non_existing_parameter(): - program_type = ts_ffront.ProgramType( - definition=ts.FunctionType( - pos_only_args=[], - pos_or_kw_args={"foo": ts.ScalarType(kind=ts.ScalarKind.INT32)}, - kw_only_args={}, - returns=ts.VoidType(), - ) - ) - with pytest.raises(errors.DSLTypeError, match="'unknown_param'"): - compiled_program._sanitize_static_args( - "foo_program", {"unknown_param": gtx.int64(1)}, program_type - ) + with pytest.raises( + errors.DSLTypeError, + match="expected 'int32'.*has.*'int64'", + ): + static_arg = arguments.StaticArg(value=gtx.int64(1)) + static_arg.validate("foo", ts.ScalarType(kind=ts.ScalarKind.INT32)) TDim = gtx.Dimension("TDim") From f83925df940ab07cadd1554417bb81519aec0bfa Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 01:21:29 +0200 Subject: [PATCH 37/93] Cleanup --- src/gt4py/next/ffront/decorator.py | 2 +- src/gt4py/next/ffront/foast_to_past.py | 10 +- src/gt4py/next/ffront/past_process_args.py | 6 +- src/gt4py/next/ffront/past_to_itir.py | 9 +- 
src/gt4py/next/ffront/stages.py | 5 +- src/gt4py/next/otf/arguments.py | 43 ++--- src/gt4py/next/otf/compiled_program.py | 164 ++++++++++++------ .../codegens/gtfn/gtfn_module.py | 7 +- .../runners/dace/program.py | 1 + .../runners/dace/workflow/translation.py | 6 +- .../otf_tests/test_compiled_program.py | 11 +- 11 files changed, 159 insertions(+), 105 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 2bd8f1a8b5..7f825a8bf5 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -283,7 +283,7 @@ def _compiled_programs(self) -> compiled_program.CompiledProgramsPool: backend=self.backend, definition_stage=self.definition_stage, program_type=program_type, - argument_descriptor_mapping=argument_descriptor_mapping, + argument_descriptor_mapping=argument_descriptor_mapping, # type: ignore[arg-type] # covariant `type[T]` not possible ) def _extend_offset_provider( diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index afc2b1689e..8adb7ea87f 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -20,7 +20,7 @@ from gt4py.next.ffront.past_passes import closure_var_type_deduction, type_deduction from gt4py.next.ffront.stages import AOT_FOP, AOT_PRG from gt4py.next.iterator import ir as itir -from gt4py.next.otf import arguments, toolchain, workflow +from gt4py.next.otf import toolchain, workflow from gt4py.next.type_system import type_info, type_specifications as ts @@ -95,13 +95,7 @@ def __call__(self, inp: AOT_FOP) -> AOT_PRG: # TODO(tehrengruber): check foast operator has no out argument that clashes # with the out argument of the program we generate here. 
- arg_types = tuple( - arg.type_ if isinstance(arg, arguments.StaticArg) else arg for arg in inp.args.args - ) - kwarg_types = { - k: v.type_ if isinstance(v, arguments.StaticArg) else v - for k, v in inp.args.kwargs.items() - } + arg_types, kwarg_types = inp.args.args, inp.args.kwargs loc = inp.data.foast_node.location # use a new UID generator to allow caching diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index b5dc5a12ea..81f9363822 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -64,9 +64,9 @@ def _validate_args( def _process_args( past_node: past.Program, - args: Sequence[ts.TypeSpec | arguments.StaticArg], - kwargs: dict[str, ts.TypeSpec | arguments.StaticArg], -) -> tuple[tuple, dict[str, Any]]: + args: Sequence[ts.TypeSpec], + kwargs: dict[str, ts.TypeSpec], +) -> tuple[tuple[ts.TypeSpec], dict[str, ts.TypeSpec]]: if not isinstance(past_node.type, ts_ffront.ProgramType): raise TypeError("Can not process arguments for PAST programs prior to type inference.") diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 3315aecdcc..60ba7023f0 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -96,10 +96,15 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: ) static_arg_descriptors = inp.args.argument_descriptors[arguments.StaticArg] - if not all(isinstance(arg_descriptor, arguments.StaticArg) for arg_descriptor in static_arg_descriptors.values()): + if not all( + isinstance(arg_descriptor, arguments.StaticArg) + for arg_descriptor in static_arg_descriptors.values() + ): raise NotImplementedError("Only top-level arguments can be static.") static_args = { - name: im.literal_from_tuple_value(descr.value) for name, descr in static_arg_descriptors.items() + name: im.literal_from_tuple_value(descr.value) + for name, descr in static_arg_descriptors.items() + if 
isinstance(descr, arguments.StaticArg) } body = remap_symbols.RemapSymbolRefs().visit(itir_program.body, symbol_map=static_args) itir_program = itir.Program( diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index caa97ef247..46c2e9811d 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -129,7 +129,10 @@ def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgo @add_content_to_fingerprint.register def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None: - for key, value in sorted(obj.items(), key=lambda x: (x[0].__module__, x[0].__qualname__) if isinstance(x[0], type) else x[0]): + for key, value in sorted( + obj.items(), + key=lambda x: (x[0].__module__, x[0].__qualname__) if isinstance(x[0], type) else x[0], + ): add_content_to_fingerprint(key, hasher) add_content_to_fingerprint(value, hasher) diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index b91c68b7ee..6c621e3c68 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -10,27 +10,24 @@ import dataclasses import enum -import functools import typing -from pickletools import ArgumentDescriptor -from typing import Any, Generic, Optional, TypeAlias +from typing import Any, Generic, Mapping, Optional from typing_extensions import Self from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing -from gt4py.next import common, utils, errors +from gt4py.next import common, errors from gt4py.next.otf import toolchain, workflow -from gt4py.next.type_system import type_specifications as ts, type_translation, type_info -from gt4py.next.type_system.type_info import apply_to_primitive_constituents +from gt4py.next.type_system import type_info, type_specifications as ts, type_translation + DATA_T = typing.TypeVar("DATA_T") T = typing.TypeVar("T") -TOrTupleOf: TypeAlias = T | tuple["TupleOf[T]", ...] 
-ArgumentDescriptorT = typing.TypeVar("ArgumentDescriptorT", bound=ArgumentDescriptor) + class ArgumentDescriptor: - def validate(self, name: str, type_: ts.TypeSpec): + def validate(self, name: str, type_: ts.TypeSpec) -> None: """ Validate argument descriptor. @@ -40,7 +37,7 @@ def validate(self, name: str, type_: ts.TypeSpec): pass @classmethod - def attribute_extractor(cls, arg_expr: str) -> dict[str, str]: + def attribute_extractor(cls, arg_expr: str) -> dict[str, str]: # type: ignore[empty-body] # classmethod is abstract """ Return a mapping from the attributes of our descriptor to the expressions to retrieve them. @@ -51,16 +48,17 @@ def attribute_extractor(cls, arg_expr: str) -> dict[str, str]: """ ... + @dataclasses.dataclass(frozen=True) -class StaticArg(ArgumentDescriptor, Generic[T]): - value: TOrTupleOf[core_defs.ScalarT] +class StaticArg(ArgumentDescriptor, Generic[core_defs.ScalarT]): + value: extended_typing.MaybeNestedInTuple[core_defs.ScalarT] - def __post_init__(self): + def __post_init__(self) -> None: # transform enum value into the actual value if isinstance(self.value, enum.Enum): object.__setattr__(self, "value", self.value.value) - def validate(self, name: str, type_: ts.TypeSpec): + def validate(self, name: str, type_: ts.TypeSpec) -> None: if not type_info.is_type_or_tuple_of_type(type_, ts.ScalarType): raise errors.DSLTypeError( message=f"Invalid static argument '{name}' with type '{type_}' (only scalars or (nested) tuples of scalars can be static).", @@ -75,17 +73,10 @@ def validate(self, name: str, type_: ts.TypeSpec): ) @classmethod - def attribute_extractor(cls, arg_expr: str): + def attribute_extractor(cls, arg_expr: str) -> dict[str, str]: return {"value": arg_expr} -class _RuntimeArgument: - pass - -#: Sentinel value used to describe that there is no ArgumentDescriptor for an argument. Can be -#: used by transformation passes for consistency checks. 
-RUNTIME_ARGUMENT = _RuntimeArgument() - @dataclasses.dataclass(frozen=True) class JITArgs: """Concrete (runtime) arguments to a GTX program in a format that can be passed into the toolchain.""" @@ -106,7 +97,10 @@ class CompileTimeArgs: kwargs: dict[str, ts.TypeSpec] offset_provider: common.OffsetProvider # TODO(havogt): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information column_axis: Optional[common.Dimension] - argument_descriptors: dict[type[ArgumentDescriptor], extended_typing.MaybeNestedInTuple[ArgumentDescriptor | _RuntimeArgument]] + argument_descriptors: Mapping[ + type[ArgumentDescriptor], + dict[str, ArgumentDescriptor], + ] @property def offset_provider_type(self) -> common.OffsetProviderType: @@ -123,6 +117,7 @@ def from_concrete(cls, *args: Any, **kwargs: Any) -> Self: kwargs={ k: type_translation.from_value(v) for k, v in kwargs_copy.items() if v is not None }, + argument_descriptors={}, ) @classmethod @@ -141,4 +136,4 @@ def adapted_jit_to_aot_args_factory() -> workflow.Workflow[ toolchain.CompilableProgram[DATA_T, CompileTimeArgs], ]: """Wrap `jit_to_aot` into a workflow adapter to fit into backend transform workflows.""" - return toolchain.ArgsOnlyAdapter(jit_to_aot_args) \ No newline at end of file + return toolchain.ArgsOnlyAdapter(jit_to_aot_args) diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 86e8a1c0a4..8e49d060fd 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -12,17 +12,18 @@ import dataclasses import functools import itertools -from collections.abc import Sequence -from typing import Any, TypeAlias, Callable, DefaultDict +from typing import Any, Callable, DefaultDict, Sequence, TypeAlias, TypeVar from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing, utils as eve_utils from gt4py.next import backend as gtx_backend, common, config, errors from 
gt4py.next.ffront import stages as ffront_stages, type_specifications as ts_ffront from gt4py.next.otf import arguments, stages -from gt4py.next.type_system import type_info, type_specifications as ts, type_translation as tt +from gt4py.next.type_system import type_info, type_specifications as ts +T = TypeVar("T") + # TODO(havogt): We would like this to be a ProcessPoolExecutor, which requires (to decide what) to pickle. _async_compilation_pool = concurrent.futures.ThreadPoolExecutor(max_workers=config.BUILD_JOBS) @@ -30,6 +31,9 @@ CompiledProgramsKey: TypeAlias = tuple[ tuple[ScalarOrTupleOfScalars, ...], common.OffsetProviderType ] +ArgumentDescriptors: TypeAlias = dict[ + type[arguments.ArgumentDescriptor], dict[str, arguments.ArgumentDescriptor] +] def _hash_compiled_program_unsafe(cp_key: CompiledProgramsKey) -> int: @@ -37,19 +41,28 @@ def _hash_compiled_program_unsafe(cp_key: CompiledProgramsKey) -> int: assert common.is_offset_provider_type(offset_provider) return hash((values, id(offset_provider))) -def _make_tuple_expr(el_exprs): + +def _make_tuple_expr(el_exprs: list[str]) -> str: return "".join((f"{el},") for el in el_exprs) -def _make_param_context_from_func_type(func_type: ts.FunctionType, type_map: Callable[[ts.TypeSpec], Any] = lambda x: x): + +def _make_param_context_from_func_type( + func_type: ts.FunctionType, + type_map: Callable[[ts.TypeSpec], T] = lambda x: x, # type: ignore[assignment, return-value] # mypy not smart enough to narrow type for default +) -> dict[str, extended_typing.MaybeNestedInTuple[T]]: params = func_type.pos_or_kw_args | func_type.kw_only_args - return {param: type_info.apply_to_primitive_constituents( - type_map, type_, tuple_constructor=lambda *els: tuple(els) - ) for param, type_ in params.items()} + return { + param: type_info.apply_to_primitive_constituents( + type_map, type_, tuple_constructor=lambda *els: tuple(els) + ) + for param, type_ in params.items() + } + -def _get_type_of_param_expr(program_type: 
ts_ffront.ProgramType, expr: str): - # TODO: error handling +def _get_type_of_param_expr(program_type: ts_ffront.ProgramType, expr: str) -> ts.TypeSpec: return eval(expr, _make_param_context_from_func_type(program_type.definition)) + @dataclasses.dataclass class CompiledProgramsPool: """ @@ -71,7 +84,7 @@ class CompiledProgramsPool: #: mapping from an argument descriptor type to a list of parameters or expression thereof #: e.g. `{arguments.StaticArg: ["static_int_param"]}` #: Note: The list is not ordered. - argument_descriptor_mapping: dict[type[arguments.ArgumentDescriptor], list[str]] | None + argument_descriptor_mapping: dict[type[arguments.ArgumentDescriptor], Sequence[str]] | None _compiled_programs: eve_utils.CustomMapping = dataclasses.field( default_factory=lambda: eve_utils.CustomMapping(_hash_compiled_program_unsafe), @@ -115,7 +128,7 @@ def __call__( assert self.argument_descriptor_mapping is not None self._compile_variant( argument_descriptors=self._make_argument_descriptors(*args, **kwargs), - offset_provider=offset_provider + offset_provider=offset_provider, ) return self( *args, offset_provider=offset_provider, enable_jit=False, **kwargs @@ -126,83 +139,113 @@ def __call__( @functools.cached_property def _argument_descriptor_cache_key_from_args(self) -> Callable: func_type = self.program_type.definition - params = func_type.pos_only_args + list(func_type.pos_or_kw_args.keys()) + list(func_type.kw_only_args.keys()) - elements = [] - for descriptor_cls, arg_exprs in self.argument_descriptor_mapping.items(): + params = list(func_type.pos_or_kw_args.keys()) + list(func_type.kw_only_args.keys()) + elements: list[str] = [] + for descriptor_cls, arg_exprs in self.argument_descriptor_mapping.items(): # type: ignore[union-attr] # can never be `None` at this point for arg_expr in arg_exprs: attr_extractor = descriptor_cls.attribute_extractor(arg_expr) elements.extend(attr_extractor.values()) return eval(f"""lambda {",".join(params)}: 
({_make_tuple_expr(elements)})""") - def _argument_descriptor_cache_key_from_descriptor(self, argument_descriptors: dict[type[arguments.ArgumentDescriptor], dict[str, arguments.ArgumentDescriptor]]) -> tuple: + def _argument_descriptor_cache_key_from_structured_descriptors( + self, + argument_descriptors: dict[ + type[arguments.ArgumentDescriptor], + dict[str, extended_typing.MaybeNestedInTuple[arguments.ArgumentDescriptor | None]], + ], + ) -> tuple: elements = [] - for descriptor_cls, arg_exprs in self.argument_descriptor_mapping.items(): + for descriptor_cls, arg_exprs in self.argument_descriptor_mapping.items(): # type: ignore[union-attr] # can never be `None` at this point for arg_expr in arg_exprs: attr_extractor = descriptor_cls.attribute_extractor(arg_expr) attrs = attr_extractor.keys() for attr in attrs: - elements.append(getattr(eval(f"{arg_expr}", argument_descriptors[descriptor_cls]), attr)) + elements.append( + getattr(eval(f"{arg_expr}", argument_descriptors[descriptor_cls]), attr) + ) return tuple(elements) @functools.cached_property - def _make_argument_descriptors(self) -> Callable: - def make_dict_expr(exprs: dict[str, str]): - return "{"+",".join((f"'{k}': {v}" for k, v in exprs.items()))+"}" + def _descriptor_attr_retrievers( + self, + ) -> dict[type[arguments.ArgumentDescriptor], dict[str, Callable]]: + """ + For each argument expression build a lambda function that constructs (the attributes of) + its argument descriptor + """ + + def make_dict_expr(exprs: dict[str, str]) -> str: + return "{" + ",".join((f"'{k}': {v}" for k, v in exprs.items())) + "}" - # for each argument expression build a lambda function that constructs (the attributes of) + # # its argument descriptor func_type = self.program_type.definition - params = func_type.pos_only_args + list(func_type.pos_or_kw_args.keys()) + list( - func_type.kw_only_args.keys()) - descriptor_attrs = DefaultDict(dict) - for descriptor_cls, arg_exprs in self.argument_descriptor_mapping.items(): + 
params = list(func_type.pos_or_kw_args.keys()) + list(func_type.kw_only_args.keys()) + retrievers: dict[type[arguments.ArgumentDescriptor], dict[str, Callable]] = DefaultDict( + dict + ) + for descriptor_cls, arg_exprs in self.argument_descriptor_mapping.items(): # type: ignore[union-attr] # can never be `None` at this point for arg_expr in arg_exprs: attr_exprs = descriptor_cls.attribute_extractor(arg_expr) - descriptor_attrs[descriptor_cls][arg_expr] = eval(f"""lambda {",".join(params)}: {make_dict_expr(attr_exprs)}""") - - def _impl(*args, **kwargs): - descriptors = {} - for descriptor_cls, expr_descriptor_attr_mapping in descriptor_attrs.items(): - descriptors[descriptor_cls] = {} - for expr, attr in expr_descriptor_attr_mapping.items(): - descriptor = descriptor_cls(**attr(*args, **kwargs)) - descriptors[descriptor_cls][expr] = descriptor - self._validate_argument_descriptors(descriptors) - return descriptors - return _impl - - def _validate_argument_descriptors(self, all_descriptors: dict[type[arguments.ArgumentDescriptor], dict[str, arguments.ArgumentDescriptor]]): + retrievers[descriptor_cls][arg_expr] = eval( + f"""lambda {",".join(params)}: {make_dict_expr(attr_exprs)}""" + ) + + return retrievers + + def _make_argument_descriptors(self, *args: Any, **kwargs: Any) -> ArgumentDescriptors: + descriptors: ArgumentDescriptors = {} + for descriptor_cls, attr_retrievers in self._descriptor_attr_retrievers.items(): + descriptors[descriptor_cls] = {} + for expr, attr_retriever in attr_retrievers.items(): + descriptor = descriptor_cls(**attr_retriever(*args, **kwargs)) + descriptors[descriptor_cls][expr] = descriptor + self._validate_argument_descriptors(descriptors) + return descriptors + + def _validate_argument_descriptors( + self, + all_descriptors: ArgumentDescriptors, + ) -> None: for descriptors in all_descriptors.values(): for expr, descriptor in descriptors.items(): - param_type = _get_type_of_param_expr(self.program_type, expr) # TODO: error handling if 
type is wrong + param_type = _get_type_of_param_expr( + self.program_type, expr + ) # TODO: error handling if type is wrong descriptor.validate(expr, param_type) - def _validate_argument_descriptor_mapping(self): + def _validate_argument_descriptor_mapping(self) -> None: if self.argument_descriptor_mapping is None: return context = _make_param_context_from_func_type(self.program_type.definition, lambda x: None) for descr_cls, exprs in self.argument_descriptor_mapping.items(): for expr in exprs: try: - assert eval(expr, context) == None - except: - raise errors.DSLTypeError( + if eval(expr, context) is not None: + raise ValueError() + except (ValueError, KeyError): + raise errors.DSLTypeError( # noqa: B904 # we don't care about the original exception message=f"Invalid parameter expression '{expr}' for '{descr_cls.__name__}'. " - f"Must be the name of a parameter or an access to one of its elements.", + f"Must be the name of a parameter or an access to one of its elements.", location=None, ) def _compile_variant( self, - argument_descriptors: dict[type[arguments.ArgumentDescriptor], dict[str, arguments.ArgumentDescriptor]], + argument_descriptors: ArgumentDescriptors, offset_provider: common.OffsetProviderType | common.OffsetProvider, ) -> None: if self.argument_descriptor_mapping is None: - self.argument_descriptor_mapping = {descr_cls: descriptor_expr_mapping.keys() for descr_cls, descriptor_expr_mapping in argument_descriptors.items()} + self.argument_descriptor_mapping = { + descr_cls: list(descriptor_expr_mapping.keys()) + for descr_cls, descriptor_expr_mapping in argument_descriptors.items() + } self._validate_argument_descriptor_mapping() else: for descr_cls, descriptor_expr_mapping in argument_descriptors.items(): - if (expected := set(self.argument_descriptor_mapping[descr_cls])) != (got := set(descriptor_expr_mapping.keys())): + if (expected := set(self.argument_descriptor_mapping[descr_cls])) != ( + got := set(descriptor_expr_mapping.keys()) + ): raise 
ValueError( f"Argument descriptor {descr_cls.__name__} must be the same for all compiled programs. Got {list(got)}, expected {list(expected)}." ) @@ -211,15 +254,21 @@ def _compile_variant( structured_descriptors = {} for descriptor_cls, descriptor_expr_mapping in argument_descriptors.items(): - structured_descriptors[descriptor_cls] = _make_param_context_from_func_type(self.program_type.definition, lambda x: arguments.RUNTIME_ARGUMENT) + structured_descriptors[descriptor_cls] = _make_param_context_from_func_type( + self.program_type.definition, lambda x: None + ) assert "__descriptor" not in structured_descriptors[descriptor_cls] for expr, descriptor in descriptor_expr_mapping.items(): # note: we don't need to handle any errors here since the `expr` has been validated # in `_validate_argument_descriptor_mapping` - exec(f"{expr} = __descriptor", {"__descriptor": descriptor}, structured_descriptors[descriptor_cls]) + exec( + f"{expr} = __descriptor", + {"__descriptor": descriptor}, + structured_descriptors[descriptor_cls], + ) key = ( - self._argument_descriptor_cache_key_from_descriptor(structured_descriptors), + self._argument_descriptor_cache_key_from_structured_descriptors(structured_descriptors), # type: ignore[arg-type] # mypy not smart enough self._offset_provider_to_type_unsafe(offset_provider), ) if key in self._compiled_programs: @@ -228,7 +277,8 @@ def _compile_variant( compile_time_args = arguments.CompileTimeArgs( offset_provider=offset_provider, # type:ignore[arg-type] # TODO(havogt): resolve OffsetProviderType vs OffsetProvider column_axis=None, # TODO(havogt): column_axis seems to a unused, even for programs with scans - args=tuple(self.program_type.definition.pos_only_args) + tuple(self.program_type.definition.pos_or_kw_args.values()), + args=tuple(self.program_type.definition.pos_only_args) + + tuple(self.program_type.definition.pos_or_kw_args.values()), kwargs=self.program_type.definition.kw_only_args, argument_descriptors=argument_descriptors, ) 
@@ -272,7 +322,13 @@ def compile( for static_values in itertools.product(*static_args.values()): self._compile_variant( argument_descriptors={ - arguments.StaticArg: dict(zip(static_args.keys(), map(arguments.StaticArg, static_values), strict=True)), + arguments.StaticArg: dict( + zip( + static_args.keys(), + [arguments.StaticArg(value=v) for v in static_values], + strict=True, + ) + ), }, offset_provider=offset_provider, ) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index e71e8a7783..5bc796de9d 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -21,7 +21,7 @@ from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import pass_manager -from gt4py.next.otf import arguments, languages, stages, step_types, workflow +from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import cpp_interface, interface from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_to_gtfn_im_ir import GTFN_IM_lowering @@ -85,7 +85,6 @@ def _process_regular_arguments( arg_exprs: list[str] = [] for arg_type, program_param in zip(arg_types, program.params, strict=True): - arg_type = arg_type.type_ if isinstance(arg_type, arguments.StaticArg) else arg_type # parameter parameter = get_param_description(program_param.id, arg_type) parameters.append(parameter) @@ -214,9 +213,7 @@ def __call__( # handle regular parameters and arguments of the program (i.e. 
what the user defined in # the program) - arg_types = tuple( - arg.type_ if isinstance(arg, arguments.StaticArg) else arg for arg in inp.args.args - ) + arg_types = inp.args.args regular_parameters, regular_args_expr = self._process_regular_arguments( program, arg_types, inp.args.offset_provider_type ) diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py index acf12880e5..6e5fcc29c2 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -55,6 +55,7 @@ def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: kwargs={}, column_axis=column_axis, offset_provider=offset_provider, + argument_descriptors={}, ), ) ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index 73313ae9a4..8055f7b0d8 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -17,7 +17,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common, config, metrics from gt4py.next.iterator import ir as itir, transforms as itir_transforms -from gt4py.next.otf import arguments, languages, stages, step_types, workflow +from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import interface from gt4py.next.otf.languages import LanguageSettings from gt4py.next.program_processors.runners.dace import ( @@ -276,9 +276,7 @@ def __call__( inp.args.column_axis, ) - arg_types = tuple( - arg.type_ if isinstance(arg, arguments.StaticArg) else arg for arg in inp.args.args - ) + arg_types = inp.args.args program_parameters = tuple( interface.Parameter(param.id, arg_type) diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py 
b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index 1ff565a566..1c132773e0 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -16,6 +16,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.program_processors.runners import gtfn + def test_static_arg_from_enum(): class SomeEnum(eve.IntEnum): FOO = 1 @@ -30,7 +31,9 @@ def test_static_args_non_scalar_type(): match="only scalars.*can be static", ): static_arg = arguments.StaticArg(value=1) - static_arg.validate("foo", ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))) + static_arg.validate( + "foo", ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + ) def test_sanitize_static_args_wrong_type(): @@ -73,8 +76,10 @@ def test_inlining_of_scalars_works(): data=prog.definition_stage, args=arguments.CompileTimeArgs( args=list(prog.past_stage.past_node.type.definition.pos_or_kw_args.values()), - kwargs={}, offset_provider={}, column_axis=None, - argument_descriptors={arguments.StaticArg: {"cond": arguments.StaticArg(value=True)}} + kwargs={}, + offset_provider={}, + column_axis=None, + argument_descriptors={arguments.StaticArg: {"cond": arguments.StaticArg(value=True)}}, ), ) From 2d00cb643fcbd89284446ac373b542a1dc32cec8 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 01:22:39 +0200 Subject: [PATCH 38/93] Cleanup --- src/gt4py/next/ffront/foast_to_past.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 8adb7ea87f..1668aa0ae2 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -75,6 +75,7 @@ class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]): ... kwargs={}, ... offset_provider={"I", IDim}, ... column_axis=None, + ... argument_descriptors={} ... 
) >>> copy_program = op_to_prog( From 44773ea7a632ef2549e96a4da02a0769239f2d95 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 01:26:40 +0200 Subject: [PATCH 39/93] Cleanup --- src/gt4py/next/otf/arguments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 6c621e3c68..647af295c3 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -43,7 +43,7 @@ def attribute_extractor(cls, arg_expr: str) -> dict[str, str]: # type: ignore[e E.g. if `arg_expr` would be `myarg` and the result of this function `{'value': 'my_arg.value'}` then the descriptor is constructed as - `ArgumentDescriptor(value=my_arg.value)`. We use expression here such that we can compute + `ArgumentDescriptor(value=my_arg.value)`. We use an expression here such that we can compute a cache key by just hashing `my_arg.value` instead of first constructing the descriptor. """ ... From 4f455730a51d370954c46450a4d4ffc6c283f223 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 01:29:22 +0200 Subject: [PATCH 40/93] Cleanup --- src/gt4py/next/ffront/foast_to_past.py | 2 +- src/gt4py/next/otf/compiled_program.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 1668aa0ae2..272823ad6b 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -75,7 +75,7 @@ class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]): ... kwargs={}, ... offset_provider={"I", IDim}, ... column_axis=None, - ... argument_descriptors={} + ... argument_descriptors={}, ... 
) >>> copy_program = op_to_prog( diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 8e49d060fd..b283b3125b 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -135,7 +135,6 @@ def __call__( ) # passing `enable_jit=False` because a cache miss should be a hard-error in this call` raise RuntimeError("No program compiled for this set of static arguments.") from e - # TODO: test that compares _argument_descriptor_cache_key_from_args with _argument_descriptor_cache_key_from_descriptor @functools.cached_property def _argument_descriptor_cache_key_from_args(self) -> Callable: func_type = self.program_type.definition From b01d6eeb29228b8400bce34b5efb798e998234fa Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 01:29:55 +0200 Subject: [PATCH 41/93] Cleanup --- src/gt4py/next/otf/compiled_program.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index b283b3125b..6bf5dedbf4 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -176,8 +176,6 @@ def _descriptor_attr_retrievers( def make_dict_expr(exprs: dict[str, str]) -> str: return "{" + ",".join((f"'{k}': {v}" for k, v in exprs.items())) + "}" - # - # its argument descriptor func_type = self.program_type.definition params = list(func_type.pos_or_kw_args.keys()) + list(func_type.kw_only_args.keys()) retrievers: dict[type[arguments.ArgumentDescriptor], dict[str, Callable]] = DefaultDict( From fa9fa57fb0239a79c51572891cf9841c3b79d632 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 01:31:12 +0200 Subject: [PATCH 42/93] Cleanup --- src/gt4py/next/otf/compiled_program.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 6bf5dedbf4..da2ca7ab5e 100644 --- 
a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -60,7 +60,9 @@ def _make_param_context_from_func_type( def _get_type_of_param_expr(program_type: ts_ffront.ProgramType, expr: str) -> ts.TypeSpec: - return eval(expr, _make_param_context_from_func_type(program_type.definition)) + type_ = eval(expr, _make_param_context_from_func_type(program_type.definition)) + assert isinstance(type_, ts.TypeSpec) + return type_ @dataclasses.dataclass @@ -206,9 +208,7 @@ def _validate_argument_descriptors( ) -> None: for descriptors in all_descriptors.values(): for expr, descriptor in descriptors.items(): - param_type = _get_type_of_param_expr( - self.program_type, expr - ) # TODO: error handling if type is wrong + param_type = _get_type_of_param_expr(self.program_type, expr) descriptor.validate(expr, param_type) def _validate_argument_descriptor_mapping(self) -> None: From 60e3b791c1cfcbea5677bb92a45a489b93912599 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 01:32:22 +0200 Subject: [PATCH 43/93] Cleanup --- src/gt4py/next/otf/compiled_program.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index da2ca7ab5e..0b631a6703 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -244,7 +244,7 @@ def _compile_variant( got := set(descriptor_expr_mapping.keys()) ): raise ValueError( - f"Argument descriptor {descr_cls.__name__} must be the same for all compiled programs. Got {list(got)}, expected {list(expected)}." + f"Argument descriptor {descr_cls.__name__} must be the same for all compiled programs, got {list(got)} expected {list(expected)}." 
) self._validate_argument_descriptors(argument_descriptors) From b694623aa2c7a04d8667b7068282586f156a922c Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 01:33:31 +0200 Subject: [PATCH 44/93] Cleanup --- src/gt4py/next/otf/compiled_program.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 0b631a6703..f7b8fd0590 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -34,6 +34,10 @@ ArgumentDescriptors: TypeAlias = dict[ type[arguments.ArgumentDescriptor], dict[str, arguments.ArgumentDescriptor] ] +StructuredArgumentDescriptors: TypeAlias = dict[ + type[arguments.ArgumentDescriptor], + dict[str, extended_typing.MaybeNestedInTuple[arguments.ArgumentDescriptor | None]], +] def _hash_compiled_program_unsafe(cp_key: CompiledProgramsKey) -> int: @@ -150,10 +154,7 @@ def _argument_descriptor_cache_key_from_args(self) -> Callable: def _argument_descriptor_cache_key_from_structured_descriptors( self, - argument_descriptors: dict[ - type[arguments.ArgumentDescriptor], - dict[str, extended_typing.MaybeNestedInTuple[arguments.ArgumentDescriptor | None]], - ], + structured_descriptors: StructuredArgumentDescriptors, ) -> tuple: elements = [] for descriptor_cls, arg_exprs in self.argument_descriptor_mapping.items(): # type: ignore[union-attr] # can never be `None` at this point @@ -162,7 +163,7 @@ def _argument_descriptor_cache_key_from_structured_descriptors( attrs = attr_extractor.keys() for attr in attrs: elements.append( - getattr(eval(f"{arg_expr}", argument_descriptors[descriptor_cls]), attr) + getattr(eval(f"{arg_expr}", structured_descriptors[descriptor_cls]), attr) ) return tuple(elements) @@ -249,7 +250,7 @@ def _compile_variant( self._validate_argument_descriptors(argument_descriptors) - structured_descriptors = {} + structured_descriptors: StructuredArgumentDescriptors = {} for 
descriptor_cls, descriptor_expr_mapping in argument_descriptors.items(): structured_descriptors[descriptor_cls] = _make_param_context_from_func_type( self.program_type.definition, lambda x: None From f7e131c222832404653108d4c20749ee2d2f0136 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 02:30:42 +0200 Subject: [PATCH 45/93] Cleanup --- src/gt4py/next/ffront/past_to_itir.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 60ba7023f0..f7f8386420 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -58,6 +58,7 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: ... kwargs={}, ... offset_provider={"I": IDim}, ... column_axis=None, + ... argument_descriptors={} ... ) >>> itir_copy = past_to_gtir( From d7a09d8bfbc0074ff5a7fb4f6666d13726649cbb Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 02:31:16 +0200 Subject: [PATCH 46/93] Cleanup --- src/gt4py/next/ffront/past_to_itir.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index f7f8386420..dcc1d9d5da 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -105,7 +105,6 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: static_args = { name: im.literal_from_tuple_value(descr.value) for name, descr in static_arg_descriptors.items() - if isinstance(descr, arguments.StaticArg) } body = remap_symbols.RemapSymbolRefs().visit(itir_program.body, symbol_map=static_args) itir_program = itir.Program( From 2277bc1ced759822665b5266eb7b9849dce74c9b Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 02:47:02 +0200 Subject: [PATCH 47/93] Cleanup --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py 
b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 7c3aace73a..fefca65a62 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -487,6 +487,9 @@ def get_field_domain( if isinstance(field, itir.Expr) and isinstance(field.type, ts.FieldType): assert dims is None or all(d1 == d2 for d1, d2 in zip(field.type.dims, dims, strict=True)) dims = field.type.dims + else: + if dims is None: + raise ValueError("Field expression must be typed if 'dims' is not given.") return domain( grid_type, From 86d19a701dae7567f29e39019f6777b11eb2820d Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 10:40:31 +0200 Subject: [PATCH 48/93] Cleanup --- src/gt4py/next/ffront/past_to_itir.py | 37 ++++++++-------- src/gt4py/next/iterator/ir_utils/ir_makers.py | 10 ++++- .../test_transform_get_domain_range.py | 13 ++---- .../otf_tests/test_compiled_program.py | 44 ++++++++++++++----- 4 files changed, 62 insertions(+), 42 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 5a6a28b4a9..d5ea5bbf96 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -97,24 +97,25 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type ) - static_arg_descriptors = inp.args.argument_descriptors[arguments.StaticArg] - if not all( - isinstance(arg_descriptor, arguments.StaticArg) - for arg_descriptor in static_arg_descriptors.values() - ): - raise NotImplementedError("Only top-level arguments can be static.") - static_args = { - name: im.literal_from_tuple_value(descr.value) - for name, descr in static_arg_descriptors.items() - } - body = remap_symbols.RemapSymbolRefs().visit(itir_program.body, symbol_map=static_args) - itir_program = itir.Program( - id=itir_program.id, - function_definitions=itir_program.function_definitions, - params=itir_program.params, - 
declarations=itir_program.declarations, - body=body, - ) + if arguments.StaticArg in inp.args.argument_descriptors: + static_arg_descriptors = inp.args.argument_descriptors[arguments.StaticArg] + if not all( + isinstance(arg_descriptor, arguments.StaticArg) + for arg_descriptor in static_arg_descriptors.values() + ): + raise NotImplementedError("Only top-level arguments can be static.") + static_args = { + name: im.literal_from_tuple_value(descr.value) + for name, descr in static_arg_descriptors.items() + } + body = remap_symbols.RemapSymbolRefs().visit(itir_program.body, symbol_map=static_args) + itir_program = itir.Program( + id=itir_program.id, + function_definitions=itir_program.function_definitions, + params=itir_program.params, + declarations=itir_program.declarations, + body=body, + ) if arguments.FieldDomainDescriptor in inp.args.argument_descriptors: field_domains = { diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index fefca65a62..1ad3334c98 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -453,7 +453,7 @@ def _impl(*its: itir.Expr) -> itir.FunCall: def domain( grid_type: Union[common.GridType, str], - ranges: dict[common.Dimension, tuple[itir.Expr, itir.Expr]], + ranges_or_domain: dict[common.Dimension, tuple[itir.Expr, itir.Expr]] | common.Domain, ) -> itir.FunCall: """ >>> IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) @@ -463,6 +463,13 @@ def domain( >>> str(domain(common.GridType.UNSTRUCTURED, {IDim: (0, 10), JDim: (0, 20)})) 'u⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩' """ + if isinstance(ranges_or_domain, common.Domain): + domain = ranges_or_domain + ranges = {d: (r.start, r.stop) for d, r in zip(domain.dims, domain.ranges)} + else: + assert isinstance(ranges_or_domain, dict) + ranges = ranges_or_domain + if isinstance(grid_type, common.GridType): grid_type = f"{grid_type!s}_domain" expr = 
call(grid_type)( @@ -478,7 +485,6 @@ def domain( expr.type = ts.DomainType(dims=list(ranges.keys())) return expr - def get_field_domain( grid_type: Union[common.GridType, str], field: str | itir.Expr, diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py index a4864ff00e..98d52830e2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py @@ -71,18 +71,11 @@ def run_test_program( assert actual == expected -def domain_as_expr(domain: gtx.Domain) -> itir.Expr: - return im.domain( - common.GridType.UNSTRUCTURED, - {d: (r.start, r.stop) for d, r in zip(domain.dims, domain.ranges)}, - ) - - def test_get_domain(): sizes = {"out": gtx.domain({Vertex: (0, 10), KDim: (0, 20)})} get_domain_expr = im.get_field_domain(common.GridType.UNSTRUCTURED, "out", sizes["out"].dims) - run_test_program(["inp", "out"], sizes, "out", domain_as_expr(sizes["out"]), get_domain_expr) + run_test_program(["inp", "out"], sizes, "out", im.domain_as_expr(sizes["out"]), get_domain_expr) def test_get_domain_tuples(): @@ -92,7 +85,7 @@ def test_get_domain_tuples(): common.GridType.UNSTRUCTURED, im.tuple_get(1, "out"), sizes["out"][1].dims ) - run_test_program(["inp", "out"], sizes, "out", domain_as_expr(sizes["out"][1]), get_domain_expr) + run_test_program(["inp", "out"], sizes, "out", im.domain_as_expr(sizes["out"][1]), get_domain_expr) def test_get_domain_nested_tuples(): @@ -110,5 +103,5 @@ def test_get_domain_nested_tuples(): ) run_test_program( - ["inp", "a", "b", "c", "d"], sizes, "a", domain_as_expr(sizes["a"]), get_domain_expr + ["inp", "a", "b", "c", "d"], sizes, "a", im.domain_as_expr(sizes["a"]), get_domain_expr ) diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py 
b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index 1c132773e0..40dabef1d6 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -9,9 +9,10 @@ import pytest from gt4py import eve, next as gtx -from gt4py.next import errors, backend -from gt4py.next.ffront import type_specifications as ts_ffront -from gt4py.next.otf import compiled_program, toolchain, arguments +from gt4py.next import errors, backend, broadcast, common +from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.otf import toolchain, arguments from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator import ir as itir from gt4py.next.program_processors.runners import gtfn @@ -49,18 +50,16 @@ def test_sanitize_static_args_wrong_type(): @gtx.field_operator -def fop(cond: bool, a: gtx.Field[gtx.Dims[TDim], float], b: gtx.Field[gtx.Dims[TDim], float]): - return a if cond else b +def fop(cond: bool): + return broadcast(cond, (TDim,)) @gtx.program def prog( cond: bool, - a: gtx.Field[gtx.Dims[TDim], gtx.float64], - b: gtx.Field[gtx.Dims[TDim], gtx.float64], - out: gtx.Field[gtx.Dims[TDim], gtx.float64], + out: gtx.Field[gtx.Dims[TDim], bool], ): - fop(cond, a, b, out=out) + fop(cond, out=out) def _verify_program_has_expected_true_value(program: itir.Program): @@ -108,10 +107,31 @@ def pirate(program: toolchain.CompilableProgram): testee = prog.with_backend(hacked_gtfn_backend).compile(cond=[True], offset_provider={}) testee( cond=True, - a=gtx.zeros(domain={TDim: 1}, dtype=gtx.float64), - b=gtx.zeros(domain={TDim: 1}, dtype=gtx.float64), - out=gtx.zeros(domain={TDim: 1}, dtype=gtx.float64), + out=gtx.zeros(domain={TDim: 1}, dtype=bool), offset_provider={}, ) _verify_program_has_expected_true_value(hijacked_program.data) + +def _verify_program_has_expected_domain(program: 
itir.Program, expected_domain: gtx.Domain): + assert isinstance(program.body[0], itir.SetAt) + assert isinstance(program.body[0].expr, itir.FunCall) + assert program.body[0].expr.fun == itir.SymRef(id="fop") + domain = CollapseTuple.apply(program.body[0].domain, within_stencil=False) + assert domain == im.domain(common.GridType.CARTESIAN, expected_domain) + +def test_inlining_of_static_domain_works(): + domain = gtx.Domain(dims=(TDim,), ranges=(gtx.UnitRange(0, 1),)) + input_pair = toolchain.CompilableProgram( + data=prog.definition_stage, + args=arguments.CompileTimeArgs( + args=list(prog.past_stage.past_node.type.definition.pos_or_kw_args.values()), + kwargs={}, + offset_provider={}, + column_axis=None, + argument_descriptors={arguments.FieldDomainDescriptor: {"out": arguments.FieldDomainDescriptor(domain)}}, + ), + ) + + transformed = backend.DEFAULT_TRANSFORMS(input_pair).data + _verify_program_has_expected_domain(transformed, domain) \ No newline at end of file From ebca236203286751e741a10b8ac57ea5afe3d58b Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 14:44:45 +0200 Subject: [PATCH 49/93] Improve docs add ADR --- .../ADRs/next/0021-Argument-Descriptors.md | 26 +++++ src/gt4py/next/ffront/past_to_itir.py | 2 +- src/gt4py/next/otf/arguments.py | 12 +++ src/gt4py/next/otf/compiled_program.py | 99 +++++++++++-------- 4 files changed, 98 insertions(+), 41 deletions(-) create mode 100644 docs/development/ADRs/next/0021-Argument-Descriptors.md diff --git a/docs/development/ADRs/next/0021-Argument-Descriptors.md b/docs/development/ADRs/next/0021-Argument-Descriptors.md new file mode 100644 index 0000000000..5c9ebad16d --- /dev/null +++ b/docs/development/ADRs/next/0021-Argument-Descriptors.md @@ -0,0 +1,26 @@ +--- +tags: [] +--- + +# [Argument Descriptors] + +- **Status**: valid +- **Authors**: Till Ehrengruber (@tehrengruber) +- **Created**: 2025-09-12 +- **Updated**: 2025-09-12 + +A generic mechanism to deduce (partial) information from runtime 
arguments and provide it at compile time is introduced. + +## History + +The first version of static parameters introduced a new class `CompiledProgramsPool` which was responsible for compiling and dispatching to a program compiled for a given set of arguments which were marked as static. How a static argument was fingerprinted in order to dispatch and how it was constructed to pass it down the toolchain were implemented directly inside the `CompiledProgramsPool` class. For static parameters this was a simple and working design, but adding additional information about arguments would have resulted in bloating the class with more and more code that was specific to the information we wanted to add: How do we construct it and from which subset of the value, how to fingerprint it without much overhead, how to validate the extracted data is correct, etc. + +## Decision + +We introduce a new class `ArgumentDescriptor` that all classes providing additional compile time information inherit from. This class contains all information needed to represent, extract and validate compile time information of an argument, hence uncoupling this from the `CompiledProgramsPool` implementation. + +The `CompiledProgramsPool` class gets a new attribute `argument_descriptor_mapping` which maps from subclasses of `ArgumentDescriptor` to a list of parameter expressions for which the respective descriptor is constructed. We extend this from parameter names in the initial implementation to expressions such that we can, if desired, retrieve information for parts of the arguments (e.g. an element of a nested tuple). For static parameters this is less important, but usually we want a descriptor for a single leaf-type instead of a tuple / container, e.g. for a field instead of a tuple of fields. + +In a first version we chose to allow multiple argument descriptors (internally) for a single parameter as this gives us maximum flexibility in the future.
However, we will not expose this in the public-facing interface until we have actual cases where this is useful. Both approaches have their advantages and disadvantages. Having a single descriptor would mean we would need an additional mechanism to decide which information should be available to a descriptor, e.g. only the memory layout of a field or also its domain. Multiple descriptors allow distinguishing simply by not passing them in the `argument_descriptor_mapping`. On the other hand a single descriptor allows describing a parameter with a single descriptor object when precompiling which reduces code bloat. + +We chose expressions to describe which element of a parameter is used to construct argument descriptors as this is a convenient way to describe them for the user with a syntax that they are familiar with, e.g. one can just write `tuple_arg[0]`, and because it was convenient in the implementation. Right now this is only used internally though and the user is only allowed to use parameter names. diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index dcc1d9d5da..aa76173275 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -58,7 +58,7 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: ... kwargs={}, ... offset_provider={"I": IDim}, ... column_axis=None, - ... argument_descriptors={} + ... argument_descriptors={}, ... ) >>> itir_copy = past_to_gtir( diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 647af295c3..c3b3fc70ef 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -27,6 +27,18 @@ class ArgumentDescriptor: + """ + Abstract class to represent, extract, validate compile time information of an argument.
+ + The information that is available at compile time and extracted from the runtime argument + (or provided when pre-compiling) is described by a set of (python) expressions returned by the + `attribute_extractor` class-method. These expressions are evaluated in the context of the + arguments. We chose expressions here instead of a method taking the actual value such that we + can code generate a single expression for all argument descriptors only retrieving the necessary + values without actually constructing the descriptors. That way the cache key computation for the + compiled program is fast. + """ + def validate(self, name: str, type_: ts.TypeSpec) -> None: """ Validate argument descriptor. diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index f7b8fd0590..ff9de4b6bb 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -34,7 +34,7 @@ ArgumentDescriptors: TypeAlias = dict[ type[arguments.ArgumentDescriptor], dict[str, arguments.ArgumentDescriptor] ] -StructuredArgumentDescriptors: TypeAlias = dict[ +ArgumentDescriptorContext: TypeAlias = dict[ type[arguments.ArgumentDescriptor], dict[str, extended_typing.MaybeNestedInTuple[arguments.ArgumentDescriptor | None]], ] @@ -74,9 +74,9 @@ class CompiledProgramsPool: """ A pool of compiled programs for a given program and backend. - If 'static_params' is set (or static arguments are passed to 'compile'), - the pool will create a program for each argument that is marked static - and each 'OffsetProviderType'. + If 'argument_descriptor_mapping' is populated the pool will create a program for each + argument that has an argument descriptor. E.g., if a param is marked static we create + a new program for each value of that parameter.
If `enable_jit` is True in the call to the pool, it will compile a program with static arguments corresponding to the 'static_params', otherwise it @@ -143,6 +143,14 @@ def __call__( @functools.cached_property def _argument_descriptor_cache_key_from_args(self) -> Callable: + """ + Given the entire set of runtime arguments compute the cache key used to retrieve the + instance of the compiled program which is compiled for the argument descriptors from + the given set of arguments. + + This is part of the performance critical path that is called on every program call, + hence we code generate a single lambda expression here. + """ func_type = self.program_type.definition params = list(func_type.pos_or_kw_args.keys()) + list(func_type.kw_only_args.keys()) elements: list[str] = [] @@ -152,10 +160,33 @@ def _argument_descriptor_cache_key_from_args(self) -> Callable: elements.extend(attr_extractor.values()) return eval(f"""lambda {",".join(params)}: ({_make_tuple_expr(elements)})""") - def _argument_descriptor_cache_key_from_structured_descriptors( + def _argument_descriptor_cache_key_from_descriptors( self, - structured_descriptors: StructuredArgumentDescriptors, + argument_descriptors: ArgumentDescriptors, ) -> tuple: + """ + Given a set of argument descriptors deduce the cache key used to retrieve the instance + of the compiled program which is compiled for the given argument descriptors. + + This function is not performance critical as it is only called once when compiling a + variant. 
+ """ + # first build a context that we can evaluate parameter expressions on descriptors in + descriptor_context: ArgumentDescriptorContext = {} + for descriptor_cls, descriptor_expr_mapping in argument_descriptors.items(): + descriptor_context[descriptor_cls] = _make_param_context_from_func_type( + self.program_type.definition, lambda x: None + ) + assert "__descriptor" not in descriptor_context[descriptor_cls] + for expr, descriptor in descriptor_expr_mapping.items(): + # note: we don't need to handle any errors here since the `expr` has been validated + # in `_validate_argument_descriptor_mapping` + exec( + f"{expr} = __descriptor", + {"__descriptor": descriptor}, + descriptor_context[descriptor_cls], + ) + elements = [] for descriptor_cls, arg_exprs in self.argument_descriptor_mapping.items(): # type: ignore[union-attr] # can never be `None` at this point for arg_expr in arg_exprs: @@ -163,12 +194,12 @@ def _argument_descriptor_cache_key_from_structured_descriptors( attrs = attr_extractor.keys() for attr in attrs: elements.append( - getattr(eval(f"{arg_expr}", structured_descriptors[descriptor_cls]), attr) + getattr(eval(f"{arg_expr}", descriptor_context[descriptor_cls]), attr) ) return tuple(elements) @functools.cached_property - def _descriptor_attr_retrievers( + def _argument_descriptor_attr_retrievers( self, ) -> dict[type[arguments.ArgumentDescriptor], dict[str, Callable]]: """ @@ -194,8 +225,9 @@ def make_dict_expr(exprs: dict[str, str]) -> str: return retrievers def _make_argument_descriptors(self, *args: Any, **kwargs: Any) -> ArgumentDescriptors: + """Given a set of runtime arguments construct all argument descriptors from them.""" descriptors: ArgumentDescriptors = {} - for descriptor_cls, attr_retrievers in self._descriptor_attr_retrievers.items(): + for descriptor_cls, attr_retrievers in self._argument_descriptor_attr_retrievers.items(): descriptors[descriptor_cls] = {} for expr, attr_retriever in attr_retrievers.items(): descriptor = 
descriptor_cls(**attr_retriever(*args, **kwargs)) @@ -212,6 +244,22 @@ def _validate_argument_descriptors( param_type = _get_type_of_param_expr(self.program_type, expr) descriptor.validate(expr, param_type) + def _initialize_argument_descriptor_mapping(self, argument_descriptors: ArgumentDescriptors): + if self.argument_descriptor_mapping is None: + self.argument_descriptor_mapping = { + descr_cls: list(descriptor_expr_mapping.keys()) + for descr_cls, descriptor_expr_mapping in argument_descriptors.items() + } + self._validate_argument_descriptor_mapping() + else: + for descr_cls, descriptor_expr_mapping in argument_descriptors.items(): + if (expected := set(self.argument_descriptor_mapping[descr_cls])) != ( + got := set(descriptor_expr_mapping.keys()) + ): + raise ValueError( + f"Argument descriptor {descr_cls.__name__} must be the same for all compiled programs, got {list(got)} expected {list(expected)}." + ) + def _validate_argument_descriptor_mapping(self) -> None: if self.argument_descriptor_mapping is None: return @@ -233,40 +281,11 @@ def _compile_variant( argument_descriptors: ArgumentDescriptors, offset_provider: common.OffsetProviderType | common.OffsetProvider, ) -> None: - if self.argument_descriptor_mapping is None: - self.argument_descriptor_mapping = { - descr_cls: list(descriptor_expr_mapping.keys()) - for descr_cls, descriptor_expr_mapping in argument_descriptors.items() - } - self._validate_argument_descriptor_mapping() - else: - for descr_cls, descriptor_expr_mapping in argument_descriptors.items(): - if (expected := set(self.argument_descriptor_mapping[descr_cls])) != ( - got := set(descriptor_expr_mapping.keys()) - ): - raise ValueError( - f"Argument descriptor {descr_cls.__name__} must be the same for all compiled programs, got {list(got)} expected {list(expected)}." 
- ) - + self._initialize_argument_descriptor_mapping(argument_descriptors) self._validate_argument_descriptors(argument_descriptors) - structured_descriptors: StructuredArgumentDescriptors = {} - for descriptor_cls, descriptor_expr_mapping in argument_descriptors.items(): - structured_descriptors[descriptor_cls] = _make_param_context_from_func_type( - self.program_type.definition, lambda x: None - ) - assert "__descriptor" not in structured_descriptors[descriptor_cls] - for expr, descriptor in descriptor_expr_mapping.items(): - # note: we don't need to handle any errors here since the `expr` has been validated - # in `_validate_argument_descriptor_mapping` - exec( - f"{expr} = __descriptor", - {"__descriptor": descriptor}, - structured_descriptors[descriptor_cls], - ) - key = ( - self._argument_descriptor_cache_key_from_structured_descriptors(structured_descriptors), # type: ignore[arg-type] # mypy not smart enough + self._argument_descriptor_cache_key_from_descriptors(argument_descriptors), # type: ignore[arg-type] # mypy not smart enough self._offset_provider_to_type_unsafe(offset_provider), ) if key in self._compiled_programs: From 0c9bb894bad7b0f643dd4cfc8b9dbd3782798d9c Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 14:48:19 +0200 Subject: [PATCH 50/93] Improve docs --- src/gt4py/next/otf/compiled_program.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index ff9de4b6bb..4f895741d4 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -76,12 +76,13 @@ class CompiledProgramsPool: If 'argument_descriptor_mapping' is populated the pool will create a program for each argument that has an argument descriptor. E.g., if a param is marked static we create - a new program for each value of that parameter. + a new program for each value of that parameter. 
See :ref:`arguments.ArgumentDescriptor` for + more information on argument descriptors. If `enable_jit` is True in the call to the pool, it will compile a program - with static arguments corresponding to the 'static_params', otherwise it + with static information as described in `argument_descriptor_mapping`, otherwise it will error. In the latter case, the pool needs to be filled with call(s) - to 'compile' before it can be used. + to `compile` before it can be used. """ backend: gtx_backend.Backend @@ -121,7 +122,8 @@ def __call__( """ args, kwargs = type_info.canonicalize_arguments(self.program_type, args, kwargs) static_args_values = self._argument_descriptor_cache_key_from_args(*args, **kwargs) - # TODO: dispatching over offset provider type is wrong. especially when we use compile time domains. test? + # TODO: Dispatching over offset provider type is wrong, especially when we use compile time + # domains. key = (static_args_values, self._offset_provider_to_type_unsafe(offset_provider)) try: self._compiled_programs[key](*args, **kwargs, offset_provider=offset_provider) From a284ea7ad43032825a79cb335b7f55bd51c735bc Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 14:49:42 +0200 Subject: [PATCH 51/93] Cleanup --- src/gt4py/next/ffront/past_to_itir.py | 2 +- src/gt4py/next/otf/compiled_program.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index aa76173275..602409b91c 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -103,7 +103,7 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: ): raise NotImplementedError("Only top-level arguments can be static.") static_args = { - name: im.literal_from_tuple_value(descr.value) + name: im.literal_from_tuple_value(descr.value) # type: ignore[attr-defined] # type checked above for name, descr in static_arg_descriptors.items() } body = 
remap_symbols.RemapSymbolRefs().visit(itir_program.body, symbol_map=static_args) diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 4f895741d4..5d75d0042e 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -287,7 +287,7 @@ def _compile_variant( self._validate_argument_descriptors(argument_descriptors) key = ( - self._argument_descriptor_cache_key_from_descriptors(argument_descriptors), # type: ignore[arg-type] # mypy not smart enough + self._argument_descriptor_cache_key_from_descriptors(argument_descriptors), self._offset_provider_to_type_unsafe(offset_provider), ) if key in self._compiled_programs: From c7c7361dea50e8cb326b8c13f124ff7c98cab475 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 14:57:15 +0200 Subject: [PATCH 52/93] Cleanup --- src/gt4py/next/otf/compiled_program.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 5d75d0042e..22c0392753 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -246,7 +246,7 @@ def _validate_argument_descriptors( param_type = _get_type_of_param_expr(self.program_type, expr) descriptor.validate(expr, param_type) - def _initialize_argument_descriptor_mapping(self, argument_descriptors: ArgumentDescriptors): + def _initialize_argument_descriptor_mapping(self, argument_descriptors: ArgumentDescriptors) -> None: if self.argument_descriptor_mapping is None: self.argument_descriptor_mapping = { descr_cls: list(descriptor_expr_mapping.keys()) From efd331be57e89baf9fdb9a175e9d9b630c05f9c0 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 20:31:08 +0200 Subject: [PATCH 53/93] Fix failing test --- src/gt4py/next/ffront/past_to_itir.py | 37 ++++++++++++++------------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git 
a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 602409b91c..f82680be15 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -96,24 +96,25 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type ) - static_arg_descriptors = inp.args.argument_descriptors[arguments.StaticArg] - if not all( - isinstance(arg_descriptor, arguments.StaticArg) - for arg_descriptor in static_arg_descriptors.values() - ): - raise NotImplementedError("Only top-level arguments can be static.") - static_args = { - name: im.literal_from_tuple_value(descr.value) # type: ignore[attr-defined] # type checked above - for name, descr in static_arg_descriptors.items() - } - body = remap_symbols.RemapSymbolRefs().visit(itir_program.body, symbol_map=static_args) - itir_program = itir.Program( - id=itir_program.id, - function_definitions=itir_program.function_definitions, - params=itir_program.params, - declarations=itir_program.declarations, - body=body, - ) + if arguments.StaticArg in inp.args.argument_descriptors: + static_arg_descriptors = inp.args.argument_descriptors[arguments.StaticArg] + if not all( + isinstance(arg_descriptor, arguments.StaticArg) + for arg_descriptor in static_arg_descriptors.values() + ): + raise NotImplementedError("Only top-level arguments can be static.") + static_args = { + name: im.literal_from_tuple_value(descr.value) # type: ignore[attr-defined] # type checked above + for name, descr in static_arg_descriptors.items() + } + body = remap_symbols.RemapSymbolRefs().visit(itir_program.body, symbol_map=static_args) + itir_program = itir.Program( + id=itir_program.id, + function_definitions=itir_program.function_definitions, + params=itir_program.params, + declarations=itir_program.declarations, + body=body, + ) if config.DEBUG or inp.data.debug: devtools.debug(itir_program) From 
07b968cf1f0014d9fdb8af0a5b25d5943ac5df86 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 20:33:36 +0200 Subject: [PATCH 54/93] Fix failing test, abstract classmethod --- src/gt4py/next/otf/arguments.py | 4 +++- src/gt4py/next/otf/compiled_program.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index c3b3fc70ef..52e94f719a 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -8,6 +8,7 @@ from __future__ import annotations +import abc import dataclasses import enum import typing @@ -26,7 +27,7 @@ T = typing.TypeVar("T") -class ArgumentDescriptor: +class ArgumentDescriptor(abc.ABC): """ Abstract class to represent, extract, validate compile time information of an argument. @@ -49,6 +50,7 @@ def validate(self, name: str, type_: ts.TypeSpec) -> None: pass @classmethod + @abc.abstractmethod def attribute_extractor(cls, arg_expr: str) -> dict[str, str]: # type: ignore[empty-body] # classmethod is abstract """ Return a mapping from the attributes of our descriptor to the expressions to retrieve them. 
diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 22c0392753..f1574dfa03 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -246,7 +246,9 @@ def _validate_argument_descriptors( param_type = _get_type_of_param_expr(self.program_type, expr) descriptor.validate(expr, param_type) - def _initialize_argument_descriptor_mapping(self, argument_descriptors: ArgumentDescriptors) -> None: + def _initialize_argument_descriptor_mapping( + self, argument_descriptors: ArgumentDescriptors + ) -> None: if self.argument_descriptor_mapping is None: self.argument_descriptor_mapping = { descr_cls: list(descriptor_expr_mapping.keys()) From 1692f5cc6d2052f9b02eb0dbb686366b8b56a0b8 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 20:35:03 +0200 Subject: [PATCH 55/93] Cleanup --- src/gt4py/next/otf/arguments.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 52e94f719a..73681b55cc 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -40,7 +40,7 @@ class ArgumentDescriptor(abc.ABC): compiled is fast. """ - def validate(self, name: str, type_: ts.TypeSpec) -> None: + def validate(self, name: str, type_: ts.TypeSpec) -> None: # noqa: B027 # method is not abstract, but just empty when not implemented """ Validate argument descriptor. @@ -51,7 +51,7 @@ def validate(self, name: str, type_: ts.TypeSpec) -> None: @classmethod @abc.abstractmethod - def attribute_extractor(cls, arg_expr: str) -> dict[str, str]: # type: ignore[empty-body] # classmethod is abstract + def attribute_extractor(cls, arg_expr: str) -> dict[str, str]: """ Return a mapping from the attributes of our descriptor to the expressions to retrieve them. 
From 7e3ec654062f63eb8a827a6fca7049c669a3032d Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 12 Sep 2025 20:47:59 +0200 Subject: [PATCH 56/93] Cleanup --- .../ir_utils_test.py/test_misc.py | 40 +++++++++++++++++++ .../transforms_tests/test_constant_folding.py | 39 ------------------ 2 files changed, 40 insertions(+), 39 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_misc.py b/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_misc.py index ad0781acb3..d4f3f4b07f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_misc.py +++ b/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_misc.py @@ -8,8 +8,48 @@ import pytest +from gt4py import next as gtx +from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im, misc from gt4py.next.iterator.transforms import inline_lambdas +from gt4py.next.type_system import type_specifications as ts + + +@pytest.mark.parametrize( + "value,expected", + [ + (itir.Literal(value="True", type=ts.ScalarType(kind=ts.ScalarKind.BOOL)), True), + (itir.Literal(value="False", type=ts.ScalarType(kind=ts.ScalarKind.BOOL)), False), + (itir.Literal(value="1", type=ts.ScalarType(kind=ts.ScalarKind.INT8)), gtx.int8(1)), + (itir.Literal(value="1", type=ts.ScalarType(kind=ts.ScalarKind.UINT8)), gtx.uint8(1)), + (itir.Literal(value="1", type=ts.ScalarType(kind=ts.ScalarKind.INT16)), gtx.int16(1)), + (itir.Literal(value="1", type=ts.ScalarType(kind=ts.ScalarKind.UINT16)), gtx.uint16(1)), + (itir.Literal(value="1", type=ts.ScalarType(kind=ts.ScalarKind.INT32)), gtx.int32(1)), + (itir.Literal(value="1", type=ts.ScalarType(kind=ts.ScalarKind.UINT32)), gtx.uint32(1)), + (itir.Literal(value="1", type=ts.ScalarType(kind=ts.ScalarKind.INT64)), gtx.int64(1)), + (itir.Literal(value="1", type=ts.ScalarType(kind=ts.ScalarKind.UINT64)), gtx.uint64(1)), + ( + itir.Literal(value="0.1", 
type=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)), + gtx.float32("0.1"), + ), + ( + itir.Literal(value="0.1", type=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), + gtx.float64("0.1"), + ), + ( + itir.Literal(value="1", type=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), + gtx.float64("1.0"), + ), + ], + ids=lambda param: f"Literal[{param.value}, {param.type}]" + if isinstance(param, itir.Literal) + else str(param), +) +def test_value_from_literal(value, expected): + result = misc.value_from_literal(value) + + assert result == expected + assert type(result) is type(expected) @pytest.mark.parametrize( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index a56b539014..90a28086e8 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -8,48 +8,9 @@ import pytest -from gt4py import next as gtx from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms import constant_folding -from gt4py.next.type_system import type_specifications as ts - - -@pytest.mark.parametrize( - "value,expected", - [ - (itir.Literal(value="True", type=ts.ScalarType(kind=ts.ScalarKind.BOOL)), True), - (itir.Literal(value="False", type=ts.ScalarType(kind=ts.ScalarKind.BOOL)), False), - (itir.Literal(value="1", type=ts.ScalarType(kind=ts.ScalarKind.INT8)), gtx.int8(1)), - (itir.Literal(value="1", type=ts.ScalarType(kind=ts.ScalarKind.UINT8)), gtx.uint8(1)), - (itir.Literal(value="1", type=ts.ScalarType(kind=ts.ScalarKind.INT16)), gtx.int16(1)), - (itir.Literal(value="1", type=ts.ScalarType(kind=ts.ScalarKind.UINT16)), gtx.uint16(1)), - (itir.Literal(value="1", type=ts.ScalarType(kind=ts.ScalarKind.INT32)), gtx.int32(1)), - (itir.Literal(value="1", 
type=ts.ScalarType(kind=ts.ScalarKind.UINT32)), gtx.uint32(1)), - (itir.Literal(value="1", type=ts.ScalarType(kind=ts.ScalarKind.INT64)), gtx.int64(1)), - (itir.Literal(value="1", type=ts.ScalarType(kind=ts.ScalarKind.UINT64)), gtx.uint64(1)), - ( - itir.Literal(value="0.1", type=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)), - gtx.float32("0.1"), - ), - ( - itir.Literal(value="0.1", type=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), - gtx.float64("0.1"), - ), - ( - itir.Literal(value="1", type=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), - gtx.float64("1.0"), - ), - ], - ids=lambda param: f"Literal[{param.value}, {param.type}]" - if isinstance(param, itir.Literal) - else str(param), -) -def test_value_from_literal(value, expected): - result = constant_folding._value_from_literal(value) - - assert result == expected - assert type(result) is type(expected) @pytest.mark.parametrize( From 6dfb03770d062e042e99166d2f401d904d0ce6e7 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Mon, 22 Sep 2025 01:53:10 +0200 Subject: [PATCH 57/93] Address review comments --- src/gt4py/next/otf/arguments.py | 26 +++-- src/gt4py/next/otf/compiled_program.py | 109 ++++++++---------- .../ffront_tests/test_compiled_program.py | 2 +- 3 files changed, 69 insertions(+), 68 deletions(-) diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 73681b55cc..9a80a83137 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -12,7 +12,7 @@ import dataclasses import enum import typing -from typing import Any, Generic, Mapping, Optional +from typing import Any, Generic, Mapping, Optional, final from typing_extensions import Self @@ -27,7 +27,11 @@ T = typing.TypeVar("T") -class ArgumentDescriptor(abc.ABC): +def _make_dict_expr(exprs: dict[str, str]) -> str: + return "{" + ",".join((f"'{k}': {v}" for k, v in exprs.items())) + "}" + + +class ArgStaticDescriptor(abc.ABC): """ Abstract class to represent, extract, validate compile time information of 
an argument. @@ -42,16 +46,22 @@ class ArgumentDescriptor(abc.ABC): def validate(self, name: str, type_: ts.TypeSpec) -> None: # noqa: B027 # method is not abstract, but just empty when not implemented """ - Validate argument descriptor. + Validate argument descriptor in the context of an actual program. This function is called when the type of the argument is available. The name is merely given to give good error messages. """ pass + @classmethod + @final + def from_value(cls, value: Any) -> ArgStaticDescriptor: + attr_exprs = cls.attribute_extractor_exprs("self") + return cls(**eval(f"""lambda self: {_make_dict_expr(attr_exprs)}""")(value)) + @classmethod @abc.abstractmethod - def attribute_extractor(cls, arg_expr: str) -> dict[str, str]: + def attribute_extractor_exprs(cls, arg_expr: str) -> dict[str, str]: """ Return a mapping from the attributes of our descriptor to the expressions to retrieve them. @@ -64,7 +74,7 @@ def attribute_extractor(cls, arg_expr: str) -> dict[str, str]: @dataclasses.dataclass(frozen=True) -class StaticArg(ArgumentDescriptor, Generic[core_defs.ScalarT]): +class StaticArg(ArgStaticDescriptor, Generic[core_defs.ScalarT]): value: extended_typing.MaybeNestedInTuple[core_defs.ScalarT] def __post_init__(self) -> None: @@ -87,7 +97,7 @@ def validate(self, name: str, type_: ts.TypeSpec) -> None: ) @classmethod - def attribute_extractor(cls, arg_expr: str) -> dict[str, str]: + def attribute_extractor_exprs(cls, arg_expr: str) -> dict[str, str]: return {"value": arg_expr} @@ -112,8 +122,8 @@ class CompileTimeArgs: offset_provider: common.OffsetProvider # TODO(havogt): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information column_axis: Optional[common.Dimension] argument_descriptors: Mapping[ - type[ArgumentDescriptor], - dict[str, ArgumentDescriptor], + type[ArgStaticDescriptor], + dict[str, ArgStaticDescriptor], ] @property diff --git a/src/gt4py/next/otf/compiled_program.py 
b/src/gt4py/next/otf/compiled_program.py index f1574dfa03..9f47dd0613 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -12,14 +12,15 @@ import dataclasses import functools import itertools -from typing import Any, Callable, DefaultDict, Sequence, TypeAlias, TypeVar +from typing import Any, Callable, Sequence, TypeAlias, TypeVar from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing, utils as eve_utils -from gt4py.next import backend as gtx_backend, common, config, errors +from gt4py.next import backend as gtx_backend, common, config, errors, utils as gtx_utils from gt4py.next.ffront import stages as ffront_stages, type_specifications as ts_ffront from gt4py.next.otf import arguments, stages from gt4py.next.type_system import type_info, type_specifications as ts +from gt4py.next.utils import tree_map T = TypeVar("T") @@ -32,11 +33,11 @@ tuple[ScalarOrTupleOfScalars, ...], common.OffsetProviderType ] ArgumentDescriptors: TypeAlias = dict[ - type[arguments.ArgumentDescriptor], dict[str, arguments.ArgumentDescriptor] + type[arguments.ArgStaticDescriptor], dict[str, arguments.ArgStaticDescriptor] ] ArgumentDescriptorContext: TypeAlias = dict[ - type[arguments.ArgumentDescriptor], - dict[str, extended_typing.MaybeNestedInTuple[arguments.ArgumentDescriptor | None]], + type[arguments.ArgStaticDescriptor], + dict[str, extended_typing.MaybeNestedInTuple[arguments.ArgStaticDescriptor | None]], ] @@ -64,11 +65,43 @@ def _make_param_context_from_func_type( def _get_type_of_param_expr(program_type: ts_ffront.ProgramType, expr: str) -> ts.TypeSpec: - type_ = eval(expr, _make_param_context_from_func_type(program_type.definition)) + structured_type_ = eval(expr, _make_param_context_from_func_type(program_type.definition)) + type_ = tree_map( + lambda v: v, result_collection_constructor=lambda elts: ts.TupleType(types=list(elts)) + )(structured_type_) assert isinstance(type_, ts.TypeSpec) return type_ 
+def _make_argument_descriptors( + program_type: ts_ffront.ProgramType, + argument_descriptor_mapping: dict[type[arguments.ArgStaticDescriptor], Sequence[str]], + args: tuple[Any], + kwargs: dict[str, Any], +) -> ArgumentDescriptors: + """Given a set of runtime arguments construct all argument descriptors from them.""" + func_type = program_type.definition + params = list(func_type.pos_or_kw_args.keys()) + list(func_type.kw_only_args.keys()) + descriptors: ArgumentDescriptors = {} + for descriptor_cls, exprs in argument_descriptor_mapping.items(): + descriptors[descriptor_cls] = {} + for expr in exprs: + argument = eval(f"""lambda {",".join(params)}: {expr}""")(*args, **kwargs) + descriptors[descriptor_cls][expr] = descriptor_cls.from_value(argument) + _validate_argument_descriptors(program_type, descriptors) + return descriptors + + +def _validate_argument_descriptors( + program_type: ts_ffront.ProgramType, + all_descriptors: ArgumentDescriptors, +) -> None: + for descriptors in all_descriptors.values(): + for expr, descriptor in descriptors.items(): + param_type = _get_type_of_param_expr(program_type, expr) + descriptor.validate(expr, param_type) + + @dataclasses.dataclass class CompiledProgramsPool: """ @@ -91,7 +124,7 @@ class CompiledProgramsPool: #: mapping from an argument descriptor type to a list of parameters or expression thereof #: e.g. `{arguments.StaticArg: ["static_int_param"]}` #: Note: The list is not ordered. 
- argument_descriptor_mapping: dict[type[arguments.ArgumentDescriptor], Sequence[str]] | None + argument_descriptor_mapping: dict[type[arguments.ArgStaticDescriptor], Sequence[str]] | None _compiled_programs: eve_utils.CustomMapping = dataclasses.field( default_factory=lambda: eve_utils.CustomMapping(_hash_compiled_program_unsafe), @@ -135,7 +168,9 @@ def __call__( if enable_jit: assert self.argument_descriptor_mapping is not None self._compile_variant( - argument_descriptors=self._make_argument_descriptors(*args, **kwargs), + argument_descriptors=_make_argument_descriptors( + self.program_type, self.argument_descriptor_mapping, args, kwargs + ), offset_provider=offset_provider, ) return self( @@ -158,7 +193,7 @@ def _argument_descriptor_cache_key_from_args(self) -> Callable: elements: list[str] = [] for descriptor_cls, arg_exprs in self.argument_descriptor_mapping.items(): # type: ignore[union-attr] # can never be `None` at this point for arg_expr in arg_exprs: - attr_extractor = descriptor_cls.attribute_extractor(arg_expr) + attr_extractor = descriptor_cls.attribute_extractor_exprs(arg_expr) elements.extend(attr_extractor.values()) return eval(f"""lambda {",".join(params)}: ({_make_tuple_expr(elements)})""") @@ -192,7 +227,7 @@ def _argument_descriptor_cache_key_from_descriptors( elements = [] for descriptor_cls, arg_exprs in self.argument_descriptor_mapping.items(): # type: ignore[union-attr] # can never be `None` at this point for arg_expr in arg_exprs: - attr_extractor = descriptor_cls.attribute_extractor(arg_expr) + attr_extractor = descriptor_cls.attribute_extractor_exprs(arg_expr) attrs = attr_extractor.keys() for attr in attrs: elements.append( @@ -200,52 +235,6 @@ def _argument_descriptor_cache_key_from_descriptors( ) return tuple(elements) - @functools.cached_property - def _argument_descriptor_attr_retrievers( - self, - ) -> dict[type[arguments.ArgumentDescriptor], dict[str, Callable]]: - """ - For each argument expression build a lambda function that 
constructs (the attributes of) - its argument descriptor - """ - - def make_dict_expr(exprs: dict[str, str]) -> str: - return "{" + ",".join((f"'{k}': {v}" for k, v in exprs.items())) + "}" - - func_type = self.program_type.definition - params = list(func_type.pos_or_kw_args.keys()) + list(func_type.kw_only_args.keys()) - retrievers: dict[type[arguments.ArgumentDescriptor], dict[str, Callable]] = DefaultDict( - dict - ) - for descriptor_cls, arg_exprs in self.argument_descriptor_mapping.items(): # type: ignore[union-attr] # can never be `None` at this point - for arg_expr in arg_exprs: - attr_exprs = descriptor_cls.attribute_extractor(arg_expr) - retrievers[descriptor_cls][arg_expr] = eval( - f"""lambda {",".join(params)}: {make_dict_expr(attr_exprs)}""" - ) - - return retrievers - - def _make_argument_descriptors(self, *args: Any, **kwargs: Any) -> ArgumentDescriptors: - """Given a set of runtime arguments construct all argument descriptors from them.""" - descriptors: ArgumentDescriptors = {} - for descriptor_cls, attr_retrievers in self._argument_descriptor_attr_retrievers.items(): - descriptors[descriptor_cls] = {} - for expr, attr_retriever in attr_retrievers.items(): - descriptor = descriptor_cls(**attr_retriever(*args, **kwargs)) - descriptors[descriptor_cls][expr] = descriptor - self._validate_argument_descriptors(descriptors) - return descriptors - - def _validate_argument_descriptors( - self, - all_descriptors: ArgumentDescriptors, - ) -> None: - for descriptors in all_descriptors.values(): - for expr, descriptor in descriptors.items(): - param_type = _get_type_of_param_expr(self.program_type, expr) - descriptor.validate(expr, param_type) - def _initialize_argument_descriptor_mapping( self, argument_descriptors: ArgumentDescriptors ) -> None: @@ -271,9 +260,11 @@ def _validate_argument_descriptor_mapping(self) -> None: for descr_cls, exprs in self.argument_descriptor_mapping.items(): for expr in exprs: try: - if eval(expr, context) is not None: + if any( 
+ v is not None for v in gtx_utils.flatten_nested_tuple(eval(expr, context)) + ): raise ValueError() - except (ValueError, KeyError): + except (ValueError, KeyError, NameError): raise errors.DSLTypeError( # noqa: B904 # we don't care about the original exception message=f"Invalid parameter expression '{expr}' for '{descr_cls.__name__}'. " f"Must be the name of a parameter or an access to one of its elements.", @@ -286,7 +277,7 @@ def _compile_variant( offset_provider: common.OffsetProviderType | common.OffsetProvider, ) -> None: self._initialize_argument_descriptor_mapping(argument_descriptors) - self._validate_argument_descriptors(argument_descriptors) + _validate_argument_descriptors(self.program_type, argument_descriptors) key = ( self._argument_descriptor_cache_key_from_descriptors(argument_descriptors), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index dae2c7a975..609836be29 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -754,7 +754,7 @@ def test_compile_variants_wrong_type(cartesian_case, compile_variants_testee_not def test_compile_variants_error_static_field(cartesian_case, compile_variants_testee_not_compiled): field_a = cases.allocate(cartesian_case, compile_variants_testee_not_compiled, "field_a")() - with pytest.raises(errors.DSLTypeError, match="field_a.*cannot be static"): + with pytest.raises(errors.DSLTypeError, match="Invalid static argument.*field_a"): compile_variants_testee_not_compiled.compile(field_a=[field_a], offset_provider={}) From de7a8bd994fd1071f939a5bc73f0877b1160ead5 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Mon, 22 Sep 2025 01:55:40 +0200 Subject: [PATCH 58/93] Address review comments --- src/gt4py/next/otf/compiled_program.py | 2 ++ 1 
file changed, 2 insertions(+) diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 9f47dd0613..126f25861a 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -260,6 +260,8 @@ def _validate_argument_descriptor_mapping(self) -> None: for descr_cls, exprs in self.argument_descriptor_mapping.items(): for expr in exprs: try: + # TODO(tehrengruber): Re-evaluate the way we validate here when we add support + # for containers. if any( v is not None for v in gtx_utils.flatten_nested_tuple(eval(expr, context)) ): From 1e904425e412a0a4e1fa4d0f876994c0911126d2 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Mon, 22 Sep 2025 21:38:54 +0200 Subject: [PATCH 59/93] Cleanup --- src/gt4py/next/ffront/decorator.py | 4 +- src/gt4py/next/ffront/past_to_itir.py | 9 +- .../next/iterator/ir_utils/domain_utils.py | 94 +++++++++------- src/gt4py/next/iterator/ir_utils/ir_makers.py | 1 + src/gt4py/next/otf/arguments.py | 3 +- .../ir_utils_test.py/test_domain_utils.py | 101 +++++++++++++++++- .../test_temporary_domain_inference.py | 8 +- .../test_transform_get_domain_range.py | 4 +- .../otf_tests/test_compiled_program.py | 8 +- 9 files changed, 178 insertions(+), 54 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 3362228293..bd0c171d43 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -138,7 +138,7 @@ def from_function( connectivities=connectivities, enable_jit=enable_jit, static_params=static_params, - static_domains=static_domains + static_domains=static_domains, ) # needed in testing @@ -289,7 +289,7 @@ def path_to_expr(path: Sequence[int]): argument_descriptor_mapping = { arguments.StaticArg: self.static_params, - arguments.FieldDomainDescriptor: static_domain_args + arguments.FieldDomainDescriptor: static_domain_args, } program_type = self.past_stage.past_node.type diff --git 
a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index eeba2ba5db..d7173b5d9d 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -15,7 +15,7 @@ import devtools from gt4py.eve import NodeTranslator, concepts, traits, utils as eve_utils -from gt4py.next import common, config, errors +from gt4py.next import common, config, errors, utils from gt4py.next.ffront import ( fbuiltins, gtcallable, @@ -30,7 +30,6 @@ from gt4py.next.iterator.transforms import remap_symbols, transform_get_domain_range from gt4py.next.otf import arguments, stages, workflow from gt4py.next.type_system import type_info, type_specifications as ts -from gt4py.next import utils # FIXME[#1582](tehrengruber): This should only depend on the program not the arguments. Remove @@ -119,11 +118,11 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: if arguments.FieldDomainDescriptor in inp.args.argument_descriptors: field_domains = { - param: utils.tree_map(lambda x: x.domain)(v) for param, v in inp.args.argument_descriptors[arguments.FieldDomainDescriptor].items() + param: utils.tree_map(lambda x: x.domain)(v) + for param, v in inp.args.argument_descriptors[arguments.FieldDomainDescriptor].items() } itir_program = transform_get_domain_range.TransformGetDomainRange.apply( - itir_program, - sizes=field_domains + itir_program, sizes=field_domains ) if config.DEBUG or inp.data.debug: diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 9aa116a5ed..331a3a381a 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -22,6 +22,17 @@ from gt4py.next.iterator.transforms.constant_folding import ConstantFolding +#: Threshold fraction of domain points which may be added to a domain on translation in order +#: to have a contiguous domain before a warning is raised. 
+NON_CONTIGUOUS_DOMAIN_WARNING_THRESHOLD: float = 1 / 4 + +#: Skip printing warnings after exceeding `NON_CONTIGUOUS_DOMAIN_WARNING_THRESHOLD` this many times. +NON_CONTIGUOUS_DOMAIN_MAX_WARNINGS: int = 5 + +#: Number of warnings raised after exceeding `NON_CONTIGUOUS_DOMAIN_WARNING_THRESHOLD` +_NON_CONTIGUOUS_DOMAIN_WARNING_COUNTER: int = 0 + + def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> dict[str, int]: """ Extract horizontal domain sizes from an `offset_provider`. @@ -135,64 +146,67 @@ def translate( horizontal_sizes: dict[str, tuple[itir.Expr, itir.Expr]] old_dim = connectivity_type.source_dim new_dim = connectivity_type.codomain + assert new_dim not in new_ranges or old_dim == new_dim if symbolic_domain_sizes is not None: - horizontal_sizes = { - k: (im.literal(str(0), builtins.INTEGER_INDEX_BUILTIN), im.ensure_expr(v)) - for k, v in symbolic_domain_sizes.items() - } + new_range = SymbolicRange( + im.literal(str(0), builtins.INTEGER_INDEX_BUILTIN), + im.ensure_expr(symbolic_domain_sizes[new_dim.value]), + ) else: - # note: ugly but cheap re-computation, but should disappear assert common.is_offset_provider(offset_provider) - horizontal_sizes = { - k: ( - im.literal(str(0), builtins.INTEGER_INDEX_BUILTIN), - im.literal(str(v), builtins.INTEGER_INDEX_BUILTIN), + skip_value = offset_provider[off.value].skip_value + + # fold & convert expr into actual integers + range_exprs = new_ranges[old_dim].start, new_ranges[old_dim].stop + range_exprs = [ + collapse_tuple.CollapseTuple.apply( + expr, + within_stencil=False, + allow_undeclared_symbols=True, + ) + for expr in range_exprs + ] + assert all(isinstance(expr, itir.Literal) for expr in range_exprs) + start, stop = (int(literal.value) for literal in range_exprs) + + if val in [trace_shifts.Sentinel.ALL_NEIGHBORS, trace_shifts.Sentinel.VALUE]: + nb_index = slice(None) + else: + nb_index = val.value + + accessed = offset_provider[off.value].ndarray[start:stop, nb_index] + + if 
np.any(accessed == skip_value): + raise NotImplementedError( + f"Translating '{self.as_expr()}' using '{shift[0].value}' contains " + f"skipped values. This is not supported." ) - for k, v in _max_domain_sizes_by_location_type(offset_provider).items() - } - start = 0 - stop = -1 - start_ = collapse_tuple.CollapseTuple.apply( - new_ranges[old_dim].start, - within_stencil=False, - allow_undeclared_symbols=True, - ) - stop_ = collapse_tuple.CollapseTuple.apply( - new_ranges[old_dim].stop, - within_stencil=False, - allow_undeclared_symbols=True, - ) - if isinstance(start_, itir.Literal) and isinstance(stop_, itir.Literal): - start = int(start_.value) - stop = int(stop_.value) - off_index = ( - slice(None) if val == trace_shifts.Sentinel.ALL_NEIGHBORS else val.value - ) - accessed = offset_provider[off.value].ndarray[start:stop, off_index] min_ = np.min(accessed) max_ = np.max(accessed) + 1 - if (covered := np.unique(accessed).size) < (max_ - min_) / 2: + fraction_accessed = np.unique(accessed).size / (max_ - min_) + + if ( + fraction_accessed < NON_CONTIGUOUS_DOMAIN_WARNING_THRESHOLD + and (_NON_CONTIGUOUS_DOMAIN_WARNING_COUNTER := +1) + < NON_CONTIGUOUS_DOMAIN_MAX_WARNINGS + ): warnings.warn( UserWarning( - f"For {new_dim} the accessed range [{min_}, {max_}[ covers {max_ - min_} values, " - f"but only {covered} are actually present and {max_ - min_ - covered} were added " - f"in between {accessed}. Please consider reordering the mesh." + f"Translating '{self.as_expr()}' using '{shift[0].value}' requires " + f"computations on many additional points " + f"({round((1 - fraction_accessed) * 100)}%) in order to get a contiguous " + f"domain. Please consider reordering your mesh." 
), stacklevel=2, ) - horizontal_sizes[new_dim.value] = ( + new_range = SymbolicRange( im.literal(str(min_), builtins.INTEGER_INDEX_BUILTIN), im.literal(str(max_), builtins.INTEGER_INDEX_BUILTIN), ) - assert new_dim not in new_ranges or old_dim == new_dim - - new_range = SymbolicRange( - horizontal_sizes[new_dim.value][0], horizontal_sizes[new_dim.value][1] - ) new_ranges = dict( (dim, range_) if dim != old_dim else (new_dim, new_range) for dim, range_ in new_ranges.items() diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 1ad3334c98..53474bbcff 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -485,6 +485,7 @@ def domain( expr.type = ts.DomainType(dims=list(ranges.keys())) return expr + def get_field_domain( grid_type: Union[common.GridType, str], field: str | itir.Expr, diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index bc94420c37..8e4ebeb29f 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -100,8 +100,9 @@ def validate(self, name: str, type_: ts.TypeSpec) -> None: def attribute_extractor_exprs(cls, arg_expr: str) -> dict[str, str]: return {"value": arg_expr} + @dataclasses.dataclass(frozen=True) -class FieldDomainDescriptor(ArgumentDescriptor): +class FieldDomainDescriptor(ArgStaticDescriptor): domain: common.Domain @classmethod diff --git a/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py b/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py index 69a2ed772b..a0280fc040 100644 --- a/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py +++ b/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py @@ -7,13 +7,21 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest - +import numpy as np from gt4py.next import common +from gt4py.next.ffront import 
fbuiltins from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import domain_utils, ir_makers as im +from gt4py.next import backend as next_backend, common, allocators as next_allocators, constructors I = common.Dimension("I") J = common.Dimension("J") +K = common.Dimension("J", kind=common.DimensionKind.VERTICAL) +Vertex = common.Dimension("Vertex") +Edge = common.Dimension("Edge") +V2EDim = common.Dimension("V2E", kind=common.DimensionKind.LOCAL) +E2VDim = common.Dimension("E2V", kind=common.DimensionKind.LOCAL) +V2VDim = common.Dimension("V2V", kind=common.DimensionKind.LOCAL) a_range = domain_utils.SymbolicRange(0, 10) another_range = domain_utils.SymbolicRange(5, 15) @@ -180,3 +188,94 @@ def test_is_finite_symbolic_domain(ranges, expected): ) == expected ) + + +@pytest.mark.parametrize( + "shift_chain, expected_end_domain", + [ + (("V2V", 0), {Vertex: (0, 4)}), + (("V2V", 1), {Vertex: (0, 4)}), + (("V2V", 2), {Vertex: (0, 1)}), + (("V2V", 3), {Vertex: (1, 4)}), + (("V2V", 0, "V2V", 3, "V2V", 0), {Vertex: (1, 4)}), + (("V2E", 0), {Edge: (0, 4)}), + (("V2E", 0, "E2V", 0), {Vertex: (0, 4)}), + (("V2V", 3, "V2E", 0), {Edge: (1, 4)}), + ], +) +def test_unstructured_translate(shift_chain, expected_end_domain): + offset_provider = { + "V2V": constructors.as_connectivity( + domain={Vertex: (0, 4), V2VDim: 5}, + codomain=Vertex, + data=np.asarray( + [[0, 3, 0, 1, -1], [1, 2, 0, 1, 1], [2, 1, 0, 3, 2], [3, 0, 0, 3, -1]], + dtype=fbuiltins.IndexType, + ), + ), + "V2E": constructors.as_connectivity( + domain={Vertex: (0, 4), V2EDim: 1}, + codomain=Edge, + data=np.asarray( + [ + [0, 1, 2, 3], + ], + dtype=fbuiltins.IndexType, + ).reshape((4, 1)), + ), + "E2V": constructors.as_connectivity( + domain={Edge: (0, 4), E2VDim: 1}, + codomain=Vertex, + data=np.asarray( + [ + [0, 1, 2, 3], + ], + dtype=fbuiltins.IndexType, + ).reshape((4, 1)), + ), + } + shift_chain = [im.ensure_offset(o) for o in shift_chain] + expected_end_domain = 
im.domain(common.GridType.UNSTRUCTURED, expected_end_domain) + + init_domain = domain_utils.SymbolicDomain.from_expr( + im.domain(common.GridType.UNSTRUCTURED, {Vertex: (0, 4)}) + ) + end_domain = init_domain.translate(shift_chain, offset_provider).as_expr() + assert end_domain == expected_end_domain + + +def test_non_contiguous_domain_warning(): + offset_provider = { + "V2V": constructors.as_connectivity( + domain={Vertex: (0, 100), V2VDim: 1}, + codomain=Vertex, + data=np.asarray([0] + [99] * 99, dtype=fbuiltins.IndexType).reshape((100, 1)), + ) + } + shift_chain = ("V2V", 0) + shift_chain = [im.ensure_offset(o) for o in shift_chain] + domain = domain_utils.SymbolicDomain.from_expr( + im.domain(common.GridType.UNSTRUCTURED, {Vertex: (0, 2)}) + ) + with pytest.warns( + UserWarning, + match=r"98%.*Please consider reordering your mesh.", + ): + domain.translate(shift_chain, offset_provider).as_expr() + + +def test_contains_skip_values_error(): + offset_provider = { + "V2V": constructors.as_connectivity( + domain={Vertex: (0, 3), V2VDim: 1}, + codomain=Vertex, + data=np.asarray([0, -1, 1], dtype=fbuiltins.IndexType).reshape((3, 1)), + ) + } + shift_chain = ("V2V", 0) + shift_chain = [im.ensure_offset(o) for o in shift_chain] + domain = domain_utils.SymbolicDomain.from_expr( + im.domain(common.GridType.UNSTRUCTURED, {Vertex: (0, 3)}) + ) + with pytest.raises(NotImplementedError, match=r"contains skipped values"): + domain.translate(shift_chain, offset_provider).as_expr() diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py index dad33c4d2a..4b3fbf82aa 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py @@ -433,8 +433,12 @@ def test_trivial_cartesian_forward(): 
common.GridType.CARTESIAN, { IDim: ( - im.minus(im.tuple_get(0, im.call("get_domain_range")("out", im.axis_literal(IDim))), 4), - im.minus(im.tuple_get(1, im.call("get_domain_range")("out", im.axis_literal(IDim))), 4), + im.minus( + im.tuple_get(0, im.call("get_domain_range")("out", im.axis_literal(IDim))), 4 + ), + im.minus( + im.tuple_get(1, im.call("get_domain_range")("out", im.axis_literal(IDim))), 4 + ), ) }, ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py index 98d52830e2..5ac023fac7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py @@ -85,7 +85,9 @@ def test_get_domain_tuples(): common.GridType.UNSTRUCTURED, im.tuple_get(1, "out"), sizes["out"][1].dims ) - run_test_program(["inp", "out"], sizes, "out", im.domain_as_expr(sizes["out"][1]), get_domain_expr) + run_test_program( + ["inp", "out"], sizes, "out", im.domain_as_expr(sizes["out"][1]), get_domain_expr + ) def test_get_domain_nested_tuples(): diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index 40dabef1d6..80af3cee7d 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -113,6 +113,7 @@ def pirate(program: toolchain.CompilableProgram): _verify_program_has_expected_true_value(hijacked_program.data) + def _verify_program_has_expected_domain(program: itir.Program, expected_domain: gtx.Domain): assert isinstance(program.body[0], itir.SetAt) assert isinstance(program.body[0].expr, itir.FunCall) @@ -120,6 +121,7 @@ def _verify_program_has_expected_domain(program: itir.Program, expected_domain: domain = 
CollapseTuple.apply(program.body[0].domain, within_stencil=False) assert domain == im.domain(common.GridType.CARTESIAN, expected_domain) + def test_inlining_of_static_domain_works(): domain = gtx.Domain(dims=(TDim,), ranges=(gtx.UnitRange(0, 1),)) input_pair = toolchain.CompilableProgram( @@ -129,9 +131,11 @@ def test_inlining_of_static_domain_works(): kwargs={}, offset_provider={}, column_axis=None, - argument_descriptors={arguments.FieldDomainDescriptor: {"out": arguments.FieldDomainDescriptor(domain)}}, + argument_descriptors={ + arguments.FieldDomainDescriptor: {"out": arguments.FieldDomainDescriptor(domain)} + }, ), ) transformed = backend.DEFAULT_TRANSFORMS(input_pair).data - _verify_program_has_expected_domain(transformed, domain) \ No newline at end of file + _verify_program_has_expected_domain(transformed, domain) From a99b06a832c79c03f80a366723a229eddcf034a4 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Mon, 22 Sep 2025 21:40:27 +0200 Subject: [PATCH 60/93] Cleanup --- src/gt4py/next/ffront/decorator.py | 2 +- src/gt4py/next/iterator/ir_utils/domain_utils.py | 1 - src/gt4py/next/otf/arguments.py | 2 +- .../feature_tests/ffront_tests/test_program.py | 6 ++++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index bd0c171d43..f8133fc6e1 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -284,7 +284,7 @@ def path_to_expr(path: Sequence[int]): func_type = self.past_stage.past_node.type.definition param_types = func_type.pos_or_kw_args | func_type.kw_only_args for name, type_ in param_types.items(): - for el_type, path in type_info.primitive_constituents(type_, with_path_arg=True): + for _, path in type_info.primitive_constituents(type_, with_path_arg=True): static_domain_args.append(f"{name}{path_to_expr(path)}") argument_descriptor_mapping = { diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py 
b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 331a3a381a..9af34d0a6d 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -143,7 +143,6 @@ def translate( trace_shifts.Sentinel.ALL_NEIGHBORS, trace_shifts.Sentinel.VALUE, ] - horizontal_sizes: dict[str, tuple[itir.Expr, itir.Expr]] old_dim = connectivity_type.source_dim new_dim = connectivity_type.codomain assert new_dim not in new_ranges or old_dim == new_dim diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 8e4ebeb29f..9daa4989af 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -106,7 +106,7 @@ class FieldDomainDescriptor(ArgStaticDescriptor): domain: common.Domain @classmethod - def attribute_extractor(cls, arg_expr: str): + def attribute_extractor_exprs(cls, arg_expr: str): return {"domain": f"({arg_expr}).domain"} diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index 88613e9c99..dff13d8dca 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -62,7 +62,7 @@ def shift_by_one(in_field: cases.IFloatField) -> cases.IFloatField: # TODO(tehrengruber): slicing located fields not supported currently # shift_by_one(in_field, out=out_field[:-1], offset_provider={"Ioff": IDim}) - @gtx.program + @gtx.program(static_domains=True) def shift_by_one_program(in_field: cases.IFloatField, out_field: cases.IFloatField): shift_by_one(in_field, out=out_field[:-1]) @@ -82,7 +82,9 @@ def shift_by_one_program(in_field: cases.IFloatField, out_field: cases.IFloatFie def test_copy_execution(cartesian_case, copy_program_def): - copy_program = gtx.program(copy_program_def, backend=cartesian_case.backend, static_domains=True) + copy_program = 
gtx.program( + copy_program_def, backend=cartesian_case.backend, static_domains=True + ) cases.verify_with_default_data(cartesian_case, copy_program, ref=lambda in_field: in_field) From fe90bff2dbc52673a15fc90d3f4212d69f41164e Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Tue, 23 Sep 2025 02:35:29 +0200 Subject: [PATCH 61/93] Cleanup --- src/gt4py/next/ffront/decorator.py | 23 ++++----- src/gt4py/next/ffront/past_to_itir.py | 2 + .../next/iterator/ir_utils/domain_utils.py | 51 ++++++------------- .../next/iterator/transforms/global_tmps.py | 2 +- .../next/iterator/transforms/infer_domain.py | 12 ++--- .../next/iterator/transforms/pass_manager.py | 49 ++++++++++++++++-- .../iterator/transforms/symbol_ref_utils.py | 14 ++--- .../transforms/transform_get_domain_range.py | 6 +-- src/gt4py/next/otf/arguments.py | 2 +- src/gt4py/next/otf/compiled_program.py | 2 +- .../codegens/gtfn/gtfn_module.py | 2 +- .../ffront_tests/test_compiled_program.py | 42 ++++++++++++++- .../ffront_tests/test_program.py | 2 +- 13 files changed, 137 insertions(+), 72 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index f8133fc6e1..9c5acf1033 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -273,24 +273,22 @@ def _compiled_programs(self) -> compiled_program.CompiledProgramsPool: if self.backend is None or self.backend == eve.NOTHING: raise RuntimeError("Cannot compile a program without backend.") - if self.static_params is None: - object.__setattr__(self, "static_params", ()) - - def path_to_expr(path: Sequence[int]): + def path_to_expr(path: Sequence[int]) -> str: return "".join(map(lambda idx: f"[{idx}]", path)) - static_domain_args = [] + argument_descriptor_mapping: dict[type[arguments.ArgStaticDescriptor], Sequence[str]] = {} + + if self.static_params: + argument_descriptor_mapping[arguments.StaticArg] = self.static_params + if self.static_domains: - func_type = 
self.past_stage.past_node.type.definition + static_domain_args = [] + func_type = self.past_stage.past_node.type.definition # type: ignore[union-attr] # type inference done at this point param_types = func_type.pos_or_kw_args | func_type.kw_only_args for name, type_ in param_types.items(): for _, path in type_info.primitive_constituents(type_, with_path_arg=True): static_domain_args.append(f"{name}{path_to_expr(path)}") - - argument_descriptor_mapping = { - arguments.StaticArg: self.static_params, - arguments.FieldDomainDescriptor: static_domain_args, - } + argument_descriptor_mapping[arguments.FieldDomainDescriptor] = static_domain_args program_type = self.past_stage.past_node.type assert isinstance(program_type, ts_ffront.ProgramType) @@ -298,7 +296,7 @@ def path_to_expr(path: Sequence[int]): backend=self.backend, definition_stage=self.definition_stage, program_type=program_type, - argument_descriptor_mapping=argument_descriptor_mapping, # type: ignore[arg-type] # covariant `type[T]` not possible + argument_descriptor_mapping=argument_descriptor_mapping, ) def _extend_offset_provider( @@ -776,6 +774,7 @@ def as_program(self, compiletime_args: arguments.CompileTimeArgs) -> Program: connectivities=None, enable_jit=False, # TODO(havogt): revisit ProgramFromPast static_params=None, # TODO(havogt): revisit ProgramFromPast + static_domains=False, # TODO(havogt): revisit ProgramFromPast ) def __call__(self, *args: Any, **kwargs: Any) -> Any: diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index d7173b5d9d..c001a35b4c 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -96,6 +96,7 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type ) + # TODO(tehrengruber): Put this in a dedicated transformation step. 
if arguments.StaticArg in inp.args.argument_descriptors: static_arg_descriptors = inp.args.argument_descriptors[arguments.StaticArg] if not all( @@ -116,6 +117,7 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: body=body, ) + # TODO(tehrengruber): Put this in a dedicated transformation step. if arguments.FieldDomainDescriptor in inp.args.argument_descriptors: field_domains = { param: utils.tree_map(lambda x: x.domain)(v) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 9af34d0a6d..266c34021f 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -11,7 +11,7 @@ import dataclasses import functools import warnings -from typing import Any, Callable, Iterable, Literal, Mapping, Optional +from typing import Callable, Iterable, Literal, Optional import numpy as np @@ -33,27 +33,6 @@ _NON_CONTIGUOUS_DOMAIN_WARNING_COUNTER: int = 0 -def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> dict[str, int]: - """ - Extract horizontal domain sizes from an `offset_provider`. - - Considers the shape of the neighbor table to get the size of each `source_dim` and the maximum - value inside the neighbor table to get the size of each `codomain`. 
- """ - sizes = dict[str, int]() - for provider in offset_provider.values(): - if common.is_neighbor_connectivity(provider): - conn_type = provider.__gt_type__() - sizes[conn_type.source_dim.value] = max( - sizes.get(conn_type.source_dim.value, 0), provider.ndarray.shape[0] - ) - sizes[conn_type.codomain.value] = max( - sizes.get(conn_type.codomain.value, 0), - provider.ndarray.max() + 1, # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject - ) - return sizes - - @dataclasses.dataclass(frozen=True) class SymbolicRange: start: itir.Expr @@ -113,7 +92,7 @@ def translate( offset_provider: common.OffsetProvider | common.OffsetProviderType, #: A dictionary mapping axes names to their length. See #: func:`gt4py.next.iterator.transforms.infer_domain.infer_expr` for more details. - symbolic_domain_sizes: Optional[dict[str, str]] = None, + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None, ) -> SymbolicDomain: offset_provider_type = common.offset_provider_to_type(offset_provider) @@ -146,34 +125,37 @@ def translate( old_dim = connectivity_type.source_dim new_dim = connectivity_type.codomain assert new_dim not in new_ranges or old_dim == new_dim - if symbolic_domain_sizes is not None: + if symbolic_domain_sizes is not None and new_dim.value in symbolic_domain_sizes: new_range = SymbolicRange( im.literal(str(0), builtins.INTEGER_INDEX_BUILTIN), im.ensure_expr(symbolic_domain_sizes[new_dim.value]), ) else: assert common.is_offset_provider(offset_provider) - skip_value = offset_provider[off.value].skip_value + connectivity = offset_provider[off.value] + assert isinstance(connectivity, common.Connectivity) + skip_value = connectivity.skip_value # fold & convert expr into actual integers range_exprs = new_ranges[old_dim].start, new_ranges[old_dim].stop - range_exprs = [ + range_exprs = tuple( collapse_tuple.CollapseTuple.apply( expr, within_stencil=False, allow_undeclared_symbols=True, ) for expr in range_exprs - ] + ) # type: 
ignore[assignment] # mypy not smart enough assert all(isinstance(expr, itir.Literal) for expr in range_exprs) - start, stop = (int(literal.value) for literal in range_exprs) + start, stop = (int(literal.value) for literal in range_exprs) # type: ignore[attr-defined] # mypy does not understand assert above + nb_index: slice | int if val in [trace_shifts.Sentinel.ALL_NEIGHBORS, trace_shifts.Sentinel.VALUE]: nb_index = slice(None) else: - nb_index = val.value + nb_index = val.value # type: ignore[assignment] # assert above - accessed = offset_provider[off.value].ndarray[start:stop, nb_index] + accessed = connectivity.ndarray[start:stop, nb_index] if np.any(accessed == skip_value): raise NotImplementedError( @@ -181,10 +163,9 @@ def translate( f"skipped values. This is not supported." ) - min_ = np.min(accessed) - max_ = np.max(accessed) + 1 + new_start, new_stop = accessed.min(), accessed.max() + 1 # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject - fraction_accessed = np.unique(accessed).size / (max_ - min_) + fraction_accessed = np.unique(accessed).size / (new_stop - new_start) # type: ignore[call-overload] # TODO(havogt): improve typing for NDArrayObject if ( fraction_accessed < NON_CONTIGUOUS_DOMAIN_WARNING_THRESHOLD @@ -202,8 +183,8 @@ def translate( ) new_range = SymbolicRange( - im.literal(str(min_), builtins.INTEGER_INDEX_BUILTIN), - im.literal(str(max_), builtins.INTEGER_INDEX_BUILTIN), + im.literal(str(new_start), builtins.INTEGER_INDEX_BUILTIN), + im.literal(str(new_stop), builtins.INTEGER_INDEX_BUILTIN), ) new_ranges = dict( diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 5bee4daf8d..5f94804e0e 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -319,7 +319,7 @@ def create_global_tmps( offset_provider: common.OffsetProvider | common.OffsetProviderType, #: A dictionary mapping axes names 
to their length. See :func:`infer_domain.infer_expr` for #: more details. - symbolic_domain_sizes: Optional[dict[str, str]] = None, + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None, *, uids: Optional[eve_utils.UIDGenerator] = None, ) -> itir.Program: diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index c22b775468..b52e9060ef 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -55,7 +55,7 @@ class DomainAccessDescriptor(eve.StrEnum): class InferenceOptions(typing.TypedDict): offset_provider: common.OffsetProvider | common.OffsetProviderType - symbolic_domain_sizes: Optional[dict[str, str]] + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] allow_uninferred: bool keep_existing_domains: bool @@ -126,7 +126,7 @@ def _extract_accessed_domains( input_ids: list[str], target_domain: NonTupleDomainAccess, offset_provider: common.OffsetProvider | common.OffsetProviderType, - symbolic_domain_sizes: Optional[dict[str, str]], + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]], ) -> dict[str, NonTupleDomainAccess]: accessed_domains: dict[str, NonTupleDomainAccess] = {} @@ -182,7 +182,7 @@ def _infer_as_fieldop( target_domain: DomainAccess, *, offset_provider: common.OffsetProvider | common.OffsetProviderType, - symbolic_domain_sizes: Optional[dict[str, str]], + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]], allow_uninferred: bool, keep_existing_domains: bool, ) -> tuple[itir.FunCall, AccessedDomains]: @@ -441,7 +441,7 @@ def infer_expr( domain: DomainAccess, *, offset_provider: common.OffsetProvider | common.OffsetProviderType, - symbolic_domain_sizes: Optional[dict[str, str]] = None, + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None, allow_uninferred: bool = False, keep_existing_domains: bool = False, ) -> tuple[itir.Expr, AccessedDomains]: @@ -457,7 +457,7 @@ def 
infer_expr( Keyword Arguments: - symbolic_domain_sizes: A dictionary mapping axes names, e.g., `I`, `Vertex`, to a symbol - name that evaluates to the length of that axis. + name or expression that evaluates to the length of that axis. - allow_uninferred: Allow `as_fieldop` expressions whose domain is either unknown (e.g. because of a dynamic shift) or never accessed. - keep_existing_domains: If `True`, keep existing domains in `as_fieldop` expressions and @@ -550,7 +550,7 @@ def infer_program( program: itir.Program, *, offset_provider: common.OffsetProvider | common.OffsetProviderType, - symbolic_domain_sizes: Optional[dict[str, str]] = None, + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None, allow_uninferred: bool = False, keep_existing_domains: bool = False, ) -> itir.Program: diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 9c5e154011..70e1213d70 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -5,12 +5,12 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - -from typing import Optional, Protocol +from typing import Any, Mapping, Optional, Protocol from gt4py.eve import utils as eve_utils from gt4py.next import common from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import ( concat_where, dead_code_elimination, @@ -22,6 +22,7 @@ inline_fundefs, inline_lifts, remove_broadcast, + symbol_ref_utils, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -42,6 +43,42 @@ def __call__( ) -> itir.Program: ... 
+def _max_domain_range_sizes(offset_provider: Mapping[str, Any]) -> dict[str, itir.Literal]: + """ + Extract horizontal domain sizes from an `offset_provider`. + + Considers the shape of the neighbor table to get the size of each `source_dim` and the maximum + value inside the neighbor table to get the size of each `codomain`. + """ + sizes: dict[str, int] = {} + for provider in offset_provider.values(): + if common.is_neighbor_connectivity(provider): + conn_type = provider.__gt_type__() + sizes[conn_type.source_dim.value] = max( + sizes.get(conn_type.source_dim.value, 0), provider.ndarray.shape[0] + ) + sizes[conn_type.codomain.value] = max( + sizes.get(conn_type.codomain.value, 0), + provider.ndarray.max() + 1, # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject + ) + sizes_exprs = {k: im.literal_from_value(v) for k, v in sizes.items()} + return sizes_exprs + + +def _has_dynamic_domains(ir: itir.Program) -> bool: + # note: this function does not respect symbol collisions with builtins. As it is a temporary + # workaround we don't care about this corner cases. + domains = set() + domains |= ir.walk_values().if_isinstance(itir.SetAt).getattr("domain").to_set() + for as_fop in ( + ir.walk_values() + .if_isinstance(itir.FunCall) + .filter(lambda node: cpm.is_call_to(node, "as_fieldop") and len(node.args) == 2) + ): + domains.add(as_fop.args[1]) + return len(symbol_ref_utils.collect_symbol_refs(domains)) > 0 + + # TODO(tehrengruber): Revisit interface to configure temporary extraction. We currently forward # `extract_temporaries` and `temporary_extraction_heuristics` which is inconvenient. def apply_common_transforms( @@ -56,12 +93,18 @@ def apply_common_transforms( force_inline_lambda_args=False, #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for #: more details. 
- symbolic_domain_sizes: Optional[dict[str, str]] = None, + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None, ) -> itir.Program: assert isinstance(ir, itir.Program) offset_provider_type = common.offset_provider_to_type(offset_provider) + # TODO(tehrengruber): Remove this option again as soon as we have the necessary builtins + # to work with / translate domains. + if _has_dynamic_domains(ir): + assert not symbolic_domain_sizes, "Options are mutually exclusive." + symbolic_domain_sizes = _max_domain_range_sizes(offset_provider) # type: ignore[assignment] + tmp_uids = eve_utils.UIDGenerator(prefix="__tmp") mergeasfop_uids = eve_utils.UIDGenerator() collapse_tuple_uids = eve_utils.UIDGenerator() diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index d7ba35eed0..e87009c564 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -10,7 +10,7 @@ from collections import Counter import gt4py.eve as eve -from gt4py.eve.extended_typing import Iterable, Literal, Optional, Sequence, cast, overload +from gt4py.eve.extended_typing import Iterable, Literal, Optional, cast, overload from gt4py.next.iterator import ir as itir @@ -22,7 +22,7 @@ class CountSymbolRefs(eve.PreserveLocationVisitor, eve.NodeVisitor): @classmethod def apply( cls, - node: itir.Node | Sequence[itir.Node], + node: itir.Node | Iterable[itir.Node], symbol_names: Optional[Iterable[str]] = None, *, ignore_builtins: bool = True, @@ -33,7 +33,7 @@ def apply( @classmethod def apply( cls, - node: itir.Node | Sequence[itir.Node], + node: itir.Node | Iterable[itir.Node], symbol_names: Optional[Iterable[str]] = None, *, ignore_builtins: bool = True, @@ -43,7 +43,7 @@ def apply( @classmethod def apply( cls, - node: itir.Node | Sequence[itir.Node], + node: itir.Node | Iterable[itir.Node], symbol_names: Optional[Iterable[str]] = None, *, ignore_builtins: 
bool = True, @@ -101,7 +101,7 @@ def visit_Lambda(self, node: itir.Lambda, *, inactive_refs: set[str]): @overload def collect_symbol_refs( - node: itir.Node | Sequence[itir.Node], + node: itir.Node | Iterable[itir.Node], symbol_names: Optional[Iterable[str]] = None, *, ignore_builtins: bool = True, @@ -111,7 +111,7 @@ def collect_symbol_refs( @overload def collect_symbol_refs( - node: itir.Node | Sequence[itir.Node], + node: itir.Node | Iterable[itir.Node], symbol_names: Optional[Iterable[str]] = None, *, ignore_builtins: bool = True, @@ -120,7 +120,7 @@ def collect_symbol_refs( def collect_symbol_refs( - node: itir.Node | Sequence[itir.Node], + node: itir.Node | Iterable[itir.Node], symbol_names: Optional[Iterable[str]] = None, *, ignore_builtins: bool = True, diff --git a/src/gt4py/next/iterator/transforms/transform_get_domain_range.py b/src/gt4py/next/iterator/transforms/transform_get_domain_range.py index 27de217d10..bbd5513310 100644 --- a/src/gt4py/next/iterator/transforms/transform_get_domain_range.py +++ b/src/gt4py/next/iterator/transforms/transform_get_domain_range.py @@ -7,10 +7,10 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses -from typing import Dict from gt4py._core import definitions as core_defs from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.eve.extended_typing import MaybeNestedInTuple from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ( @@ -28,7 +28,7 @@ def visit_Node(self, node: itir.Node, **kwargs): return None # means we could not deduce the domain def visit_SymRef( - self, node: itir.SymRef, *, sizes: Dict[str, common.Domain], **kwargs + self, node: itir.SymRef, *, sizes: dict[str, MaybeNestedInTuple[common.Domain]], **kwargs ) -> DomainOrTupleThereof | None: return sizes.get(node.id, None) @@ -94,7 +94,7 @@ class TransformGetDomainRange(PreserveLocationVisitor, NodeTranslator): """ @classmethod - def apply(cls, program: itir.Program, 
sizes: Dict[str, common.Domain]): + def apply(cls, program: itir.Program, sizes: dict[str, MaybeNestedInTuple[common.Domain]]): return cls().visit(program, sizes=sizes) def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.FunCall: diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 9daa4989af..608d56696d 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -106,7 +106,7 @@ class FieldDomainDescriptor(ArgStaticDescriptor): domain: common.Domain @classmethod - def attribute_extractor_exprs(cls, arg_expr: str): + def attribute_extractor_exprs(cls, arg_expr: str) -> dict[str, str]: return {"domain": f"({arg_expr}).domain"} diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 126f25861a..0f13547fe8 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -246,7 +246,7 @@ def _initialize_argument_descriptor_mapping( self._validate_argument_descriptor_mapping() else: for descr_cls, descriptor_expr_mapping in argument_descriptors.items(): - if (expected := set(self.argument_descriptor_mapping[descr_cls])) != ( + if (expected := set(self.argument_descriptor_mapping.get(descr_cls, {}))) != ( got := set(descriptor_expr_mapping.keys()) ): raise ValueError( diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 5bc796de9d..ed1b657129 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -52,7 +52,7 @@ class GTFNTranslationStep( enable_itir_transforms: bool = True use_imperative_backend: bool = False device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - symbolic_domain_sizes: Optional[dict[str, str]] = None + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None def _default_language_settings(self) -> 
languages.LanguageWithHeaderFilesSettings: match self.device_type: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index 609836be29..eb67f270a6 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -5,7 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - +from typing import Optional from unittest import mock import numpy as np @@ -30,6 +30,7 @@ skip_value_mesh, ) +from gt4py.next.otf import arguments _raise_on_compile = mock.Mock() _raise_on_compile.compile.side_effect = AssertionError("This function should never be called.") @@ -812,3 +813,42 @@ def test_compile_variants_tuple(cartesian_case, compile_variants_testee_tuple): offset_provider=cartesian_case.offset_provider, ) assert np.allclose(out.asnumpy(), field_a.asnumpy() * 3 + field_b.asnumpy() * 4) + + +def test_compile_variants_decorator_static_domains(compile_variants_field_operator, cartesian_case): + if cartesian_case.backend is None: + pytest.skip("Embedded compiled program doesn't make sense.") + + captured_cargs: Optional[arguments.CompileTimeArgs] = None + + class CaptureCompileTimeArgsBackend: + def __getattr__(self, name): + return getattr(cartesian_case.backend, name) + + def compile(self, program, compile_time_args): + nonlocal captured_cargs + captured_cargs = compile_time_args + + return cartesian_case.backend.compile(program, compile_time_args) + + @gtx.field_operator + def identity(inp: cases.IField): + return inp + + @gtx.program(backend=CaptureCompileTimeArgsBackend(), static_domains=True) + def testee(inp: cases.IField, out: cases.IField): + identity(inp, out=out) + + inp = cases.allocate(cartesian_case, testee, "inp")() + out = 
cases.allocate(cartesian_case, testee, "out")() + + testee(inp, out, offset_provider={}) + assert np.allclose(out.ndarray, inp.ndarray) + + assert testee._compiled_programs.argument_descriptor_mapping[ + arguments.FieldDomainDescriptor + ] == ["inp", "out"] + assert captured_cargs.argument_descriptors[arguments.FieldDomainDescriptor] == { + "inp": arguments.FieldDomainDescriptor(inp.domain), + "out": arguments.FieldDomainDescriptor(out.domain), + } diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index dff13d8dca..47cdd0e3d1 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -62,7 +62,7 @@ def shift_by_one(in_field: cases.IFloatField) -> cases.IFloatField: # TODO(tehrengruber): slicing located fields not supported currently # shift_by_one(in_field, out=out_field[:-1], offset_provider={"Ioff": IDim}) - @gtx.program(static_domains=True) + @gtx.program def shift_by_one_program(in_field: cases.IFloatField, out_field: cases.IFloatField): shift_by_one(in_field, out=out_field[:-1]) From 639508642eb0490f3f701b4a0e90fd5f74f0f1e9 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Tue, 23 Sep 2025 03:04:57 +0200 Subject: [PATCH 62/93] Fix fieldview transforms --- src/gt4py/next/iterator/transforms/pass_manager.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 70e1213d70..d0ca18fff9 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -215,6 +215,13 @@ def apply_fieldview_transforms( ) -> itir.Program: offset_provider_type = common.offset_provider_to_type(offset_provider) + # TODO(tehrengruber): Remove this option again as 
soon as we have the necessary builtins + # to work with / translate domains. + if _has_dynamic_domains(ir): + symbolic_domain_sizes = _max_domain_range_sizes(offset_provider) # type: ignore[assignment] + else: + symbolic_domain_sizes = None + ir = inline_fundefs.InlineFundefs().visit(ir) ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = dead_code_elimination.dead_code_elimination(ir, offset_provider_type=offset_provider_type) @@ -226,6 +233,8 @@ def apply_fieldview_transforms( ir = concat_where.canonicalize_domain_argument(ir) ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program - ir = infer_domain.infer_program(ir, offset_provider=offset_provider) + ir = infer_domain.infer_program( + ir, symbolic_domain_sizes=symbolic_domain_sizes, offset_provider=offset_provider + ) ir = remove_broadcast.RemoveBroadcast.apply(ir) return ir From d71ae2260fd631a993e6652d71ceeeedda0f9c18 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Tue, 23 Sep 2025 03:06:13 +0200 Subject: [PATCH 63/93] Fix fieldview transforms --- src/gt4py/next/iterator/transforms/pass_manager.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index d0ca18fff9..7220951287 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -218,7 +218,7 @@ def apply_fieldview_transforms( # TODO(tehrengruber): Remove this option again as soon as we have the necessary builtins # to work with / translate domains. 
if _has_dynamic_domains(ir): - symbolic_domain_sizes = _max_domain_range_sizes(offset_provider) # type: ignore[assignment] + symbolic_domain_sizes = _max_domain_range_sizes(offset_provider) else: symbolic_domain_sizes = None @@ -234,7 +234,9 @@ def apply_fieldview_transforms( ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( - ir, symbolic_domain_sizes=symbolic_domain_sizes, offset_provider=offset_provider + ir, + symbolic_domain_sizes=symbolic_domain_sizes, # type: ignore[arg-type] + offset_provider=offset_provider, ) ir = remove_broadcast.RemoveBroadcast.apply(ir) return ir From b38d3b1b307c4b0617abb7c0fe09d6a6eb6cdaaa Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Tue, 23 Sep 2025 03:07:23 +0200 Subject: [PATCH 64/93] Cleanup --- .../test_temporary_domain_inference.py | 502 ------------------ 1 file changed, 502 deletions(-) delete mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py deleted file mode 100644 index 4b3fbf82aa..0000000000 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_temporary_domain_inference.py +++ /dev/null @@ -1,502 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -from typing import Optional, Dict - -import pytest -from next_tests.integration_tests.cases import ( - IDim, - Vertex, - mesh_descriptor, - exec_alloc_descriptor, -) -from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( - simple_cartesian_grid, - Edge, - simple_mesh, -) - -from gt4py import next as gtx -from gt4py.next import common -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ( - ir_makers as im, -) -from gt4py.next.iterator.transforms import inline_fundefs, global_tmps -from gt4py.next.iterator.transforms.transform_get_domain_range import TransformGetDomainRange -from gt4py.next.type_system import type_specifications as ts - -IOff = gtx.FieldOffset("IOff", source=IDim, target=(IDim,)) - -float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) -v_field_type = ts.FieldType(dims=[Vertex], dtype=float_type) -e_field_type = ts.FieldType(dims=[Edge], dtype=float_type) -i_field_type = ts.FieldType(dims=[IDim], dtype=float_type) - - -# override mesh descriptor to contain only the simple mesh -@pytest.fixture -def mesh_descriptor(): - return simple_mesh(None) - - -def program_factory( - params: list[itir.Sym], - body: list[itir.SetAt], - declarations: Optional[list[itir.Temporary]] = None, -) -> itir.Program: - return itir.Program( - id="testee", - function_definitions=[], - params=params, - declarations=declarations or [], - body=body, - ) - - -def run_test_program( - testee: itir.Program, - expected: itir.Program, - sizes: Dict[str, common.Domain], - offset_provider: common.OffsetProvider, -) -> None: - ir = inline_fundefs.InlineFundefs().visit(testee) - ir = inline_fundefs.prune_unreferenced_fundefs(ir) - ir = TransformGetDomainRange.apply(ir, sizes=sizes) - actual_program = global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - - assert actual_program == expected - - -def test_trivial_shift(mesh_descriptor): - sizes = {"out": 
gtx.domain({Edge: (9, 13), Vertex: (0, 9)})} - unstructured_domain_get_E = im.domain( - common.GridType.UNSTRUCTURED, - { - Edge: ( - im.tuple_get(0, im.call("get_domain_range")("out", im.axis_literal(Edge))), - im.tuple_get(1, im.call("get_domain_range")("out", im.axis_literal(Edge))), - ) - }, - ) - - unstructured_domain_E = im.domain( - common.GridType.UNSTRUCTURED, - {Edge: (im.tuple_get(0, im.make_tuple(9, 13)), im.tuple_get(1, im.make_tuple(9, 13)))}, - ) - - unstructured_domain_V_37 = im.domain(common.GridType.UNSTRUCTURED, {Vertex: (3, 7)}) - - offset_provider = mesh_descriptor.offset_provider - testee = program_factory( - params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], - body=[ - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))))( - im.as_fieldop("deref")("vertex_values") - ), - domain=unstructured_domain_get_E, - ) - ], - ) - - expected = program_factory( - params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], - declarations=[ - itir.Temporary(id="__tmp_1", domain=unstructured_domain_V_37, dtype=float_type) - ], - body=[ - itir.SetAt( - target=im.ref("__tmp_1"), - expr=im.as_fieldop("deref", unstructured_domain_V_37)("vertex_values"), - domain=unstructured_domain_V_37, - ), - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop( - im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))), unstructured_domain_E - )("__tmp_1"), - domain=unstructured_domain_E, - ), - ], - ) - - run_test_program(testee, expected, sizes, offset_provider) - - -def test_trivial_shift_warning(mesh_descriptor): - with pytest.warns( - UserWarning, - match=r"For Vertex\[horizontal\] the accessed range \[3, 9\[ covers 6 values, " - r"but only 2 are actually present and 4 were added in between \[8 3\]\. 
" - r"Please consider reordering the mesh\.", - ): - sizes = {"out": gtx.domain({Edge: (8, 10), Vertex: (0, 9)})} - unstructured_domain_get_E = im.domain( - common.GridType.UNSTRUCTURED, - { - Edge: ( - im.tuple_get(0, im.call("get_domain_range")("out", im.axis_literal(Edge))), - im.tuple_get(1, im.call("get_domain_range")("out", im.axis_literal(Edge))), - ) - }, - ) - - offset_provider = mesh_descriptor.offset_provider - testee = program_factory( - params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], - body=[ - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))))( - im.as_fieldop("deref")("vertex_values") - ), - domain=unstructured_domain_get_E, - ) - ], - ) - ir = inline_fundefs.InlineFundefs().visit(testee) - ir = inline_fundefs.prune_unreferenced_fundefs(ir) - ir = TransformGetDomainRange.apply(ir, sizes=sizes) - - global_tmps.create_global_tmps(ir, offset_provider=offset_provider) - - -def test_trivial_shift_switched(mesh_descriptor): - sizes = {"out": gtx.domain({Edge: (2, 16), Vertex: (0, 9)})} - unstructured_domain_get_E = im.domain( - common.GridType.UNSTRUCTURED, - { - Edge: ( - im.tuple_get(0, im.call("get_domain_range")("out", im.axis_literal(Edge))), - im.tuple_get(1, im.call("get_domain_range")("out", im.axis_literal(Edge))), - ) - }, - ) - - unstructured_domain_E = im.domain( - common.GridType.UNSTRUCTURED, - {Edge: (im.tuple_get(0, im.make_tuple(2, 16)), im.tuple_get(1, im.make_tuple(2, 16)))}, - ) - - offset_provider = mesh_descriptor.offset_provider - testee = program_factory( - params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], - body=[ - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop("deref")( - im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))))( - "vertex_values" - ) - ), - domain=unstructured_domain_get_E, - ) - ], - ) - - expected = program_factory( - params=[im.sym("vertex_values", v_field_type), im.sym("out", 
e_field_type)], - declarations=[itir.Temporary(id="__tmp_1", domain=unstructured_domain_E, dtype=float_type)], - body=[ - itir.SetAt( - target=im.ref("__tmp_1"), - expr=im.as_fieldop( - im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))), unstructured_domain_E - )("vertex_values"), - domain=unstructured_domain_E, - ), - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop("deref", unstructured_domain_E)("__tmp_1"), - domain=unstructured_domain_E, - ), - ], - ) - - run_test_program(testee, expected, sizes, offset_provider) - - -def test_two_shifts(mesh_descriptor): - sizes = {"out": gtx.domain({Edge: (3, 8), Vertex: (0, 9)})} - unstructured_domain_get_E = im.domain( - common.GridType.UNSTRUCTURED, - { - Edge: ( - im.tuple_get(0, im.call("get_domain_range")("out", im.axis_literal(Edge))), - im.tuple_get(1, im.call("get_domain_range")("out", im.axis_literal(Edge))), - ) - }, - ) - - unstructured_domain_E = im.domain( - common.GridType.UNSTRUCTURED, - {Edge: (im.tuple_get(0, im.make_tuple(3, 8)), im.tuple_get(1, im.make_tuple(3, 8)))}, - ) - - unstructured_domain_V_39 = im.domain(common.GridType.UNSTRUCTURED, {Vertex: (3, 9)}) - - offset_provider = mesh_descriptor.offset_provider - testee = program_factory( - params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], - body=[ - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop( - im.lambda_("x")( - im.plus( - im.deref(im.shift("E2V", 0)("x")), im.deref(im.shift("E2V", 1)("x")) - ) - ) - )(im.as_fieldop("deref")("vertex_values")), - domain=unstructured_domain_get_E, - ) - ], - ) - - expected = program_factory( - params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], - declarations=[ - itir.Temporary(id="__tmp_1", domain=unstructured_domain_V_39, dtype=float_type) - ], - body=[ - itir.SetAt( - target=im.ref("__tmp_1"), - expr=im.as_fieldop("deref", unstructured_domain_V_39)("vertex_values"), - domain=unstructured_domain_V_39, - ), - itir.SetAt( - target=im.ref("out"), - 
expr=im.as_fieldop( - im.lambda_("x")( - im.plus( - im.deref(im.shift("E2V", 0)("x")), im.deref(im.shift("E2V", 1)("x")) - ) - ), - unstructured_domain_E, - )("__tmp_1"), - domain=unstructured_domain_E, - ), - ], - ) - - run_test_program(testee, expected, sizes, offset_provider) - - -def test_nested_shift(mesh_descriptor): - sizes = {"out": gtx.domain({Edge: (0, 18), Vertex: (3, 7)})} - unstructured_domain_V = im.domain( - common.GridType.UNSTRUCTURED, - {Vertex: (im.tuple_get(0, im.make_tuple(3, 7)), im.tuple_get(1, im.make_tuple(3, 7)))}, - ) - unstructured_domain_get_V = im.domain( - common.GridType.UNSTRUCTURED, - { - Vertex: ( - im.tuple_get(0, im.call("get_domain_range")("out", im.axis_literal(Vertex))), - im.tuple_get(1, im.call("get_domain_range")("out", im.axis_literal(Vertex))), - ) - }, - ) - - unstructured_domain_V_69 = im.domain(common.GridType.UNSTRUCTURED, {Vertex: (6, 9)}) - - unstructured_domain_E_1216 = im.domain(common.GridType.UNSTRUCTURED, {Edge: (12, 16)}) - - offset_provider = mesh_descriptor.offset_provider - - testee = program_factory( - params=[im.sym("vertex_values", v_field_type), im.sym("out", v_field_type)], - body=[ - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("V2E", 3)("x"))))( - im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))))( - im.as_fieldop("deref")("vertex_values") - ) - ), - domain=unstructured_domain_get_V, - ) - ], - ) - - expected = program_factory( - params=[im.sym("vertex_values", v_field_type), im.sym("out", e_field_type)], - declarations=[ - itir.Temporary(id="__tmp_1", domain=unstructured_domain_E_1216, dtype=float_type), - itir.Temporary(id="__tmp_2", domain=unstructured_domain_V_69, dtype=float_type), - ], - body=[ - itir.SetAt( - target=im.ref("__tmp_2"), - expr=im.as_fieldop("deref", unstructured_domain_V_69)("vertex_values"), - domain=unstructured_domain_V_69, - ), - itir.SetAt( - target=im.ref("__tmp_1"), - expr=im.as_fieldop( - 
im.lambda_("x")(im.deref(im.shift("E2V", 1)("x"))), - unstructured_domain_E_1216, - )("__tmp_2"), - domain=unstructured_domain_E_1216, - ), - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop( - im.lambda_("x")(im.deref(im.shift("V2E", 3)("x"))), unstructured_domain_V - )("__tmp_1"), - domain=unstructured_domain_V, - ), - ], - ) - - run_test_program(testee, expected, sizes, offset_provider) - - -def test_trivial_cartesian(): - grid = simple_cartesian_grid() - offset_provider = {"Ioff": grid.offset_provider["Ioff"]} - sizes = {"out": gtx.domain({IDim: (2, 7)})} - - cartesian_domain = im.domain( - common.GridType.CARTESIAN, - {IDim: (im.tuple_get(0, im.make_tuple(2, 7)), im.tuple_get(1, im.make_tuple(2, 7)))}, - ) - cartesian_domain_get = im.domain( - common.GridType.CARTESIAN, - { - IDim: ( - im.tuple_get(0, im.call("get_domain_range")("out", im.axis_literal(IDim))), - im.tuple_get(1, im.call("get_domain_range")("out", im.axis_literal(IDim))), - ) - }, - ) - - cartesian_domain_27_p1 = im.domain( - common.GridType.CARTESIAN, - { - IDim: ( - im.plus(im.tuple_get(0, im.make_tuple(2, 7)), 1), - im.plus(im.tuple_get(1, im.make_tuple(2, 7)), 1), - ) - }, - ) - testee = program_factory( - params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], - body=[ - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("Ioff", 1)("x"))))( - im.as_fieldop("deref")("i_values") - ), - domain=cartesian_domain_get, - ) - ], - ) - - expected = program_factory( - params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], - declarations=[ - itir.Temporary(id="__tmp_1", domain=cartesian_domain_27_p1, dtype=float_type) - ], - body=[ - itir.SetAt( - target=im.ref("__tmp_1"), - expr=im.as_fieldop("deref", cartesian_domain_27_p1)("i_values"), - domain=cartesian_domain_27_p1, - ), - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop( - im.lambda_("x")(im.deref(im.shift("Ioff", 1)("x"))), cartesian_domain - )("__tmp_1"), 
- domain=cartesian_domain, - ), - ], - ) - - run_test_program(testee, expected, sizes, offset_provider) - - -def test_trivial_cartesian_forward(): - grid = simple_cartesian_grid() - offset_provider = {"Ioff": grid.offset_provider["Ioff"]} - sizes = {"out": gtx.domain({IDim: (2, 7)})} - - cartesian_domain_get = im.domain( - common.GridType.CARTESIAN, - { - IDim: ( - im.minus( - im.tuple_get(0, im.call("get_domain_range")("out", im.axis_literal(IDim))), 4 - ), - im.minus( - im.tuple_get(1, im.call("get_domain_range")("out", im.axis_literal(IDim))), 4 - ), - ) - }, - ) - testee = program_factory( - params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], - body=[ - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop( - im.lambda_("x")(im.deref(im.shift("Ioff", 2)("x"))), - )( - im.as_fieldop( - im.lambda_("x")(im.deref(im.shift("Ioff", 2)("x"))), - )("i_values") - ), - domain=cartesian_domain_get, - ) - ], - ) - - cartesian_domain_m2 = im.domain( - common.GridType.CARTESIAN, - { - IDim: ( - im.minus(im.tuple_get(0, im.make_tuple(2, 7)), 2), - im.minus(im.tuple_get(1, im.make_tuple(2, 7)), 2), - ) - }, - ) - - cartesian_domain_m4 = im.domain( - common.GridType.CARTESIAN, - { - IDim: ( - im.minus(im.tuple_get(0, im.make_tuple(2, 7)), 4), - im.minus(im.tuple_get(1, im.make_tuple(2, 7)), 4), - ) - }, - ) - expected = program_factory( - params=[im.sym("i_values", i_field_type), im.sym("out", i_field_type)], - declarations=[itir.Temporary(id="__tmp_1", domain=cartesian_domain_m2, dtype=float_type)], - body=[ - itir.SetAt( - target=im.ref("__tmp_1"), - expr=im.as_fieldop( - im.lambda_("x")(im.deref(im.shift("Ioff", 2)("x"))), cartesian_domain_m2 - )("i_values"), - domain=cartesian_domain_m2, - ), - itir.SetAt( - target=im.ref("out"), - expr=im.as_fieldop( - im.lambda_("x")(im.deref(im.shift("Ioff", 2)("x"))), cartesian_domain_m4 - )("__tmp_1"), - domain=cartesian_domain_m4, - ), - ], - ) - - run_test_program(testee, expected, sizes, offset_provider) From 
ccdf31256af017046c4af775ae7e368b08b72b02 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 23 Sep 2025 13:30:09 +0200 Subject: [PATCH 65/93] Cleanup --- .../dace_tests/test_dace_domain.py | 4 +- .../dace_tests/test_dace_translation.py | 6 +- .../dace_tests/test_gtir_to_sdfg.py | 100 +++++++++--------- 3 files changed, 54 insertions(+), 56 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py index 4ac3d3140d..1159f51fce 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py @@ -39,7 +39,7 @@ def test_simplify_domain_expr(param): domain_expr = im.domain( gtx_common.GridType.CARTESIAN, - ranges={ + { Cell: ("horizontal_start", "horizontal_end"), KDim: ("vertical_start", "vertical_end"), }, @@ -58,7 +58,7 @@ def test_gtir_domain(): ir = im.domain( gtx_common.GridType.UNSTRUCTURED, - ranges={ + { Vertex: (1, 10), KDim: (2, 20), }, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py index f469fa1a51..83d24f25eb 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_translation.py @@ -70,7 +70,7 @@ def test_find_constant_symbols(has_unit_stride): )("x"), domain=im.domain( gtx_common.GridType.UNSTRUCTURED, - ranges={Vertex: (0, "h_size"), KDim: (0, "v_size")}, + {Vertex: (0, "h_size"), KDim: (0, "v_size")}, ), target=itir.SymRef(id="y"), ) @@ -210,7 +210,7 @@ def test_generate_sdfg_async_call(make_async_sdfg_call: bool, 
device_type: core_ body=[ itir.SetAt( expr=im.op_as_fieldop("plus")("x", 1.0), - domain=im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "N")}), + domain=im.domain(gtx_common.GridType.CARTESIAN, {IDim: (0, "N")}), target=itir.SymRef(id="y"), ), ], @@ -247,7 +247,7 @@ def test_generate_sdfg_async_call_no_map(device_type: core_defs.DeviceType): body=[ itir.SetAt( expr=itir.SymRef(id="x"), - domain=im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "N")}), + domain=im.domain(gtx_common.GridType.CARTESIAN, {IDim: (0, "N")}), target=itir.SymRef(id="y"), ), ], diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index ea61552192..5a62da5bfe 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -147,7 +147,7 @@ def test_gtir_broadcast(): val = np.random.rand() domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("x", IDim)}, + {IDim: get_domain_range("x", IDim)}, ) testee = gtir.Program( id="gtir_broadcast", @@ -176,7 +176,7 @@ def test_gtir_broadcast(): def test_gtir_cast(): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("z", IDim)}, + {IDim: get_domain_range("z", IDim)}, ) IFTYPE_FLOAT32 = ts.FieldType(IFTYPE.dims, dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) IFTYPE_BOOL = ts.FieldType(IFTYPE.dims, dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL)) @@ -212,7 +212,7 @@ def test_gtir_cast(): def test_gtir_copy_self(): - domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (1, 2)}) + domain = im.domain(gtx_common.GridType.CARTESIAN, {IDim: (1, 2)}) testee = gtir.Program( id="gtir_copy_self", function_definitions=[], @@ -241,7 +241,7 @@ def test_gtir_copy_self(): 
def test_gtir_tuple_swap(): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("x", IDim)}, + {IDim: get_domain_range("x", IDim)}, ) testee = gtir.Program( id="gtir_tuple_swap", @@ -275,7 +275,7 @@ def test_gtir_tuple_swap(): def test_gtir_tuple_args(): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("y", IDim)}, + {IDim: get_domain_range("y", IDim)}, ) testee = gtir.Program( id="gtir_tuple_args", @@ -329,7 +329,7 @@ def test_gtir_tuple_args(): def test_gtir_tuple_expr(): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("z", IDim)}, + {IDim: get_domain_range("z", IDim)}, ) testee = gtir.Program( id="gtir_tuple_expr", @@ -372,7 +372,7 @@ def test_gtir_tuple_expr(): def test_gtir_tuple_broadcast_scalar(): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("y", IDim)}, + {IDim: get_domain_range("y", IDim)}, ) testee = gtir.Program( id="gtir_tuple_broadcast_scalar", @@ -430,9 +430,9 @@ def test_gtir_tuple_broadcast_scalar(): def test_gtir_zero_dim_fields(): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("y", IDim)}, + {IDim: get_domain_range("y", IDim)}, ) - empty_domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={}) + empty_domain = im.domain(gtx_common.GridType.CARTESIAN, {}) testee = gtir.Program( id="gtir_zero_dim_fields", function_definitions=[], @@ -464,7 +464,7 @@ def test_gtir_zero_dim_fields(): def test_gtir_tuple_return(): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range(im.tuple_get(0, im.tuple_get(0, "z")), IDim)}, + {IDim: get_domain_range(im.tuple_get(0, im.tuple_get(0, "z")), IDim)}, ) testee = gtir.Program( id="gtir_tuple_return", @@ -516,7 +516,7 @@ def test_gtir_tuple_return(): def test_gtir_tuple_target(): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("x", IDim)}, + {IDim: get_domain_range("x", 
IDim)}, ) testee = gtir.Program( id="gtir_tuple_target", @@ -549,7 +549,7 @@ def test_gtir_tuple_target(): def test_gtir_update(): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("x", IDim)}, + {IDim: get_domain_range("x", IDim)}, ) stencil1 = im.as_fieldop( im.lambda_("a")(im.plus(im.deref("a"), im.plus(im.minus(0.0, 2.0), 1.0))), @@ -585,7 +585,7 @@ def test_gtir_update(): def test_gtir_sum2(): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("z", IDim)}, + {IDim: get_domain_range("z", IDim)}, ) testee = gtir.Program( id="sum_2fields", @@ -618,7 +618,7 @@ def test_gtir_sum2(): def test_gtir_sum2_sym(): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("z", IDim)}, + {IDim: get_domain_range("z", IDim)}, ) testee = gtir.Program( id="sum_2fields_sym", @@ -649,7 +649,7 @@ def test_gtir_sum2_sym(): def test_gtir_sum3(): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("z", IDim)}, + {IDim: get_domain_range("z", IDim)}, ) stencil1 = im.op_as_fieldop("plus", domain)( "x", @@ -697,7 +697,7 @@ def test_gtir_sum3(): def test_gtir_cond(s1, s2): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("z", IDim)}, + {IDim: get_domain_range("z", IDim)}, ) testee = gtir.Program( id=f"cond_2sums_{s1}_{s2}", @@ -743,7 +743,7 @@ def test_gtir_cond(s1, s2): def test_gtir_cond_with_tuple_return(): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={ + { IDim: get_domain_range(im.tuple_get(0, "z"), IDim), }, ) @@ -801,7 +801,7 @@ def test_gtir_cond_with_tuple_return(): def test_gtir_cond_nested(s1, s2): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("z", IDim)}, + {IDim: get_domain_range("z", IDim)}, ) testee = gtir.Program( id=f"cond_nested_{int(s1)}_{int(s2)}", @@ -844,7 +844,7 @@ def test_gtir_cartesian_shift_left(): OFFSET = 1 domain = im.domain( 
gtx_common.GridType.CARTESIAN, - ranges={ + { IDim: get_domain_range("x", IDim, (0, OFFSET)), }, ) @@ -949,7 +949,7 @@ def test_gtir_cartesian_shift_right(): OFFSET = 1 domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("x", IDim, (OFFSET, 0))}, + {IDim: get_domain_range("x", IDim, (OFFSET, 0))}, ) # cartesian shift with literal integer offset @@ -1052,18 +1052,18 @@ def test_gtir_connectivity_shift(): E2V_neighbor_idx = 1 edge_domain = im.domain( gtx_common.GridType.UNSTRUCTURED, - ranges={Edge: get_domain_range("ce_field", Edge)}, + {Edge: get_domain_range("ce_field", Edge)}, ) ce_domain = im.domain( gtx_common.GridType.UNSTRUCTURED, - ranges={ + { Cell: get_domain_range("ce_field", Cell), Edge: get_domain_range("ce_field", Edge), }, ) cv_domain = im.domain( gtx_common.GridType.UNSTRUCTURED, - ranges={ + { Cell: get_domain_range("ce_field", Cell), Vertex: get_domain_range("ev_field", Vertex), }, @@ -1218,7 +1218,7 @@ def test_gtir_connectivity_shift_chain(): V2E_neighbor_idx = 2 edge_domain = im.domain( gtx_common.GridType.UNSTRUCTURED, - ranges={Edge: get_domain_range("edges", Edge)}, + {Edge: get_domain_range("edges", Edge)}, ) testee = gtir.Program( id="connectivity_shift_chain", @@ -1285,14 +1285,14 @@ def test_gtir_neighbors_as_input(): init_value = np.random.rand() outer_domain = im.domain( gtx_common.GridType.UNSTRUCTURED, - ranges={ + { Vertex: get_domain_range("vertices", Vertex), KDim: get_domain_range("vertices", KDim, (MARGIN, MARGIN)), }, ) inner_domain = im.domain( gtx_common.GridType.UNSTRUCTURED, - ranges={ + { Vertex: get_domain_range("x", Vertex), KDim: get_domain_range("x", KDim), }, @@ -1383,14 +1383,14 @@ def test_gtir_neighbors_as_output(): pytest.skip("Field of lists not fully supported by GTIR type inference") v2e_domain = im.domain( gtx_common.GridType.UNSTRUCTURED, - ranges={ + { Vertex: get_domain_range("vertices", Vertex), V2EDim: (0, SIMPLE_MESH.offset_provider_type["V2E"].max_neighbors), }, ) 
vertex_domain = im.domain( gtx_common.GridType.UNSTRUCTURED, - ranges={ + { Vertex: get_domain_range("vertices", Vertex), }, ) @@ -1435,7 +1435,7 @@ def test_gtir_reduce(): init_value = np.random.rand() vertex_domain = im.domain( gtx_common.GridType.UNSTRUCTURED, - ranges={ + { Vertex: get_domain_range("vertices", Vertex), }, ) @@ -1494,7 +1494,7 @@ def test_gtir_reduce_with_skip_values(): init_value = np.random.rand() vertex_domain = im.domain( gtx_common.GridType.UNSTRUCTURED, - ranges={ + { Vertex: get_domain_range("vertices", Vertex), }, ) @@ -1557,7 +1557,7 @@ def test_gtir_reduce_dot_product(): init_value = np.random.rand() vertex_domain = im.domain( gtx_common.GridType.UNSTRUCTURED, - ranges={ + { Vertex: get_domain_range("vertices", Vertex), }, ) @@ -1631,7 +1631,7 @@ def test_gtir_reduce_with_cond_neighbors(use_sparse): init_value = np.random.rand() vertex_domain = im.domain( gtx_common.GridType.UNSTRUCTURED, - ranges={ + { Vertex: get_domain_range("vertices", Vertex), }, ) @@ -1709,16 +1709,14 @@ def test_gtir_symbolic_domain(): MARGIN = 2 assert MARGIN < N OFFSET = 1000 * 1000 * 1000 - domain = im.domain( - gtx_common.GridType.CARTESIAN, ranges={IDim: (MARGIN, im.minus("size", MARGIN))} - ) + domain = im.domain(gtx_common.GridType.CARTESIAN, {IDim: (MARGIN, im.minus("size", MARGIN))}) left_domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: (im.minus(MARGIN, OFFSET), im.minus(im.minus("size", MARGIN), OFFSET))}, + {IDim: (im.minus(MARGIN, OFFSET), im.minus(im.minus("size", MARGIN), OFFSET))}, ) right_domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: (im.plus(MARGIN, OFFSET), im.plus(im.plus("size", MARGIN), OFFSET))}, + {IDim: (im.plus(MARGIN, OFFSET), im.plus(im.plus("size", MARGIN), OFFSET))}, ) shift_left_stencil = im.lambda_("a")(im.deref(im.shift(IDim.value, OFFSET)("a"))) shift_right_stencil = im.lambda_("a")(im.deref(im.shift(IDim.value, -OFFSET)("a"))) @@ -1793,11 +1791,11 @@ def test_gtir_symbolic_domain(): def 
test_gtir_let_lambda(): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("y", IDim)}, + {IDim: get_domain_range("y", IDim)}, ) subdomain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("y", IDim, (1, 1))}, + {IDim: get_domain_range("y", IDim, (1, 1))}, ) testee = gtir.Program( id="let_lambda", @@ -1844,10 +1842,10 @@ def test_gtir_let_lambda(): def test_gtir_let_lambda_scalar_expression(): - domain_inner = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (1, "size_inner")}) + domain_inner = im.domain(gtx_common.GridType.CARTESIAN, {IDim: (1, "size_inner")}) domain_outer = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("y", IDim)}, + {IDim: get_domain_range("y", IDim)}, ) testee = gtir.Program( id="let_lambda_scalar_expression", @@ -1901,7 +1899,7 @@ def test_gtir_let_lambda_with_connectivity(): C2V_neighbor_idx = 2 cell_domain = im.domain( gtx_common.GridType.UNSTRUCTURED, - ranges={Cell: get_domain_range("cells", Cell)}, + {Cell: get_domain_range("cells", Cell)}, ) connectivity_C2E = SIMPLE_MESH.offset_provider["C2E"] @@ -1966,7 +1964,7 @@ def test_gtir_let_lambda_with_origin(): C2E_neighbor_idx = 1 cell_domain = im.domain( gtx_common.GridType.UNSTRUCTURED, - ranges={ + { Cell: get_domain_range("cells", Cell), KDim: get_domain_range("cells", KDim, (1, 0)), }, @@ -2030,7 +2028,7 @@ def test_gtir_let_lambda_with_origin(): def test_gtir_let_lambda_with_cond(s): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("y", IDim)}, + {IDim: get_domain_range("y", IDim)}, ) testee = gtir.Program( id=f"let_lambda_with_cond_{int(s)}", @@ -2065,7 +2063,7 @@ def test_gtir_let_lambda_with_cond(s): def test_gtir_let_lambda_with_tuple1(): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range(im.tuple_get(0, "z"), IDim)}, + {IDim: get_domain_range(im.tuple_get(0, "z"), IDim)}, ) testee = gtir.Program( 
id="let_lambda_with_tuple1", @@ -2120,7 +2118,7 @@ def test_gtir_let_lambda_with_tuple1(): def test_gtir_let_lambda_with_tuple2(): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={ + { IDim: get_domain_range(im.tuple_get(0, "z"), IDim), }, ) @@ -2178,7 +2176,7 @@ def test_gtir_let_lambda_with_tuple2(): def test_gtir_if_scalars(s): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("z", IDim)}, + {IDim: get_domain_range("z", IDim)}, ) testee = gtir.Program( id=f"if_scalars_{int(s)}", @@ -2242,7 +2240,7 @@ def test_gtir_if_scalars(s): def test_gtir_if_values(): domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("z", IDim)}, + {IDim: get_domain_range("z", IDim)}, ) testee = gtir.Program( id="if_values", @@ -2279,11 +2277,11 @@ def test_gtir_index(): assert (MARGIN * 2) < N domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("x", IDim)}, + {IDim: get_domain_range("x", IDim)}, ) subdomain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("x", IDim, (MARGIN, MARGIN))}, + {IDim: get_domain_range("x", IDim, (MARGIN, MARGIN))}, ) testee = gtir.Program( @@ -2326,7 +2324,7 @@ def test_gtir_concat_where(): assert SUBSET_SIZE < N domain = im.domain( gtx_common.GridType.CARTESIAN, - ranges={IDim: get_domain_range("z", IDim)}, + {IDim: get_domain_range("z", IDim)}, ) domain_cond_lhs = im.domain( gtx_common.GridType.CARTESIAN, {IDim: (gtir.InfinityLiteral.NEGATIVE, N - SUBSET_SIZE)} From bba5db215062ed3c7a867664af44865ed335c28f Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Wed, 24 Sep 2025 19:26:42 +0200 Subject: [PATCH 66/93] Cleanup --- src/gt4py/next/ffront/stages.py | 13 ++++++----- src/gt4py/next/otf/compiled_program.py | 22 +++++++++++++++++-- .../ffront_tests/test_compiled_program.py | 2 +- 3 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 
46c2e9811d..0c590b6f44 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -129,12 +129,15 @@ def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgo @add_content_to_fingerprint.register def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None: - for key, value in sorted( - obj.items(), - key=lambda x: (x[0].__module__, x[0].__qualname__) if isinstance(x[0], type) else x[0], - ): + # just a small helper to additionally allow sorting types (by just using their name) + def key_function(key: Any) -> Any: + if isinstance(key, type): + return key + return key + + for key in sorted(obj, key=key_function): add_content_to_fingerprint(key, hasher) - add_content_to_fingerprint(value, hasher) + add_content_to_fingerprint(obj[key], hasher) @add_content_to_fingerprint.register diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 0f13547fe8..cb4d3827ca 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -55,6 +55,22 @@ def _make_param_context_from_func_type( func_type: ts.FunctionType, type_map: Callable[[ts.TypeSpec], T] = lambda x: x, # type: ignore[assignment, return-value] # mypy not smart enough to narrow type for default ) -> dict[str, extended_typing.MaybeNestedInTuple[T]]: + """ + Create a context to evaluate expressions in from a function type. + + >>> int32_t, int64_t = ( + ... ts.ScalarType(kind=ts.ScalarKind.INT32), + ... ts.ScalarType(kind=ts.ScalarKind.INT64), + ... ) + >>> type_ = ts.FunctionType( + ... pos_only_args=[], + ... pos_or_kw_args={"inp1": ts.TupleType(types=[int32_t, int64_t])}, + ... kw_only_args={"inp2": int64_t}, + ... returns=int64_t, + ... 
) + >>> context = _make_param_context_from_func_type(type_) + >>> assert context == {"inp1": (int32_t, int64_t), "inp2": int64_t} + """ params = func_type.pos_or_kw_args | func_type.kw_only_args return { param: type_info.apply_to_primitive_constituents( @@ -155,8 +171,8 @@ def __call__( """ args, kwargs = type_info.canonicalize_arguments(self.program_type, args, kwargs) static_args_values = self._argument_descriptor_cache_key_from_args(*args, **kwargs) - # TODO: Dispatching over offset provider type is wrong, especially when we use compile time - # domains. + # TODO(tehrengruber): Dispatching over offset provider type is wrong, especially when we + # use compile time domains. key = (static_args_values, self._offset_provider_to_type_unsafe(offset_provider)) try: self._compiled_programs[key](*args, **kwargs, offset_provider=offset_provider) @@ -315,6 +331,8 @@ def _offset_provider_to_type_unsafe( self._offset_provider_type_cache[offset_provider] = op_type return op_type + # TODO(tehrengruber): Rework the interface to allow precompilation with compile time + # domains. 
def compile( self, offset_providers: list[common.OffsetProvider | common.OffsetProviderType], diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index eb67f270a6..3cdf52a232 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -836,7 +836,7 @@ def identity(inp: cases.IField): return inp @gtx.program(backend=CaptureCompileTimeArgsBackend(), static_domains=True) - def testee(inp: cases.IField, out: cases.IField): + def testee(inp: tuple[cases.IField, cases.IField], out: cases.IField): identity(inp, out=out) inp = cases.allocate(cartesian_case, testee, "inp")() From d0297a626a1dfb15d60ca6ac375e337d535b0dbc Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Wed, 24 Sep 2025 19:26:42 +0200 Subject: [PATCH 67/93] Cleanup --- src/gt4py/next/ffront/stages.py | 13 ++++++++----- src/gt4py/next/otf/compiled_program.py | 22 ++++++++++++++++++++-- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 46c2e9811d..0c590b6f44 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -129,12 +129,15 @@ def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgo @add_content_to_fingerprint.register def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None: - for key, value in sorted( - obj.items(), - key=lambda x: (x[0].__module__, x[0].__qualname__) if isinstance(x[0], type) else x[0], - ): + # just a small helper to additionally allow sorting types (by just using their name) + def key_function(key: Any) -> Any: + if isinstance(key, type): + return key + return key + + for key in sorted(obj, key=key_function): add_content_to_fingerprint(key, hasher) - 
add_content_to_fingerprint(value, hasher) + add_content_to_fingerprint(obj[key], hasher) @add_content_to_fingerprint.register diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 126f25861a..d31d31a6b7 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -55,6 +55,22 @@ def _make_param_context_from_func_type( func_type: ts.FunctionType, type_map: Callable[[ts.TypeSpec], T] = lambda x: x, # type: ignore[assignment, return-value] # mypy not smart enough to narrow type for default ) -> dict[str, extended_typing.MaybeNestedInTuple[T]]: + """ + Create a context to evaluate expressions in from a function type. + + >>> int32_t, int64_t = ( + ... ts.ScalarType(kind=ts.ScalarKind.INT32), + ... ts.ScalarType(kind=ts.ScalarKind.INT64), + ... ) + >>> type_ = ts.FunctionType( + ... pos_only_args=[], + ... pos_or_kw_args={"inp1": ts.TupleType(types=[int32_t, int64_t])}, + ... kw_only_args={"inp2": int64_t}, + ... returns=int64_t, + ... ) + >>> context = _make_param_context_from_func_type(type_) + >>> assert context == {"inp1": (int32_t, int64_t), "inp2": int64_t} + """ params = func_type.pos_or_kw_args | func_type.kw_only_args return { param: type_info.apply_to_primitive_constituents( @@ -155,8 +171,8 @@ def __call__( """ args, kwargs = type_info.canonicalize_arguments(self.program_type, args, kwargs) static_args_values = self._argument_descriptor_cache_key_from_args(*args, **kwargs) - # TODO: Dispatching over offset provider type is wrong, especially when we use compile time - # domains. + # TODO(tehrengruber): Dispatching over offset provider type is wrong, especially when we + # use compile time domains. 
key = (static_args_values, self._offset_provider_to_type_unsafe(offset_provider)) try: self._compiled_programs[key](*args, **kwargs, offset_provider=offset_provider) @@ -315,6 +331,8 @@ def _offset_provider_to_type_unsafe( self._offset_provider_type_cache[offset_provider] = op_type return op_type + # TODO(tehrengruber): Rework the interface to allow precompilation with compile time + # domains. def compile( self, offset_providers: list[common.OffsetProvider | common.OffsetProviderType], From 03bb958259527bbc47bbcd176e472e5ca337cac8 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 25 Sep 2025 08:46:10 +0200 Subject: [PATCH 68/93] Cleanup --- src/gt4py/next/ffront/past_process_args.py | 2 +- src/gt4py/next/ffront/past_to_itir.py | 10 ++- src/gt4py/next/ffront/stages.py | 4 +- src/gt4py/next/otf/arguments.py | 14 +-- src/gt4py/next/otf/compiled_program.py | 90 ++++++++++++++----- .../runners/dace/program.py | 2 +- 6 files changed, 88 insertions(+), 34 deletions(-) diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index 81f9363822..2c9d3d2770 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -34,7 +34,7 @@ def transform_program_args(inp: AOT_PRG) -> AOT_PRG: kwargs=rewritten_kwargs, offset_provider=inp.args.offset_provider, column_axis=inp.args.column_axis, - argument_descriptors=inp.args.argument_descriptors, + argument_descriptor_contexts=inp.args.argument_descriptor_contexts, ), ) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index f82680be15..c48d72f41d 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -96,16 +96,18 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type ) - if arguments.StaticArg in inp.args.argument_descriptors: - static_arg_descriptors = 
inp.args.argument_descriptors[arguments.StaticArg] + # TODO(tehrengruber): Put this in a dedicated transformation step. + if arguments.StaticArg in inp.args.argument_descriptor_contexts: + static_arg_descriptors = inp.args.argument_descriptor_contexts[arguments.StaticArg] if not all( - isinstance(arg_descriptor, arguments.StaticArg) + isinstance(arg_descriptor, arguments.StaticArg) or arg_descriptor is None for arg_descriptor in static_arg_descriptors.values() ): raise NotImplementedError("Only top-level arguments can be static.") static_args = { - name: im.literal_from_tuple_value(descr.value) # type: ignore[attr-defined] # type checked above + name: im.literal_from_tuple_value(descr.value) # type: ignore[union-attr] # type checked above for name, descr in static_arg_descriptors.items() + if descr } body = remap_symbols.RemapSymbolRefs().visit(itir_program.body, symbol_map=static_args) itir_program = itir.Program( diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 0c590b6f44..d97de69704 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -132,10 +132,10 @@ def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None # just a small helper to additionally allow sorting types (by just using their name) def key_function(key: Any) -> Any: if isinstance(key, type): - return key + return key.__module__, key.__qualname__ return key - for key in sorted(obj, key=key_function): + for key in sorted(obj.keys(), key=key_function): add_content_to_fingerprint(key, hasher) add_content_to_fingerprint(obj[key], hasher) diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 9a80a83137..fc439a233c 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -17,7 +17,7 @@ from typing_extensions import Self from gt4py._core import definitions as core_defs -from gt4py.eve import extended_typing +from gt4py.eve.extended_typing import 
MaybeNestedInTuple from gt4py.next import common, errors from gt4py.next.otf import toolchain, workflow from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -75,7 +75,7 @@ def attribute_extractor_exprs(cls, arg_expr: str) -> dict[str, str]: @dataclasses.dataclass(frozen=True) class StaticArg(ArgStaticDescriptor, Generic[core_defs.ScalarT]): - value: extended_typing.MaybeNestedInTuple[core_defs.ScalarT] + value: MaybeNestedInTuple[core_defs.ScalarT] def __post_init__(self) -> None: # transform enum value into the actual value @@ -121,9 +121,13 @@ class CompileTimeArgs: kwargs: dict[str, ts.TypeSpec] offset_provider: common.OffsetProvider # TODO(havogt): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information column_axis: Optional[common.Dimension] - argument_descriptors: Mapping[ + #: A mapping from an argument descriptor type to a context containing the actual descriptors. + #: If an argument or element of an argument has no descriptor, the respective value is `None`. + #: E.g., for a tuple argument `a` with type `ts.TupleTupe(types=[field_t, int32_t])` a possible + # context would be `{"a": (FieldDomainDescriptor(...), None)}`. 
+ argument_descriptor_contexts: Mapping[ type[ArgStaticDescriptor], - dict[str, ArgStaticDescriptor], + dict[str, MaybeNestedInTuple[ArgStaticDescriptor | None]], ] @property @@ -141,7 +145,7 @@ def from_concrete(cls, *args: Any, **kwargs: Any) -> Self: kwargs={ k: type_translation.from_value(v) for k, v in kwargs_copy.items() if v is not None }, - argument_descriptors={}, + argument_descriptor_contexts={}, ) @classmethod diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index d31d31a6b7..b428c79e0f 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -36,8 +36,11 @@ type[arguments.ArgStaticDescriptor], dict[str, arguments.ArgStaticDescriptor] ] ArgumentDescriptorContext: TypeAlias = dict[ + str, extended_typing.MaybeNestedInTuple[arguments.ArgStaticDescriptor | None] +] +ArgumentDescriptorContexts: TypeAlias = dict[ type[arguments.ArgStaticDescriptor], - dict[str, extended_typing.MaybeNestedInTuple[arguments.ArgStaticDescriptor | None]], + ArgumentDescriptorContext, ] @@ -108,6 +111,61 @@ def _make_argument_descriptors( return descriptors +def _convert_to_argument_descriptor_context( + func_type: ts.FunctionType, argument_descriptors: ArgumentDescriptors +) -> ArgumentDescriptorContexts: + """ + Given argument descriptors, i.e., a mapping from an expr to a descriptor, transform them into a + context of argument descriptors in which we can evaluate expressions. + + >>> int32_t, int64_t = ( + ... ts.ScalarType(kind=ts.ScalarKind.INT32), + ... ts.ScalarType(kind=ts.ScalarKind.INT64), + ... ) + >>> type_ = ts.FunctionType( + ... pos_only_args=[], + ... pos_or_kw_args={"inp1": ts.TupleType(types=[int32_t, int64_t])}, + ... kw_only_args={"inp2": int64_t}, + ... returns=int64_t, + ... 
) + >>> argument_descriptors = {arguments.StaticArg: {"inp1[1]": arguments.StaticArg(value=1)}} + >>> contexts = _convert_to_argument_descriptor_context(type_, argument_descriptors) + >>> contexts[arguments.StaticArg] + {'inp1': (None, StaticArg(value=1)), 'inp2': None} + """ + descriptor_contexts: ArgumentDescriptorContexts = {} + for descriptor_cls, descriptor_expr_mapping in argument_descriptors.items(): + context: ArgumentDescriptorContext = _make_param_context_from_func_type( + func_type, lambda x: None + ) + # convert tuples to list such that we can alter the context easily + context = { + k: gtx_utils.tree_map( + lambda v: v, collection_type=tuple, result_collection_constructor=list + )(v) + for k, v in context.items() + } + assert "__descriptor" not in context + for expr, descriptor in descriptor_expr_mapping.items(): + # note: we don't need to handle any errors here since the `expr` has been validated + # in `_validate_argument_descriptor_mapping` + exec( + f"{expr} = __descriptor", + {"__descriptor": descriptor}, + context, + ) + # convert lists back to tuples + context = { + k: gtx_utils.tree_map( + lambda v: v, collection_type=list, result_collection_constructor=tuple + )(v) + for k, v in context.items() + } + descriptor_contexts[descriptor_cls] = context + + return descriptor_contexts + + def _validate_argument_descriptors( program_type: ts_ffront.ProgramType, all_descriptors: ArgumentDescriptors, @@ -215,7 +273,7 @@ def _argument_descriptor_cache_key_from_args(self) -> Callable: def _argument_descriptor_cache_key_from_descriptors( self, - argument_descriptors: ArgumentDescriptors, + argument_descriptor_contexts: ArgumentDescriptorContexts, ) -> tuple: """ Given a set of argument descriptors deduce the cache key used to retrieve the instance @@ -224,22 +282,6 @@ def _argument_descriptor_cache_key_from_descriptors( This function is not performance critical as it is only called once when compiling a variant. 
""" - # first build a context that we can evaluate parameter expressions on descriptors in - descriptor_context: ArgumentDescriptorContext = {} - for descriptor_cls, descriptor_expr_mapping in argument_descriptors.items(): - descriptor_context[descriptor_cls] = _make_param_context_from_func_type( - self.program_type.definition, lambda x: None - ) - assert "__descriptor" not in descriptor_context[descriptor_cls] - for expr, descriptor in descriptor_expr_mapping.items(): - # note: we don't need to handle any errors here since the `expr` has been validated - # in `_validate_argument_descriptor_mapping` - exec( - f"{expr} = __descriptor", - {"__descriptor": descriptor}, - descriptor_context[descriptor_cls], - ) - elements = [] for descriptor_cls, arg_exprs in self.argument_descriptor_mapping.items(): # type: ignore[union-attr] # can never be `None` at this point for arg_expr in arg_exprs: @@ -247,7 +289,10 @@ def _argument_descriptor_cache_key_from_descriptors( attrs = attr_extractor.keys() for attr in attrs: elements.append( - getattr(eval(f"{arg_expr}", descriptor_context[descriptor_cls]), attr) + getattr( + eval(f"{arg_expr}", {}, argument_descriptor_contexts[descriptor_cls]), + attr, + ) ) return tuple(elements) @@ -297,8 +342,11 @@ def _compile_variant( self._initialize_argument_descriptor_mapping(argument_descriptors) _validate_argument_descriptors(self.program_type, argument_descriptors) + argument_descriptor_contexts = _convert_to_argument_descriptor_context( + self.program_type.definition, argument_descriptors + ) key = ( - self._argument_descriptor_cache_key_from_descriptors(argument_descriptors), + self._argument_descriptor_cache_key_from_descriptors(argument_descriptor_contexts), self._offset_provider_to_type_unsafe(offset_provider), ) if key in self._compiled_programs: @@ -310,7 +358,7 @@ def _compile_variant( args=tuple(self.program_type.definition.pos_only_args) + tuple(self.program_type.definition.pos_or_kw_args.values()), 
kwargs=self.program_type.definition.kw_only_args, - argument_descriptors=argument_descriptors, + argument_descriptor_contexts=argument_descriptor_contexts, ) self._compiled_programs[key] = _async_compilation_pool.submit( self.backend.compile, self.definition_stage, compile_time_args=compile_time_args diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py index 6e5fcc29c2..df357f69aa 100644 --- a/src/gt4py/next/program_processors/runners/dace/program.py +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -55,7 +55,7 @@ def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: kwargs={}, column_axis=column_axis, offset_provider=offset_provider, - argument_descriptors={}, + argument_descriptor_contexts={}, ), ) ) From 886df4d586c4c83e790ca8ce8f357cfea9d2a2fc Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 25 Sep 2025 09:15:42 +0200 Subject: [PATCH 69/93] Static domains for args of inhomogenous types --- src/gt4py/next/ffront/decorator.py | 5 ++-- src/gt4py/next/ffront/past_to_itir.py | 6 ++--- .../ffront_tests/test_compiled_program.py | 27 ++++++++++++------- 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 9c5acf1033..578e5ab7ac 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -286,8 +286,9 @@ def path_to_expr(path: Sequence[int]) -> str: func_type = self.past_stage.past_node.type.definition # type: ignore[union-attr] # type inference done at this point param_types = func_type.pos_or_kw_args | func_type.kw_only_args for name, type_ in param_types.items(): - for _, path in type_info.primitive_constituents(type_, with_path_arg=True): - static_domain_args.append(f"{name}{path_to_expr(path)}") + for type_, path in type_info.primitive_constituents(type_, with_path_arg=True): + if isinstance(type_, ts.FieldType): + 
static_domain_args.append(f"{name}{path_to_expr(path)}") argument_descriptor_mapping[arguments.FieldDomainDescriptor] = static_domain_args program_type = self.past_stage.past_node.type diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 74a951565f..115ffd10b5 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -119,10 +119,10 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: ) # TODO(tehrengruber): Put this in a dedicated transformation step. - if arguments.FieldDomainDescriptor in inp.args.argument_descriptors: + if arguments.FieldDomainDescriptor in inp.args.argument_descriptor_contexts: + context = inp.args.argument_descriptor_contexts[arguments.FieldDomainDescriptor] field_domains = { - param: utils.tree_map(lambda x: x.domain)(v) - for param, v in inp.args.argument_descriptors[arguments.FieldDomainDescriptor].items() + param: utils.tree_map(lambda x: x.domain if x else x)(v) for param, v in context.items() } itir_program = transform_get_domain_range.TransformGetDomainRange.apply( itir_program, sizes=field_domains diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index 3cdf52a232..6b221fa858 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -832,23 +832,32 @@ def compile(self, program, compile_time_args): return cartesian_case.backend.compile(program, compile_time_args) @gtx.field_operator - def identity(inp: cases.IField): - return inp + def identity_like(inp: tuple[cases.IField, cases.IField, float]): + return inp[0], inp[1] + # the float argument here is merely to test that static domains work for tuple arguments + # of inhomogeneous types 
@gtx.program(backend=CaptureCompileTimeArgsBackend(), static_domains=True) - def testee(inp: tuple[cases.IField, cases.IField], out: cases.IField): - identity(inp, out=out) + def testee(inp: tuple[cases.IField, cases.IField, float], out: tuple[cases.IField, cases.IField]): + identity_like(inp, out=out) inp = cases.allocate(cartesian_case, testee, "inp")() out = cases.allocate(cartesian_case, testee, "out")() testee(inp, out, offset_provider={}) - assert np.allclose(out.ndarray, inp.ndarray) + assert np.allclose((inp[0].ndarray, inp[1].ndarray), (out[0].ndarray, out[1].ndarray)) assert testee._compiled_programs.argument_descriptor_mapping[ arguments.FieldDomainDescriptor - ] == ["inp", "out"] - assert captured_cargs.argument_descriptors[arguments.FieldDomainDescriptor] == { - "inp": arguments.FieldDomainDescriptor(inp.domain), - "out": arguments.FieldDomainDescriptor(out.domain), + ] == ["inp[0]", "inp[1]", "out[0]", "out[1]"] + assert captured_cargs.argument_descriptor_contexts[arguments.FieldDomainDescriptor] == { + "inp": ( + arguments.FieldDomainDescriptor(inp[0].domain), + arguments.FieldDomainDescriptor(inp[1].domain), + None + ), + "out": ( + arguments.FieldDomainDescriptor(out[0].domain), + arguments.FieldDomainDescriptor(out[1].domain), + ), } From 2fd1c2b16c36fb92e22345d806121a25d996e634 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 25 Sep 2025 10:57:44 +0200 Subject: [PATCH 70/93] Fix non scan projector --- .../next/iterator/transforms/global_tmps.py | 5 ---- .../transforms_tests/test_global_tmps.py | 30 ++++++++++++++++++- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 5bee4daf8d..092c3291bf 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -157,11 +157,6 @@ def _transform_by_pattern( # hide projector from extraction projector, expr = 
ir_utils_misc.extract_projector(stmt.expr) - # If we extracted a projector and the expression is not an as_fieldop of a scan, - # collapse tuple did not work as expected. We would expect that collapse - # tuple eleminated all top-level tuple expressions for non-scans. - assert projector is None or _is_as_fieldop_of_scan(expr) - new_expr, extracted_fields, _ = cse.extract_subexpression( expr, predicate=predicate, diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 6fe76542ca..4a58eb9308 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -11,10 +11,13 @@ import functools from gt4py.next import common from gt4py.next.iterator import builtins, ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms import global_tmps, infer_domain, collapse_tuple from gt4py.next.iterator.type_system import inference as type_inference from gt4py.next.type_system import type_specifications as ts +from gt4py.next.iterator.ir_utils import ( + ir_makers as im, + misc as ir_utils_misc, +) IDim = common.Dimension(value="IDim") @@ -535,3 +538,28 @@ def test_domain_preservation(): actual = global_tmps.create_global_tmps(testee, offset_provider) assert actual == expected + + +def test_non_scan_projector(): + domain = im.domain("cartesian_domain", {IDim: (0, 2)}) + offset_provider = {} + stmt = itir.SetAt( + target=im.ref("out"), + expr=im.make_tuple(im.tuple_get(0, "inp")), + domain=domain, + ) + testee = program_factory( + params=[ + im.sym("inp", ts.TupleType(types=[i_field_type, float_type])), + im.sym("out", ts.TupleType(types=[i_field_type])), + ], + body=[stmt], + ) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) + + # make sure the statement actually has a 
projector + projector, expr = ir_utils_misc.extract_projector(stmt.expr) + assert projector is not None + + actual = global_tmps.create_global_tmps(testee, offset_provider) + assert actual == testee From bc29ac349794a4e1e4c0c995114ffe021995009c Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 25 Sep 2025 11:04:42 +0200 Subject: [PATCH 71/93] Small fix --- src/gt4py/next/ffront/decorator.py | 4 ++-- .../feature_tests/ffront_tests/test_compiled_program.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 578e5ab7ac..4420cb9a83 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -286,8 +286,8 @@ def path_to_expr(path: Sequence[int]) -> str: func_type = self.past_stage.past_node.type.definition # type: ignore[union-attr] # type inference done at this point param_types = func_type.pos_or_kw_args | func_type.kw_only_args for name, type_ in param_types.items(): - for type_, path in type_info.primitive_constituents(type_, with_path_arg=True): - if isinstance(type_, ts.FieldType): + for el_type_, path in type_info.primitive_constituents(type_, with_path_arg=True): + if isinstance(el_type_, ts.FieldType): static_domain_args.append(f"{name}{path_to_expr(path)}") argument_descriptor_mapping[arguments.FieldDomainDescriptor] = static_domain_args diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index 6b221fa858..5d34b219b7 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -838,7 +838,9 @@ def identity_like(inp: tuple[cases.IField, cases.IField, float]): # the float argument here is merely to test that static domains work for tuple arguments # of 
inhomogeneous types @gtx.program(backend=CaptureCompileTimeArgsBackend(), static_domains=True) - def testee(inp: tuple[cases.IField, cases.IField, float], out: tuple[cases.IField, cases.IField]): + def testee( + inp: tuple[cases.IField, cases.IField, float], out: tuple[cases.IField, cases.IField] + ): identity_like(inp, out=out) inp = cases.allocate(cartesian_case, testee, "inp")() @@ -854,7 +856,7 @@ def testee(inp: tuple[cases.IField, cases.IField, float], out: tuple[cases.IFiel "inp": ( arguments.FieldDomainDescriptor(inp[0].domain), arguments.FieldDomainDescriptor(inp[1].domain), - None + None, ), "out": ( arguments.FieldDomainDescriptor(out[0].domain), From b8619a730892ebe34b9bbbc79b1f3aa1dd5c6a8e Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 25 Sep 2025 11:09:23 +0200 Subject: [PATCH 72/93] Fix tests --- src/gt4py/next/ffront/foast_to_past.py | 2 +- src/gt4py/next/ffront/past_to_itir.py | 2 +- tests/next_tests/unit_tests/otf_tests/test_compiled_program.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 272823ad6b..66b0f31034 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -75,7 +75,7 @@ class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]): ... kwargs={}, ... offset_provider={"I", IDim}, ... column_axis=None, - ... argument_descriptors={}, + ... argument_descriptor_contexts={}, ... ) >>> copy_program = op_to_prog( diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index c48d72f41d..2dfe22db7a 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -58,7 +58,7 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: ... kwargs={}, ... offset_provider={"I": IDim}, ... column_axis=None, - ... argument_descriptors={}, + ... argument_descriptor_contexts={}, ... 
) >>> itir_copy = past_to_gtir( diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index 1c132773e0..98c48c294a 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -79,7 +79,7 @@ def test_inlining_of_scalars_works(): kwargs={}, offset_provider={}, column_axis=None, - argument_descriptors={arguments.StaticArg: {"cond": arguments.StaticArg(value=True)}}, + argument_descriptor_contexts={arguments.StaticArg: {"cond": arguments.StaticArg(value=True)}}, ), ) From 807a1fbed26f63c66c6e31a07a69df31a6eaa2be Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 25 Sep 2025 11:10:47 +0200 Subject: [PATCH 73/93] Fix format --- .../next_tests/unit_tests/otf_tests/test_compiled_program.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index 98c48c294a..59b12b3f0a 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -79,7 +79,9 @@ def test_inlining_of_scalars_works(): kwargs={}, offset_provider={}, column_axis=None, - argument_descriptor_contexts={arguments.StaticArg: {"cond": arguments.StaticArg(value=True)}}, + argument_descriptor_contexts={ + arguments.StaticArg: {"cond": arguments.StaticArg(value=True)} + }, ), ) From f9648b3d6ce74575f83e924c890453a2f8020f42 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 26 Sep 2025 11:31:06 +0200 Subject: [PATCH 74/93] Make static domain translation configurable --- src/gt4py/eve/utils.py | 2 +- .../next/iterator/transforms/pass_manager.py | 56 ++++++++++++++----- .../codegens/gtfn/gtfn_module.py | 2 + .../runners/dace/workflow/translation.py | 7 ++- 4 files changed, 51 insertions(+), 16 deletions(-) 
diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index b681c41295..3de4414851 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -910,7 +910,7 @@ def reset_sequence(self, start: int = 1, *, warn_unsafe: Optional[bool] = None) if warn_unsafe is None: warn_unsafe = self.warn_unsafe if warn_unsafe and start < next(self._counter): - warnings.warn("Unsafe reset of UIDGenerator ({self})", stacklevel=2) + warnings.warn(f"Unsafe reset of UIDGenerator ({self})", stacklevel=2) self._counter = itertools.count(start) return self diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 7220951287..b36be31950 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -5,6 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import warnings from typing import Any, Mapping, Optional, Protocol from gt4py.eve import utils as eve_utils @@ -67,7 +68,7 @@ def _max_domain_range_sizes(offset_provider: Mapping[str, Any]) -> dict[str, iti def _has_dynamic_domains(ir: itir.Program) -> bool: # note: this function does not respect symbol collisions with builtins. As it is a temporary - # workaround we don't care about this corner cases. + # workaround we don't care about this corner case. 
domains = set() domains |= ir.walk_values().if_isinstance(itir.SetAt).getattr("domain").to_set() for as_fop in ( @@ -79,6 +80,30 @@ def _has_dynamic_domains(ir: itir.Program) -> bool: return len(symbol_ref_utils.collect_symbol_refs(domains)) > 0 +def _process_symbolic_domains_option( + ir: itir.Program, + offset_provider: common.OffsetProvider | common.OffsetProviderType, + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]], + use_max_domain_range_on_unstructured_shift: Optional[bool], +) -> Optional[dict[str, str | itir.Expr]]: + has_dynamic_domains = _has_dynamic_domains(ir) + if has_dynamic_domains and use_max_domain_range_on_unstructured_shift is None: + use_max_domain_range_on_unstructured_shift = True + else: + use_max_domain_range_on_unstructured_shift = False + if use_max_domain_range_on_unstructured_shift: + if not has_dynamic_domains: + warnings.warn( + "You are using static domains together with " + "'use_max_domain_range_on_unstructured_shift'. This is" + "likely not what you wanted.", + stacklevel=2, + ) + assert not symbolic_domain_sizes, "Options are mutually exclusive." + symbolic_domain_sizes = _max_domain_range_sizes(offset_provider) # type: ignore[assignment] + return symbolic_domain_sizes + + # TODO(tehrengruber): Revisit interface to configure temporary extraction. We currently forward # `extract_temporaries` and `temporary_extraction_heuristics` which is inconvenient. def apply_common_transforms( @@ -94,16 +119,17 @@ def apply_common_transforms( #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for #: more details. symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None, + # TODO(tehrengruber): Remove this option again as soon as we have the necessary builtins + # to work with / translate domains. 
+ use_max_domain_range_on_unstructured_shift: Optional[bool] = None, ) -> itir.Program: assert isinstance(ir, itir.Program) offset_provider_type = common.offset_provider_to_type(offset_provider) - # TODO(tehrengruber): Remove this option again as soon as we have the necessary builtins - # to work with / translate domains. - if _has_dynamic_domains(ir): - assert not symbolic_domain_sizes, "Options are mutually exclusive." - symbolic_domain_sizes = _max_domain_range_sizes(offset_provider) # type: ignore[assignment] + symbolic_domain_sizes = _process_symbolic_domains_option( + ir, offset_provider, symbolic_domain_sizes, use_max_domain_range_on_unstructured_shift + ) tmp_uids = eve_utils.UIDGenerator(prefix="__tmp") mergeasfop_uids = eve_utils.UIDGenerator() @@ -211,16 +237,18 @@ def apply_common_transforms( def apply_fieldview_transforms( - ir: itir.Program, *, offset_provider: common.OffsetProvider + ir: itir.Program, + *, + offset_provider: common.OffsetProvider, + # TODO(tehrengruber): Remove this option again as soon as we have the necessary builtins + # to work with / translate domains. + use_max_domain_range_on_unstructured_shift: Optional[bool] = None, ) -> itir.Program: offset_provider_type = common.offset_provider_to_type(offset_provider) - # TODO(tehrengruber): Remove this option again as soon as we have the necessary builtins - # to work with / translate domains. 
- if _has_dynamic_domains(ir): - symbolic_domain_sizes = _max_domain_range_sizes(offset_provider) - else: - symbolic_domain_sizes = None + symbolic_domain_sizes = _process_symbolic_domains_option( + ir, offset_provider, None, use_max_domain_range_on_unstructured_shift + ) ir = inline_fundefs.InlineFundefs().visit(ir) ir = inline_fundefs.prune_unreferenced_fundefs(ir) @@ -235,7 +263,7 @@ def apply_fieldview_transforms( ir = infer_domain.infer_program( ir, - symbolic_domain_sizes=symbolic_domain_sizes, # type: ignore[arg-type] + symbolic_domain_sizes=symbolic_domain_sizes, offset_provider=offset_provider, ) ir = remove_broadcast.RemoveBroadcast.apply(ir) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index ed1b657129..87855d7d13 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -53,6 +53,7 @@ class GTFNTranslationStep( use_imperative_backend: bool = False device_type: core_defs.DeviceType = core_defs.DeviceType.CPU symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None + use_max_domain_range_on_unstructured_shift: Optional[bool] = None def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings: match self.device_type: @@ -163,6 +164,7 @@ def _preprocess_program( extract_temporaries=True, offset_provider=offset_provider, symbolic_domain_sizes=self.symbolic_domain_sizes, + use_max_domain_range_on_unstructured_shift=self.use_max_domain_range_on_unstructured_shift, ) new_program = apply_common_transforms( diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index 083d661b90..8010111901 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ 
-191,6 +191,7 @@ class DaCeTranslator( async_sdfg_call: bool = False disable_itir_transforms: bool = False disable_field_origin_on_program_arguments: bool = False + use_max_domain_range_on_unstructured_shift: Optional[bool] = None # auto-optimize arguments gpu_block_size: tuple[int, int, int] = (32, 8, 1) @@ -216,7 +217,11 @@ def _generate_sdfg_without_configuring_dace( column_axis: Optional[common.Dimension], ) -> dace.SDFG: if not self.disable_itir_transforms: - ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) + ir = itir_transforms.apply_fieldview_transforms( + ir, + use_max_domain_range_on_unstructured_shift=self.use_max_domain_range_on_unstructured_shift, + offset_provider=offset_provider, + ) offset_provider_type = common.offset_provider_to_type(offset_provider) on_gpu = self.device_type != core_defs.DeviceType.CPU From cac49f61b3a4896bda35cdf2abaa79d060aa8520 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 26 Sep 2025 14:05:49 +0200 Subject: [PATCH 75/93] Fix --- src/gt4py/next/ffront/past_to_itir.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 2dfe22db7a..628efb001c 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -15,7 +15,7 @@ import devtools from gt4py.eve import NodeTranslator, concepts, traits, utils as eve_utils -from gt4py.next import common, config, errors +from gt4py.next import common, config, errors, utils as gtx_utils from gt4py.next.ffront import ( fbuiltins, gtcallable, @@ -100,14 +100,15 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: if arguments.StaticArg in inp.args.argument_descriptor_contexts: static_arg_descriptors = inp.args.argument_descriptor_contexts[arguments.StaticArg] if not all( - isinstance(arg_descriptor, arguments.StaticArg) or arg_descriptor is None + isinstance(arg_descriptor, arguments.StaticArg) + or 
all(el is None for el in gtx_utils.flatten_nested_tuple(arg_descriptor)) # type: ignore[arg-type] for arg_descriptor in static_arg_descriptors.values() ): raise NotImplementedError("Only top-level arguments can be static.") static_args = { name: im.literal_from_tuple_value(descr.value) # type: ignore[union-attr] # type checked above for name, descr in static_arg_descriptors.items() - if descr + if not any(el is None for el in gtx_utils.flatten_nested_tuple(descr)) # type: ignore[arg-type] } body = remap_symbols.RemapSymbolRefs().visit(itir_program.body, symbol_map=static_args) itir_program = itir.Program( From 3cb901f176e5c5333d56c0dc039a9fc42ec22a71 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 30 Sep 2025 14:09:35 +0200 Subject: [PATCH 76/93] Improve oob error msg --- src/gt4py/next/iterator/ir_utils/domain_utils.py | 8 ++++---- .../iterator_tests/ir_utils_test.py/test_domain_utils.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 266c34021f..56779ceb55 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -157,10 +157,10 @@ def translate( accessed = connectivity.ndarray[start:stop, nb_index] - if np.any(accessed == skip_value): - raise NotImplementedError( - f"Translating '{self.as_expr()}' using '{shift[0].value}' contains " - f"skipped values. This is not supported." + if isinstance(val, itir.OffsetLiteral) and np.any(accessed == skip_value): + raise ValueError( + f"Translating '{self.as_expr()}' using '{shift[0].value}' has " + f"an out-of-bounds access." 
) new_start, new_stop = accessed.min(), accessed.max() + 1 # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject diff --git a/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py b/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py index a0280fc040..cabfedcbf0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py +++ b/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py @@ -264,7 +264,7 @@ def test_non_contiguous_domain_warning(): domain.translate(shift_chain, offset_provider).as_expr() -def test_contains_skip_values_error(): +def test_oob_error(): offset_provider = { "V2V": constructors.as_connectivity( domain={Vertex: (0, 3), V2VDim: 1}, @@ -277,5 +277,5 @@ def test_contains_skip_values_error(): domain = domain_utils.SymbolicDomain.from_expr( im.domain(common.GridType.UNSTRUCTURED, {Vertex: (0, 3)}) ) - with pytest.raises(NotImplementedError, match=r"contains skipped values"): + with pytest.raises(ValueError, match=r"out-of-bounds"): domain.translate(shift_chain, offset_provider).as_expr() From 44e02c781d839d382e0a6995b0761473d35ff092 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Tue, 7 Oct 2025 09:53:03 +0200 Subject: [PATCH 77/93] Fix dace tests --- .../dace_tests/test_gtir_to_sdfg.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 5a62da5bfe..19fe783883 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -22,6 +22,7 @@ from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import 
ir_makers as im from gt4py.next.iterator.transforms import infer_domain +from gt4py.next.iterator.transforms import pass_manager from gt4py.next.type_system import type_specifications as ts from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -123,7 +124,11 @@ def build_dace_sdfg( ) -> Callable[..., Any]: if not skip_domain_inference: # run domain inference in order to add the domain annex information to the IR nodes - ir = infer_domain.infer_program(ir, offset_provider=offset_provider) + ir = infer_domain.infer_program( + ir, + offset_provider=offset_provider, + symbolic_domain_sizes=pass_manager._max_domain_range_sizes(offset_provider), + ) offset_provider_type = gtx_common.offset_provider_to_type(offset_provider) return dace_backend.build_sdfg_from_gtir(ir, offset_provider_type, column_axis=KDim) @@ -2370,7 +2375,11 @@ def test_gtir_concat_where(): ) # run domain inference in order to add the domain annex information to the concat_where node. - testee = infer_domain.infer_program(testee, offset_provider=CARTESIAN_OFFSETS) + testee = infer_domain.infer_program( + testee, + offset_provider=CARTESIAN_OFFSETS, + symbolic_domain_sizes=pass_manager._max_domain_range_sizes(CARTESIAN_OFFSETS), + ) sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) c = np.empty_like(a) @@ -2454,7 +2463,11 @@ def test_gtir_concat_where_two_dimensions(): } # run domain inference in order to add the domain annex information to the concat_where node. 
- testee = infer_domain.infer_program(testee, offset_provider=CARTESIAN_OFFSETS) + testee = infer_domain.infer_program( + testee, + offset_provider=CARTESIAN_OFFSETS, + symbolic_domain_sizes=pass_manager._max_domain_range_sizes(CARTESIAN_OFFSETS), + ) sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, c, d, **field_symbols) From d75a092127407c8d5e5d00ef3ddc8e3a42b50f32 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 7 Oct 2025 10:48:41 +0200 Subject: [PATCH 78/93] Fix failing tests --- .../next/iterator/ir_utils/domain_utils.py | 14 ++++++++++--- .../next/iterator/transforms/pass_manager.py | 4 ++++ .../transforms_tests/test_domain_inference.py | 3 ++- .../test_transform_get_domain_range.py | 20 ++++++++++++++++--- 4 files changed, 34 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 56779ceb55..c36cef2b3c 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -158,9 +158,17 @@ def translate( accessed = connectivity.ndarray[start:stop, nb_index] if isinstance(val, itir.OffsetLiteral) and np.any(accessed == skip_value): - raise ValueError( - f"Translating '{self.as_expr()}' using '{shift[0].value}' has " - f"an out-of-bounds access." + # TODO(tehrengruber): Turn this into a configurable error. This is currently + # not possible since some test cases starting from ITIR containing + # `can_deref` might lead here. The frontend never emits such IR and domain + # inference runs after we transform reductions into stmts containing + # `can_deref`. + warnings.warn( + UserWarning( + f"Translating '{self.as_expr()}' using '{shift[0].value}' has " + f"an out-of-bounds access." 
+ ), + stacklevel=2, ) new_start, new_stop = accessed.min(), accessed.max() + 1 # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index b36be31950..2df1f8b45f 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -86,6 +86,10 @@ def _process_symbolic_domains_option( symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]], use_max_domain_range_on_unstructured_shift: Optional[bool], ) -> Optional[dict[str, str | itir.Expr]]: + if symbolic_domain_sizes: + assert not use_max_domain_range_on_unstructured_shift, "Options are mutually exclusive." + return symbolic_domain_sizes + has_dynamic_domains = _has_dynamic_domains(ir) if has_dynamic_domains and use_max_domain_range_on_unstructured_shift is None: use_max_domain_range_on_unstructured_shift = True diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 0d9a55ceef..dd405ca7b0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -229,7 +229,8 @@ def test_multi_length_shift(offset_provider): def test_unstructured_shift(unstructured_offset_provider): stencil = im.lambda_("arg0")(im.deref(im.shift("E2V", 1)("arg0"))) domain = im.domain(common.GridType.UNSTRUCTURED, {Edge: (0, 1)}) - expected_domains = {"in_field1": {Vertex: (0, 2)}} + accessed_vertex = unstructured_offset_provider["E2V"].ndarray[0, 1] + expected_domains = {"in_field1": {Vertex: (accessed_vertex, accessed_vertex + np.int32(1))}} testee, expected = setup_test_as_fieldop(stencil, domain, expected_domains=expected_domains) run_test_expr(testee, expected, domain, 
expected_domains, unstructured_offset_provider) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py index 5ac023fac7..4e8ef94cc7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py @@ -75,7 +75,13 @@ def test_get_domain(): sizes = {"out": gtx.domain({Vertex: (0, 10), KDim: (0, 20)})} get_domain_expr = im.get_field_domain(common.GridType.UNSTRUCTURED, "out", sizes["out"].dims) - run_test_program(["inp", "out"], sizes, "out", im.domain_as_expr(sizes["out"]), get_domain_expr) + run_test_program( + ["inp", "out"], + sizes, + "out", + im.domain(common.GridType.UNSTRUCTURED, sizes["out"]), + get_domain_expr, + ) def test_get_domain_tuples(): @@ -86,7 +92,11 @@ def test_get_domain_tuples(): ) run_test_program( - ["inp", "out"], sizes, "out", im.domain_as_expr(sizes["out"][1]), get_domain_expr + ["inp", "out"], + sizes, + "out", + im.domain(common.GridType.UNSTRUCTURED, sizes["out"][1]), + get_domain_expr, ) @@ -105,5 +115,9 @@ def test_get_domain_nested_tuples(): ) run_test_program( - ["inp", "a", "b", "c", "d"], sizes, "a", im.domain_as_expr(sizes["a"]), get_domain_expr + ["inp", "a", "b", "c", "d"], + sizes, + "a", + im.domain(common.GridType.UNSTRUCTURED, sizes["a"]), + get_domain_expr, ) From 7db7d24082304f0b5c48ed65925966b590bb2f74 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Wed, 8 Oct 2025 11:40:15 +0200 Subject: [PATCH 79/93] Fix dace tests --- .../runners_tests/dace_tests/test_gtir_to_sdfg.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 
19fe783883..7815620fac 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -2374,12 +2374,6 @@ def test_gtir_concat_where(): ], ) - # run domain inference in order to add the domain annex information to the concat_where node. - testee = infer_domain.infer_program( - testee, - offset_provider=CARTESIAN_OFFSETS, - symbolic_domain_sizes=pass_manager._max_domain_range_sizes(CARTESIAN_OFFSETS), - ) sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) c = np.empty_like(a) @@ -2462,12 +2456,6 @@ def test_gtir_concat_where_two_dimensions(): "__z_stride_1": d.strides[1] // d.itemsize, } - # run domain inference in order to add the domain annex information to the concat_where node. - testee = infer_domain.infer_program( - testee, - offset_provider=CARTESIAN_OFFSETS, - symbolic_domain_sizes=pass_manager._max_domain_range_sizes(CARTESIAN_OFFSETS), - ) sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, c, d, **field_symbols) From 7098ccfd352ddceb8aa0e31ecc62ec306a9782a2 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Wed, 8 Oct 2025 12:14:52 +0200 Subject: [PATCH 80/93] Small fixes --- src/gt4py/next/iterator/transforms/pass_manager.py | 2 +- .../iterator_tests/ir_utils_test.py/test_domain_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 2df1f8b45f..f8dc7f2ac8 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -60,7 +60,7 @@ def _max_domain_range_sizes(offset_provider: Mapping[str, Any]) -> dict[str, iti ) sizes[conn_type.codomain.value] = max( sizes.get(conn_type.codomain.value, 0), - provider.ndarray.max() + 1, # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject + 
int(provider.ndarray.max()) + 1, # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject ) sizes_exprs = {k: im.literal_from_value(v) for k, v in sizes.items()} return sizes_exprs diff --git a/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py b/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py index 2195bc7e76..b361c5437e 100644 --- a/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py +++ b/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py @@ -278,7 +278,7 @@ def test_oob_error(): im.domain(common.GridType.UNSTRUCTURED, {Vertex: (0, 3)}) ) with pytest.warns( - UserWarning, - match=r"out-of-bounds", + UserWarning, + match=r"out-of-bounds", ): domain.translate(shift_chain, offset_provider).as_expr() From 1cf559c14ae5ee57907ec83136967839431d13a3 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Wed, 8 Oct 2025 12:25:31 +0200 Subject: [PATCH 81/93] Small fixes --- .../feature_tests/ffront_tests/test_compiled_program.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index 14af73f7bc..368b37044e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -897,7 +897,8 @@ def testee( out = cases.allocate(cartesian_case, testee, "out")() testee(inp, out, offset_provider={}) - assert np.allclose((inp[0].ndarray, inp[1].ndarray), (out[0].ndarray, out[1].ndarray)) + assert np.allclose(inp[0].ndarray, out[0].ndarray) + assert np.allclose(inp[1].ndarray, out[1].ndarray) assert testee._compiled_programs.argument_descriptor_mapping[ arguments.FieldDomainDescriptor From 6b263239d8680784dd315fa2998c1b5d2314929d 
Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 13 Nov 2025 14:26:09 +0100 Subject: [PATCH 82/93] Fix format --- src/gt4py/next/ffront/decorator.py | 1 - src/gt4py/next/ffront/past_to_itir.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index af5ce39a3c..11f96da285 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -252,7 +252,6 @@ def with_bound_args(self, **kwargs: Any) -> ProgramWithBoundArgs: }, ) - @functools.cached_property def _compiled_programs(self) -> compiled_program.CompiledProgramsPool: if self.backend is None or self.backend == eve.NOTHING: diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 50f572b91f..d351d8ab99 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -14,7 +14,7 @@ import devtools -from gt4py.eve import NodeTranslator, traits, utils as eve_utils +from gt4py.eve import NodeTranslator, traits from gt4py.next import common, config, errors, utils from gt4py.next.ffront import ( fbuiltins, From 39312c2c8d5b97252401453756708bd968b4f853 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 13 Nov 2025 14:27:52 +0100 Subject: [PATCH 83/93] Fix broken merge --- src/gt4py/next/ffront/decorator.py | 54 ++++++++++-------------------- 1 file changed, 17 insertions(+), 37 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 11f96da285..259dc92a74 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -173,12 +173,23 @@ def _compiled_programs(self) -> compiled_program.CompiledProgramsPool: if self.backend is None or self.backend == eve.NOTHING: raise RuntimeError("Cannot compile a program without backend.") - if self.static_params is None: - object.__setattr__(self, "static_params", ()) + def path_to_expr(path: Sequence[int]) 
-> str: + return "".join(map(lambda idx: f"[{idx}]", path)) - argument_descriptor_mapping = { - arguments.StaticArg: self.static_params, - } + argument_descriptor_mapping: dict[type[arguments.ArgStaticDescriptor], Sequence[str]] = {} + + if self.static_params: + argument_descriptor_mapping[arguments.StaticArg] = self.static_params + + if self.static_domains: + static_domain_args = [] + func_type = self.past_stage.past_node.type.definition # type: ignore[union-attr] # type inference done at this point + param_types = func_type.pos_or_kw_args | func_type.kw_only_args + for name, type_ in param_types.items(): + for el_type_, path in type_info.primitive_constituents(type_, with_path_arg=True): + if isinstance(el_type_, ts.FieldType): + static_domain_args.append(f"{name}{path_to_expr(path)}") + argument_descriptor_mapping[arguments.FieldDomainDescriptor] = static_domain_args program_type = self.past_stage.past_node.type assert isinstance(program_type, ts_ffront.ProgramType) @@ -186,7 +197,7 @@ def _compiled_programs(self) -> compiled_program.CompiledProgramsPool: backend=self.backend, definition_stage=self.definition_stage, program_type=program_type, - argument_descriptor_mapping=argument_descriptor_mapping, # type: ignore[arg-type] # covariant `type[T]` not possible + argument_descriptor_mapping=argument_descriptor_mapping, ) def with_backend(self, backend: next_backend.Backend) -> Program: @@ -252,37 +263,6 @@ def with_bound_args(self, **kwargs: Any) -> ProgramWithBoundArgs: }, ) - @functools.cached_property - def _compiled_programs(self) -> compiled_program.CompiledProgramsPool: - if self.backend is None or self.backend == eve.NOTHING: - raise RuntimeError("Cannot compile a program without backend.") - - def path_to_expr(path: Sequence[int]) -> str: - return "".join(map(lambda idx: f"[{idx}]", path)) - - argument_descriptor_mapping: dict[type[arguments.ArgStaticDescriptor], Sequence[str]] = {} - - if self.static_params: - 
argument_descriptor_mapping[arguments.StaticArg] = self.static_params - - if self.static_domains: - static_domain_args = [] - func_type = self.past_stage.past_node.type.definition # type: ignore[union-attr] # type inference done at this point - param_types = func_type.pos_or_kw_args | func_type.kw_only_args - for name, type_ in param_types.items(): - for el_type_, path in type_info.primitive_constituents(type_, with_path_arg=True): - if isinstance(el_type_, ts.FieldType): - static_domain_args.append(f"{name}{path_to_expr(path)}") - argument_descriptor_mapping[arguments.FieldDomainDescriptor] = static_domain_args - - program_type = self.past_stage.past_node.type - assert isinstance(program_type, ts_ffront.ProgramType) - return compiled_program.CompiledProgramsPool( - backend=self.backend, - definition_stage=self.definition_stage, - program_type=program_type, - argument_descriptor_mapping=argument_descriptor_mapping, - ) def __call__( self, From f1efb3f6f47eff582590c37572117febe8c6dd98 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 13 Nov 2025 14:28:47 +0100 Subject: [PATCH 84/93] Fix format --- src/gt4py/next/ffront/decorator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 259dc92a74..e466ddafdb 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -263,7 +263,6 @@ def with_bound_args(self, **kwargs: Any) -> ProgramWithBoundArgs: }, ) - def __call__( self, *args: Any, From 888a2a9936fc16e619e5a36586fac9eba4106fb8 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 19 Nov 2025 17:16:18 +0100 Subject: [PATCH 85/93] Fix failing test in dace --- .../runners_tests/dace_tests/test_dace_domain.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py 
b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py index 19314bc2d8..511d8dbcab 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py @@ -37,14 +37,14 @@ ], ) def test_simplify_domain_expr(param): - domain_expr = im.domain( + domain_expr = domain_utils.SymbolicDomain.from_expr(im.domain( gtx_common.GridType.CARTESIAN, { Cell: ("horizontal_start", "horizontal_end"), KDim: ("vertical_start", "vertical_end"), }, - ) - domain = gtx_dace_domain.extract_domain(domain_expr) + )) + domain = gtx_dace_domain.get_field_domain(domain_expr) expr = dace.symbolic.pystr_to_symbolic(param[0]) expected_expr = dace.symbolic.pystr_to_symbolic(param[1]) @@ -56,21 +56,21 @@ def test_gtir_domain(): Vertex = gtx_common.Dimension(value="Vertex", kind=gtx_common.DimensionKind.HORIZONTAL) KDim = gtx_common.Dimension(value="KDim", kind=gtx_common.DimensionKind.VERTICAL) - ir = im.domain( + ir = domain_utils.SymbolicDomain.from_expr(im.domain( gtx_common.GridType.UNSTRUCTURED, { Vertex: (1, 10), KDim: (2, 20), }, - ) + )) - assert gtir_domain.extract_domain(ir) == [ - gtir_domain.FieldopDomainRange( + assert gtx_dace_domain.get_field_domain(ir) == [ + gtx_dace_domain.FieldopDomainRange( Vertex, 1, 10, ), - gtir_domain.FieldopDomainRange( + gtx_dace_domain.FieldopDomainRange( KDim, 2, 20, From 86556af1aaf4b338a0f3716073696bb9489a8167 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 19 Nov 2025 17:17:08 +0100 Subject: [PATCH 86/93] Fix format --- .../dace_tests/test_dace_domain.py | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py index 
511d8dbcab..0c899ead44 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py @@ -37,13 +37,15 @@ ], ) def test_simplify_domain_expr(param): - domain_expr = domain_utils.SymbolicDomain.from_expr(im.domain( - gtx_common.GridType.CARTESIAN, - { - Cell: ("horizontal_start", "horizontal_end"), - KDim: ("vertical_start", "vertical_end"), - }, - )) + domain_expr = domain_utils.SymbolicDomain.from_expr( + im.domain( + gtx_common.GridType.CARTESIAN, + { + Cell: ("horizontal_start", "horizontal_end"), + KDim: ("vertical_start", "vertical_end"), + }, + ) + ) domain = gtx_dace_domain.get_field_domain(domain_expr) expr = dace.symbolic.pystr_to_symbolic(param[0]) @@ -56,13 +58,15 @@ def test_gtir_domain(): Vertex = gtx_common.Dimension(value="Vertex", kind=gtx_common.DimensionKind.HORIZONTAL) KDim = gtx_common.Dimension(value="KDim", kind=gtx_common.DimensionKind.VERTICAL) - ir = domain_utils.SymbolicDomain.from_expr(im.domain( - gtx_common.GridType.UNSTRUCTURED, - { - Vertex: (1, 10), - KDim: (2, 20), - }, - )) + ir = domain_utils.SymbolicDomain.from_expr( + im.domain( + gtx_common.GridType.UNSTRUCTURED, + { + Vertex: (1, 10), + KDim: (2, 20), + }, + ) + ) assert gtx_dace_domain.get_field_domain(ir) == [ gtx_dace_domain.FieldopDomainRange( From 52e67d9adab16586dd8306fb51c8ca2feae035b6 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 26 Nov 2025 20:37:29 +0100 Subject: [PATCH 87/93] Address review comments --- src/gt4py/next/ffront/decorator.py | 22 ++++---- src/gt4py/next/ffront/past_to_itir.py | 6 +- .../next/iterator/ir_utils/domain_utils.py | 25 ++++----- .../next/iterator/transforms/pass_manager.py | 6 +- .../ir_utils_test.py/test_domain_utils.py | 4 +- .../dace_tests/test_dace_domain.py | 56 ------------------- 6 files changed, 31 insertions(+), 88 deletions(-) diff --git 
a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index e466ddafdb..71f0d89110 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -54,6 +54,16 @@ DEFAULT_BACKEND: next_backend.Backend | None = None +def _field_domain_descriptor_mapping_from_func_type(func_type: ts.FunctionType) -> list[str]: + static_domain_args = [] + param_types = func_type.pos_or_kw_args | func_type.kw_only_args + for name, type_ in param_types.items(): + for el_type_, path in type_info.primitive_constituents(type_, with_path_arg=True): + if isinstance(el_type_, ts.FieldType): + path_as_expr = "".join(map(lambda idx: f"[{idx}]", path)) + static_domain_args.append(f"{name}{path_as_expr}") + return static_domain_args + # TODO(tehrengruber): Decide if and how programs can call other programs. As a # result Program could become a GTCallable. @dataclasses.dataclass(frozen=True) @@ -173,23 +183,13 @@ def _compiled_programs(self) -> compiled_program.CompiledProgramsPool: if self.backend is None or self.backend == eve.NOTHING: raise RuntimeError("Cannot compile a program without backend.") - def path_to_expr(path: Sequence[int]) -> str: - return "".join(map(lambda idx: f"[{idx}]", path)) - argument_descriptor_mapping: dict[type[arguments.ArgStaticDescriptor], Sequence[str]] = {} if self.static_params: argument_descriptor_mapping[arguments.StaticArg] = self.static_params if self.static_domains: - static_domain_args = [] - func_type = self.past_stage.past_node.type.definition # type: ignore[union-attr] # type inference done at this point - param_types = func_type.pos_or_kw_args | func_type.kw_only_args - for name, type_ in param_types.items(): - for el_type_, path in type_info.primitive_constituents(type_, with_path_arg=True): - if isinstance(el_type_, ts.FieldType): - static_domain_args.append(f"{name}{path_to_expr(path)}") - argument_descriptor_mapping[arguments.FieldDomainDescriptor] = static_domain_args + 
argument_descriptor_mapping[arguments.FieldDomainDescriptor] = _field_domain_descriptor_mapping_from_func_type(self.past_stage.past_node.type.definition) program_type = self.past_stage.past_node.type assert isinstance(program_type, ts_ffront.ProgramType) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index d351d8ab99..9194e87c2e 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -15,6 +15,7 @@ import devtools from gt4py.eve import NodeTranslator, traits +from gt4py.eve.extended_typing import MaybeNestedInTuple from gt4py.next import common, config, errors, utils from gt4py.next.ffront import ( fbuiltins, @@ -121,10 +122,9 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: ) # TODO(tehrengruber): Put this in a dedicated transformation step. - if arguments.FieldDomainDescriptor in inp.args.argument_descriptor_contexts: - context = inp.args.argument_descriptor_contexts[arguments.FieldDomainDescriptor] + if (context := inp.args.argument_descriptor_contexts.get(arguments.FieldDomainDescriptor, None)): field_domains = { - param: utils.tree_map(lambda x: x.domain if x else x)(v) for param, v in context.items() + param: utils.tree_map(lambda x: x.domain if x is not None else x)(v) for param, v in context.items() } itir_program = transform_get_domain_range.TransformGetDomainRange.apply( itir_program, sizes=field_domains diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 7eebe984ae..152d7b930f 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -26,11 +26,8 @@ #: to have a contiguous domain before a warning is raised. NON_CONTIGUOUS_DOMAIN_WARNING_THRESHOLD: float = 1 / 4 -#: Skip printing warnings after exceeding `NON_CONTIGUOUS_DOMAIN_WARNING_THRESHOLD` this many times. 
-NON_CONTIGUOUS_DOMAIN_MAX_WARNINGS: int = 5 - -#: Number of warnings raised after exceeding `NON_CONTIGUOUS_DOMAIN_WARNING_THRESHOLD` -_NON_CONTIGUOUS_DOMAIN_WARNING_COUNTER: int = 0 +#: Offset tags for which a non-contiguous domain warning has already been printed +_NON_CONTIGUOUS_DOMAIN_WARNING_SKIPPED_OFFSET_TAGS: set[str] = set() @dataclasses.dataclass(frozen=True) @@ -152,17 +149,17 @@ def translate( skip_value = connectivity.skip_value # fold & convert expr into actual integers - range_exprs = new_ranges[old_dim].start, new_ranges[old_dim].stop - range_exprs = tuple( + start_expr, stop_expr = new_ranges[old_dim].start, new_ranges[old_dim].stop + start_expr, stop_expr = ( collapse_tuple.CollapseTuple.apply( expr, within_stencil=False, allow_undeclared_symbols=True, ) - for expr in range_exprs + for expr in (start_expr, stop_expr) ) # type: ignore[assignment] # mypy not smart enough - assert all(isinstance(expr, itir.Literal) for expr in range_exprs) - start, stop = (int(literal.value) for literal in range_exprs) # type: ignore[attr-defined] # mypy does not understand assert above + assert isinstance(start_expr, itir.Literal) and isinstance(stop_expr, itir.Literal) + start, stop = (int(literal.value) for literal in (start_expr, stop_expr)) # type: ignore[attr-defined] # mypy does not understand assert above nb_index: slice | int if val in [trace_shifts.Sentinel.ALL_NEIGHBORS, trace_shifts.Sentinel.VALUE]: @@ -180,7 +177,7 @@ def translate( # `can_deref`. warnings.warn( UserWarning( - f"Translating '{self.as_expr()}' using '{shift[0].value}' has " + f"Translating '{self.as_expr()}' using '{off.value}' has " f"an out-of-bounds access." 
), stacklevel=2, @@ -192,12 +189,12 @@ def translate( if ( fraction_accessed < NON_CONTIGUOUS_DOMAIN_WARNING_THRESHOLD - and (_NON_CONTIGUOUS_DOMAIN_WARNING_COUNTER := +1) - < NON_CONTIGUOUS_DOMAIN_MAX_WARNINGS + and (off.value not in _NON_CONTIGUOUS_DOMAIN_WARNING_SKIPPED_OFFSET_TAGS) ): + _NON_CONTIGUOUS_DOMAIN_WARNING_SKIPPED_OFFSET_TAGS.add(off.value) warnings.warn( UserWarning( - f"Translating '{self.as_expr()}' using '{shift[0].value}' requires " + f"Translating '{self.as_expr()}' using '{off.value}' requires " f"computations on many additional points " f"({round((1 - fraction_accessed) * 100)}%) in order to get a contiguous " f"domain. Please consider reordering your mesh." diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 78dc49aeb7..aadce76cca 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -45,7 +45,7 @@ def __call__( ) -> itir.Program: ... -def _max_domain_range_sizes(offset_provider: Mapping[str, Any]) -> dict[str, itir.Literal]: +def _max_domain_range_sizes(offset_provider: common.OffsetProvider) -> dict[str, itir.Literal]: """ Extract horizontal domain sizes from an `offset_provider`. @@ -83,7 +83,7 @@ def _has_dynamic_domains(ir: itir.Program) -> bool: def _process_symbolic_domains_option( ir: itir.Program, - offset_provider: common.OffsetProvider | common.OffsetProviderType, + offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]], use_max_domain_range_on_unstructured_shift: Optional[bool], ) -> Optional[dict[str, str | itir.Expr]]: @@ -116,7 +116,7 @@ def apply_common_transforms( *, # TODO(havogt): should be replaced by `common.OffsetProviderType`, but global_tmps currently # relies on runtime info or `symbolic_domain_sizes`. 
- offset_provider: common.OffsetProvider | common.OffsetProviderType, + offset_provider: common.OffsetProvider, extract_temporaries=False, unroll_reduce=False, common_subexpression_elimination=True, diff --git a/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py b/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py index b361c5437e..04e820cd27 100644 --- a/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py +++ b/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py @@ -244,7 +244,9 @@ def test_unstructured_translate(shift_chain, expected_end_domain): assert end_domain == expected_end_domain -def test_non_contiguous_domain_warning(): +def test_non_contiguous_domain_warning(monkeypatch): + monkeypatch.setattr(domain_utils, "_NON_CONTIGUOUS_DOMAIN_WARNING_SKIPPED_OFFSET_TAGS", set()) + offset_provider = { "V2V": constructors.as_connectivity( domain={Vertex: (0, 100), V2VDim: 1}, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py index 0c899ead44..a31370baa1 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py @@ -26,62 +26,6 @@ ) -@pytest.mark.parametrize( - "param", - [ - ( - "-horizontal_start + Max(horizontal_start, Min(horizontal_end, start_nudge_line_idx_e))", - "Max(0, -horizontal_start + Min(horizontal_end, start_nudge_line_idx_e))", - ), - ("Max(0, vertical_end - vertical_start)", "vertical_end - vertical_start"), - ], -) -def test_simplify_domain_expr(param): - domain_expr = domain_utils.SymbolicDomain.from_expr( - im.domain( - gtx_common.GridType.CARTESIAN, - { - Cell: ("horizontal_start", "horizontal_end"), - KDim: 
("vertical_start", "vertical_end"), - }, - ) - ) - domain = gtx_dace_domain.get_field_domain(domain_expr) - - expr = dace.symbolic.pystr_to_symbolic(param[0]) - expected_expr = dace.symbolic.pystr_to_symbolic(param[1]) - - assert gtx_dace_domain.simplify_domain_expr(expr, domain) == expected_expr - - -def test_gtir_domain(): - Vertex = gtx_common.Dimension(value="Vertex", kind=gtx_common.DimensionKind.HORIZONTAL) - KDim = gtx_common.Dimension(value="KDim", kind=gtx_common.DimensionKind.VERTICAL) - - ir = domain_utils.SymbolicDomain.from_expr( - im.domain( - gtx_common.GridType.UNSTRUCTURED, - { - Vertex: (1, 10), - KDim: (2, 20), - }, - ) - ) - - assert gtx_dace_domain.get_field_domain(ir) == [ - gtx_dace_domain.FieldopDomainRange( - Vertex, - 1, - 10, - ), - gtx_dace_domain.FieldopDomainRange( - KDim, - 2, - 20, - ), - ] - - def test_symbolic_domain(): domain = domain_utils.SymbolicDomain.from_expr( im.get_field_domain(gtx_common.GridType.UNSTRUCTURED, "arg", [Vertex, KDim]) From d7bf7a789f26d1890e59fee7fd118f5b2205133a Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 26 Nov 2025 23:39:10 +0100 Subject: [PATCH 88/93] Address review comments --- src/gt4py/next/ffront/decorator.py | 7 +++++- src/gt4py/next/ffront/past_to_itir.py | 6 ++--- .../next/iterator/ir_utils/domain_utils.py | 11 +++++---- .../next/iterator/transforms/pass_manager.py | 24 ++++++++++++++++++- ...eplace_get_domain_range_with_constants.py} | 6 ++--- .../ffront_tests/test_program.py | 4 +--- ...eplace_get_domain_range_with_constants.py} | 6 +++-- 7 files changed, 46 insertions(+), 18 deletions(-) rename src/gt4py/next/iterator/transforms/{transform_get_domain_range.py => replace_get_domain_range_with_constants.py} (94%) rename tests/next_tests/unit_tests/iterator_tests/transforms_tests/{test_transform_get_domain_range.py => test_replace_get_domain_range_with_constants.py} (96%) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 
71f0d89110..c1845d08c6 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -64,6 +64,7 @@ def _field_domain_descriptor_mapping_from_func_type(func_type: ts.FunctionType) static_domain_args.append(f"{name}{path_as_expr}") return static_domain_args + # TODO(tehrengruber): Decide if and how programs can call other programs. As a # result Program could become a GTCallable. @dataclasses.dataclass(frozen=True) @@ -189,7 +190,11 @@ def _compiled_programs(self) -> compiled_program.CompiledProgramsPool: argument_descriptor_mapping[arguments.StaticArg] = self.static_params if self.static_domains: - argument_descriptor_mapping[arguments.FieldDomainDescriptor] = _field_domain_descriptor_mapping_from_func_type(self.past_stage.past_node.type.definition) + argument_descriptor_mapping[arguments.FieldDomainDescriptor] = ( + _field_domain_descriptor_mapping_from_func_type( + self.past_stage.past_node.type.definition + ) + ) program_type = self.past_stage.past_node.type assert isinstance(program_type, ts_ffront.ProgramType) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 9194e87c2e..621a643be7 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -15,7 +15,6 @@ import devtools from gt4py.eve import NodeTranslator, traits -from gt4py.eve.extended_typing import MaybeNestedInTuple from gt4py.next import common, config, errors, utils from gt4py.next.ffront import ( fbuiltins, @@ -122,9 +121,10 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: ) # TODO(tehrengruber): Put this in a dedicated transformation step. 
- if (context := inp.args.argument_descriptor_contexts.get(arguments.FieldDomainDescriptor, None)): + if context := inp.args.argument_descriptor_contexts.get(arguments.FieldDomainDescriptor, None): field_domains = { - param: utils.tree_map(lambda x: x.domain if x is not None else x)(v) for param, v in context.items() + param: utils.tree_map(lambda x: x.domain if x is not None else x)(v) + for param, v in context.items() } itir_program = transform_get_domain_range.TransformGetDomainRange.apply( itir_program, sizes=field_domains diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 152d7b930f..6c73fc0d67 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -24,7 +24,7 @@ #: Threshold fraction of domain points which may be added to a domain on translation in order #: to have a contiguous domain before a warning is raised. -NON_CONTIGUOUS_DOMAIN_WARNING_THRESHOLD: float = 1 / 4 +_NON_CONTIGUOUS_DOMAIN_WARNING_THRESHOLD: float = 1 / 4 #: Offset tags for which a non-contiguous domain warning has already been printed _NON_CONTIGUOUS_DOMAIN_WARNING_SKIPPED_OFFSET_TAGS: set[str] = set() @@ -158,7 +158,9 @@ def translate( ) for expr in (start_expr, stop_expr) ) # type: ignore[assignment] # mypy not smart enough - assert isinstance(start_expr, itir.Literal) and isinstance(stop_expr, itir.Literal) + assert isinstance(start_expr, itir.Literal) and isinstance( + stop_expr, itir.Literal + ) start, stop = (int(literal.value) for literal in (start_expr, stop_expr)) # type: ignore[attr-defined] # mypy does not understand assert above nb_index: slice | int @@ -187,9 +189,8 @@ def translate( fraction_accessed = np.unique(accessed).size / (new_stop - new_start) # type: ignore[call-overload] # TODO(havogt): improve typing for NDArrayObject - if ( - fraction_accessed < NON_CONTIGUOUS_DOMAIN_WARNING_THRESHOLD - and (off.value not in 
_NON_CONTIGUOUS_DOMAIN_WARNING_SKIPPED_OFFSET_TAGS) + if fraction_accessed < _NON_CONTIGUOUS_DOMAIN_WARNING_THRESHOLD and ( + off.value not in _NON_CONTIGUOUS_DOMAIN_WARNING_SKIPPED_OFFSET_TAGS ): _NON_CONTIGUOUS_DOMAIN_WARNING_SKIPPED_OFFSET_TAGS.add(off.value) warnings.warn( diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index aadce76cca..441d7a40fa 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause import warnings -from typing import Any, Mapping, Optional, Protocol +from typing import Optional, Protocol from gt4py.eve import utils as eve_utils from gt4py.next import common @@ -87,6 +87,28 @@ def _process_symbolic_domains_option( symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]], use_max_domain_range_on_unstructured_shift: Optional[bool], ) -> Optional[dict[str, str | itir.Expr]]: + """ + Given a program, offset_provider and some configuration options determine how domains are + inferred. + + The output of this function is used as `symbolic_domain_sizes` argument of domain inference, i.e. + :func:`infer_domain.infer_program`. + + Right now domains of `as_fieldop` expressions can be inferred either a) using static information + from the offset provider, or b) they are set to an expression controlled by + the user and configured in the backend, or c) they are set to the maximum possible domain / + everywhere (see :func:`_max_domain_range_sizes`) + + Option a) applies when the program is decorated with `static_domains = True` (unless option c) + is explicitly requested). Then all dynamic domains were replaced with static ones + which we recognize here. The domain inference then uses this static information which we + communicate by returning `None`, i.e. no symbolic domain sizes. 
+ Option b) applies when the user explicitly configured `symbolic_domain_sizes` in the backend. + In that case we just forward the value. + Option c) applies when `static_domains = False` or when explicitly configured in the backend + with `use_max_domain_range_on_unstructured_shift = True`. In that case we determine the + maximum sizes using :func:`_max_domain_range_sizes` and return them. + """ if symbolic_domain_sizes: assert not use_max_domain_range_on_unstructured_shift, "Options are mutually exclusive." return symbolic_domain_sizes diff --git a/src/gt4py/next/iterator/transforms/transform_get_domain_range.py b/src/gt4py/next/iterator/transforms/replace_get_domain_range_with_constants.py similarity index 94% rename from src/gt4py/next/iterator/transforms/transform_get_domain_range.py rename to src/gt4py/next/iterator/transforms/replace_get_domain_range_with_constants.py index bbd5513310..e4801d24a4 100644 --- a/src/gt4py/next/iterator/transforms/transform_get_domain_range.py +++ b/src/gt4py/next/iterator/transforms/replace_get_domain_range_with_constants.py @@ -48,9 +48,9 @@ def visit_FunCall(self, node, **kwargs): @dataclasses.dataclass(frozen=True) -class TransformGetDomainRange(PreserveLocationVisitor, NodeTranslator): +class ReplaceGetDomainRangeWithConstants(PreserveLocationVisitor, NodeTranslator): """ - Transforms `get_domain` calls into a tuple containing start and stop. + Replace `get_domain` calls into a tuple containing start and stop. Example: >>> from gt4py import next as gtx @@ -86,7 +86,7 @@ class TransformGetDomainRange(PreserveLocationVisitor, NodeTranslator): ... ), ... ], ... 
) - >>> result = TransformGetDomainRange.apply(ir, sizes=sizes) + >>> result = ReplaceGetDomainRangeWithConstants.apply(ir, sizes=sizes) >>> print(result) test(inp, out) { out @ u⟨ Vertexₕ: [{0, 10}[0], {0, 10}[1][, KDimᵥ: [{0, 20}[0], {0, 20}[1][ ⟩ ← (⇑deref)(inp); diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index 3b3d5b0fe2..f01c150a74 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -82,9 +82,7 @@ def shift_by_one_program(in_field: cases.IFloatField, out_field: cases.IFloatFie def test_copy_execution(cartesian_case, copy_program_def): - copy_program = gtx.program( - copy_program_def, backend=cartesian_case.backend, static_domains=True - ) + copy_program = gtx.program(copy_program_def, backend=cartesian_case.backend) cases.verify_with_default_data(cartesian_case, copy_program, ref=lambda in_field: in_field) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_replace_get_domain_range_with_constants.py similarity index 96% rename from tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py rename to tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_replace_get_domain_range_with_constants.py index 4e8ef94cc7..dc190ac382 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_replace_get_domain_range_with_constants.py @@ -21,7 +21,9 @@ from gt4py.next import Domain, common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from 
gt4py.next.iterator.transforms.transform_get_domain_range import TransformGetDomainRange +from gt4py.next.iterator.transforms.transform_get_domain_range import ( + ReplaceGetDomainRangeWithConstants, +) from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -64,7 +66,7 @@ def run_test_program( params=params, body=[setat_factory(domain=domain, target=im.ref(target))], ) - actual = TransformGetDomainRange.apply(testee, sizes=sizes) + actual = ReplaceGetDomainRangeWithConstants.apply(testee, sizes=sizes) actual = CollapseTuple.apply( actual, enabled_transformations=CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE ) From cdeff436ff83d6ec077256a61dc1d1816b012406 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 26 Nov 2025 23:58:58 +0100 Subject: [PATCH 89/93] Address review comments --- src/gt4py/next/ffront/past_to_itir.py | 8 +- .../next/iterator/ir_utils/domain_utils.py | 137 ++++++++++-------- ...replace_get_domain_range_with_constants.py | 2 +- 3 files changed, 80 insertions(+), 67 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 621a643be7..bd3ee6c2b4 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -28,7 +28,7 @@ from gt4py.next.ffront.stages import AOT_PRG from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms import remap_symbols, transform_get_domain_range +from gt4py.next.iterator.transforms import remap_symbols, replace_get_domain_range_with_constants from gt4py.next.otf import arguments, stages, workflow from gt4py.next.type_system import type_info, type_specifications as ts @@ -126,8 +126,10 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: param: utils.tree_map(lambda x: x.domain if x is not None else x)(v) for param, v in context.items() } - itir_program = transform_get_domain_range.TransformGetDomainRange.apply( 
- itir_program, sizes=field_domains + itir_program = ( + replace_get_domain_range_with_constants.ReplaceGetDomainRangeWithConstants.apply( + itir_program, sizes=field_domains + ) ) # Translate NamedCollectionTypes to TupleTypes in compile-time args diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 6c73fc0d67..30826f9d87 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -58,6 +58,78 @@ def empty(self) -> bool | None: } +def _unstructured_translate_range_statically( + range_: SymbolicRange, + tag: str, + val: itir.OffsetLiteral + | Literal[trace_shifts.Sentinel.VALUE, trace_shifts.Sentinel.ALL_NEIGHBORS], + offset_provider: common.OffsetProvider, + expr: itir.Expr | None = None, +) -> SymbolicRange: + """ + Translate `range_` using static connectivity information from `offset_provider`. + """ + assert common.is_offset_provider(offset_provider) + connectivity = offset_provider[tag] + assert isinstance(connectivity, common.Connectivity) + skip_value = connectivity.skip_value + + # fold & convert expr into actual integers + start_expr, stop_expr = range_.start, range_.stop + start_expr, stop_expr = ( + collapse_tuple.CollapseTuple.apply( + expr, + within_stencil=False, + allow_undeclared_symbols=True, + ) + for expr in (start_expr, stop_expr) + ) # type: ignore[assignment] # mypy not smart enough + assert isinstance(start_expr, itir.Literal) and isinstance(stop_expr, itir.Literal) + start, stop = (int(literal.value) for literal in (start_expr, stop_expr)) # type: ignore[attr-defined] # mypy does not understand assert above + + nb_index: slice | int + if val in [trace_shifts.Sentinel.ALL_NEIGHBORS, trace_shifts.Sentinel.VALUE]: + nb_index = slice(None) + else: + nb_index = val.value # type: ignore[assignment] # assert above + + accessed = connectivity.ndarray[start:stop, nb_index] + + if isinstance(val, itir.OffsetLiteral) and 
np.any(accessed == skip_value): + # TODO(tehrengruber): Turn this into a configurable error. This is currently + # not possible since some test cases starting from ITIR containing + # `can_deref` might lead here. The frontend never emits such IR and domain + # inference runs after we transform reductions into stmts containing + # `can_deref`. + warnings.warn( + UserWarning(f"Translating '{expr}' using '{tag}' has an out-of-bounds access."), + stacklevel=2, + ) + + new_start, new_stop = accessed.min(), accessed.max() + 1 # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject + + fraction_accessed = np.unique(accessed).size / (new_stop - new_start) # type: ignore[call-overload] # TODO(havogt): improve typing for NDArrayObject + + if fraction_accessed < _NON_CONTIGUOUS_DOMAIN_WARNING_THRESHOLD and ( + tag not in _NON_CONTIGUOUS_DOMAIN_WARNING_SKIPPED_OFFSET_TAGS + ): + _NON_CONTIGUOUS_DOMAIN_WARNING_SKIPPED_OFFSET_TAGS.add(tag) + warnings.warn( + UserWarning( + f"Translating '{expr}' using '{tag}' requires " + f"computations on many additional points " + f"({round((1 - fraction_accessed) * 100)}%) in order to get a contiguous " + f"domain. Please consider reordering your mesh." 
+ ), + stacklevel=2, + ) + + return SymbolicRange( + im.literal(str(new_start), builtins.INTEGER_INDEX_BUILTIN), + im.literal(str(new_stop), builtins.INTEGER_INDEX_BUILTIN), + ) + + @dataclasses.dataclass(frozen=True) class SymbolicDomain: grid_type: common.GridType @@ -143,69 +215,8 @@ def translate( im.ensure_expr(symbolic_domain_sizes[new_dim.value]), ) else: - assert common.is_offset_provider(offset_provider) - connectivity = offset_provider[off.value] - assert isinstance(connectivity, common.Connectivity) - skip_value = connectivity.skip_value - - # fold & convert expr into actual integers - start_expr, stop_expr = new_ranges[old_dim].start, new_ranges[old_dim].stop - start_expr, stop_expr = ( - collapse_tuple.CollapseTuple.apply( - expr, - within_stencil=False, - allow_undeclared_symbols=True, - ) - for expr in (start_expr, stop_expr) - ) # type: ignore[assignment] # mypy not smart enough - assert isinstance(start_expr, itir.Literal) and isinstance( - stop_expr, itir.Literal - ) - start, stop = (int(literal.value) for literal in (start_expr, stop_expr)) # type: ignore[attr-defined] # mypy does not understand assert above - - nb_index: slice | int - if val in [trace_shifts.Sentinel.ALL_NEIGHBORS, trace_shifts.Sentinel.VALUE]: - nb_index = slice(None) - else: - nb_index = val.value # type: ignore[assignment] # assert above - - accessed = connectivity.ndarray[start:stop, nb_index] - - if isinstance(val, itir.OffsetLiteral) and np.any(accessed == skip_value): - # TODO(tehrengruber): Turn this into a configurable error. This is currently - # not possible since some test cases starting from ITIR containing - # `can_deref` might lead here. The frontend never emits such IR and domain - # inference runs after we transform reductions into stmts containing - # `can_deref`. - warnings.warn( - UserWarning( - f"Translating '{self.as_expr()}' using '{off.value}' has " - f"an out-of-bounds access." 
- ), - stacklevel=2, - ) - - new_start, new_stop = accessed.min(), accessed.max() + 1 # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject - - fraction_accessed = np.unique(accessed).size / (new_stop - new_start) # type: ignore[call-overload] # TODO(havogt): improve typing for NDArrayObject - - if fraction_accessed < _NON_CONTIGUOUS_DOMAIN_WARNING_THRESHOLD and ( - off.value not in _NON_CONTIGUOUS_DOMAIN_WARNING_SKIPPED_OFFSET_TAGS - ): - _NON_CONTIGUOUS_DOMAIN_WARNING_SKIPPED_OFFSET_TAGS.add(off.value) - warnings.warn( - UserWarning( - f"Translating '{self.as_expr()}' using '{off.value}' requires " - f"computations on many additional points " - f"({round((1 - fraction_accessed) * 100)}%) in order to get a contiguous " - f"domain. Please consider reordering your mesh." - ), - stacklevel=2, - ) - - new_range = SymbolicRange( - im.literal(str(new_start), builtins.INTEGER_INDEX_BUILTIN), - im.literal(str(new_stop), builtins.INTEGER_INDEX_BUILTIN), + new_range = _unstructured_translate_range_statically( + new_ranges[old_dim], off.value, val, offset_provider, self.as_expr() ) new_ranges = dict( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_replace_get_domain_range_with_constants.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_replace_get_domain_range_with_constants.py index dc190ac382..322c04f5e5 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_replace_get_domain_range_with_constants.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_replace_get_domain_range_with_constants.py @@ -21,7 +21,7 @@ from gt4py.next import Domain, common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms.transform_get_domain_range import ( +from gt4py.next.iterator.transforms.replace_get_domain_range_with_constants import ( ReplaceGetDomainRangeWithConstants, ) from 
gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple From 83b0e88f054a55b3e9eab668e211c6214e5c7661 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 27 Nov 2025 00:10:51 +0100 Subject: [PATCH 90/93] Address review comments --- src/gt4py/next/ffront/decorator.py | 1 + src/gt4py/next/iterator/ir_utils/domain_utils.py | 7 ++++--- src/gt4py/next/iterator/transforms/pass_manager.py | 7 ++++--- .../transforms/replace_get_domain_range_with_constants.py | 4 +++- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index c1845d08c6..31dc3aa5f7 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -190,6 +190,7 @@ def _compiled_programs(self) -> compiled_program.CompiledProgramsPool: argument_descriptor_mapping[arguments.StaticArg] = self.static_params if self.static_domains: + assert isinstance(self.past_stage.past_node.type, ts_ffront.ProgramType) argument_descriptor_mapping[arguments.FieldDomainDescriptor] = ( _field_domain_descriptor_mapping_from_func_type( self.past_stage.past_node.type.definition diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 30826f9d87..835b6c2f2f 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -76,16 +76,16 @@ def _unstructured_translate_range_statically( # fold & convert expr into actual integers start_expr, stop_expr = range_.start, range_.stop - start_expr, stop_expr = ( + start_expr, stop_expr = ( # type: ignore[assignment] # mypy not smart enough collapse_tuple.CollapseTuple.apply( expr, within_stencil=False, allow_undeclared_symbols=True, ) for expr in (start_expr, stop_expr) - ) # type: ignore[assignment] # mypy not smart enough + ) assert isinstance(start_expr, itir.Literal) and isinstance(stop_expr, itir.Literal) - start, stop = (int(literal.value) for literal 
in (start_expr, stop_expr)) # type: ignore[attr-defined] # mypy does not understand assert above + start, stop = int(start_expr.value), int(stop_expr.value) nb_index: slice | int if val in [trace_shifts.Sentinel.ALL_NEIGHBORS, trace_shifts.Sentinel.VALUE]: @@ -215,6 +215,7 @@ def translate( im.ensure_expr(symbolic_domain_sizes[new_dim.value]), ) else: + assert common.is_offset_provider(offset_provider) new_range = _unstructured_translate_range_statically( new_ranges[old_dim], off.value, val, offset_provider, self.as_expr() ) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 441d7a40fa..22bf93d7f2 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -136,9 +136,7 @@ def _process_symbolic_domains_option( def apply_common_transforms( ir: itir.Program, *, - # TODO(havogt): should be replaced by `common.OffsetProviderType`, but global_tmps currently - # relies on runtime info or `symbolic_domain_sizes`. - offset_provider: common.OffsetProvider, + offset_provider: common.OffsetProvider | common.OffsetProviderType, extract_temporaries=False, unroll_reduce=False, common_subexpression_elimination=True, @@ -151,6 +149,9 @@ def apply_common_transforms( use_max_domain_range_on_unstructured_shift: Optional[bool] = None, ) -> itir.Program: assert isinstance(ir, itir.Program) + # TODO(tehrengruber): Allow `common.OffsetProviderType`, but domain inference currently + # relies on static information or `symbolic_domain_sizes`. 
+ assert common.is_offset_provider(offset_provider) offset_provider_type = common.offset_provider_to_type(offset_provider) diff --git a/src/gt4py/next/iterator/transforms/replace_get_domain_range_with_constants.py b/src/gt4py/next/iterator/transforms/replace_get_domain_range_with_constants.py index e4801d24a4..f228c24998 100644 --- a/src/gt4py/next/iterator/transforms/replace_get_domain_range_with_constants.py +++ b/src/gt4py/next/iterator/transforms/replace_get_domain_range_with_constants.py @@ -94,7 +94,9 @@ class ReplaceGetDomainRangeWithConstants(PreserveLocationVisitor, NodeTranslator """ @classmethod - def apply(cls, program: itir.Program, sizes: dict[str, MaybeNestedInTuple[common.Domain]]): + def apply( + cls, program: itir.Program, sizes: dict[str, MaybeNestedInTuple[common.Domain | None]] + ): return cls().visit(program, sizes=sizes) def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.FunCall: From 9df083edcf32d21837a4dbba328a21d33221bb37 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 6 Jan 2026 14:43:09 +0100 Subject: [PATCH 91/93] Simplify `symbolic_domain_sizes` --- src/gt4py/next/iterator/ir_utils/domain_utils.py | 2 +- src/gt4py/next/iterator/transforms/global_tmps.py | 2 +- src/gt4py/next/iterator/transforms/infer_domain.py | 10 +++++----- src/gt4py/next/iterator/transforms/pass_manager.py | 6 +++--- .../program_processors/codegens/gtfn/gtfn_module.py | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 835b6c2f2f..97d12cb4da 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -176,7 +176,7 @@ def translate( offset_provider: common.OffsetProvider | common.OffsetProviderType, #: A dictionary mapping axes names to their length. See #: func:`gt4py.next.iterator.transforms.infer_domain.infer_expr` for more details. 
- symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None, + symbolic_domain_sizes: Optional[dict[str, itir.Expr]] = None, ) -> SymbolicDomain: offset_provider_type = common.offset_provider_to_type(offset_provider) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 3bd734fa96..9388612466 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -312,7 +312,7 @@ def create_global_tmps( offset_provider: common.OffsetProvider | common.OffsetProviderType, #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for #: more details. - symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None, + symbolic_domain_sizes: Optional[dict[str, itir.Expr]] = None, *, uids: Optional[eve_utils.UIDGenerator] = None, ) -> itir.Program: diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index f4dfd4096e..f3c4af37dd 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -55,7 +55,7 @@ class DomainAccessDescriptor(eve.StrEnum): class InferenceOptions(typing.TypedDict): offset_provider: common.OffsetProvider | common.OffsetProviderType - symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] + symbolic_domain_sizes: Optional[dict[str, itir.Expr]] allow_uninferred: bool keep_existing_domains: bool @@ -126,7 +126,7 @@ def _extract_accessed_domains( input_ids: list[str], target_domain: NonTupleDomainAccess, offset_provider: common.OffsetProvider | common.OffsetProviderType, - symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]], + symbolic_domain_sizes: Optional[dict[str, itir.Expr]], ) -> dict[str, NonTupleDomainAccess]: accessed_domains: dict[str, NonTupleDomainAccess] = {} @@ -182,7 +182,7 @@ def _infer_as_fieldop( target_domain: DomainAccess, *, 
offset_provider: common.OffsetProvider | common.OffsetProviderType, - symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]], + symbolic_domain_sizes: Optional[dict[str, itir.Expr]], allow_uninferred: bool, keep_existing_domains: bool, ) -> tuple[itir.FunCall, AccessedDomains]: @@ -441,7 +441,7 @@ def infer_expr( domain: DomainAccess, *, offset_provider: common.OffsetProvider | common.OffsetProviderType, - symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None, + symbolic_domain_sizes: Optional[dict[str, itir.Expr]] = None, allow_uninferred: bool = False, keep_existing_domains: bool = False, ) -> tuple[itir.Expr, AccessedDomains]: @@ -557,7 +557,7 @@ def infer_program( program: itir.Program, *, offset_provider: common.OffsetProvider | common.OffsetProviderType, - symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None, + symbolic_domain_sizes: Optional[dict[str, itir.Expr]] = None, allow_uninferred: bool = False, keep_existing_domains: bool = False, ) -> itir.Program: diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 22bf93d7f2..543d6a9282 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -84,9 +84,9 @@ def _has_dynamic_domains(ir: itir.Program) -> bool: def _process_symbolic_domains_option( ir: itir.Program, offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]], + symbolic_domain_sizes: Optional[dict[str, itir.Expr]], use_max_domain_range_on_unstructured_shift: Optional[bool], -) -> Optional[dict[str, str | itir.Expr]]: +) -> Optional[dict[str, itir.Expr]]: """ Given a program, offset_provider and some configuration options determine how domains are inferred. @@ -143,7 +143,7 @@ def apply_common_transforms( force_inline_lambda_args=False, #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for #: more details. 
- symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None, + symbolic_domain_sizes: Optional[dict[str, itir.Expr]] = None, # TODO(tehrengruber): Remove this option again as soon as we have the necessary builtins # to work with / translate domains. use_max_domain_range_on_unstructured_shift: Optional[bool] = None, diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 63d9416bc5..c480cd70e6 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -52,7 +52,7 @@ class GTFNTranslationStep( enable_itir_transforms: bool = True use_imperative_backend: bool = False device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None + symbolic_domain_sizes: Optional[dict[str, itir.Expr]] = None use_max_domain_range_on_unstructured_shift: Optional[bool] = None def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings: From a8039508d19f9b021e1cbb9a02bd25154af342be Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 6 Jan 2026 15:40:45 +0100 Subject: [PATCH 92/93] Cleanup --- src/gt4py/next/iterator/ir_utils/domain_utils.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 97d12cb4da..bd930f76ec 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -76,14 +76,8 @@ def _unstructured_translate_range_statically( # fold & convert expr into actual integers start_expr, stop_expr = range_.start, range_.stop - start_expr, stop_expr = ( # type: ignore[assignment] # mypy not smart enough - collapse_tuple.CollapseTuple.apply( - expr, - within_stencil=False, - allow_undeclared_symbols=True, - ) - for expr in 
(start_expr, stop_expr) - ) + # note: if you find tuple expressions on literals here, you likely forgot to collapse tuple + # expressions beforehand assert isinstance(start_expr, itir.Literal) and isinstance(stop_expr, itir.Literal) start, stop = int(start_expr.value), int(stop_expr.value) From c9b4ad15fd77bea1d967138aed022ad9fb1978ae Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 6 Jan 2026 15:41:00 +0100 Subject: [PATCH 93/93] Format --- src/gt4py/next/iterator/ir_utils/domain_utils.py | 2 +- .../next_tests/unit_tests/otf_tests/test_compiled_program.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index bd930f76ec..2a72f95441 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -18,7 +18,7 @@ from gt4py.next import common from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.iterator.transforms import collapse_tuple, trace_shifts +from gt4py.next.iterator.transforms import trace_shifts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index 5afa841107..6746dd02d4 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -125,7 +125,9 @@ def pirate(program: toolchain.CompilableProgram): _verify_program_has_expected_true_value(hijacked_program.data) -def _verify_program_has_expected_domain(program: itir.Program, expected_domain: gtx.Domain, uids: utils.IDGeneratorPool): +def _verify_program_has_expected_domain( + program: itir.Program, expected_domain: gtx.Domain, uids: utils.IDGeneratorPool +): assert 
isinstance(program.body[0], itir.SetAt) assert isinstance(program.body[0].expr, itir.FunCall) assert program.body[0].expr.fun == itir.SymRef(id="fop")