Fix crash manifesting after pr226#515
Draft
mminutoli wants to merge 2 commits into rocm-jaxlib-v0.8.0
The use of a unique pointer for pc is incorrect here: the lambda capturing the smart pointer can be executed after the unique pointer goes out of scope, so the memory is released before its intended use.
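The lifetime issue described above can be illustrated with a minimal sketch. `ProfilingContext`, `pending_tasks`, `Schedule`, and `RunPending` here are hypothetical stand-ins, not the XLA types or functions; the only point is that a `std::shared_ptr` captured by value keeps the object alive until the deferred lambda has run.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <vector>

// Hypothetical stand-in for the profiling context; not the XLA type.
struct ProfilingContext {
  int id = 42;
};

// Deferred work queue, standing in for a thread pool (assumption).
std::vector<std::function<int()>> pending_tasks;

void Schedule() {
  auto pc = std::make_shared<ProfilingContext>();
  // Capturing pc by value copies the shared_ptr, so the lambda co-owns
  // the ProfilingContext and can safely use it later.
  pending_tasks.push_back([pc] { return pc->id; });
}  // The local pc is destroyed here; the object survives inside the lambda.

int RunPending() {
  int sum = 0;
  for (auto& task : pending_tasks) sum += task();
  pending_tasks.clear();  // Last owner released; the object is freed now.
  return sum;
}
```

Calling `Schedule()` followed by `RunPending()` reads the context after the scheduling scope has ended. With a `std::unique_ptr` local captured by reference, the same access would touch freed memory.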
There is also a race condition on first_failure_status. The initial state of first_failure_status is OK, which has no message associated with it. As a consequence, logging crashes with a segfault when the program logs before that status changes to something that is not OK. I couldn't fix the race in the time I had available. Nevertheless, the real problem is that we have failures when running those tasks. This patch fixes the crash in the logging path and uncovers another one that needs triaging.
Arech8
requested changes
Jan 13, 2026
Arech8
left a comment
Logging needs to be fixed, otherwise a great catch, Marco, thanks!
Comment on lines -3158 to +3164
-   "the error log for all failures): \n\n"
-   << first_failure_status.message();
+   "the error log for all failures): \n\n";
+
+ if (!first_failure_status.ok()) {
+   LOG(FATAL) << first_failure_status.message();
+ } else {
+   LOG(FATAL) << "(no failure status captured)";
+ }
LOG(FATAL) unfolds to a no-return function call, IIRC std::abort() or smth like that. You need to put everything into a single call to the logger:
LOG(FATAL) << "Replicated computation...bla-bla-bla"
           << ( first_failure_status.ok()
                ? "(no failure status captured)"
                : first_failure_status.message().c_str() ) // check if .c_str() is needed here, or smth else
Oh, ok, I see it's more involved than just that. Ok, I'm taking over it...
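The single-statement form suggested in the review can be sketched as follows. This is an illustration under assumptions: `FakeStatus` is a hypothetical stand-in whose `message()` returns a `std::string_view` (modeled on a string-view-returning status accessor), `FormatFailure` writes to a string stream instead of `LOG(FATAL)`, and the message prefix is placeholder text. With a string-view return type, both branches of the conditional can be string views, so no `.c_str()` call is needed.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <string_view>
#include <utility>

// Hypothetical minimal status type; only ok() and message() are modeled.
class FakeStatus {
 public:
  FakeStatus() = default;  // Default state is OK with an empty message.
  explicit FakeStatus(std::string msg)
      : ok_(false), message_(std::move(msg)) {}
  bool ok() const { return ok_; }
  std::string_view message() const { return message_; }

 private:
  bool ok_ = true;
  std::string message_;
};

// Builds the whole log line in one streaming expression, so a no-return
// logger would still see the complete message.
std::string FormatFailure(const FakeStatus& first_failure_status) {
  std::ostringstream os;
  os << "Replicated computation failed: "  // Placeholder prefix (assumption).
     << (first_failure_status.ok()
             ? std::string_view("(no failure status captured)")
             : first_failure_status.message());
  return os.str();
}
```

Keeping everything in one expression matters because a fatal logging statement terminates the process when it completes, so a second statement after it would never run.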
Motivation
This PR fixes a critical segmentation fault that occurs in PjRtStreamExecutorLoadedExecutable::Execute when logging error messages during multi-device execution failures. The crash happens at line 3154 when attempting to access first_failure_status.message() on an uninitialized status object.
The segfault was masking underlying execution failures and making debugging extremely difficult. This fix enables proper error reporting so that the actual root causes of execution failures can be identified and addressed.
Technical Details
Changes Made
File: xla/pjrt/pjrt_stream_executor_client.cc
Three related fixes to prevent crashes and undefined behavior:
1. Defensive check before accessing status message (Lines 3154-3165)
   - Check if (!first_failure_status.ok()) before calling .message()
   - Covers the case where first_failure_status is still in default OK state
2. Fix memory safety with ProfilingContext (Line 3106)
   - Change unique_ptr to shared_ptr for ProfilingContext
3. Explicit lambda capture (Line 3119)
   - Add pc to lambda capture list: [&, pc, replica, partition, i]
Root Cause
The segfault occurs due to a race condition where:
- The failed counter is updated, waking the main thread
- first_failure_status hasn't been assigned yet by any worker thread
- Calling .message() on default-constructed (OK) status causes segfault
Note: This PR does not eliminate the underlying race condition, but makes the code defensive against it by preventing the segfault. The race would require more extensive synchronization refactoring to fully resolve.
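One possible shape for the synchronization refactoring mentioned in the note is sketched below. It is not what this PR does and not the XLA code: `FailureState`, `ReportFailure`, and `WaitForFirstFailure` are hypothetical names, and a plain `std::string` replaces the status object. The idea is to record the first failure and increment the counter inside the same critical section, so a waiter that observes `failed > 0` is guaranteed to see the message as well.

```cpp
#include <cassert>
#include <condition_variable>
#include <mutex>
#include <string>

// Hypothetical shared state; the layout is an assumption, not the XLA one.
struct FailureState {
  std::mutex mu;
  std::condition_variable cv;
  int failed = 0;
  std::string first_failure_message;  // Empty until a failure is recorded.
};

// Worker side: store the message and bump the counter under one lock,
// so the two updates become visible together.
void ReportFailure(FailureState& state, const std::string& message) {
  {
    std::lock_guard<std::mutex> lock(state.mu);
    if (state.failed == 0) state.first_failure_message = message;
    ++state.failed;
  }
  state.cv.notify_all();
}

// Main-thread side: waits until at least one failure has been reported,
// then returns the message recorded by the first reporter.
std::string WaitForFirstFailure(FailureState& state) {
  std::unique_lock<std::mutex> lock(state.mu);
  state.cv.wait(lock, [&] { return state.failed > 0; });
  return state.first_failure_message;
}
```

In a real setting the workers would call `ReportFailure` from their own threads while the main thread blocks in `WaitForFirstFailure`. Whether this layout fits the existing code, which also tracks successes and timeouts, is an open question.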
Test Plan
Manual Testing
rocgdb --args python -m pytest jax/tests/shard_map_test.py::ShardMapTest::test_all_gather_with_axis_index_groups
Test Results
Behavior Changes
Before: Segmentation fault when logging error messages during multi-device failures
After: Segmentation fault in a different place; more investigation is currently required.
Known Limitations
The underlying race condition on first_failure_status is not eliminated. This is a defensive fix that prevents crashes and enables proper error reporting. A complete fix would require refactoring the synchronization logic.
Submission Checklist