-
Notifications
You must be signed in to change notification settings - Fork 54
feat[next][dace]: Lowering of concat_where to SDFG #2137
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This reverts commit ba40dd6.
This reverts commit b439c8b.
Broadcast scalar to array with the expected domain in order to avoid subset validation issue on scalar, in case of empty domain.
philip-paul-mueller
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have some comments, but I do not see a major problem.
src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
Outdated
Show resolved
Hide resolved
src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py
Outdated
Show resolved
Hide resolved
src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py
Outdated
Show resolved
Hide resolved
| # The strict order of lower and upper bounds of output domain is not guaranteed, | ||
| # for concat_where expressions, so we apply a runtime check. | ||
| out_shape[concat_dim_index] = dace.symbolic.pystr_to_symbolic( | ||
| f"max(0, {out_shape[concat_dim_index]})" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not fully sure but shouldn't you use origin here.
Is it because you incorporated origin in the index computation below?
If so I would add a note to explain that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the array shape, so the minimum size is always zero. I use origin in the computation of the memlet subset.
src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py
Show resolved
Hide resolved
| upper_bound = output_domain[concat_dim_index][2] | ||
| upper_domain = [(concat_dim, concat_dim_bound, upper_bound)] | ||
|
|
||
| if concat_dim not in lower.gt_type.dims: # type: ignore[union-attr] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does these cases mean.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Description added.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for being persistent and a bit low here, but I think you should also add some typical concat_where() expressions to make it simpler to understand such as concat_where(k < 0, ij_field, 1.0) or something.
Because, I think that it is on a conceptional level simple, but its actual implementation is rather involved.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea.
src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py
Outdated
Show resolved
Hide resolved
| upper_range_1 = output_domain[concat_dim_index][2] | ||
| upper_range_0 = dace.symbolic.pystr_to_symbolic( | ||
| f"min({upper_range_1}, {upper_domain[concat_dim_index][1]})" | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| upper_range_1 = output_domain[concat_dim_index][2] | |
| upper_range_0 = dace.symbolic.pystr_to_symbolic( | |
| f"min({upper_range_1}, {upper_domain[concat_dim_index][1]})" | |
| ) | |
| upper_range_0 = dace.symbolic.pystr_to_symbolic( | |
| f"min({upper_range_1}, {upper_domain[concat_dim_index][1]})" | |
| ) | |
| upper_range_1 = output_domain[concat_dim_index][2] |
Now they have the same order as above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I use upper_range_1 in the computation of upper_range_0, that is why they appear in this order.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see it now.
I guess there was some tricky thinking involved to get that right.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All unit tests still pass even if I switch the order. Honestly, I do not remember why I chose this order in the end. Probably, I wanted to follow the principle that the far left-hand side and right-hand side correspond exactly to the output domain, without any symbolic manipulation. Instead, I only manipulate the neighborhood of the concat boundary:
In other words, I could have written:
lower_range_0, upper_range_1 = output_domain[concat_dim_index][1:3]
lower_range_1 = dace.symbolic.pystr_to_symbolic(
f"max({lower_range_0}, {lower_domain[concat_dim_index][2]})"
)
upper_range_0 = dace.symbolic.pystr_to_symbolic(
f"min({upper_range_1}, {upper_domain[concat_dim_index][1]})"
)
src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py
Outdated
Show resolved
Hide resolved
src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py
Outdated
Show resolved
Hide resolved
edopao
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review comments, I agree it looks better now.
src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
Outdated
Show resolved
Hide resolved
| # The strict order of lower and upper bounds of output domain is not guaranteed, | ||
| # for concat_where expressions, so we apply a runtime check. | ||
| out_shape[concat_dim_index] = dace.symbolic.pystr_to_symbolic( | ||
| f"max(0, {out_shape[concat_dim_index]})" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the array shape, so the minimum size is always zero. I use origin in the computation of the memlet subset.
src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py
Outdated
Show resolved
Hide resolved
src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py
Show resolved
Hide resolved
src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py
Outdated
Show resolved
Hide resolved
| upper_range_1 = output_domain[concat_dim_index][2] | ||
| upper_range_0 = dace.symbolic.pystr_to_symbolic( | ||
| f"min({upper_range_1}, {upper_domain[concat_dim_index][1]})" | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I use upper_range_1 in the computation of upper_range_0, that is why they appear in this order.
src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py
Outdated
Show resolved
Hide resolved
src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py
Outdated
Show resolved
Hide resolved
| upper_bound = output_domain[concat_dim_index][2] | ||
| upper_domain = [(concat_dim, concat_dim_bound, upper_bound)] | ||
|
|
||
| if concat_dim not in lower.gt_type.dims: # type: ignore[union-attr] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Description added.
src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py
Show resolved
Hide resolved
philip-paul-mueller
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the problem is more with me.
I guess that from a conceptional point it is relatively simple, but its implementation is quite involved and there are some things I do not fully get to be honest.
| upper_range_1 = output_domain[concat_dim_index][2] | ||
| upper_range_0 = dace.symbolic.pystr_to_symbolic( | ||
| f"min({upper_range_1}, {upper_domain[concat_dim_index][1]})" | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see it now.
I guess there was some tricky thinking involved to get that right.
| origin = tuple( | ||
| [*field.origin[:concat_dim_index], concat_dim_origin, *field.origin[concat_dim_index:]] | ||
| ) | ||
| shape = tuple([*field_desc.shape[:concat_dim_index], 1, *field_desc.shape[concat_dim_index:]]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do not understand that the concat_dim_index dimension alsways has length 1?
Or rather what happens if a slice is higher than one level?
src/gt4py/next/program_processors/runners/dace/gtir_to_sdfg_concat_where.py
Show resolved
Hide resolved
| upper_bound = output_domain[concat_dim_index][2] | ||
| upper_domain = [(concat_dim, concat_dim_bound, upper_bound)] | ||
|
|
||
| if concat_dim not in lower.gt_type.dims: # type: ignore[union-attr] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for being persistent and a bit low here, but I think you should also add some typical concat_where() expressions to make it simpler to understand such as concat_where(k < 0, ij_field, 1.0) or something.
Because, I think that it is on a conceptional level simple, but its actual implementation is rather involved.
philip-paul-mueller
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
…ange (#2142) This PR introduces a dataclass type to represent the domain range in the lowering to SDFG. It follows a suggestion from @philip-paul-mueller in the review of #2137.
Lowering to SDFG of `concat_where` primitive, follows GTIR PR GridTools#1713. Test coverage provided in previous PR.
…ange (GridTools#2142) This PR introduces a dataclass type to represent the domain range in the lowering to SDFG. It follows a suggestion from @philip-paul-mueller in the review of GridTools#2137.
Lowering to SDFG of `concat_where` primitive, follows GTIR PR GridTools#1713. Test coverage provided in previous PR.
…ange (GridTools#2142) This PR introduces a dataclass type to represent the domain range in the lowering to SDFG. It follows a suggestion from @philip-paul-mueller in the review of GridTools#2137.
Lowering to SDFG of
concat_whereprimitive, follows GTIR PR #1713.Test coverage provided in previous PR.