diff --git a/cupy_backends/hip/cupy_hip_common.h b/cupy_backends/hip/cupy_hip_common.h index 48337125851..1ced342e28c 100644 --- a/cupy_backends/hip/cupy_hip_common.h +++ b/cupy_backends/hip/cupy_hip_common.h @@ -2,6 +2,7 @@ #define INCLUDE_GUARD_HIP_CUPY_COMMON_H #include +#define HIPBLAS_V2 #if HIP_VERSION >= 50530600 #include #include diff --git a/cupy_backends/hip/cupy_hipblas.h b/cupy_backends/hip/cupy_hipblas.h index 83b36772894..a559c4f93a2 100644 --- a/cupy_backends/hip/cupy_hipblas.h +++ b/cupy_backends/hip/cupy_hipblas.h @@ -49,18 +49,34 @@ static hipblasSideMode_t convert_hipblasSideMode_t(cublasSideMode_t mode) { return static_cast(static_cast(mode) + 141); } -static hipblasDatatype_t convert_hipblasDatatype_t(cudaDataType_t type) { +static hipDataType convert_hipDatatype(cudaDataType_t type) { switch(static_cast(type)) { - case 0 /* CUDA_R_32F */: return HIPBLAS_R_32F; - case 1 /* CUDA_R_64F */: return HIPBLAS_R_64F; - case 2 /* CUDA_R_16F */: return HIPBLAS_R_16F; - case 3 /* CUDA_R_8I */ : return HIPBLAS_R_8I; - case 4 /* CUDA_C_32F */: return HIPBLAS_C_32F; - case 5 /* CUDA_C_64F */: return HIPBLAS_C_64F; - case 6 /* CUDA_C_16F */: return HIPBLAS_C_16F; - case 7 /* CUDA_C_8I */ : return HIPBLAS_C_8I; - case 8 /* CUDA_R_8U */ : return HIPBLAS_R_8U; - case 9 /* CUDA_C_8U */ : return HIPBLAS_C_8U; + case 0 /* CUDA_R_32F */: return HIP_R_32F; + case 1 /* CUDA_R_64F */: return HIP_R_64F; + case 2 /* CUDA_R_16F */: return HIP_R_16F; + case 3 /* CUDA_R_8I */ : return HIP_R_8I; + case 4 /* CUDA_C_32F */: return HIP_C_32F; + case 5 /* CUDA_C_64F */: return HIP_C_64F; + case 6 /* CUDA_C_16F */: return HIP_C_16F; + case 7 /* CUDA_C_8I */ : return HIP_C_8I; + case 8 /* CUDA_R_8U */ : return HIP_R_8U; + case 9 /* CUDA_C_8U */ : return HIP_C_8U; + default: throw std::runtime_error("unrecognized type"); + } +} + +static hipblasComputeType_t convert_hipblasComputeType_t(cudaDataType_t type) { + switch(static_cast(type)) { + case 0 /* CUDA_R_32F */: return HIPBLAS_COMPUTE_32F; + case 1 /* CUDA_R_64F */: return HIPBLAS_COMPUTE_64F; + case 2 /* CUDA_R_16F */: return HIPBLAS_COMPUTE_16F; + case 3 /* CUDA_R_8I */ : return HIPBLAS_COMPUTE_32I; + case 4 /* CUDA_C_32F */: return HIPBLAS_COMPUTE_32F; + case 5 /* CUDA_C_64F */: return HIPBLAS_COMPUTE_64F; + case 6 /* CUDA_C_16F */: return HIPBLAS_COMPUTE_16F; + case 7 /* CUDA_C_8I */ : return HIPBLAS_COMPUTE_32I; + case 8 /* CUDA_R_8U */ : return HIPBLAS_COMPUTE_32I; + case 9 /* CUDA_C_8U */ : return HIPBLAS_COMPUTE_32I; default: throw std::runtime_error("unrecognized type"); } } @@ -119,11 +135,11 @@ cublasStatus_t cublasIdamax(cublasHandle_t handle, int n, const double *x, int i } cublasStatus_t cublasIcamax(cublasHandle_t handle, int n, const cuComplex *x, int incx, int *result) { - return hipblasIcamax(handle, n, reinterpret_cast(x), incx, result); + return hipblasIcamax(handle, n, reinterpret_cast(x), incx, result); } cublasStatus_t cublasIzamax(cublasHandle_t handle, int n, const cuDoubleComplex *x, int incx, int *result) { - return hipblasIzamax(handle, n, reinterpret_cast(x), incx, result); + return hipblasIzamax(handle, n, reinterpret_cast(x), incx, result); } cublasStatus_t cublasIsamin(cublasHandle_t handle, int n, float* x, int incx, int* result) { @@ -135,11 +151,11 @@ cublasStatus_t cublasIdamin(cublasHandle_t handle, int n, const double *x, int i } cublasStatus_t cublasIcamin(cublasHandle_t handle, int n, const cuComplex *x, int incx, int *result) { - return hipblasIcamin(handle, n, reinterpret_cast(x), incx, result); + return hipblasIcamin(handle, n, reinterpret_cast(x), incx, result); } cublasStatus_t cublasIzamin(cublasHandle_t handle, int n, const cuDoubleComplex *x, int incx, int *result) { - return hipblasIzamin(handle, n, reinterpret_cast(x), incx, result); + return hipblasIzamin(handle, n, reinterpret_cast(x), incx, result); } cublasStatus_t cublasSasum(cublasHandle_t handle, int n, float* x, int incx, float* result) { @@ -151,11 +167,11 @@ cublasStatus_t cublasDasum(cublasHandle_t handle, int n, double* x, int incx, do } cublasStatus_t cublasScasum(cublasHandle_t handle, int n, cuComplex* x, int incx, float* result) { - return hipblasScasum(handle, n, reinterpret_cast(x), incx, result); + return hipblasScasum(handle, n, reinterpret_cast(x), incx, result); } cublasStatus_t cublasDzasum(cublasHandle_t handle, int n, cuDoubleComplex* x, int incx, double* result) { - return hipblasDzasum(handle, n, reinterpret_cast(x), incx, result); + return hipblasDzasum(handle, n, reinterpret_cast(x), incx, result); } cublasStatus_t cublasSaxpy(cublasHandle_t handle, int n, float* alpha, float* x, int incx, float* y, int incy) { @@ -168,16 +184,16 @@ cublasStatus_t cublasDaxpy(cublasHandle_t handle, int n, double* alpha, double* cublasStatus_t cublasCaxpy(cublasHandle_t handle, int n, cuComplex* alpha, cuComplex* x, int incx, cuComplex* y, int incy) { return hipblasCaxpy(handle, n, - reinterpret_cast(alpha), - reinterpret_cast(x), incx, - reinterpret_cast(y), incy); + reinterpret_cast(alpha), + reinterpret_cast(x), incx, + reinterpret_cast(y), incy); } cublasStatus_t cublasZaxpy(cublasHandle_t handle, int n, cuDoubleComplex* alpha, cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy) { return hipblasZaxpy(handle, n, - reinterpret_cast(alpha), - reinterpret_cast(x), incx, - reinterpret_cast(y), incy); + reinterpret_cast(alpha), + reinterpret_cast(x), incx, + reinterpret_cast(y), incy); } cublasStatus_t cublasSdot(cublasHandle_t handle, int n, float* x, int incx, float* y, int incy, float* result) { @@ -191,33 +207,33 @@ cublasStatus_t cublasDdot(cublasHandle_t handle, int n, double* x, int incx, dou cublasStatus_t cublasCdotu(cublasHandle_t handle, int n, cuComplex* x, int incx, cuComplex* y, int incy, cuComplex* result) { return hipblasCdotu(handle, n, - reinterpret_cast(x), incx, - reinterpret_cast(y), incy, - reinterpret_cast(result)); + reinterpret_cast(x), incx, + reinterpret_cast(y), incy, + reinterpret_cast(result)); } cublasStatus_t cublasCdotc(cublasHandle_t handle, int n, cuComplex* x, int incx, cuComplex* y, int incy, cuComplex* result) { return hipblasCdotc(handle, n, - reinterpret_cast(x), incx, - reinterpret_cast(y), incy, - reinterpret_cast(result)); + reinterpret_cast(x), incx, + reinterpret_cast(y), incy, + reinterpret_cast(result)); } cublasStatus_t cublasZdotu(cublasHandle_t handle, int n, cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy, cuDoubleComplex* result) { return hipblasZdotu(handle, n, - reinterpret_cast(x), incx, - reinterpret_cast(y), incy, - reinterpret_cast(result)); + reinterpret_cast(x), incx, + reinterpret_cast(y), incy, + reinterpret_cast(result)); } cublasStatus_t cublasZdotc(cublasHandle_t handle, int n, cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy, cuDoubleComplex* result) { return hipblasZdotc(handle, n, - reinterpret_cast(x), incx, - reinterpret_cast(y), incy, - reinterpret_cast(result)); + reinterpret_cast(x), incx, + reinterpret_cast(y), incy, + reinterpret_cast(result)); } cublasStatus_t cublasSnrm2(cublasHandle_t handle, int n, float* x, int incx, float* result) { @@ -229,11 +245,11 @@ cublasStatus_t cublasDnrm2(cublasHandle_t handle, int n, double* x, int incx, do } cublasStatus_t cublasScnrm2(cublasHandle_t handle, int n, cuComplex* x, int incx, float* result) { - return hipblasScnrm2(handle, n, reinterpret_cast(x), incx, result); + return hipblasScnrm2(handle, n, reinterpret_cast(x), incx, result); } cublasStatus_t cublasDznrm2(cublasHandle_t handle, int n, cuDoubleComplex* x, int incx, double* result) { - return hipblasDznrm2(handle, n, reinterpret_cast(x), incx, result); + return hipblasDznrm2(handle, n, reinterpret_cast(x), incx, result); } cublasStatus_t cublasSscal(cublasHandle_t handle, int n, float* alpha, float* x, int incx) { @@ -245,19 +261,19 @@ cublasStatus_t cublasDscal(cublasHandle_t handle, int n, double* alpha, double* } cublasStatus_t cublasCscal(cublasHandle_t handle, int n, cuComplex* alpha, cuComplex* x, int incx) { - return hipblasCscal(handle, n, reinterpret_cast(alpha), reinterpret_cast(x), incx); + return hipblasCscal(handle, n, reinterpret_cast(alpha), reinterpret_cast(x), incx); } cublasStatus_t cublasCsscal(cublasHandle_t handle, int n, float* alpha, cuComplex* x, int incx) { - return hipblasCsscal(handle, n, alpha, reinterpret_cast(x), incx); + return hipblasCsscal(handle, n, alpha, reinterpret_cast(x), incx); } cublasStatus_t cublasZscal(cublasHandle_t handle, int n, cuDoubleComplex* alpha, cuDoubleComplex* x, int incx) { - return hipblasZscal(handle, n, reinterpret_cast(alpha), reinterpret_cast(x), incx); + return hipblasZscal(handle, n, reinterpret_cast(alpha), reinterpret_cast(x), incx); } cublasStatus_t cublasZdscal(cublasHandle_t handle, int n, double* alpha, cuDoubleComplex* x, int incx) { - return hipblasZdscal(handle, n, alpha, reinterpret_cast(x), incx); + return hipblasZdscal(handle, n, alpha, reinterpret_cast(x), incx); } @@ -280,22 +296,22 @@ cublasStatus_t cublasCgemv(cublasHandle_t handle, cublasOperation_t trans, int m cuComplex* A, int lda, cuComplex* x, int incx, cuComplex* beta, cuComplex* y, int incy) { return hipblasCgemv(handle, convert_hipblasOperation_t(trans), m, n, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(x), incx, - reinterpret_cast(beta), - reinterpret_cast(y), incy); + reinterpret_cast(alpha), + reinterpret_cast(A), lda, + reinterpret_cast(x), incx, + reinterpret_cast(beta), + reinterpret_cast(y), incy); } cublasStatus_t cublasZgemv(cublasHandle_t handle, cublasOperation_t trans, int m, int n, cuDoubleComplex* alpha, cuDoubleComplex* A, int lda, cuDoubleComplex* x, int incx, cuDoubleComplex* beta, cuDoubleComplex* y, int incy) { return hipblasZgemv(handle, convert_hipblasOperation_t(trans), m, n, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(x), incx, - reinterpret_cast(beta), - reinterpret_cast(y), incy); + reinterpret_cast(alpha), + reinterpret_cast(A), lda, + reinterpret_cast(x), incx, + reinterpret_cast(beta), + reinterpret_cast(y), incy); } cublasStatus_t cublasSger(cublasHandle_t handle, int m, int n, float* alpha, float* x, int incx, @@ -311,39 +327,39 @@ cublasStatus_t cublasDger(cublasHandle_t handle, int m, int n, double* alpha, do cublasStatus_t cublasCgeru(cublasHandle_t handle, int m, int n, cuComplex* alpha, cuComplex* x, int incx, cuComplex* y, int incy, cuComplex* A, int lda) { return hipblasCgeru(handle, m, n, - reinterpret_cast(alpha), - reinterpret_cast(x), incx, - reinterpret_cast(y), incy, - reinterpret_cast(A), lda); + reinterpret_cast(alpha), + reinterpret_cast(x), incx, + reinterpret_cast(y), incy, + reinterpret_cast(A), lda); } cublasStatus_t cublasCgerc(cublasHandle_t handle, int m, int n, cuComplex* alpha, cuComplex* x, int incx, cuComplex* y, int incy, cuComplex* A, int lda) { return hipblasCgerc(handle, m, n, - reinterpret_cast(alpha), - reinterpret_cast(x), incx, - reinterpret_cast(y), incy, - reinterpret_cast(A), lda); + reinterpret_cast(alpha), + reinterpret_cast(x), incx, + reinterpret_cast(y), incy, + reinterpret_cast(A), lda); } cublasStatus_t cublasZgeru(cublasHandle_t handle, int m, int n, cuDoubleComplex* alpha, cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy, cuDoubleComplex* A, int lda) { return hipblasZgeru(handle, m, n, - reinterpret_cast(alpha), - reinterpret_cast(x), incx, - reinterpret_cast(y), incy, - reinterpret_cast(A), lda); + reinterpret_cast(alpha), + reinterpret_cast(x), incx, + reinterpret_cast(y), incy, + reinterpret_cast(A), lda); } cublasStatus_t cublasZgerc(cublasHandle_t handle, int m, int n, cuDoubleComplex* alpha, cuDoubleComplex* x, int incx, cuDoubleComplex* y, int incy, cuDoubleComplex* A, int lda) { return hipblasZgerc(handle, m, n, - reinterpret_cast(alpha), - reinterpret_cast(x), incx, - reinterpret_cast(y), incy, - reinterpret_cast(A), lda); + reinterpret_cast(alpha), + reinterpret_cast(x), incx, + reinterpret_cast(y), incy, + reinterpret_cast(A), lda); } cublasStatus_t cublasSsbmv(cublasHandle_t handle, cublasFillMode_t uplo, int n, int k, @@ -376,11 +392,11 @@ cublasStatus_t cublasGemmStridedBatchedEx(cublasHandle_t handle, cublasOperation } return hipblasGemmStridedBatchedEx(handle, convert_hipblasOperation_t(transa), convert_hipblasOperation_t(transb), m, n, k, alpha, - A, convert_hipblasDatatype_t(Atype), lda, strideA, - B, convert_hipblasDatatype_t(Btype), ldb, strideB, + A, convert_hipDatatype(Atype), lda, strideA, + B, convert_hipDatatype(Btype), ldb, strideB, beta, - C, convert_hipblasDatatype_t(Ctype), ldc, strideC, - batchCount, convert_hipblasDatatype_t(computeType), + C, convert_hipDatatype(Ctype), ldc, strideC, + batchCount, convert_hipblasComputeType_t(computeType), static_cast(160)); // HIPBLAS_GEMM_DEFAULT } @@ -412,11 +428,11 @@ cublasStatus_t cublasCgemm( return hipblasCgemm( handle, convert_hipblasOperation_t(transa), convert_hipblasOperation_t(transb), m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb, - reinterpret_cast(beta), - reinterpret_cast(C), ldc); + reinterpret_cast(alpha), + reinterpret_cast(A), lda, + reinterpret_cast(B), ldb, + reinterpret_cast(beta), + reinterpret_cast(C), ldc); } cublasStatus_t cublasZgemm( @@ -429,11 +445,11 @@ cublasStatus_t cublasZgemm( return hipblasZgemm( handle, convert_hipblasOperation_t(transa), convert_hipblasOperation_t(transb), m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb, - reinterpret_cast(beta), - reinterpret_cast(C), ldc); + reinterpret_cast(alpha), + reinterpret_cast(A), lda, + reinterpret_cast(B), ldb, + reinterpret_cast(beta), + reinterpret_cast(C), ldc); } cublasStatus_t cublasSgemmBatched( @@ -469,11 +485,11 @@ cublasStatus_t cublasCgemmBatched( return hipblasCgemmBatched( handle, convert_hipblasOperation_t(transa), convert_hipblasOperation_t(transb), m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb, - reinterpret_cast(beta), - reinterpret_cast(C), ldc, batchCount); + reinterpret_cast(alpha), + reinterpret_cast(A), lda, + reinterpret_cast(B), ldb, + reinterpret_cast(beta), + reinterpret_cast(C), ldc, batchCount); } cublasStatus_t cublasZgemmBatched( @@ -487,11 +503,11 @@ cublasStatus_t cublasZgemmBatched( return hipblasZgemmBatched( handle, convert_hipblasOperation_t(transa), convert_hipblasOperation_t(transb), m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb, - reinterpret_cast(beta), - reinterpret_cast(C), ldc, batchCount); + reinterpret_cast(alpha), + reinterpret_cast(A), lda, + reinterpret_cast(B), ldb, + reinterpret_cast(beta), + reinterpret_cast(C), ldc, batchCount); } cublasStatus_t cublasSgemmEx( @@ -525,11 +541,11 @@ cublasStatus_t cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, cub } return hipblasGemmEx(handle, convert_hipblasOperation_t(transa), convert_hipblasOperation_t(transb), m, n, k, alpha, - A, convert_hipblasDatatype_t(Atype), lda, - B, convert_hipblasDatatype_t(Btype), ldb, + A, convert_hipDatatype(Atype), lda, + B, convert_hipDatatype(Btype), ldb, beta, - C, convert_hipblasDatatype_t(Ctype), ldc, - convert_hipblasDatatype_t(computetype), + C, convert_hipDatatype(Ctype), ldc, + convert_hipblasComputeType_t(computetype), static_cast(160)); // HIPBLAS_GEMM_DEFAULT } @@ -571,9 +587,9 @@ cublasStatus_t cublasCtrsm(cublasHandle_t handle, cublasSideMode_t size, cublasF convert_hipblasOperation_t(trans), convert_hipblasDiagType_t(diag), m, n, - reinterpret_cast(alpha), - reinterpret_cast(const_cast(A)), lda, - reinterpret_cast(B), ldb); + reinterpret_cast(alpha), + reinterpret_cast(const_cast(A)), lda, + reinterpret_cast(B), ldb); } cublasStatus_t cublasZtrsm(cublasHandle_t handle, cublasSideMode_t size, cublasFillMode_t uplo, cublasOperation_t trans, @@ -585,9 +601,9 @@ cublasStatus_t cublasZtrsm(cublasHandle_t handle, cublasSideMode_t size, cublasF convert_hipblasOperation_t(trans), convert_hipblasDiagType_t(diag), m, n, - reinterpret_cast(alpha), - reinterpret_cast(const_cast(A)), lda, - reinterpret_cast(B), ldb); + reinterpret_cast(alpha), + reinterpret_cast(const_cast(A)), lda, + reinterpret_cast(B), ldb); } cublasStatus_t cublasStrsmBatched(cublasHandle_t handle, cublasSideMode_t size, cublasFillMode_t uplo, cublasOperation_t trans, @@ -621,9 +637,9 @@ cublasStatus_t cublasCtrsmBatched(cublasHandle_t handle, cublasSideMode_t size, convert_hipblasOperation_t(trans), convert_hipblasDiagType_t(diag), m, n, - reinterpret_cast(alpha), - reinterpret_cast(const_cast(A)), lda, - reinterpret_cast(const_cast(B)), ldb, batchCount); + reinterpret_cast(alpha), + reinterpret_cast(const_cast(A)), lda, + reinterpret_cast(const_cast(B)), ldb, batchCount); } cublasStatus_t cublasZtrsmBatched(cublasHandle_t handle, cublasSideMode_t size, cublasFillMode_t uplo, cublasOperation_t trans, @@ -635,9 +651,9 @@ cublasStatus_t cublasZtrsmBatched(cublasHandle_t handle, cublasSideMode_t size, convert_hipblasOperation_t(trans), convert_hipblasDiagType_t(diag), m, n, - reinterpret_cast(alpha), - reinterpret_cast(const_cast(A)), lda, - reinterpret_cast(const_cast(B)), ldb, batchCount); + reinterpret_cast(alpha), + reinterpret_cast(const_cast(A)), lda, + reinterpret_cast(const_cast(B)), ldb, batchCount); } cublasStatus_t cublasSsyrk(cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, int n, int k, @@ -659,10 +675,10 @@ cublasStatus_t cublasCsyrk(cublasHandle_t handle, cublasFillMode_t uplo, cublasO const cuComplex* beta, cuComplex* C, int ldc) { return hipblasCsyrk(handle, convert_hipblasFillMode_t(uplo), convert_hipblasOperation_t(trans), n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(beta), - reinterpret_cast(C), ldc); + reinterpret_cast(alpha), + reinterpret_cast(A), lda, + reinterpret_cast(beta), + reinterpret_cast(C), ldc); } cublasStatus_t cublasZsyrk(cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, int n, int k, @@ -670,10 +686,10 @@ cublasStatus_t cublasZsyrk(cublasHandle_t handle, cublasFillMode_t uplo, cublasO const cuDoubleComplex* beta, cuDoubleComplex* C, int ldc) { return hipblasZsyrk(handle, convert_hipblasFillMode_t(uplo), convert_hipblasOperation_t(trans), n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(beta), - reinterpret_cast(C), ldc); + reinterpret_cast(alpha), + reinterpret_cast(A), lda, + reinterpret_cast(beta), + reinterpret_cast(C), ldc); } // BLAS extension @@ -705,13 +721,13 @@ cublasStatus_t cublasCgeam( return HIPBLAS_STATUS_NOT_SUPPORTED; #else return hipblasCgeam(handle, convert_hipblasOperation_t(transa), convert_hipblasOperation_t(transb), m, n, - reinterpret_cast(alpha), - reinterpret_cast(A), + reinterpret_cast(alpha), + reinterpret_cast(A), lda, - reinterpret_cast(beta), - reinterpret_cast(B), + reinterpret_cast(beta), + reinterpret_cast(B), ldb, - reinterpret_cast(C), + reinterpret_cast(C), ldc); #endif } @@ -726,13 +742,13 @@ cublasStatus_t cublasZgeam( return HIPBLAS_STATUS_NOT_SUPPORTED; #else return hipblasZgeam(handle, convert_hipblasOperation_t(transa), convert_hipblasOperation_t(transb), m, n, - reinterpret_cast(alpha), - reinterpret_cast(A), + reinterpret_cast(alpha), + reinterpret_cast(A), lda, - reinterpret_cast(beta), - reinterpret_cast(B), + reinterpret_cast(beta), + reinterpret_cast(B), ldb, - reinterpret_cast(C), + reinterpret_cast(C), ldc); #endif } @@ -767,9 +783,9 @@ cublasStatus_t cublasCdgmm(cublasHandle_t handle, cublasSideMode_t mode, int m, return HIPBLAS_STATUS_NOT_SUPPORTED; #else return hipblasCdgmm(handle, convert_hipblasSideMode_t(mode), m, n, - reinterpret_cast(A), lda, - reinterpret_cast(x), incx, - reinterpret_cast(C), ldc); + reinterpret_cast(A), lda, + reinterpret_cast(x), incx, + reinterpret_cast(C), ldc); #endif } @@ -781,9 +797,9 @@ cublasStatus_t cublasZdgmm(cublasHandle_t handle, cublasSideMode_t mode, int m, return HIPBLAS_STATUS_NOT_SUPPORTED; #else return hipblasZdgmm(handle, convert_hipblasSideMode_t(mode), m, n, - reinterpret_cast(A), lda, - reinterpret_cast(x), incx, - reinterpret_cast(C), ldc); + reinterpret_cast(A), lda, + reinterpret_cast(x), incx, + reinterpret_cast(C), ldc); #endif } @@ -832,9 +848,9 @@ cublasStatus_t cublasCgetriBatched(cublasHandle_t handle, return HIPBLAS_STATUS_NOT_SUPPORTED; #else return hipblasCgetriBatched(handle, n, - reinterpret_cast(const_cast(A)), + reinterpret_cast(const_cast(A)), lda, const_cast(P), - reinterpret_cast(C), + reinterpret_cast(C), ldc, info, batchSize); #endif } @@ -852,9 +868,9 @@ cublasStatus_t cublasZgetriBatched(cublasHandle_t handle, return HIPBLAS_STATUS_NOT_SUPPORTED; #else return hipblasZgetriBatched(handle, n, - reinterpret_cast(const_cast(A)), + reinterpret_cast(const_cast(A)), lda, const_cast(P), - reinterpret_cast(C), + reinterpret_cast(C), ldc, info, batchSize); #endif } @@ -899,11 +915,11 @@ cublasStatus_t cublasCgemmStridedBatched( handle, convert_hipblasOperation_t(transa), convert_hipblasOperation_t(transb), m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, bsa, - reinterpret_cast(B), ldb, bsb, - reinterpret_cast(beta), - reinterpret_cast(C), ldc, bsc, + reinterpret_cast(alpha), + reinterpret_cast(A), lda, bsa, + reinterpret_cast(B), ldb, bsb, + reinterpret_cast(beta), + reinterpret_cast(C), ldc, bsc, batchCount); } @@ -919,11 +935,11 @@ cublasStatus_t cublasZgemmStridedBatched( handle, convert_hipblasOperation_t(transa), convert_hipblasOperation_t(transb), m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, bsa, - reinterpret_cast(B), ldb, bsb, - reinterpret_cast(beta), - reinterpret_cast(C), ldc, bsc, + reinterpret_cast(alpha), + reinterpret_cast(A), lda, bsa, + reinterpret_cast(B), ldb, bsb, + reinterpret_cast(beta), + reinterpret_cast(C), ldc, bsc, batchCount); } @@ -956,14 +972,14 @@ cublasStatus_t cublasDgetrfBatched(cublasHandle_t handle, int n, double **Aarray cublasStatus_t cublasCgetrfBatched(cublasHandle_t handle, int n, cuComplex **Aarray, int lda, int *PivotArray, int *infoArray, int batchSize) { return hipblasCgetrfBatched(handle, n, - reinterpret_cast(Aarray), lda, + reinterpret_cast(Aarray), lda, PivotArray, infoArray, batchSize); } cublasStatus_t cublasZgetrfBatched(cublasHandle_t handle, int n, cuDoubleComplex **Aarray, int lda, int *PivotArray, int *infoArray, int batchSize) { return hipblasZgetrfBatched(handle, n, - reinterpret_cast(Aarray), lda, + reinterpret_cast(Aarray), lda, PivotArray, infoArray, batchSize); } @@ -1021,9 +1037,9 @@ cublasStatus_t cublasCgetrsBatched(cublasHandle_t handle, return hipblasCgetrsBatched(handle, convert_hipblasOperation_t(trans), n, nrhs, - reinterpret_cast(const_cast(Aarray)), lda, + reinterpret_cast(const_cast(Aarray)), lda, devIpiv, - reinterpret_cast(Barray), ldb, + reinterpret_cast(Barray), ldb, info, batchSize); } @@ -1041,9 +1057,9 @@ cublasStatus_t cublasZgetrsBatched(cublasHandle_t handle, return hipblasZgetrsBatched(handle, convert_hipblasOperation_t(trans), n, nrhs, - reinterpret_cast(const_cast(Aarray)), lda, + reinterpret_cast(const_cast(Aarray)), lda, devIpiv, - reinterpret_cast(Barray), ldb, + reinterpret_cast(Barray), ldb, info, batchSize); }