Skip to content

[ROCm] Fix reduction emitter tests to work on ROCm#542

Open
nurmukhametov wants to merge 1 commit intorocm-jaxlib-v0.8.2from
anurmukh/fix-reduce-emitter-mof-tests-v0.8.2
Open

[ROCm] Fix reduction emitter tests to work on ROCm#542
nurmukhametov wants to merge 1 commit intorocm-jaxlib-v0.8.2from
anurmukh/fix-reduce-emitter-mof-tests-v0.8.2

Conversation

@nurmukhametov
Copy link

The tests mof_scalar_variadic.hlo and side_output_broadcast.hlo were designed to test the Reduction emitter with Multi-Output Fusions (MOF) containing variadic reductions and side outputs. However, they used f32[6,6] (36 elements) which fails the
IsUnnestedReductionFasterThanElemental heuristic on ROCm (warp_size=64) while passing on CUDA (warp_size=32).

When the heuristic fails, the fusion is routed to the Loop emitter instead of the Reduction emitter. The Loop emitter cannot handle fusions with incompatible output shapes (e.g., f32[] and f32[6,6] in the same tuple), causing a segmentation fault in GetBitcastMap/ApplyIndexing.

Fix by increasing the test dimensions from 6x6 to 8x8 (64 elements), which satisfies the heuristic on both platforms. This ensures the tests exercise the Reduction emitter as originally intended on both CUDA and ROCm.

@i-chaochen i-chaochen requested a review from pemeliya January 20, 2026 17:16
The tests mof_scalar_variadic.hlo and side_output_broadcast.hlo were
designed to test the Reduction emitter with Multi-Output Fusions (MOF)
containing variadic reductions and side outputs. However, they used
f32[6,6] (36 elements) which fails the
IsUnnestedReductionFasterThanElemental heuristic on ROCm (warp_size=64)
while passing on CUDA (warp_size=32).

When the heuristic fails, the fusion is routed to the Loop emitter
instead of the Reduction emitter. The Loop emitter cannot handle fusions
with incompatible output shapes (e.g., f32[] and f32[6,6] in the same
tuple), causing a segmentation fault in GetBitcastMap/ApplyIndexing.

Fix by increasing the test dimensions from 6x6 to 8x8 (64 elements),
which satisfies the heuristic on both platforms. This ensures the tests
exercise the Reduction emitter as originally intended on both CUDA and
ROCm.
@nurmukhametov nurmukhametov force-pushed the anurmukh/fix-reduce-emitter-mof-tests-v0.8.2 branch from b4580c7 to bf3b075 Compare January 20, 2026 18:16
@i-chaochen i-chaochen added rocm-jaxlib-v0.8.2 cherry-pick-candidate Mark a PR to be cherry-picked into the next ROCm JAX. Remove IIF the latest upstream contain the PR. labels Jan 26, 2026
@nurmukhametov nurmukhametov removed the cherry-pick-candidate Mark a PR to be cherry-picked into the next ROCm JAX. Remove IIF the latest upstream contain the PR. label Jan 26, 2026
@nurmukhametov
Copy link
Author

I don't think it is cherry-pick-candidate anymore because same was merged upstream openxla#36612

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants