Fix NCCL/RCCL rendezvous deadlock when mixing clique sizes #589
Open
phambinhfin wants to merge 1 commit into rocm-jaxlib-v0.8.0 from
When NCCL comm splitting is enabled, running operations with different device counts sequentially (e.g., 2-device then 4-device) can cause rendezvous deadlocks. This happens because:

1. A standalone clique [0,1] is created for 2-device operations
2. Later, a parent clique [0,1,2,3] is created for 4-device operations
3. When [2,3] needs a communicator, it splits from the parent [0,1,2,3]
4. But [0,1] still uses the old standalone communicator
5. The split communicators expect all devices to participate together, causing a deadlock

The fix invalidates any cached subset cliques when a new superset clique is created. This ensures subset operations will create properly split communicators from the parent clique rather than reusing stale standalone communicators.

This issue commonly manifests in test suites that run distributed tests with varying device counts (e.g., JAX TransformerEngine tests).

Performance impact: Negligible - the check only runs during clique creation (initialization), not during collective operations.
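The invalidation rule described above can be illustrated with a small toy model. This is only a sketch, not the XLA implementation: the `CliqueCache` class, its `acquire` method, and the string labels are hypothetical names introduced for illustration, and the cache stores a label instead of a real NCCL/RCCL communicator.

```python
from typing import Dict, FrozenSet, Iterable

CliqueKey = FrozenSet[int]


class CliqueCache:
    """Toy model of a communicator cache keyed by device set (hypothetical)."""

    def __init__(self) -> None:
        # Maps a device set to a label describing how its communicator was made.
        self._cliques: Dict[CliqueKey, str] = {}

    def acquire(self, devices: Iterable[int]) -> str:
        key = frozenset(devices)
        if key in self._cliques:
            return self._cliques[key]

        # The fix: creating a new clique invalidates every cached strict
        # subset, so those subsets are rebuilt later by splitting.
        stale = [k for k in self._cliques if k < key]
        for k in stale:
            del self._cliques[k]

        # With comm splitting enabled, prefer splitting from a cached
        # strict superset (the parent clique) over a standalone communicator.
        parents = [k for k in self._cliques if k > key]
        if parents:
            parent = min(parents, key=len)
            origin = f"split from {sorted(parent)}"
        else:
            origin = "standalone"
        self._cliques[key] = origin
        return origin


cache = CliqueCache()
first = cache.acquire([0, 1])         # no parent yet: standalone
parent = cache.acquire([0, 1, 2, 3])  # new superset: evicts cached [0,1]
right = cache.acquire([2, 3])         # split from the parent
left = cache.acquire([0, 1])          # rebuilt as a split, not the stale one
```

Without the eviction step, the last call would return the cached standalone entry for [0,1] while [2,3] uses a split communicator, which is the mismatch described in steps 4 and 5.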
i-chaochen
requested changes
Jan 27, 2026
Author
I still have not been able to reproduce it; the error comes from the TE test case in the ticket https://ontrack-internal.amd.com/browse/SWDEV-571637
Collaborator
Thanks! Just to clarify: if you can provide a JAX reproducer like this commit, that would be great!
Author
It can be reproduced here. You can test with Docker.
Author
Upstream pull request: openxla#37063
i-chaochen
approved these changes
Jan 30, 2026
Collaborator
i-chaochen
left a comment
Thanks for upstreaming this. Let's wait for upstream feedback, and then we will merge it into our release branches.