diff --git a/xla/stream_executor/cuda/cuda_blas.cc b/xla/stream_executor/cuda/cuda_blas.cc index cf946b68cae91..10c311e843097 100644 --- a/xla/stream_executor/cuda/cuda_blas.cc +++ b/xla/stream_executor/cuda/cuda_blas.cc @@ -1399,6 +1399,16 @@ absl::Status CUDABlas::GetVersion(std::string *version) { } void initialize_cublas() { + // Check if already registered before attempting - prevents duplicate + // registration error messages (can happen with multiple library loads) + auto already_registered = PluginRegistry::Instance()->HasFactory( + kCudaPlatformId, PluginKind::kBlas); + + if (already_registered) { + // Already registered, skip silently (mimics ROCm behavior) + return; + } + absl::Status status = PluginRegistry::Instance()->RegisterFactory( kCudaPlatformId, "cuBLAS", diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc index 631b8489c22c7..7cc04ecacd11a 100644 --- a/xla/stream_executor/cuda/cuda_dnn.cc +++ b/xla/stream_executor/cuda/cuda_dnn.cc @@ -6947,6 +6947,16 @@ absl::Status CudnnGraph::PopulateOrUpdateRawCommandBuffer( } // namespace gpu void initialize_cudnn() { + // Check if already registered before attempting - prevents duplicate + // registration error messages (can happen with multiple library loads) + auto already_registered = PluginRegistry::Instance()->HasFactory( + cuda::kCudaPlatformId, PluginKind::kDnn); + + if (already_registered) { + // Already registered, skip silently (mimics ROCm behavior) + return; + } + absl::Status status = PluginRegistry::Instance()->RegisterFactory( cuda::kCudaPlatformId, "cuDNN", diff --git a/xla/stream_executor/cuda/cuda_fft.cc b/xla/stream_executor/cuda/cuda_fft.cc index 5500b2c4586cd..54568733d4fe6 100644 --- a/xla/stream_executor/cuda/cuda_fft.cc +++ b/xla/stream_executor/cuda/cuda_fft.cc @@ -460,6 +460,16 @@ STREAM_EXECUTOR_CUDA_DEFINE_FFT(double, Z2Z, D2Z, Z2D) } // namespace gpu void initialize_cufft() { + // Check if already registered before attempting - prevents duplicate + // registration error messages (can happen with multiple library loads) + auto already_registered = PluginRegistry::Instance()->HasFactory( + cuda::kCudaPlatformId, PluginKind::kFft); + + if (already_registered) { + // Already registered, skip silently (mimics ROCm behavior) + return; + } + absl::Status status = PluginRegistry::Instance()->RegisterFactory( cuda::kCudaPlatformId, "cuFFT",