Skip to content

Comments

Fix crash manifesting after pr226#515

Draft
mminutoli wants to merge 2 commits intorocm-jaxlib-v0.8.0from
fix_crash_manifesting_after_pr226
Draft

Fix crash manifesting after pr226#515
mminutoli wants to merge 2 commits intorocm-jaxlib-v0.8.0from
fix_crash_manifesting_after_pr226

Conversation

@mminutoli
Copy link

Motivation

This PR fixes a critical segmentation fault that occurs in PjRtStreamExecutorLoadedExecutable::Execute when logging error messages during multi-device execution failures. The crash happens at line 3154 when attempting to access first_failure_status.message() on an uninitialized status object.

The segfault was masking underlying execution failures and making debugging extremely difficult. This fix enables proper error reporting so that the actual root causes of execution failures can be identified and addressed.

Technical Details

Changes Made

File: xla/pjrt/pjrt_stream_executor_client.cc

Three related fixes to prevent crashes and undefined behavior:

  1. Defensive check before accessing status message (Lines 3154-3165)

    • Added null check: if (!first_failure_status.ok()) before calling .message()
    • Prevents segfault when first_failure_status is still in default OK state
    • Provides diagnostic message "(no failure status captured)" when this occurs
  2. Fix memory safety with ProfilingContext (Line 3106)

    • Changed from unique_ptr to shared_ptr for ProfilingContext
    • Ensures object lifetime extends to all scheduled lambda executions
  3. Explicit lambda capture (Line 3119)

    • Added pc to lambda capture list: [&, pc, replica, partition, i]
    • Increments shared_ptr reference count for thread safety

Root Cause

The segfault occurs due to a race condition where:

  • Worker threads increment failed counter, waking the main thread
  • Main thread proceeds to log the error message
  • But first_failure_status hasn't been assigned yet by any worker thread
  • Accessing .message() on default-constructed (OK) status causes segfault

Note: This PR does not eliminate the underlying race condition, but makes the code defensive against it by preventing the segfault. The race would require more extensive synchronization refactoring to fully resolve.

Test Plan

Manual Testing

  • Run: rocgdb --args python -m pytest jax/tests/shard_map_test.py::ShardMapTest::test_all_gather_with_axis_index_groups

Test Results

Behavior Changes

Before: Segmentation fault when logging error messages during multi-device failures
After: Segmentation fault in a different place; more investigation is currently required.

Known Limitations

The underlying race condition on first_failure_status is not eliminated. This is a defensive fix that prevents crashes and enables proper error reporting. A complete fix would require refactoring the synchronization logic.

Submission Checklist

The use of a unique pointer for pc is incorrect here as the
lambda capturing the smart pointer can be executed after the unique
pointer goes out of scope, and therefore, releasing the memory
before its intended use.
There is a race condition on the first_failure_status.
The initial state of first_failture_status is OK and that does not
have a message associated with it. As a consequence, logging crashes
with a segfault when the program logs before that status changes to
something be not OK.

I couldn't fix the race with the time I had available. Nevertheless,
the real problem is that we have failures when running those tasks.
This patch solves the crash at the log and uncorvers another one that
needs triaging.
@Arech8 Arech8 added cherry-pick-candidate Mark a PR to be cherry-picked into the next ROCm JAX. Remove IIF the latest upstream contain the PR. open-upstream Tag when you want a copy of this PR to be opened on upstream labels Jan 13, 2026
Copy link

@Arech8 Arech8 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Logging needs to be fixed, otherwise a great catch, Marco, thanks!

Comment on lines -3158 to +3164
"the error log for all failures): \n\n"
<< first_failure_status.message();
"the error log for all failures): \n\n";

if (!first_failure_status.ok()) {
LOG(FATAL) << first_failure_status.message();
} else {
LOG(FATAL) << "(no failure status captured)";
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LOG(FATAL) unfolds to a no-return function call, IIRC std::abort() or smth like that. You need to put everything into a single call to the logger:

LOG(FATAL) << "Replicated computation...bla-bla-bla"
   << ( first_failure_status.ok()
       ? "(no failure status captured)"
       : first_failure_status.message().c_str() ) // check if .c_str() is needed here, or smth else

@Arech8
Copy link

Arech8 commented Jan 13, 2026

Oh, ok, I see it's more involved than just that. Ok, I'm taking over it...

@Arech8 Arech8 removed cherry-pick-candidate Mark a PR to be cherry-picked into the next ROCm JAX. Remove IIF the latest upstream contain the PR. open-upstream Tag when you want a copy of this PR to be opened on upstream labels Jan 14, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants