Skip to content

Conversation

@edopao
Copy link
Contributor

@edopao edopao commented Jul 11, 2025

Lowering to SDFG of concat_where primitive, follows GTIR PR #1713.
Test coverage provided in previous PR.

tehrengruber and others added 30 commits February 20, 2025 17:26
Broadcast scalar to array with the expected domain in order
to avoid subset validation issue on scalar, in case of empty domain.
Copy link
Contributor

@philip-paul-mueller philip-paul-mueller left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have some comments, but I do not see a major problem.

# The strict order of lower and upper bounds of output domain is not guaranteed,
# for concat_where expressions, so we apply a runtime check.
out_shape[concat_dim_index] = dace.symbolic.pystr_to_symbolic(
f"max(0, {out_shape[concat_dim_index]})"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not fully sure but shouldn't you use origin here.
Is it because you incorporated origin in the index computation below?
If so I would add a note to explain that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the array shape, so the minimum size is always zero. I use origin in the computation of the memlet subset.

upper_bound = output_domain[concat_dim_index][2]
upper_domain = [(concat_dim, concat_dim_bound, upper_bound)]

if concat_dim not in lower.gt_type.dims: # type: ignore[union-attr]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does these cases mean.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Description added.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for being persistent and a bit low here, but I think you should also add some typical concat_where() expressions to make it simpler to understand such as concat_where(k < 0, ij_field, 1.0) or something.
Because, I think that it is on a conceptional level simple, but its actual implementation is rather involved.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea.

Comment on lines +275 to +278
upper_range_1 = output_domain[concat_dim_index][2]
upper_range_0 = dace.symbolic.pystr_to_symbolic(
f"min({upper_range_1}, {upper_domain[concat_dim_index][1]})"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
upper_range_1 = output_domain[concat_dim_index][2]
upper_range_0 = dace.symbolic.pystr_to_symbolic(
f"min({upper_range_1}, {upper_domain[concat_dim_index][1]})"
)
upper_range_0 = dace.symbolic.pystr_to_symbolic(
f"min({upper_range_1}, {upper_domain[concat_dim_index][1]})"
)
upper_range_1 = output_domain[concat_dim_index][2]

Now they have the same order as above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I use upper_range_1 in the computation of upper_range_0, that is why they appear in this order.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see it now.
I guess there was some tricky thinking involved to get that right.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All unit tests still pass even if I switch the order. Honestly, I do not remember why I chose this order in the end. Probably, I wanted to follow the principle that the far left-hand side and right-hand side correspond exactly to the output domain, without any symbolic manipulation. Instead, I only manipulate the neighborhood of the concat boundary:

In other words, I could have written:

    lower_range_0, upper_range_1 = output_domain[concat_dim_index][1:3]
    lower_range_1 = dace.symbolic.pystr_to_symbolic(
        f"max({lower_range_0}, {lower_domain[concat_dim_index][2]})"
    )
    upper_range_0 = dace.symbolic.pystr_to_symbolic(
        f"min({upper_range_1}, {upper_domain[concat_dim_index][1]})"
    )

Copy link
Contributor Author

@edopao edopao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review comments, I agree it looks better now.

# The strict order of lower and upper bounds of output domain is not guaranteed,
# for concat_where expressions, so we apply a runtime check.
out_shape[concat_dim_index] = dace.symbolic.pystr_to_symbolic(
f"max(0, {out_shape[concat_dim_index]})"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the array shape, so the minimum size is always zero. I use origin in the computation of the memlet subset.

Comment on lines +275 to +278
upper_range_1 = output_domain[concat_dim_index][2]
upper_range_0 = dace.symbolic.pystr_to_symbolic(
f"min({upper_range_1}, {upper_domain[concat_dim_index][1]})"
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I use upper_range_1 in the computation of upper_range_0, that is why they appear in this order.

upper_bound = output_domain[concat_dim_index][2]
upper_domain = [(concat_dim, concat_dim_bound, upper_bound)]

if concat_dim not in lower.gt_type.dims: # type: ignore[union-attr]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Description added.

Copy link
Contributor

@philip-paul-mueller philip-paul-mueller left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the problem is more with me.
I guess that from a conceptional point it is relatively simple, but its implementation is quite involved and there are some things I do not fully get to be honest.

Comment on lines +275 to +278
upper_range_1 = output_domain[concat_dim_index][2]
upper_range_0 = dace.symbolic.pystr_to_symbolic(
f"min({upper_range_1}, {upper_domain[concat_dim_index][1]})"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see it now.
I guess there was some tricky thinking involved to get that right.

origin = tuple(
[*field.origin[:concat_dim_index], concat_dim_origin, *field.origin[concat_dim_index:]]
)
shape = tuple([*field_desc.shape[:concat_dim_index], 1, *field_desc.shape[concat_dim_index:]])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not understand that the concat_dim_index dimension alsways has length 1?
Or rather what happens if a slice is higher than one level?

upper_bound = output_domain[concat_dim_index][2]
upper_domain = [(concat_dim, concat_dim_bound, upper_bound)]

if concat_dim not in lower.gt_type.dims: # type: ignore[union-attr]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for being persistent and a bit low here, but I think you should also add some typical concat_where() expressions to make it simpler to understand such as concat_where(k < 0, ij_field, 1.0) or something.
Because, I think that it is on a conceptional level simple, but its actual implementation is rather involved.

Copy link
Contributor

@philip-paul-mueller philip-paul-mueller left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@edopao edopao merged commit 0cad969 into GridTools:main Jul 14, 2025
23 checks passed
@edopao edopao deleted the gtir-dace-concat_where branch July 14, 2025 13:41
edopao added a commit that referenced this pull request Jul 14, 2025
…ange (#2142)

This PR introduces a dataclass type to represent the domain range in the
lowering to SDFG.
It follows a suggestion from @philip-paul-mueller in the review of #2137.
philip-paul-mueller pushed a commit to philip-paul-mueller/gt4py that referenced this pull request Jul 15, 2025
Lowering to SDFG of `concat_where` primitive, follows GTIR PR GridTools#1713.
Test coverage provided in previous PR.
philip-paul-mueller pushed a commit to philip-paul-mueller/gt4py that referenced this pull request Jul 15, 2025
…ange (GridTools#2142)

This PR introduces a dataclass type to represent the domain range in the
lowering to SDFG.
It follows a suggestion from @philip-paul-mueller in the review of GridTools#2137.
stubbiali pushed a commit to stubbiali/gt4py that referenced this pull request Aug 19, 2025
Lowering to SDFG of `concat_where` primitive, follows GTIR PR GridTools#1713.
Test coverage provided in previous PR.
stubbiali pushed a commit to stubbiali/gt4py that referenced this pull request Aug 19, 2025
…ange (GridTools#2142)

This PR introduces a dataclass type to represent the domain range in the
lowering to SDFG.
It follows a suggestion from @philip-paul-mueller in the review of GridTools#2137.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants