diff --git a/include/xstl/association_map/cuda/detail/association_map.hpp b/include/xstl/association_map/cuda/detail/association_map.hpp index cf1a8ef..2be06b5 100644 --- a/include/xstl/association_map/cuda/detail/association_map.hpp +++ b/include/xstl/association_map/cuda/detail/association_map.hpp @@ -9,7 +9,6 @@ #include "xstl/association_map/cuda/detail/kernels.cuh" #include #include -#include #include namespace xstd::cuda { @@ -55,10 +54,10 @@ namespace xstd::cuda { auto temporary_keys = make_device_unique(m_extents.keys + 1, stream); cudaMemsetAsync(temporary_keys.data(), 0, sizeof(key_type) * (m_extents.keys + 1), stream); - thrust::async::inclusive_scan(thrust::device.on(stream), - accumulator.data(), - accumulator.data() + m_extents.keys, - temporary_keys.data() + 1); + thrust::inclusive_scan(thrust::cuda::par_nosync.on(stream), + accumulator.data(), + accumulator.data() + m_extents.keys, + temporary_keys.data() + 1); cudaMemcpyAsync(m_data.keys.data(), temporary_keys.data(), sizeof(key_type) * (m_extents.keys + 1), diff --git a/include/xstl/association_map/hip/detail/association_map.hpp b/include/xstl/association_map/hip/detail/association_map.hpp index 30ca0c4..56aeaf6 100644 --- a/include/xstl/association_map/hip/detail/association_map.hpp +++ b/include/xstl/association_map/hip/detail/association_map.hpp @@ -9,7 +9,6 @@ #include "xstl/association_map/hip/detail/kernels.cuh" #include #include -#include #include namespace xstd::hip { @@ -55,10 +54,10 @@ namespace xstd::hip { auto temporary_keys = make_device_unique(m_extents.keys + 1, stream); hipMemsetAsync(temporary_keys.data(), 0, sizeof(key_type) * (m_extents.keys + 1), stream); - thrust::async::inclusive_scan(thrust::device.on(stream), - accumulator.data(), - accumulator.data() + m_extents.keys, - temporary_keys.data() + 1); + thrust::inclusive_scan(thrust::cuda::par_nosync.on(stream), + accumulator.data(), + accumulator.data() + m_extents.keys, + temporary_keys.data() + 1); hipMemcpyAsync(m_data.keys.data(), temporary_keys.data(), sizeof(key_type) * (m_extents.keys + 1),