diff --git a/include/spblas/vendor/cusparse/cusparse.hpp b/include/spblas/vendor/cusparse/cusparse.hpp index 7caa698..26c1d64 100644 --- a/include/spblas/vendor/cusparse/cusparse.hpp +++ b/include/spblas/vendor/cusparse/cusparse.hpp @@ -1,3 +1,4 @@ #pragma once #include "multiply.hpp" +#include "trisolve.hpp" diff --git a/include/spblas/vendor/cusparse/trisolve.hpp b/include/spblas/vendor/cusparse/trisolve.hpp new file mode 100644 index 0000000..7afa9d0 --- /dev/null +++ b/include/spblas/vendor/cusparse/trisolve.hpp @@ -0,0 +1,123 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include "cuda_allocator.hpp" +#include "detail/cusparse_tensors.hpp" +#include "exception.hpp" +#include "types.hpp" + +namespace spblas { /* triangular_solve_state_t: stateful helper for sparse triangular solves through the cuSPARSE SpSV routines. It keeps a cuSPARSE handle, an allocator and a workspace buffer that is reallocated only when a larger size is requested. NOTE(review): text inside angle brackets (include targets, template arguments) appears to have been stripped from this patch text - restore it from the original patch before applying. */ +class triangular_solve_state_t { +public: /* Default constructor: delegates to the allocator constructor with a default-constructed cuda_allocator. */ + triangular_solve_state_t() + : triangular_solve_state_t(cusparse::cuda_allocator{}) {} + /* Creates a cuSPARSE handle, attaches alloc.stream() to it when that stream is non-null, and stores the handle with a deleter that calls cusparseDestroy. NOTE(review): the status returned by cusparseSetStream is ignored, and the deleter goes through throw_if_error, which presumably can throw while this object is being destroyed - confirm. */ + triangular_solve_state_t(cusparse::cuda_allocator alloc) + : alloc_(alloc), buffer_size_(0), workspace_(nullptr) { + cusparseHandle_t handle; + __cusparse::throw_if_error(cusparseCreate(&handle)); + if (auto stream = alloc.stream()) { + cusparseSetStream(handle, stream); + } + handle_ = handle_manager(handle, [](cusparseHandle_t handle) { + __cusparse::throw_if_error(cusparseDestroy(handle)); + }); + } + /* Adopts a caller-provided cuSPARSE handle; the deleter body is empty, so this object never destroys that handle. */ + triangular_solve_state_t(cusparse::cuda_allocator alloc, + cusparseHandle_t handle) + : alloc_(alloc), buffer_size_(0), workspace_(nullptr) { + handle_ = handle_manager(handle, [](cusparseHandle_t handle) { + // it is provided by user, we do not delete it at all. 
+ }); + } + /* Destructor: returns the workspace to the allocator. NOTE(review): this calls the one-argument deallocate(workspace_), while triangular_solve calls deallocate(workspace_, buffer_size_) - confirm cuda_allocator provides both forms. */ + ~triangular_solve_state_t() { + alloc_.deallocate(workspace_); + } + /* triangular_solve: runs cusparseSpSV_bufferSize, cusparseSpSV_analysis and cusparseSpSV_solve (non-transposed, alpha = 1, default algorithm), passing a as the sparse matrix and b and c as the two dense vectors in that argument order. fill_mode and diag_type are chosen with std::is_same_v (its arguments are not visible in this text) and set on the matrix descriptor. The workspace is reallocated only when the requested buffer size exceeds the stored buffer_size_. Checked calls report failure through __cusparse::throw_if_error (presumably by throwing; its definition is not visible here). NOTE(review): descr is created with cusparseSpSV_createDescr (status unchecked) and no matching destroy call is visible in this function - looks like a per-call leak. NOTE(review): a_descr, b_descr and c_descr are destroyed only on the success path, so they are not released if an earlier check throws. */ + template + requires __detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range + void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, + C&& c) { + auto a_base = __detail::get_ultimate_base(a); + auto b_base = __detail::get_ultimate_base(b); + using matrix_type = decltype(a_base); + using value_type = typename matrix_type::scalar_type; + // the following needs to be non-const because cusparseSpMatSetAttribute + // only accept void* + auto diag_type = std::is_same_v + ? CUSPARSE_DIAG_TYPE_NON_UNIT + : CUSPARSE_DIAG_TYPE_UNIT; + auto fill_mode = std::is_same_v + ? CUSPARSE_FILL_MODE_UPPER + : CUSPARSE_FILL_MODE_LOWER; + + auto a_descr = __cusparse::create_cusparse_handle(a_base); + auto b_descr = __cusparse::create_cusparse_handle(b_base); + auto c_descr = __cusparse::create_cusparse_handle(c); + + __cusparse::throw_if_error(cusparseSpMatSetAttribute( + a_descr, CUSPARSE_SPMAT_FILL_MODE, &fill_mode, sizeof(fill_mode))); + __cusparse::throw_if_error(cusparseSpMatSetAttribute( + a_descr, CUSPARSE_SPMAT_DIAG_TYPE, &diag_type, sizeof(diag_type))); + value_type alpha = 1.0; + size_t buffer_size = 0; + auto handle = this->handle_.get(); + cusparseSpSVDescr_t descr; + cusparseSpSV_createDescr(&descr); + __cusparse::throw_if_error(cusparseSpSV_bufferSize( + handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, a_descr, b_descr, + c_descr, detail::cuda_data_type_v, + CUSPARSE_SPSV_ALG_DEFAULT, descr, &buffer_size)); + if (buffer_size > this->buffer_size_) { + this->alloc_.deallocate(workspace_, this->buffer_size_); + this->buffer_size_ = buffer_size; + workspace_ = this->alloc_.allocate(buffer_size); + } + __cusparse::throw_if_error(cusparseSpSV_analysis( + handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, a_descr, b_descr, + c_descr, detail::cuda_data_type_v, + CUSPARSE_SPSV_ALG_DEFAULT, descr, this->workspace_)); + 
__cusparse::throw_if_error(cusparseSpSV_solve( + handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, a_descr, b_descr, + c_descr, detail::cuda_data_type_v, + CUSPARSE_SPSV_ALG_DEFAULT, descr)); + __cusparse::throw_if_error(cusparseDestroySpMat(a_descr)); + __cusparse::throw_if_error(cusparseDestroyDnVec(b_descr)); + __cusparse::throw_if_error(cusparseDestroyDnVec(c_descr)); + } + +private: /* handle_manager is a unique_ptr with a std::function deleter, so the creating constructor and the adopting constructor can install different deleters on the same member type. */ + using handle_manager = + std::unique_ptr::element_type, + std::function>; + handle_manager handle_; + cusparse::cuda_allocator alloc_; + std::uint64_t buffer_size_; + char* workspace_; +}; + /* Free-function form of the solve: forwards to trisolve_handle.triangular_solve(a, uplo, diag, b, c). */ +template + requires __detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range +void triangular_solve(triangular_solve_state_t& trisolve_handle, A&& a, + Triangle uplo, DiagonalStorage diag, B&& b, C&& c) { + trisolve_handle.triangular_solve(a, uplo, diag, b, c); +} + +} // namespace spblas diff --git a/test/gtest/CMakeLists.txt b/test/gtest/CMakeLists.txt index f1d1504..2bb7f55 100644 --- a/test/gtest/CMakeLists.txt +++ b/test/gtest/CMakeLists.txt @@ -20,7 +20,7 @@ if (SPBLAS_GPU_BACKEND) set(GPUTEST_SOURCES device/spmv_test.cpp device/spgemm_test.cpp device/spgemm_reuse_test.cpp device/rocsparse/spgemm_4args_test.cpp device/triangular_solve_test.cpp) set_source_files_properties(${GPUTEST_SOURCES} PROPERTIES LANGUAGE HIP) else () - set(GPUTEST_SOURCES device/spmv_test.cpp) + set(GPUTEST_SOURCES device/spmv_test.cpp device/triangular_solve_test.cpp) endif () list(APPEND TEST_SOURCES ${GPUTEST_SOURCES}) endif()