From 879a7d3f558d86a83ea7faa8c563ad06f565a9ad Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Thu, 21 Aug 2025 16:14:02 +0200 Subject: [PATCH] implement rocsparse spsv --- include/spblas/vendor/rocsparse/rocsparse.hpp | 1 + include/spblas/vendor/rocsparse/trisolve.hpp | 118 ++++++++++++++++++ test/gtest/CMakeLists.txt | 2 +- test/gtest/device/triangular_solve_test.cpp | 116 +++++++++++++++++ 4 files changed, 236 insertions(+), 1 deletion(-) create mode 100644 include/spblas/vendor/rocsparse/trisolve.hpp create mode 100644 test/gtest/device/triangular_solve_test.cpp diff --git a/include/spblas/vendor/rocsparse/rocsparse.hpp b/include/spblas/vendor/rocsparse/rocsparse.hpp index 014b2ba..66a4cd9 100644 --- a/include/spblas/vendor/rocsparse/rocsparse.hpp +++ b/include/spblas/vendor/rocsparse/rocsparse.hpp @@ -2,3 +2,4 @@ #include "multiply.hpp" #include "multiply_spgemm.hpp" +#include "trisolve.hpp" diff --git a/include/spblas/vendor/rocsparse/trisolve.hpp b/include/spblas/vendor/rocsparse/trisolve.hpp new file mode 100644 index 0000000..3120acc --- /dev/null +++ b/include/spblas/vendor/rocsparse/trisolve.hpp @@ -0,0 +1,118 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include "exception.hpp" +#include "hip_allocator.hpp" +#include "types.hpp" + +namespace spblas { +class triangular_solve_state_t { /* Reusable state for the rocSPARSE triangular solve: holds a rocSPARSE handle, an allocator and a workspace buffer that is kept between calls and released in the destructor. */ +public: + triangular_solve_state_t() + : triangular_solve_state_t(rocsparse::hip_allocator{}) {} /* Default construction delegates to the allocator constructor with a value-initialized hip_allocator. */ + + triangular_solve_state_t(rocsparse::hip_allocator alloc) + : alloc_(alloc), buffer_size_(0), workspace_(nullptr) { /* Creates a new rocSPARSE handle that this object destroys through its deleter; starts with no workspace. */ + rocsparse_handle handle; + __rocsparse::throw_if_error(rocsparse_create_handle(&handle)); + if (auto stream = alloc.stream()) { /* Only a non-null allocator stream is attached to the handle. */ + rocsparse_set_stream(handle, stream); /* NOTE(review): the status returned here is not passed to throw_if_error, unlike the neighbouring rocSPARSE calls. */ + } + handle_ = handle_manager(handle, [](rocsparse_handle handle) { /* Deleter destroys the handle created above. NOTE(review): it routes through throw_if_error, so presumably it can throw while this object is being destroyed - confirm. */ + __rocsparse::throw_if_error(rocsparse_destroy_handle(handle)); + }); + } + + triangular_solve_state_t(rocsparse::hip_allocator alloc, + 
rocsparse_handle handle) + : alloc_(alloc), buffer_size_(0), workspace_(nullptr) { /* Wraps a caller-supplied handle with an empty deleter. NOTE(review): alloc.stream() is not applied to the handle in this constructor. */ + handle_ = handle_manager(handle, [](rocsparse_handle handle) { + // it is provided by user, we do not delete it at all. + }); + } + + ~triangular_solve_state_t() { /* Releases the cached workspace. NOTE(review): deallocate is called with one argument here but with an additional size argument inside triangular_solve. */ + alloc_.deallocate(workspace_); + } + + template + requires __detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range + void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, + C&& c) { /* Runs rocsparse_spsv on matrix a with the vectors b and c, using rocsparse_operation_none and alpha = 1. uplo and diag are not read as values; fill mode and diagonal type are selected by the is_same_v tests below. NOTE(review): the template parameter list and the arguments of the constraints and is_same_v tests are missing from this copy of the patch. Presumably b is the right-hand side and c receives the solution - confirm against the rocsparse_spsv documentation. */ + auto a_base = __detail::get_ultimate_base(a); + auto b_base = __detail::get_ultimate_base(b); + using matrix_type = decltype(a_base); + using value_type = typename matrix_type::scalar_type; + const auto diag_type = std::is_same_v + ? rocsparse_diag_type_non_unit + : rocsparse_diag_type_unit; + const auto fill_mode = std::is_same_v + ? rocsparse_fill_mode_upper + : rocsparse_fill_mode_lower; + + auto a_descr = __rocsparse::create_rocsparse_handle(a_base); + auto b_descr = __rocsparse::create_rocsparse_handle(b_base); + auto c_descr = __rocsparse::create_rocsparse_handle(c); /* NOTE(review): the three descriptors are released only by the destroy calls at the end of this function, so if any throw_if_error in between throws, those calls are skipped. */ + + __rocsparse::throw_if_error(rocsparse_spmat_set_attribute( + a_descr, rocsparse_spmat_fill_mode, &fill_mode, sizeof(fill_mode))); + __rocsparse::throw_if_error(rocsparse_spmat_set_attribute( + a_descr, rocsparse_spmat_diag_type, &diag_type, sizeof(diag_type))); + value_type alpha = 1.0; + size_t buffer_size = 0; + auto handle = this->handle_.get(); /* Stage 1 (buffer_size): query the workspace size; no workspace pointer is passed. */ + __rocsparse::throw_if_error(rocsparse_spsv( + handle, rocsparse_operation_none, &alpha, a_descr, b_descr, c_descr, + detail::rocsparse_data_type_v, rocsparse_spsv_alg_default, + rocsparse_spsv_stage_buffer_size, &buffer_size, nullptr)); + if (buffer_size > this->buffer_size_) { /* Grow-only: the workspace is reallocated only when the queried size exceeds the cached size. */ + this->alloc_.deallocate(workspace_, this->buffer_size_); + this->buffer_size_ = buffer_size; + workspace_ = this->alloc_.allocate(buffer_size); + } /* Stage 2 (preprocess): same arguments, now with the workspace. */ + __rocsparse::throw_if_error(rocsparse_spsv( + handle, rocsparse_operation_none, &alpha, a_descr, b_descr, c_descr, + 
detail::rocsparse_data_type_v, rocsparse_spsv_alg_default, + rocsparse_spsv_stage_preprocess, &buffer_size, this->workspace_)); /* Stage 3 (compute): performs the solve. */ + __rocsparse::throw_if_error(rocsparse_spsv( + handle, rocsparse_operation_none, &alpha, a_descr, b_descr, c_descr, + detail::rocsparse_data_type_v, rocsparse_spsv_alg_default, + rocsparse_spsv_stage_compute, &buffer_size, this->workspace_)); /* Release the per-call descriptors created above. */ + __rocsparse::throw_if_error(rocsparse_destroy_spmat_descr(a_descr)); + __rocsparse::throw_if_error(rocsparse_destroy_dnvec_descr(b_descr)); + __rocsparse::throw_if_error(rocsparse_destroy_dnvec_descr(c_descr)); + } + +private: /* The handle is held in a unique_ptr with a std::function deleter, so the destroying deleter and the empty deleter share one member type. */ + using handle_manager = + std::unique_ptr::element_type, + std::function>; + handle_manager handle_; + rocsparse::hip_allocator alloc_; + std::uint64_t buffer_size_; /* Size most recently passed to allocate for workspace_; 0 until the first allocation. */ + char* workspace_; /* Cached workspace passed to rocsparse_spsv; nullptr until the first allocation. */ +}; + +template + requires __detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range +void triangular_solve(triangular_solve_state_t& trisolve_handle, A&& a, + Triangle uplo, DiagonalStorage diag, B&& b, C&& c) { /* Free-function entry point: passes every argument through to trisolve_handle.triangular_solve. */ + trisolve_handle.triangular_solve(a, uplo, diag, b, c); +} + +} // namespace spblas diff --git a/test/gtest/CMakeLists.txt b/test/gtest/CMakeLists.txt index cd96de1..f1d1504 100644 --- a/test/gtest/CMakeLists.txt +++ b/test/gtest/CMakeLists.txt @@ -17,7 +17,7 @@ endif() # GPU tests if (SPBLAS_GPU_BACKEND) if (ENABLE_ROCSPARSE) - set(GPUTEST_SOURCES device/spmv_test.cpp device/spgemm_test.cpp device/spgemm_reuse_test.cpp device/rocsparse/spgemm_4args_test.cpp) + set(GPUTEST_SOURCES device/spmv_test.cpp device/spgemm_test.cpp device/spgemm_reuse_test.cpp device/rocsparse/spgemm_4args_test.cpp device/triangular_solve_test.cpp) set_source_files_properties(${GPUTEST_SOURCES} PROPERTIES LANGUAGE HIP) else () set(GPUTEST_SOURCES device/spmv_test.cpp) diff --git a/test/gtest/device/triangular_solve_test.cpp b/test/gtest/device/triangular_solve_test.cpp new file mode 100644 index 0000000..2afa362 --- /dev/null +++ 
b/test/gtest/device/triangular_solve_test.cpp @@ -0,0 +1,116 @@ +#include + +#include "../util.hpp" +#include + +#include + +template +void reference_triangular_solve(spblas::csr_view a, Triangle t, + DiagonalStorage d, B&& b, X&& x) { /* Host reference solve over a CSR view: reads b and writes the result into x, one row at a time. The first branch walks rows from last to first, the second from first to last. t and d are not read as values. NOTE(review): the template parameter list and the arguments of the is_same_v tests are missing from this copy of the patch. */ + auto&& values = a.values(); + auto&& colind = a.colind(); + auto&& rowptr = a.rowptr(); + auto shape = a.shape(); + + if constexpr (std::is_same_v) { + // backward solve + for (I row = shape[0]; row-- > 0;) { /* Visits rows shape[0]-1 down to 0; the test-then-decrement form avoids comparing a decremented index against zero. */ + T tmp = b[row]; + T diag_val = 0.0; /* Stays 0 when the row stores no diagonal entry, in which case the division below divides by zero. */ + for (I j = rowptr[row]; j < rowptr[row + 1]; j++) { + I col = colind[j]; + if (col > row) { /* Entries with col less than row are skipped in this branch. */ + T a_val = values[j]; + T x_val = x[col]; + tmp -= a_val * x_val; // b - U*x + } else if (col == row) { + diag_val = values[j]; + } + } + if constexpr (std::is_same_v) { + x[row] = tmp / diag_val; // ( b - U*x) / d + } else { + x[row] = tmp; // ( b- U*x) / 1 + } + } + } else if constexpr (std::is_same_v) { + // Forward Solve + for (I row = 0; row < shape[0]; row++) { + T tmp = b[row]; + T diag_val = 0.0; /* Same caveat as above: remains 0 if the row has no stored diagonal. */ + for (I j = rowptr[row]; j < rowptr[row + 1]; ++j) { + I col = colind[j]; + if (col < row) { /* Entries with col greater than row are skipped in this branch. */ + T a_val = values[j]; + T x_val = x[col]; + tmp -= a_val * x_val; // b - L*x + } else if (col == row) { + diag_val = values[j]; + } + } + if constexpr (std::is_same_v) { + x[row] = tmp / diag_val; // ( b - L*x) / d + } else { + x[row] = tmp; // ( b- L*x) / 1 + } + } + } +} + +template +void triangular_solve_test(Triangle t, DiagonalStorage d) { /* For every (m, n, nnz) in util::square_dims: builds a CSR problem on the host, mirrors it on the device, runs spblas::triangular_solve there and compares each entry of the copied-back result with the host reference. t and d are not referenced; Triangle{} and DiagonalStorage{} are constructed instead. */ + for (auto&& [m, n, nnz] : util::square_dims) { + // generate problem on host + auto [values, rowptr, colind, shape, _] = + spblas::generate_csr(m, n, nnz); + spblas::csr_view a(values, rowptr, colind, shape, nnz); + std::vector x(n, 1); + std::vector b(m, 1); + T scale_factor = 1e-3f; /* Values are scaled in place after the host view a was built; presumably a refers to the same storage, so host and device both see the scaled values - confirm. */ + std::transform(values.begin(), values.end(), values.begin(), + [scale_factor](T val) { return scale_factor * val; }); + // setup the problem on device + thrust::device_vector d_b(b); + thrust::device_vector d_x(x); + thrust::device_vector d_values(values); + 
thrust::device_vector d_rowptr(rowptr); + thrust::device_vector d_colind(colind); + spblas::csr_view d_a(d_values.data().get(), d_rowptr.data().get(), + d_colind.data().get(), shape, nnz); + std::span b_span(d_b.data().get(), m); + std::span x_span(d_x.data().get(), n); + + spblas::triangular_solve_state_t state; /* A fresh default-constructed state is created for every problem size. */ + spblas::triangular_solve(state, d_a, Triangle{}, DiagonalStorage{}, b_span, + x_span); + thrust::copy(d_x.begin(), d_x.end(), x.begin()); /* Copies the device result into the host vector x. */ + + std::vector x_ref(m, 0); + reference_triangular_solve(a, Triangle{}, DiagonalStorage{}, b, x_ref); + + for (std::size_t i = 0; i < x.size(); i++) { + EXPECT_EQ_(x[i], x_ref[i]); /* EXPECT_EQ_ is not defined in this file; presumably it comes from ../util.hpp - confirm its comparison semantics. */ + } + } +} + +TEST(CsrView, TriangularSolveLowerImplicit) { /* Lower-triangle solve with the implicit unit diagonal tag, T = float. */ + using T = float; + using I = spblas::index_t; + + triangular_solve_test(spblas::lower_triangle_t{}, + spblas::implicit_unit_diagonal_t{}); +} + +TEST(CsrView, TriangularSolveUpperImplicit) { /* Upper-triangle solve with the implicit unit diagonal tag, T = float. NOTE(review): no test visible here exercises an explicit-diagonal variant. */ + using T = float; + using I = spblas::index_t; + + triangular_solve_test(spblas::upper_triangle_t{}, + spblas::implicit_unit_diagonal_t{}); +}