[ROCm] Add thread safety test for hipblaslt GetAlgorithms#609
Draft
i-chaochen wants to merge 2 commits intorocm-jaxlib-v0.8.0from
Draft
[ROCm] Add thread safety test for hipblaslt GetAlgorithms#609i-chaochen wants to merge 2 commits intorocm-jaxlib-v0.8.0from
i-chaochen wants to merge 2 commits intorocm-jaxlib-v0.8.0from
Conversation
Add unit tests to detect potential data races in hipblaslt's MasterSolutionLibrary and ContractionSolution when multiple threads concurrently call GetAlgorithms. This test is designed to help catch race conditions that can cause segfaults in multi-threaded JAX workloads using BF16 GEMM operations, where: - One thread loads solution libraries (loadLibrary) - Another thread calls getSolutionByIndex concurrently - Mutable state like autoGSU is modified without synchronization The test exercises: 1. Concurrent GetAlgorithms calls on the same plan 2. Concurrent access with different problem sizes (triggers lazy loading)
88e394c to
0b5bdc1
Compare
…er buffer race Add two new test cases to hip_blas_lt_thread_test.cc to exercise concurrent GEMM execution paths that may trigger the hipblaslt synchronizer buffer race: - MultiStreamGemmExecutionRace: Creates multiple streams with a shared BlasLt plan and executes concurrent GEMM operations, verifying result correctness. - FireAndForgetGemmRace: More aggressive test with 8 streams and rapid GEMM launches to maximize GPU kernel overlap. These tests allocate device memory, execute actual hipblasLtMatmul operations via XLA's BlasLt interface, and verify output correctness. While they may not reliably reproduce the GSU synchronizer buffer race (which requires specific algorithm selection like GSUAMBSK), they serve as: - Stress tests for concurrent GEMM execution - Regression tests after hipblaslt fixes are applied - Documentation of the known multi-stream race condition pattern Also adds scratch_allocator dependency to BUILD file.
bc34256 to
d29ad35
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Add unit tests to detect potential data races in hipblaslt's MasterSolutionLibrary and ContractionSolution when multiple threads concurrently call GetAlgorithms.
Motivation
This test is designed to help catch race conditions that can cause segfaults in multi-threaded JAX workloads using BF16 GEMM operations, where:
Technical Details
The test exercises:
[ROCm] Add multi-stream GEMM execution tests for hipblaslt synchronizer buffer race
Add two new test cases to hip_blas_lt_thread_test.cc to exercise concurrent
GEMM execution paths that may trigger the hipblaslt synchronizer buffer race:
MultiStreamGemmExecutionRace: Creates multiple streams with a shared BlasLt
plan and executes concurrent GEMM operations, verifying result correctness.
FireAndForgetGemmRace: More aggressive test with 8 streams and rapid GEMM
launches to maximize GPU kernel overlap.
These tests allocate device memory, execute actual hipblasLtMatmul operations
via XLA's BlasLt interface, and verify output correctness. While they may not
reliably reproduce the GSU synchronizer buffer race, they serve as: