Fix NCCL/RCCL rendezvous deadlock when mixing clique sizes #589

Open

phambinhfin wants to merge 1 commit into rocm-jaxlib-v0.8.0 from phambinh/fix-nccl-clique-subset-invalidation

Conversation

@phambinhfin

When NCCL comm splitting is enabled, running operations with different device counts sequentially (e.g., 2-device then 4-device) can cause rendezvous deadlocks. This happens because:

  1. A standalone clique [0,1] is created for 2-device operations
  2. Later, a parent clique [0,1,2,3] is created for 4-device operations
  3. When [2,3] needs a communicator, it splits from the parent [0,1,2,3]
  4. But [0,1] still uses the old standalone communicator
  5. The split communicators expect all devices to participate together, causing a deadlock

The fix invalidates any cached subset cliques when a new superset clique is created. This ensures subset operations will create properly split communicators from the parent clique rather than reusing stale standalone communicators.

This issue commonly manifests in test suites that run distributed tests with varying device counts (e.g., JAX TransformerEngine tests).

Performance impact: Negligible - the check only runs during clique creation (initialization), not during collective operations.
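The invalidation rule described above can be sketched in Python. This is illustration only: the actual fix lives in XLA's C++ clique cache, and every name below is invented. The idea is that creating a superset clique drops any cached clique covering a strict subset of its devices, so later subset operations split from the parent instead of reusing a stale standalone communicator.

```python
# Hypothetical sketch of the invalidation rule (names invented; the real
# change is in XLA's C++ clique cache, not this Python code).
class CliqueCache:
    def __init__(self):
        # Maps frozenset of device ordinals -> communicator handle.
        self._cliques = {}

    def create_clique(self, devices, make_comm):
        key = frozenset(devices)
        # Invalidate stale standalone subset cliques so that subsequent
        # subset operations split from this parent clique instead of
        # reusing an old standalone communicator.
        for cached in list(self._cliques):
            if cached < key:  # strict subset of the new superset
                del self._cliques[cached]
        self._cliques[key] = make_comm(devices)
        return self._cliques[key]

cache = CliqueCache()
cache.create_clique([0, 1], lambda d: "standalone[0,1]")       # 2-device op
cache.create_clique([0, 1, 2, 3], lambda d: "parent[0,1,2,3]") # 4-device op
# The standalone [0,1] clique is gone; a later [0,1] operation must be
# re-created (in XLA, split) under the [0,1,2,3] parent.
assert frozenset([0, 1]) not in cache._cliques
```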

@i-chaochen (Collaborator) left a comment:


Can you provide reproduction code or a test for the rendezvous deadlock when mixing clique sizes?

In addition, please add a unit test for your changes.

@i-chaochen requested a review from ScXfjiang on January 27, 2026
@i-chaochen added labels on January 27, 2026: bug (Something isn't working), cherry-pick-candidate (mark a PR to be cherry-picked into the next ROCm JAX; remove iff the latest upstream contains the PR), Upstream, rocm-jaxlib-v0.8.0
@phambinhfin (Author)

I still could not reproduce it yet; the error comes from a TE test case in the ticket https://ontrack-internal.amd.com/browse/SWDEV-571637

 cd /workspace/TransformerEngine/tests/jax
timeout 120 pytest -v -s ./test_distributed_softmax.py -k "test_softmax[False-True-float16-1.0-SoftmaxType.SCALED-data_shape0- and tp2"
========================================= test session starts ==========================================
platform linux -- Python 3.12.3, pytest-9.0.1, pluggy-1.6.0 -- /usr/bin/python3
cachedir: .pytest_cache
hypothesis profile 'default'
metadata: {'Python': '3.12.3', 'Platform': 'Linux-6.2.0-25-generic-x86_64-with-glibc2.39', 'Packages': {'pytest': '9.0.1', 'pluggy': '1.6.0'}, 'Plugins': {'hypothesis': '6.148.5', 'html': '4.1.1', 'json-report': '1.5.0', 'metadata': '3.1.1', 'reportlog': '1.0.0', 'rerunfailures': '16.1'}}
rootdir: /workspace/TransformerEngine/tests/jax
configfile: pytest.ini
plugins: hypothesis-6.148.5, html-4.1.1, json-report-1.5.0, metadata-3.1.1, reportlog-1.0.0, rerunfailures-16.1
collected 312 items / 310 deselected / 2 selected                                                      

test_distributed_softmax.py::TestDistributedSoftmax::test_softmax[False-True-float16-1.0-SoftmaxType.SCALED-data_shape0-n2_dp1_tp2] W0129 14:40:22.973233  349919 sharding_propagation.cc:3126] GSPMD sharding propagation is going to be deprecated and not supported in the future. Please consider migrating to Shardy (https://openxla.org/shardy). For reference, Shardy is already the default partitioner in JAX.
W0129 14:40:27.270130  349919 sharding_propagation.cc:3126] GSPMD sharding propagation is going to be deprecated and not supported in the future. Please consider migrating to Shardy (https://openxla.org/shardy). For reference, Shardy is already the default partitioner in JAX.
W0129 14:40:27.563850  349919 sharding_propagation.cc:3126] GSPMD sharding propagation is going to be deprecated and not supported in the future. Please consider migrating to Shardy (https://openxla.org/shardy). For reference, Shardy is already the default partitioner in JAX.
W0129 14:40:27.875302  349919 sharding_propagation.cc:3126] GSPMD sharding propagation is going to be deprecated and not supported in the future. Please consider migrating to Shardy (https://openxla.org/shardy). For reference, Shardy is already the default partitioner in JAX.
PASSED
test_distributed_softmax.py::TestDistributedSoftmax::test_softmax[False-True-float16-1.0-SoftmaxType.SCALED-data_shape0-n4_dp2_tp2] W0129 14:40:30.632970  349919 sharding_propagation.cc:3126] GSPMD sharding propagation is going to be deprecated and not supported in the future. Please consider migrating to Shardy (https://openxla.org/shardy). For reference, Shardy is already the default partitioner in JAX.
W0129 14:40:41.560363  349919 sharding_propagation.cc:3126] GSPMD sharding propagation is going to be deprecated and not supported in the future. Please consider migrating to Shardy (https://openxla.org/shardy). For reference, Shardy is already the default partitioner in JAX.
E0129 14:40:51.853999  350697 rendezvous.cc:92] [id=0] This thread has been waiting for `initialize clique for rank 0; clique=devices=[2,3]; is_p2p=0; root_device=-1; num_local_participants=2; incarnations=[]; run_id=934432511; parent=devices=[0,1,2,3]; is_p2p=0; root_device=-1; num_local_participants=4; incarnations=[]` for 10 seconds and may be stuck. All 2 threads joined the rendezvous, however the leader has not marked the rendezvous as completed. Leader can be deadlocked inside the rendezvous callback.
E0129 14:40:51.854047  350688 rendezvous.cc:100] [id=1] This thread has been waiting for `thunk initialization completion for device ordinal 0; run_id=934432511` for 10 seconds and may be stuck. Expected 4 threads to join the rendezvous, but not all of them arrived on time.
E0129 14:40:51.854047  350694 rendezvous.cc:100] [id=0] This thread has been waiting for `thunk initialization completion for device ordinal 1; run_id=934432511` for 10 seconds and may be stuck. Expected 4 threads to join the rendezvous, but not all of them arrived on time.
F0129 14:41:21.854192  350688 rendezvous.cc:127] [id=1] Termination timeout for `thunk initialization completion for device ordinal 0; run_id=934432511` of 30 seconds exceeded. Exiting to ensure a consistent program state. Expected 4 threads to join the rendezvous, but only 2 of them arrived on time.
F0129 14:41:21.854255  350694 rendezvous.cc:127] [id=0] Termination timeout for `thunk initialization completion for device ordinal 1; run_id=934432511` of 30 seconds exceeded. Exiting to ensure a consistent program state. Expected 4 threads to join the rendezvous, but only 2 of them arrived on time.
*** Check failure stack trace: ***
    @     0x7f5ad9f84834  absl::lts_20250814::log_internal::LogMessage::SendToLog()
    @     0x7f5ad9f847b6  absl::lts_20250814::log_internal::LogMessage::Flush()
    @     0x7f5ad4e398c1  xla::internal::AwaitAndLogIfStuck()
    @     0x7f5ace16d1f5  xla::gpu::(anonymous namespace)::ExecuteThunksImpl()
    @     0x7f5ace16b184  xla::gpu::GpuExecutable::ExecuteThunks()
    @     0x7f5acd4be2df  xla::StreamExecutorGpuClient::RunAsync()
    @     0x7f5acd4e88cd  xla::PjRtStreamExecutorLoadedExecutable::EnqueueExecution()
    @     0x7f5acd4ed73b  xla::PjRtStreamExecutorLoadedExecutable::ExecuteHelper()
    @     0x7f5acd50f93f  absl::lts_20250814::internal_any_invocable::RemoteInvoker<>()
    @     0x7f5acd581952  xla::WorkerThread::WorkLoop()
    @     0x7f5ad9d80448  tsl::(anonymous namespace)::PThread::ThreadFn()
    @     0x7f9c6cae4aa4  (unknown)
    @     0x7f9c6cb71c6c  (unknown)
Fatal Python error: Aborted
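The log above shows the mismatch directly: the thunk rendezvous expects 4 threads, but only ranks 2 and 3 arrive because ranks 0 and 1 are still served by the old standalone communicator. A minimal Python sketch (standard-library threading only, not XLA code) of the same failure mode:

```python
# Illustration of the rendezvous mismatch, not XLA code: a barrier sized
# for the 4-device parent clique is joined by only 2 ranks, so everyone
# who did arrive times out -- analogous to the rendezvous.cc errors above.
import threading

barrier = threading.Barrier(4)  # parent clique [0,1,2,3] expects 4 ranks
results = []

def rank(i):
    try:
        barrier.wait(timeout=0.5)  # rank blocks waiting for all 4
        results.append(("joined", i))
    except threading.BrokenBarrierError:
        results.append(("stuck", i))

# Only ranks 2 and 3 arrive; ranks 0 and 1 keep using the stale standalone
# [0,1] communicator and never join this rendezvous.
threads = [threading.Thread(target=rank, args=(i,)) for i in (2, 3)]
for t in threads:
    t.start()
for t in threads:
    t.join()
# Both arriving ranks report "stuck" once the barrier times out.
```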

@i-chaochen (Collaborator)

i-chaochen commented Jan 29, 2026

Thanks! Just to clarify: if you can provide a JAX reproducer like this commit, that would be great!

openxla#15935
openxla@fde6a0e

@phambinhfin (Author)

phambinhfin commented Jan 29, 2026

It can be reproduced here. You can test with the docker image registry-sc-harbor.amd.com/framework/compute-rocm-rel-7.2:10_ubuntu24.04_py3.12_jax_rocm-jaxlib-v0.8.0_eeb59d9.
It also passes when setting
XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false"
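For reference, that workaround can also be applied from inside a Python script. This sketch assumes only that XLA reads XLA_FLAGS during backend initialization, so the variable must be set before jax is first imported:

```python
# Workaround from the comment above: disable NCCL comm splitting so every
# clique gets a standalone communicator and no subset ever splits from a
# stale parent. Set the flag before importing jax, since XLA reads
# XLA_FLAGS when the backend initializes.
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_nccl_comm_splitting=false"
# import jax  # import only after XLA_FLAGS is set
```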

#!/usr/bin/env python3
"""
RCCL clique deadlock reproducer using PURE JAX (no TransformerEngine).

This proves the bug is in XLA/RCCL, not TransformerEngine.
The deadlock occurs when:
1. 2-GPU operations run with different mesh axes (DP, then TP)
2. 4-GPU operations run afterward
3. Sharding patterns trigger collectives on hidden dimension

Run: python /workspace/reproduce_rccl_pure_jax.py
"""
import os
os.environ["JAX_TRACEBACK_FILTERING"] = "off"

import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax._src.pjit import pjit

jax.config.update("jax_use_shardy_partitioner", False)


def run_test(n_devices, mesh_shape, axes, dp_res, tp_res):
    """Run pure JAX operations with specific sharding to trigger collectives."""
    print(f"Running {n_devices}-GPU test with axes={axes}...")
    
    devices = np.array(jax.devices()[:n_devices]).reshape(mesh_shape)
    mesh = Mesh(devices, axes)
    
    # Shard on hidden dim (like bad_sharding=True) to force collectives
    x_spec = P(dp_res, None, None, tp_res)
    
    x = random.normal(random.PRNGKey(0), [32, 12, 128, 128], jnp.float16)
    
    # Function 1: softmax + reduction (forces all-reduce)
    def fn1(x):
        y = jax.nn.softmax(x, axis=-1)  # softmax on sharded dim
        return jnp.mean(y)
    
    # Function 2: different pattern - layer norm style
    def fn2(x):
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.var(x, axis=-1, keepdims=True)
        y = (x - mean) / jnp.sqrt(var + 1e-5)
        return jnp.mean(y)
    
    with mesh:
        x_ = jax.device_put(x, NamedSharding(mesh, x_spec))
        
        # Run first function with grad
        jit1 = pjit(jax.value_and_grad(fn1), in_shardings=x_spec, out_shardings=(None, x_spec))
        fwd1, grad1 = jit1(x_)
        
        # Get HLO (triggers compilation)
        _ = jit1.lower(x_).compile().as_text()
        
        # Run second function with grad - different collective pattern
        jit2 = pjit(jax.value_and_grad(fn2), in_shardings=x_spec, out_shardings=(None, x_spec))
        fwd2, grad2 = jit2(x_)
        grad2.block_until_ready()
    
    print(f"  {n_devices}-GPU PASSED")


if __name__ == "__main__":
    print(f"JAX version: {jax.__version__}")
    print(f"Devices: {jax.devices()[:4]}")
    print()
    
    # Same sequence as TE reproducer:
    # 1. 2-GPU with DP axis
    run_test(2, (2,), ("dp",), "dp", None)
    
    # 2. 2-GPU with TP axis  
    run_test(2, (2,), ("tp",), None, "tp")
    
    # 3. 4-GPU with DP+TP axes -> Does it deadlock?
    run_test(4, (2,2), ("dp","tp"), "dp", "tp")
    
    print("\nAll passed!")

@phambinhfin (Author)

Upstream pull request: openxla#37063

@i-chaochen (Collaborator) left a comment:

Thanks for upstreaming this. Let's wait for upstream feedback and then merge it into our release branches.
