diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index d96209bbb1..aabcc062bc 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -406,7 +406,9 @@ def create_if( return im.let(cond_symref_name, cond_)(result) - _visit_concat_where = _visit_where # TODO(havogt): upgrade concat_where + def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: + domain, true_branch, false_branch = self.visit(node.args, **kwargs) + return im.concat_where(domain, true_branch, false_branch) def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return im.call("broadcast")(*self.visit(node.args, **kwargs)) @@ -488,7 +490,7 @@ def _map( Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists. """ if all( - isinstance(t, ts.ScalarType) + isinstance(t, (ts.ScalarType, ts.DimensionType, ts.DomainType)) for arg_type in original_arg_types for t in type_info.primitive_constituents(arg_type) ): diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 9b065de658..e3f45f6c74 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -407,6 +407,11 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] raise BackendNotSelectedError() +@builtin_dispatch +def concat_where(*args): + raise BackendNotSelectedError() + + UNARY_MATH_NUMBER_BUILTINS = {"abs", "neg"} UNARY_LOGICAL_BUILTINS = {"not_"} UNARY_MATH_FP_BUILTINS = { @@ -494,6 +499,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "scan", "tuple_get", "unstructured_domain", + "concat_where", *ARITHMETIC_BUILTINS, *TYPE_BUILTINS, } diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index d86a584b8d..3888ccf2de 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1802,6 +1802,11 @@ def index(axis: common.Dimension) -> common.Field: return IndexField(axis) +@builtins.concat_where.register(EMBEDDED) +def concat_where(*args): + raise NotImplementedError("To be implemented in frontend embedded.") + + def closure( domain_: runtime.CartesianDomain | runtime.UnstructuredDomain, sten: Callable[..., Any], diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 69881dd3ed..85a4854998 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -50,7 +50,7 @@ def _with_altered_iterator_position_dims( ) -def _is_trivial_make_tuple_call(node: ir.Expr): +def _is_trivial_make_tuple_call(node: itir.Expr): """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" if not cpm.is_call_to(node, "make_tuple"): return False @@ -307,9 +307,10 @@ def transform_propagate_tuple_get(self, node: itir.FunCall, **kwargs) -> Optiona self.fp_transform(im.tuple_get(idx.value, expr.fun.expr), **kwargs) ) )(*expr.args) - elif cpm.is_call_to(expr, "if_"): + elif cpm.is_call_to(expr, ("if_", "concat_where")): + fun = expr.fun cond, true_branch, false_branch = expr.args - return im.if_( + return im.call(fun)( cond, self.fp_transform(im.tuple_get(idx.value, true_branch), **kwargs), self.fp_transform(im.tuple_get(idx.value, false_branch), **kwargs), diff --git a/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py b/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py index a693770ad8..108488add6 100644 --- a/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py @@ -25,8 +25,8 @@ def _in(pos: itir.Expr, domain: itir.Expr) -> itir.Expr: """ Given a position and a domain return an expression that evaluates to `True` if the position is inside the domain. - `in_({i, j, k}, u⟨ Iₕ: [i0, i1[, Iₕ: [j0, j1[, Iₕ: [k0, k1[ ⟩` - -> `i0 <= i < i1 & j0 <= j < j1 & k0 <= k < k1` + pos = `{i, j, k}`, domain = `u⟨ Iₕ: [i0, i1[, Iₕ: [j0, j1[, Iₕ: [k0, k1[ ⟩` + -> `((i0 <= i) & (i < i1)) & ((j0 <= j) & (j < j1)) & ((k0 <= k)l & (k < k1))` """ ret = [ im.and_( diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 4a16122a71..48653ba5b5 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -98,6 +98,8 @@ class Transformation(enum.Flag): # `if_(True, true_branch, false_branch)` -> `true_branch` FOLD_IF = enum.auto() + FOLD_INFINITY_ARITHMETIC = enum.auto() + @classmethod def all(self) -> ConstantFolding.Transformation: return functools.reduce(operator.or_, self.__members__.values()) @@ -239,3 +241,60 @@ def transform_fold_if(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: assert node.args[0].value == "False" return node.args[2] return None + + def transform_fold_infinity_arithmetic(self, node: ir.FunCall) -> Optional[ir.Node]: + if cpm.is_call_to(node, "plus"): + # `a + +/-inf` -> `+/-inf` + a, b = node.args + assert not (isinstance(a, ir.InfinityLiteral) and isinstance(b, ir.InfinityLiteral)) + for arg in a, b: + if isinstance(arg, ir.InfinityLiteral): + return arg + + if cpm.is_call_to(node, "minimum"): + if ir.InfinityLiteral.NEGATIVE in node.args: + # `minimum(-inf, a)` -> `-inf` + return ir.InfinityLiteral.NEGATIVE + if ir.InfinityLiteral.POSITIVE in node.args: + # `minimum(inf, a)` -> `a` + a, b = node.args + return b if a == ir.InfinityLiteral.POSITIVE else a + + if cpm.is_call_to(node, "maximum"): + if ir.InfinityLiteral.POSITIVE in node.args: + # `maximum(inf, a)` -> `inf` + return ir.InfinityLiteral.POSITIVE + if ir.InfinityLiteral.NEGATIVE in node.args: + # `maximum(-inf, a)` -> `a` + a, b = node.args + return b if a == ir.InfinityLiteral.NEGATIVE else a + + if cpm.is_call_to(node, ("less", "less_equal")): + a, b = node.args + # we don't handle `inf < inf` or `-inf < -inf`.args + assert a != b or not isinstance(a, ir.InfinityLiteral) + + # `-inf < v` -> `True` + # `v < inf` -> `True` + if a == ir.InfinityLiteral.NEGATIVE or b == ir.InfinityLiteral.POSITIVE: + return im.literal_from_value(True) + # `inf < v` -> `False` + # `v < -inf ` -> `False` + if a == ir.InfinityLiteral.POSITIVE or b == ir.InfinityLiteral.NEGATIVE: + return im.literal_from_value(False) + + if cpm.is_call_to(node, ("greater", "greater_equal")): + a, b = node.args + # we don't handle `inf > inf` or `-inf > -inf`.args + assert a != b or not isinstance(a, ir.InfinityLiteral) + + # `inf > v` -> `True` + # `v > -inf ` -> `True` + if a == ir.InfinityLiteral.POSITIVE or b == ir.InfinityLiteral.NEGATIVE: + return im.literal_from_value(True) + # `-inf > v` -> `False` + # `v > inf` -> `False` + if a == ir.InfinityLiteral.NEGATIVE or b == ir.InfinityLiteral.POSITIVE: + return im.literal_from_value(False) + + return None diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 87241e1ba8..2fcbd5df0d 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -87,9 +87,19 @@ def _is_collectable_expr(node: itir.Node) -> bool: # backend (single pass eager depth first visit approach) # do also not collect lifts or applied lifts as they become invisible to the lift inliner # otherwise - if cpm.is_call_to(node, ("lift", "shift", "reduce", "map_")) or cpm.is_applied_lift(node): + # do also not collect index nodes because otherwise the right hand side of SetAts becomes a let statement + # instead of an as_fieldop + if cpm.is_call_to( + node, ("lift", "shift", "reduce", "map_", "index") + ) or cpm.is_applied_lift(node): return False return True + # do also not collect make_tuple(index) nodes because otherwise the right hand side of SetAts becomes a let statement + # instead of an as_fieldop + if cpm.is_call_to(node, "make_tuple") and all( + cpm.is_call_to(arg, "index") for arg in node.args + ): + return False elif isinstance(node, itir.Lambda): return True diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 0e01dafed0..4b3a258396 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -450,7 +450,4 @@ def visit(self, node, **kwargs): node = super().visit(node, **kwargs) - if isinstance(node, itir.Expr) and hasattr(node.annex, "domain"): - node.annex.domain = node.annex.domain - return node diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index e1e6d74a72..b3c81ca2d0 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -329,11 +329,16 @@ def create_global_tmps( This pass looks at all `as_fieldop` calls and transforms field-typed subexpressions of its arguments into temporaries. """ - offset_provider_type = common.offset_provider_to_type(offset_provider) + # TODO(tehrengruber): document why to keep existing domains and add test program = infer_domain.infer_program( - program, offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes + program, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + keep_existing_domains=True, + ) + program = type_inference.infer( + program, offset_provider_type=common.offset_provider_to_type(offset_provider) ) - program = type_inference.infer(program, offset_provider_type=offset_provider_type) if not uids: uids = eve_utils.UIDGenerator(prefix="__tmp") diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index eef2c1bab0..08538788b6 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -12,10 +12,12 @@ from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import ( + concat_where, dead_code_elimination, fuse_as_fieldop, global_tmps, infer_domain, + infer_domain_ops, inline_dynamic_shifts, inline_fundefs, inline_lifts, @@ -81,6 +83,10 @@ def apply_common_transforms( ir = inline_dynamic_shifts.InlineDynamicShifts.apply( ir ) # domain inference does not support dynamic offsets yet + ir = infer_domain_ops.InferDomainOps.apply(ir) + ir = concat_where.canonicalize_domain_argument(ir) + + ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( ir, offset_provider=offset_provider, @@ -88,6 +94,8 @@ def apply_common_transforms( ) ir = remove_broadcast.RemoveBroadcast.apply(ir) + ir = concat_where.transform_to_as_fieldop(ir) + for _ in range(10): inlined = ir @@ -183,6 +191,11 @@ def apply_fieldview_transforms( ir = inline_dynamic_shifts.InlineDynamicShifts.apply( ir ) # domain inference does not support dynamic offsets yet + + ir = infer_domain_ops.InferDomainOps.apply(ir) + ir = concat_where.canonicalize_domain_argument(ir) + ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program + ir = infer_domain.infer_program(ir, offset_provider=offset_provider) ir = remove_broadcast.RemoveBroadcast.apply(ir) return ir diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index de91b9ee87..8173ceebbb 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -274,6 +274,9 @@ class TraceShifts(PreserveLocationVisitor, NodeTranslator): def visit_Literal(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: return Sentinel.VALUE + def visit_InfinityLiteral(self, node: ir.InfinityLiteral, *, ctx: dict[str, Any]): + return Sentinel.VALUE + def visit_SymRef(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: if node.id in ctx: return ctx[node.id] diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 702bd48dec..f7445461c0 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -125,7 +125,11 @@ def _values_validator( ) -> None: if not all( isinstance(el, (SidFromScalar, SidComposite)) - or _is_tuple_expr_of(lambda expr: isinstance(expr, (SymRef, Literal)), el) + or _is_tuple_expr_of( + lambda expr: isinstance(expr, (SymRef, Literal)) + or (isinstance(expr, FunCall) and expr.fun == SymRef(id="index")), + el, + ) for el in value ): raise ValueError( diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index e7ef2c7c74..e395bcf991 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -223,7 +223,9 @@ class Params: run_gtfn_gpu = GTFNBackendFactory(gpu=True) -run_gtfn_gpu_cached = GTFNBackendFactory(gpu=True, cached=True) +run_gtfn_gpu_cached = GTFNBackendFactory( + gpu=True, cached=True, otf_workflow__cached_translation=True +) run_gtfn_no_transforms = GTFNBackendFactory( otf_workflow__bare_translation__enable_itir_transforms=False diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 2ccbb94443..ad3ff4bbfc 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -144,7 +144,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLES_ARGS_WITH_DIFFERENT_BUT_PROMOTABLE_DIMS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), ] # Markers to skip because of missing features in the domain inference DOMAIN_INFERENCE_SKIP_LIST = [ @@ -161,6 +160,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLE_ITERATOR, XFAIL, UNSUPPORTED_MESSAGE), + (USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), ] ) EMBEDDED_SKIP_LIST = [ @@ -179,6 +179,9 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_TUPLES_ARGS_WITH_DIFFERENT_BUT_PROMOTABLE_DIMS, XFAIL, UNSUPPORTED_MESSAGE), (USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), ] +GTIR_EMBEDDED_SKIP_LIST = ROUNDTRIP_SKIP_LIST + [ + (USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), +] GTFN_SKIP_TEST_LIST = ( COMMON_SKIP_TEST_LIST + DOMAIN_INFERENCE_SKIP_LIST @@ -219,5 +222,5 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (ALL, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), ], - ProgramBackendId.GTIR_EMBEDDED: ROUNDTRIP_SKIP_LIST, + ProgramBackendId.GTIR_EMBEDDED: GTIR_EMBEDDED_SKIP_LIST, } diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 1360fd44cf..967cf0ab11 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -8,6 +8,7 @@ from __future__ import annotations +import copy import dataclasses import functools import inspect @@ -66,6 +67,7 @@ JField: TypeAlias = gtx.Field[[JDim], np.int32] # type: ignore [valid-type] IFloatField: TypeAlias = gtx.Field[[IDim], np.float64] # type: ignore [valid-type] IBoolField: TypeAlias = gtx.Field[[IDim], bool] # type: ignore [valid-type] +JField: TypeAlias = gtx.Field[[JDim], np.int32] # type: ignore [valid-type] KField: TypeAlias = gtx.Field[[KDim], np.int32] # type: ignore [valid-type] IJField: TypeAlias = gtx.Field[[IDim, JDim], np.int32] # type: ignore [valid-type] IKField: TypeAlias = gtx.Field[[IDim, KDim], np.int32] # type: ignore [valid-type] @@ -107,7 +109,7 @@ def scalar(self, dtype: np.typing.DTypeLike) -> ScalarValue: def field( self, allocator: next_allocators.FieldBufferAllocatorProtocol, - sizes: dict[gtx.Dimension, int], + domain: gtx.Domain, dtype: np.typing.DTypeLike, ) -> FieldValue: ... @@ -140,11 +142,11 @@ def scalar_value(self) -> ScalarValue: def field( self, allocator: next_allocators.FieldBufferAllocatorProtocol, - sizes: dict[gtx.Dimension, int], + domain: gtx.Domain, dtype: np.typing.DTypeLike, ) -> FieldValue: return constructors.full( - domain=common.domain(sizes), fill_value=self.value, dtype=dtype, allocator=allocator + domain=domain, fill_value=self.value, dtype=dtype, allocator=allocator ) @@ -166,16 +168,17 @@ def scalar_value(self) -> ScalarValue: def field( self, allocator: next_allocators.FieldBufferAllocatorProtocol, - sizes: dict[gtx.Dimension, int], + domain: gtx.Domain, dtype: np.typing.DTypeLike, ) -> FieldValue: - if len(sizes) > 1: + if len(domain.dims) > 1: raise ValueError( - f"'IndexInitializer' only supports fields with a single 'Dimension', got {sizes}." + f"'IndexInitializer' only supports fields with a single 'Dimension', got {domain}." ) - n_data = list(sizes.values())[0] return constructors.as_field( - domain=common.domain(sizes), data=np.arange(0, n_data, dtype=dtype), allocator=allocator + domain=domain, + data=np.arange(domain.ranges[0].start, domain.ranges[0].stop, dtype=dtype), + allocator=allocator, ) def from_case( @@ -207,16 +210,15 @@ def scalar_value(self) -> ScalarValue: def field( self, allocator: next_allocators.FieldBufferAllocatorProtocol, - sizes: dict[gtx.Dimension, int], + domain: common.Domain, dtype: np.typing.DTypeLike, ) -> FieldValue: start = self.start - svals = tuple(sizes.values()) - n_data = int(np.prod(svals)) - self.start += n_data + assert isinstance(domain.size, int) + self.start += domain.size return constructors.as_field( - common.domain(sizes), - np.arange(start, start + n_data, dtype=dtype).reshape(svals), + common.domain(domain), + np.arange(start, self.start, dtype=dtype).reshape(domain.shape), allocator=allocator, ) @@ -329,6 +331,7 @@ def allocate( name: str, *, sizes: Optional[dict[gtx.Dimension, int]] = None, + domain: Optional[dict[gtx.Dimension, tuple[int, int]] | gtx.Domain] = None, strategy: Optional[DataInitializer] = None, dtype: Optional[np.typing.DTypeLike] = None, extend: Optional[dict[gtx.Dimension, tuple[int, int]]] = None, @@ -350,9 +353,22 @@ def allocate( Useful for shifted fields, which must start off bigger than the output field in the shifted dimension. """ - sizes = extend_sizes( - case.default_sizes | (sizes or {}), extend + if sizes: + assert not domain and all(dim in case.default_sizes for dim in sizes) + domain = { + dim: (0, sizes.get(dim, default_size)) + for dim, default_size in case.default_sizes.items() + } + + domain = domain or {dim: (0, size) for dim, size in case.default_sizes.items()} + + if not isinstance(domain, gtx.Domain): + domain = gtx.domain(domain) + + domain = extend_domain( + domain, extend ) # TODO: this should take into account the Domain of the allocated field + arg_type = get_param_types(fieldview_prog)[name] if strategy is None: if name in ["out", RETURN]: @@ -362,7 +378,7 @@ def allocate( return _allocate_from_type( case=case, arg_type=arg_type, - sizes=sizes, + domain=domain, dtype=dtype, strategy=strategy.from_case(case=case, fieldview_prog=fieldview_prog, arg_name=name), ) @@ -551,14 +567,14 @@ def unstructured_case_3d(unstructured_case): return dataclasses.replace( unstructured_case, default_sizes={**unstructured_case.default_sizes, KDim: 10}, - offset_provider={**unstructured_case.offset_provider, "KOff": KDim}, + offset_provider={**unstructured_case.offset_provider, "Koff": KDim}, ) def _allocate_from_type( case: Case, arg_type: ts.TypeSpec, - sizes: dict[gtx.Dimension, int], + domain: gtx.Domain, strategy: DataInitializer, dtype: Optional[np.typing.DTypeLike] = None, tuple_start: Optional[int] = None, @@ -568,7 +584,7 @@ def _allocate_from_type( case ts.FieldType(dims=dims, dtype=arg_dtype): return strategy.field( allocator=case.allocator, - sizes={dim: sizes[dim] for dim in dims}, + domain=common.domain(tuple(domain[dim] for dim in dims)), dtype=dtype or arg_dtype.kind.name.lower(), ) case ts.ScalarType(kind=kind): @@ -577,7 +593,7 @@ def _allocate_from_type( return tuple( ( _allocate_from_type( - case=case, arg_type=t, sizes=sizes, dtype=dtype, strategy=strategy + case=case, arg_type=t, domain=domain, dtype=dtype, strategy=strategy ) for t in types ) @@ -613,15 +629,26 @@ def get_param_size(param_type: ts.TypeSpec, sizes: dict[gtx.Dimension, int]) -> raise TypeError(f"Can not get size for parameter of type '{param_type}'.") -def extend_sizes( - sizes: dict[gtx.Dimension, int], extend: Optional[dict[gtx.Dimension, tuple[int, int]]] = None +def extend_domain( + domain: gtx.Domain, extend: Optional[dict[gtx.Dimension, tuple[int, int]]] = None ) -> dict[gtx.Dimension, int]: """Calculate the sizes per dimension given a set of extensions.""" - sizes = sizes.copy() if extend: + domain = copy.deepcopy(domain) for dim, (lower, upper) in extend.items(): - sizes[dim] += upper - lower - return sizes + domain = domain.replace( + dim, + common.named_range( + ( + dim, + ( + domain[dim].unit_range.start - lower, + domain[dim].unit_range.stop + upper, + ), + ) + ), + ) + return domain def get_default_data( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 44ef9b62f0..7be2ad6999 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -9,8 +9,10 @@ import numpy as np from typing import Tuple import pytest -from next_tests.integration_tests.cases import KDim, cartesian_case +from next_tests.integration_tests.cases import IDim, JDim, KDim, cartesian_case from gt4py import next as gtx +from gt4py.next import errors +from gt4py.next import broadcast from gt4py.next.ffront.experimental import concat_where from next_tests.integration_tests import cases from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -20,108 +22,290 @@ pytestmark = pytest.mark.uses_concat_where -def test_boundary_same_size_fields(cartesian_case): +def test_concat_where_simple(cartesian_case): @gtx.field_operator - def testee( - k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField - ) -> cases.IJKField: - return concat_where(k == 0, boundary, interior) + def testee(ground: cases.IJKField, air: cases.IJKField) -> cases.IJKField: + return concat_where(KDim > 0, air, ground) + + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + ground = cases.allocate(cartesian_case, testee, "ground")() + air = cases.allocate(cartesian_case, testee, "air")() + + k = np.arange(0, cartesian_case.default_sizes[KDim]) + ref = np.where(k[np.newaxis, np.newaxis, :] == 0, ground.asnumpy(), air.asnumpy()) + cases.verify(cartesian_case, testee, ground, air, out=out, ref=ref) + + +def test_concat_where(cartesian_case): + @gtx.field_operator + def testee(ground: cases.IJKField, air: cases.IJKField) -> cases.IJKField: + return concat_where(KDim == 0, ground, air) + + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + ground = cases.allocate(cartesian_case, testee, "ground")() + air = cases.allocate(cartesian_case, testee, "air")() + + k = np.arange(0, cartesian_case.default_sizes[KDim]) + ref = np.where(k[np.newaxis, np.newaxis, :] == 0, ground.asnumpy(), air.asnumpy()) + cases.verify(cartesian_case, testee, ground, air, out=out, ref=ref) + + +def test_concat_where_non_overlapping(cartesian_case): + """Fields only defined in their respective region in concat_where.""" + + @gtx.field_operator + def testee(ground: cases.IJKField, air: cases.IJKField) -> cases.IJKField: + return concat_where(KDim == 0, ground, air) + + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + ground = cases.allocate( + cartesian_case, testee, "ground", domain=out.domain.slice_at[:, :, 0:1] + )() + air = cases.allocate(cartesian_case, testee, "air", domain=out.domain.slice_at[:, :, 1:])() + + ref = np.concatenate((ground.asnumpy(), air.asnumpy()), axis=2) + cases.verify(cartesian_case, testee, ground, air, out=out, ref=ref) + + +def test_concat_where_scalar_broadcast(cartesian_case): + @gtx.field_operator + def testee(a: np.int32, b: cases.IJKField, N: np.int32) -> cases.IJKField: + return concat_where(KDim < N - 1, a, b) + + a = 3 + b = cases.allocate(cartesian_case, testee, "b")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.concatenate( + ( + np.full((*out.domain.shape[0:2], out.domain.shape[2] - 1), a), + b.asnumpy()[:, :, -1:], + ), + axis=2, + ) + cases.verify(cartesian_case, testee, a, b, cartesian_case.default_sizes[KDim], out=out, ref=ref) + + +def test_concat_where_scalar_broadcast_on_empty_branch(cartesian_case): + """Output domain such that the scalar branch is never active.""" + + @gtx.field_operator + def testee(a: np.int32, b: cases.KField, N: np.int32) -> cases.KField: + return concat_where(KDim < N, a, b) + + a = 3 + b = cases.allocate(cartesian_case, testee, "b")() + out = cases.allocate(cartesian_case, testee, cases.RETURN, domain=b.domain.slice_at[1:])() + + ref = b.asnumpy()[1:] + cases.verify(cartesian_case, testee, a, b, 1, out=out, ref=ref) + + +def test_concat_where_single_level_broadcast(cartesian_case): + @gtx.field_operator + def testee(a: cases.KField, b: cases.IJKField) -> cases.IJKField: + return concat_where(KDim == 0, a, b) + + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + a = cases.allocate( + cartesian_case, testee, "a", domain=gtx.domain({KDim: out.domain.shape[2]}) + )() + b = cases.allocate(cartesian_case, testee, "b", domain=out.domain.slice_at[:, :, 1:])() + + ref = np.concatenate( + ( + np.tile(a.asnumpy()[0], (*b.domain.shape[0:2], 1)), + b.asnumpy(), + ), + axis=2, + ) + cases.verify(cartesian_case, testee, a, b, out=out, ref=ref) + + +def test_concat_where_single_level_restricted_domain_broadcast(cartesian_case): + @gtx.field_operator + def testee(a: cases.KField, b: cases.IJKField) -> cases.IJKField: + return concat_where(KDim == 0, a, b) + + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + # note: this field is only defined on K: 0, 1, i.e., contains only a single value + a = cases.allocate(cartesian_case, testee, "a", domain=gtx.domain({KDim: (0, 1)}))() + b = cases.allocate(cartesian_case, testee, "b", domain=out.domain.slice_at[:, :, 1:])() + + ref = np.concatenate( + ( + np.tile(a.asnumpy()[0], (*b.domain.shape[0:2], 1)), + b.asnumpy(), + ), + axis=2, + ) + cases.verify(cartesian_case, testee, a, b, out=out, ref=ref) + + +def test_boundary_single_layer_3d_bc(cartesian_case): + @gtx.field_operator + def testee(interior: cases.IJKField, boundary: cases.IJKField) -> cases.IJKField: + return concat_where(KDim == 0, boundary, interior) - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() - boundary = cases.allocate(cartesian_case, testee, "boundary")() + boundary = cases.allocate(cartesian_case, testee, "boundary", sizes={KDim: 1})() out = cases.allocate(cartesian_case, testee, cases.RETURN)() + k = np.arange(0, cartesian_case.default_sizes[KDim]) ref = np.where( - k.asnumpy()[np.newaxis, np.newaxis, :] == 0, boundary.asnumpy(), interior.asnumpy() + k[np.newaxis, np.newaxis, :] == 0, + np.broadcast_to(boundary.asnumpy(), interior.shape), + interior.asnumpy(), ) - cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) -def test_boundary_horizontal_slice(cartesian_case): +def test_boundary_single_layer_2d_bc(cartesian_case): @gtx.field_operator - def testee( - k: cases.KField, interior: cases.IJKField, boundary: cases.IJField - ) -> cases.IJKField: - return concat_where(k == 0, boundary, interior) + def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField: + return concat_where(KDim == 0, boundary, interior) - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() boundary = cases.allocate(cartesian_case, testee, "boundary")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() + k = np.arange(0, cartesian_case.default_sizes[KDim]) ref = np.where( - k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + k[np.newaxis, np.newaxis, :] == 0, boundary.asnumpy()[:, :, np.newaxis], interior.asnumpy(), ) - cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) -def test_boundary_single_layer(cartesian_case): +def test_boundary_single_layer_2d_bc_on_empty_branch(cartesian_case): @gtx.field_operator - def testee( - k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField - ) -> cases.IJKField: - return concat_where(k == 0, boundary, interior) + def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField: + return concat_where(KDim == 0, boundary, interior) - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() - boundary = cases.allocate(cartesian_case, testee, "boundary", sizes={KDim: 1})() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate( + cartesian_case, testee, cases.RETURN, domain=interior.domain.slice_at[:, :, 1:] + )() + + ref = interior.asnumpy()[:, :, 1:] + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) + + +def test_dimension_two_nested_conditions(cartesian_case): + @gtx.field_operator + def testee(interior: cases.IJKField, boundary: cases.IJKField) -> cases.IJKField: + return concat_where((KDim < 2), boundary, concat_where((KDim >= 5), boundary, interior)) + + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() + k = np.arange(0, cartesian_case.default_sizes[KDim]) ref = np.where( - k.asnumpy()[np.newaxis, np.newaxis, :] == 0, - np.broadcast_to(boundary.asnumpy(), interior.shape), + (k[np.newaxis, np.newaxis, :] < 2) | (k[np.newaxis, np.newaxis, :] >= 5), + boundary.asnumpy(), interior.asnumpy(), ) + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) - cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) +def test_dimension_two_conditions_and(cartesian_case): + @gtx.field_operator + def testee(interior: cases.KField, boundary: cases.KField, nlev: np.int32) -> cases.KField: + return concat_where((0 < KDim) & (KDim < (nlev - 1)), interior, boundary) + + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + nlev = cartesian_case.default_sizes[KDim] + k = np.arange(0, nlev) + ref = np.where((0 < k) & (k < (nlev - 1)), interior.asnumpy(), boundary.asnumpy()) + cases.verify(cartesian_case, testee, interior, boundary, nlev, out=out, ref=ref) + + +def test_dimension_eq_in_middle_of_domain(cartesian_case): + @gtx.field_operator + def testee(interior: cases.KField, boundary: cases.KField) -> cases.KField: + return concat_where((KDim == 2), interior, boundary) + + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + k = np.arange(0, cartesian_case.default_sizes[KDim]) + ref = np.where(k == 2, interior.asnumpy(), boundary.asnumpy()) + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) -def test_alternating_mask(cartesian_case): + +def test_dimension_two_conditions_or(cartesian_case): @gtx.field_operator - def testee(k: cases.KField, f0: cases.IJKField, f1: cases.IJKField) -> cases.IJKField: - return concat_where(k % 2 == 0, f1, f0) + def testee(interior: cases.KField, boundary: cases.KField) -> cases.KField: + return concat_where(((KDim < 2) | (KDim >= 5)), boundary, interior) - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() - f0 = cases.allocate(cartesian_case, testee, "f0")() - f1 = cases.allocate(cartesian_case, testee, "f1")() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() - ref = np.where(k.asnumpy()[np.newaxis, np.newaxis, :] % 2 == 0, f1.asnumpy(), f0.asnumpy()) + k = np.arange(0, cartesian_case.default_sizes[KDim]) + ref = np.where((k < 2) | (k >= 5), boundary.asnumpy(), interior.asnumpy()) + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) + + +def test_lap_like(cartesian_case): + @gtx.field_operator + def testee( + inp: cases.IJField, boundary: np.int32, shape: tuple[np.int32, np.int32] + ) -> cases.IJField: + # TODO add support for multi-dimensional concat_where masks + return concat_where( + (IDim == 0) | (IDim == shape[0] - 1), + boundary, + concat_where((JDim == 0) | (JDim == shape[1] - 1), boundary, inp), + ) + + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + inp = cases.allocate(cartesian_case, testee, "inp", domain=out.domain.slice_at[1:-1, 1:-1])() + boundary = 2 - cases.verify(cartesian_case, testee, k, f0, f1, out=out, ref=ref) + ref = np.full(out.domain.shape, np.nan) + ref[0, :] = boundary + ref[:, 0] = boundary + ref[-1, :] = boundary + ref[:, -1] = boundary + ref[1:-1, 1:-1] = inp.asnumpy() + cases.verify(cartesian_case, testee, inp, boundary, out.domain.shape, out=out, ref=ref) @pytest.mark.uses_tuple_returns def test_with_tuples(cartesian_case): @gtx.field_operator def testee( - k: cases.KField, interior0: cases.IJKField, boundary0: cases.IJField, interior1: cases.IJKField, boundary1: cases.IJField, - ) -> Tuple[cases.IJKField, cases.IJKField]: - return concat_where(k == 0, (boundary0, boundary1), (interior0, interior1)) + ) -> tuple[cases.IJKField, cases.IJKField]: + return concat_where(KDim == 0, (boundary0, boundary1), (interior0, interior1)) - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior0 = cases.allocate(cartesian_case, testee, "interior0")() boundary0 = cases.allocate(cartesian_case, testee, "boundary0")() interior1 = cases.allocate(cartesian_case, testee, "interior1")() boundary1 = cases.allocate(cartesian_case, testee, "boundary1")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() + k = np.arange(0, cartesian_case.default_sizes[KDim]) ref0 = np.where( - k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + k[np.newaxis, np.newaxis, :] == 0, boundary0.asnumpy()[:, :, np.newaxis], interior0.asnumpy(), ) ref1 = np.where( - k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + k[np.newaxis, np.newaxis, :] == 0, boundary1.asnumpy()[:, :, np.newaxis], interior1.asnumpy(), ) @@ -129,7 +313,71 @@ def testee( cases.verify( cartesian_case, testee, - k, + interior0, + boundary0, + interior1, + boundary1, + out=out, + ref=(ref0, ref1), + ) + + +def test_nested_conditions_with_empty_branches(cartesian_case): + @gtx.field_operator + def testee(interior: cases.IField, boundary: cases.IField, N: gtx.int32) -> cases.IField: + interior = concat_where(IDim == 0, boundary, interior) + interior = concat_where((1 <= IDim) & (IDim < N - 1), interior * 2, interior) + interior = concat_where(IDim == N - 1, boundary, interior) + return interior + + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + N = cartesian_case.default_sizes[IDim] + + i = np.arange(0, cartesian_case.default_sizes[IDim]) + ref = np.where( + (i[:] == 0) | (i[:] == N - 1), + boundary.asnumpy(), + interior.asnumpy() * 2, + ) + cases.verify(cartesian_case, testee, interior, boundary, N, out=out, ref=ref) + + +@pytest.mark.uses_tuple_returns +def test_with_tuples_different_domain(cartesian_case): + @gtx.field_operator + def testee( + interior0: cases.IJKField, + boundary0: cases.IJKField, + interior1: cases.KField, + boundary1: cases.KField, + ) -> tuple[cases.IJKField, cases.IJKField]: + a, b = concat_where(KDim == 0, (boundary0, boundary1), (interior0, interior1)) + # the broadcast is only needed since we can not return fields on different domains yet + return a, broadcast(b, (IDim, JDim, KDim)) + + interior0 = cases.allocate(cartesian_case, testee, "interior0")() + boundary0 = cases.allocate(cartesian_case, testee, "boundary0")() + interior1 = cases.allocate(cartesian_case, testee, "interior1")() + boundary1 = cases.allocate(cartesian_case, testee, "boundary1")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + k = np.arange(0, cartesian_case.default_sizes[KDim]) + ref0 = np.where( + k[np.newaxis, np.newaxis, :] == 0, + boundary0.asnumpy(), + interior0.asnumpy(), + ) + ref1 = np.where( + k == 0, + boundary1.asnumpy(), + interior1.asnumpy(), + ) + + cases.verify( + cartesian_case, + testee, interior0, boundary0, interior1, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 803ab0c6bc..1a1984a71b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -135,7 +135,7 @@ def test_neighbor_sum(unstructured_case_3d, fop): @pytest.mark.uses_unstructured_shift -def test_reduction_execution_with_offset(unstructured_case): +def test_reduction_execution_with_offset(unstructured_case_3d): EKField: TypeAlias = gtx.Field[[Edge, KDim], np.int32] VKField: TypeAlias = gtx.Field[[Vertex, KDim], np.int32] @@ -152,12 +152,12 @@ def fencil_op(edge_f: EKField) -> VKField: def fencil(edge_f: EKField, out: VKField): fencil_op(edge_f, out=out) - v2e_table = unstructured_case.offset_provider["V2E"].asnumpy() - field = cases.allocate(unstructured_case, fencil, "edge_f", sizes={KDim: 2})() - out = cases.allocate(unstructured_case, fencil_op, cases.RETURN, sizes={KDim: 1})() + v2e_table = unstructured_case_3d.offset_provider["V2E"].asnumpy() + field = cases.allocate(unstructured_case_3d, fencil, "edge_f", sizes={KDim: 2})() + out = cases.allocate(unstructured_case_3d, fencil_op, cases.RETURN, sizes={KDim: 1})() cases.verify( - unstructured_case, + unstructured_case_3d, fencil, field, out, @@ -168,7 +168,7 @@ def fencil(edge_f: EKField, out: VKField): initial=0, where=v2e_table != common._DEFAULT_SKIP_VALUE, ).reshape(out.shape), - offset_provider=unstructured_case.offset_provider | {"Koff": KDim}, + offset_provider=unstructured_case_3d.offset_provider, ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index d97869156b..97ffb0ee6d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -5,11 +5,14 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import pytest + from gt4py.next import common from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple -from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.type_system import type_specifications as it_ts +from gt4py.next.type_system import type_specifications as ts + int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) Vertex = common.Dimension(value="Vertex", kind=common.DimensionKind.HORIZONTAL) @@ -114,7 +117,27 @@ def test_simple_tuple_get_make_tuple(): assert expected == actual -def test_propagate_tuple_get(): +@pytest.mark.parametrize("fun", ["if_", "concat_where"]) +def test_propagate_tuple_get(fun): + testee = im.tuple_get( + 0, im.call(fun)("cond", im.make_tuple("el1", "el2"), im.make_tuple("el1", "el2")) + ) + expected = im.call(fun)( + "cond", + im.tuple_get(0, im.make_tuple("el1", "el2")), + im.tuple_get(0, im.make_tuple("el1", "el2")), + ) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + enabled_transformations=CollapseTuple.Transformation.PROPAGATE_TUPLE_GET, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert expected == actual + + +def test_propagate_tuple_get_let(): expected = im.let(("el1", 1), ("el2", 2))(im.tuple_get(0, im.make_tuple("el1", "el2"))) testee = im.tuple_get(0, im.let(("el1", 1), ("el2", 2))(im.make_tuple("el1", "el2"))) actual = CollapseTuple.apply( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index d9dea7e2d5..a56b539014 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -193,6 +193,71 @@ def test_value_from_literal(value, expected): im.plus(im.maximum(im.minus(1, "a"), im.plus("a", 1)), im.minus(1, "a")), ), ), + # InfinityLiteral folding + ( + im.call("maximum")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE), + itir.InfinityLiteral.POSITIVE, + ), + ( + im.call("maximum")(itir.InfinityLiteral.POSITIVE, im.literal_from_value(1)), + itir.InfinityLiteral.POSITIVE, + ), + ( + im.call("maximum")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE), + im.literal_from_value(1), + ), + ( + im.call("maximum")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)), + im.literal_from_value(1), + ), + ( + im.call("minimum")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE), + im.literal_from_value(1), + ), + ( + im.call("minimum")(itir.InfinityLiteral.POSITIVE, im.literal_from_value(1)), + im.literal_from_value(1), + ), + ( + im.call("minimum")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE), + itir.InfinityLiteral.NEGATIVE, + ), + ( + im.call("minimum")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)), + itir.InfinityLiteral.NEGATIVE, + ), + ( + im.call("greater")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE), + im.literal_from_value(False), + ), + ( + im.call("greater")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE), + im.literal_from_value(True), + ), + ( + im.call("less")(im.literal_from_value(1), itir.InfinityLiteral.POSITIVE), + im.literal_from_value(True), + ), + ( + im.call("less")(im.literal_from_value(1), itir.InfinityLiteral.NEGATIVE), + im.literal_from_value(False), + ), + ( + im.call("greater")(itir.InfinityLiteral.POSITIVE, im.literal_from_value(1)), + im.literal_from_value(True), + ), + ( + im.call("greater")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)), + im.literal_from_value(False), + ), + ( + im.call("less")(itir.InfinityLiteral.POSITIVE, im.literal_from_value(1)), + im.literal_from_value(False), + ), + ( + im.call("less")(itir.InfinityLiteral.NEGATIVE, im.literal_from_value(1)), + im.literal_from_value(True), + ), ), ids=lambda x: str(x[0]), ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index d157edde1a..0d9a55ceef 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -8,7 +8,7 @@ # TODO(SF-N): test scan operator -from typing import Iterable, Literal, Optional, Union +from typing import Iterable, Literal, Optional import numpy as np import pytest @@ -26,6 +26,7 @@ from gt4py.next.iterator.transforms.constant_folding import ConstantFolding from gt4py.next.type_system import type_specifications as ts + float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) JDim = common.Dimension(value="JDim", kind=common.DimensionKind.HORIZONTAL) @@ -1225,6 +1226,116 @@ def test_never_accessed_domain_tuple(offset_provider): run_test_expr(testee, testee, domain, expected_domains, offset_provider) +def test_concat_where(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 4)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 4)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (4, 11)}) + testee = im.concat_where( + domain_cond, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ) + + expected = im.concat_where( + domain_cond, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_concat_where_two_dimensions(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 30)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 10)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10), JDim: (10, 30)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (10, 20), JDim: (10, 30)}) + testee = im.concat_where( + domain_cond, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ) + + expected = im.concat_where( + domain_cond, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_concat_where_two_dimensions_J(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 30)}) + domain_cond = im.domain(common.GridType.CARTESIAN, {JDim: (20, itir.InfinityLiteral.POSITIVE)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (20, 30)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 20)}) + testee = im.concat_where( + domain_cond, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ) + + expected = im.concat_where( + domain_cond, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + +def test_nested_concat_where_two_dimensions(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 20)}) + domain_cond1 = im.domain(common.GridType.CARTESIAN, {JDim: (10, itir.InfinityLiteral.POSITIVE)}) + domain_cond2 = im.domain(common.GridType.CARTESIAN, {IDim: (itir.InfinityLiteral.NEGATIVE, 20)}) + domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 20), JDim: (10, 20)}) + domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (20, 30), JDim: (10, 20)}) + domain3 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 30), JDim: (0, 10)}) + testee = im.concat_where( + domain_cond1, + im.concat_where( + domain_cond2, im.as_fieldop("deref")("in_field1"), im.as_fieldop("deref")("in_field2") + ), + im.as_fieldop("deref")("in_field3"), + ) + + expected = im.concat_where( + domain_cond1, # 0, 30; 10,20 + im.concat_where( + domain_cond2, + im.as_fieldop("deref", domain1)("in_field1"), + im.as_fieldop("deref", domain2)("in_field2"), + ), + im.as_fieldop("deref", domain3)("in_field3"), + ) + expected_domains = {"in_field1": domain1, "in_field2": domain2, "in_field3": domain3} + + actual_call, actual_domains = infer_domain.infer_expr( + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider + ) + + folded_call = constant_fold_domain_exprs(actual_call) + assert expected == folded_call + assert expected_domains == constant_fold_accessed_domains(actual_domains) + + def test_broadcast(offset_provider): testee = im.call("broadcast")("in_field", im.make_tuple(itir.AxisLiteral(value="IDim"))) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)})