@@ -22,9 +22,10 @@ class spgemm_state_t {
2222 spgemm_state_t () : spgemm_state_t (cusparse::cuda_allocator<char >{}) {}
2323
2424 spgemm_state_t (cusparse::cuda_allocator<char > alloc)
25- : alloc_(alloc), buffer_size_1_(0 ), buffer_size_2_(0 ),
26- workspace_1_ (nullptr ), workspace_2_(nullptr ), result_nnz_(0 ),
27- result_shape_(0 , 0 ) {
25+ : alloc_(alloc), buffer_size_1_(0 ), buffer_size_2_(0 ), buffer_size_3_(0 ),
26+ buffer_size_4_ (0 ), buffer_size_5_(0 ), workspace_1_(nullptr ),
27+ workspace_2_(nullptr ), workspace_3_(nullptr ), workspace_4_(nullptr ),
28+ workspace_5_(nullptr ), result_nnz_(0 ), result_shape_(0 , 0 ) {
2829 cusparseHandle_t handle;
2930 __cusparse::throw_if_error (cusparseCreate (&handle));
3031 if (auto stream = alloc.stream ()) {
@@ -157,6 +158,156 @@ class spgemm_state_t {
157158 to_cuda_datatype<value_type>(), CUSPARSE_SPGEMM_DEFAULT, this ->descr_ ));
158159 }
159160
161+ template <matrix A, matrix B, matrix C>
162+ requires __detail::has_csr_base<A> && __detail::has_csr_base<B> &&
163+ __detail::is_csr_view_v<C>
164+ void multiply_symbolic_compute (A&& a, B&& b, C&& c) {
165+ auto a_base = __detail::get_ultimate_base (a);
166+ auto b_base = __detail::get_ultimate_base (b);
167+ using matrix_type = decltype (a_base);
168+ using input_type = decltype (b_base);
169+ using output_type = std::remove_reference_t <decltype (c)>;
170+ using value_type = typename matrix_type::scalar_type;
171+ size_t buffer_size = 0 ;
172+
173+ auto alpha_optional = __detail::get_scaling_factor (a, b);
174+ value_type alpha = alpha_optional.value_or (1 );
175+ value_type beta = 1 ;
176+ auto handle = this ->handle_ .get ();
177+ __cusparse::throw_if_error (cusparseDestroySpMat (mat_a_));
178+ __cusparse::throw_if_error (cusparseDestroySpMat (mat_b_));
179+ __cusparse::throw_if_error (cusparseDestroySpMat (mat_c_));
180+ mat_a_ = __cusparse::create_matrix_descr (a_base);
181+ mat_b_ = __cusparse::create_matrix_descr (b_base);
182+ mat_c_ = __cusparse::create_matrix_descr (c);
183+
184+ // ask bufferSize1 bytes for external memory
185+ size_t buffer_size_1 = 0 ;
186+ __cusparse::throw_if_error (cusparseSpGEMMreuse_workEstimation (
187+ handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
188+ CUSPARSE_OPERATION_NON_TRANSPOSE, mat_a_, mat_b_, mat_c_,
189+ CUSPARSE_SPGEMM_DEFAULT, this ->descr_ , &buffer_size_1, NULL ));
190+ if (buffer_size_1 > this ->buffer_size_1_ ) {
191+ this ->alloc_ .deallocate (this ->workspace_1_ , buffer_size_1_);
192+ this ->buffer_size_1_ = buffer_size_1;
193+ this ->workspace_1_ = this ->alloc_ .allocate (buffer_size_1);
194+ }
195+ // inspect the matrices A and B to understand the memory requirement for
196+ // the next step
197+ __cusparse::throw_if_error (cusparseSpGEMMreuse_workEstimation (
198+ handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
199+ CUSPARSE_OPERATION_NON_TRANSPOSE, mat_a_, mat_b_, mat_c_,
200+ CUSPARSE_SPGEMM_DEFAULT, this ->descr_ , &buffer_size_1,
201+ this ->workspace_1_ ));
202+
203+ // ask buffer_size_2/3/4 bytes for external memory
204+ size_t buffer_size_2 = 0 ;
205+ size_t buffer_size_3 = 0 ;
206+ size_t buffer_size_4 = 0 ;
207+ cusparseSpGEMMreuse_nnz (handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
208+ CUSPARSE_OPERATION_NON_TRANSPOSE, mat_a_, mat_b_,
209+ mat_c_, CUSPARSE_SPGEMM_DEFAULT, this ->descr_ ,
210+ &buffer_size_2, NULL , &buffer_size_3, NULL ,
211+ &buffer_size_4, NULL );
212+ if (buffer_size_2 > this ->buffer_size_2_ ) {
213+ this ->alloc_ .deallocate (this ->workspace_2_ , buffer_size_2_);
214+ this ->buffer_size_2_ = buffer_size_2;
215+ this ->workspace_2_ = this ->alloc_ .allocate (buffer_size_2);
216+ }
217+ if (buffer_size_3 > this ->buffer_size_3_ ) {
218+ this ->alloc_ .deallocate (this ->workspace_3_ , buffer_size_3_);
219+ this ->buffer_size_3_ = buffer_size_3;
220+ this ->workspace_3_ = this ->alloc_ .allocate (buffer_size_3);
221+ }
222+ if (buffer_size_4 > this ->buffer_size_4_ ) {
223+ this ->alloc_ .deallocate (this ->workspace_4_ , buffer_size_4_);
224+ this ->buffer_size_4_ = buffer_size_4;
225+ this ->workspace_4_ = this ->alloc_ .allocate (buffer_size_4);
226+ }
227+
228+ // compute nnz
229+ cusparseSpGEMMreuse_nnz (handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
230+ CUSPARSE_OPERATION_NON_TRANSPOSE, mat_a_, mat_b_,
231+ mat_c_, CUSPARSE_SPGEMM_DEFAULT, this ->descr_ ,
232+ &buffer_size_2, this ->workspace_2_ , &buffer_size_3,
233+ this ->workspace_3_ , &buffer_size_4,
234+ this ->workspace_4_ );
235+ // get matrix C non-zero entries c_nnz
236+ int64_t c_num_rows, c_num_cols, c_nnz;
237+ cusparseSpMatGetSize (mat_c_, &c_num_rows, &c_num_cols, &c_nnz);
238+ this ->result_nnz_ = c_nnz;
239+ this ->result_shape_ = index<index_t >(c_num_rows, c_num_cols);
240+ }
241+
242+ template <matrix A, matrix B, matrix C>
243+ requires __detail::has_csr_base<A> && __detail::has_csr_base<B> &&
244+ __detail::is_csr_view_v<C>
245+ void multiply_symbolic_fill (A&& a, B&& b, C&& c) {
246+ auto a_base = __detail::get_ultimate_base (a);
247+ auto b_base = __detail::get_ultimate_base (b);
248+ using matrix_type = decltype (a_base);
249+ using input_type = decltype (b_base);
250+ using output_type = std::remove_reference_t <decltype (c)>;
251+ using value_type = typename matrix_type::scalar_type;
252+
253+ auto alpha_optional = __detail::get_scaling_factor (a, b);
254+ value_type alpha = alpha_optional.value_or (1 );
255+ value_type beta = 0 ;
256+
257+ __cusparse::throw_if_error (cusparseCsrSetPointers (
258+ this ->mat_c_ , c.rowptr ().data (), c.colind ().data (), c.values ().data ()));
259+
260+ auto handle = this ->handle_ .get ();
261+ size_t buffer_size_5 = 0 ;
262+ cusparseSpGEMMreuse_copy (handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
263+ CUSPARSE_OPERATION_NON_TRANSPOSE, mat_a_, mat_b_,
264+ mat_c_, CUSPARSE_SPGEMM_DEFAULT, this ->descr_ ,
265+ &buffer_size_5, NULL );
266+ if (buffer_size_5 > this ->buffer_size_5_ ) {
267+ this ->alloc_ .deallocate (this ->workspace_5_ , buffer_size_5_);
268+ this ->buffer_size_5_ = buffer_size_5;
269+ this ->workspace_5_ = this ->alloc_ .allocate (buffer_size_5);
270+ }
271+ cusparseSpGEMMreuse_copy (handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
272+ CUSPARSE_OPERATION_NON_TRANSPOSE, mat_a_, mat_b_,
273+ mat_c_, CUSPARSE_SPGEMM_DEFAULT, this ->descr_ ,
274+ &buffer_size_5, this ->workspace_5_ );
275+ }
276+
277+ template <matrix A, matrix B, matrix C>
278+ requires __detail::has_csr_base<A> && __detail::has_csr_base<B> &&
279+ __detail::is_csr_view_v<C>
280+ void multiply_numeric (A&& a, B&& b, C&& c) {
281+ auto a_base = __detail::get_ultimate_base (a);
282+ auto b_base = __detail::get_ultimate_base (b);
283+ using matrix_type = decltype (a_base);
284+ using input_type = decltype (b_base);
285+ using output_type = std::remove_reference_t <decltype (c)>;
286+ using value_type = typename matrix_type::scalar_type;
287+
288+ auto alpha_optional = __detail::get_scaling_factor (a, b);
289+ tensor_scalar_t <A> alpha = alpha_optional.value_or (1 );
290+ value_type alpha_val = alpha;
291+ value_type beta = 0 ;
292+
293+ auto handle = this ->handle_ .get ();
294+
295+ // Update the pointer from the matrix but they must contains the same
296+ // sparsity as the previous call.
297+ __cusparse::throw_if_error (
298+ cusparseCsrSetPointers (this ->mat_a_ , a_base.rowptr ().data (),
299+ a_base.colind ().data (), a_base.values ().data ()));
300+ __cusparse::throw_if_error (
301+ cusparseCsrSetPointers (this ->mat_b_ , b_base.rowptr ().data (),
302+ b_base.colind ().data (), b_base.values ().data ()));
303+ __cusparse::throw_if_error (cusparseCsrSetPointers (
304+ this ->mat_c_ , c.rowptr ().data (), c.colind ().data (), c.values ().data ()));
305+ cusparseSpGEMMreuse_compute (
306+ handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
307+ CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a_, mat_b_, &beta, mat_c_,
308+ to_cuda_datatype<value_type>(), CUSPARSE_SPGEMM_DEFAULT, this ->descr_ );
309+ }
310+
160311private:
161312 using handle_manager =
162313 std::unique_ptr<std::pointer_traits<cusparseHandle_t>::element_type,
@@ -165,8 +316,14 @@ class spgemm_state_t {
165316 cusparse::cuda_allocator<char > alloc_;
166317 size_t buffer_size_1_;
167318 size_t buffer_size_2_;
319+ size_t buffer_size_3_;
320+ size_t buffer_size_4_;
321+ size_t buffer_size_5_;
168322 char * workspace_1_;
169323 char * workspace_2_;
324+ char * workspace_3_;
325+ char * workspace_4_;
326+ char * workspace_5_;
170327 index<index_t > result_shape_;
171328 index_t result_nnz_;
172329 cusparseSpMatDescr_t mat_a_ = nullptr ;
@@ -194,4 +351,27 @@ void multiply_fill(spgemm_state_t& spgemm_handle, A&& a, B&& b, C&& c) {
194351 spgemm_handle.multiply_fill (a, b, c);
195352}
196353
354+ template <matrix A, matrix B, matrix C>
355+ requires __detail::has_csr_base<A> && __detail::has_csr_base<B> &&
356+ __detail::is_csr_view_v<C>
357+ void multiply_symbolic_compute (spgemm_state_t & spgemm_handle, A&& a, B&& b,
358+ C&& c) {
359+ spgemm_handle.multiply_symbolic_compute (a, b, c);
360+ }
361+
362+ template <matrix A, matrix B, matrix C>
363+ requires __detail::has_csr_base<A> && __detail::has_csr_base<B> &&
364+ __detail::is_csr_view_v<C>
365+ void multiply_symbolic_fill (spgemm_state_t & spgemm_handle, A&& a, B&& b,
366+ C&& c) {
367+ spgemm_handle.multiply_symbolic_fill (a, b, c);
368+ }
369+
370+ template <matrix A, matrix B, matrix C>
371+ requires __detail::has_csr_base<A> && __detail::has_csr_base<B> &&
372+ __detail::is_csr_view_v<C>
373+ void multiply_numeric (spgemm_state_t & spgemm_handle, A&& a, B&& b, C&& c) {
374+ spgemm_handle.multiply_numeric (a, b, c);
375+ }
376+
197377} // namespace spblas
0 commit comments