diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 86d072a8a..1750d5949 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -23,7 +23,6 @@ #include #include -#include #include #include @@ -262,7 +261,7 @@ class open_addressing_ref_impl { /** * @brief Inserts an element. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param value The element to insert * @@ -273,7 +272,8 @@ class open_addressing_ref_impl { { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); - auto const key = this->extract_key(value); + auto const val = this->heterogeneous_value(value); + auto const key = this->extract_key(val); auto probing_iter = probing_scheme_(key, storage_ref_.window_extent()); while (true) { @@ -290,7 +290,7 @@ class open_addressing_ref_impl { auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index, slot_content, - value)) { + val)) { case insert_result::CONTINUE: continue; case insert_result::SUCCESS: return true; case insert_result::DUPLICATE: return false; @@ -304,7 +304,7 @@ class open_addressing_ref_impl { /** * @brief Inserts an element. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param group The Cooperative Group used to perform group insert * @param value The element to insert @@ -315,7 +315,8 @@ class open_addressing_ref_impl { __device__ bool insert(cooperative_groups::thread_block_tile const& group, Value const& value) noexcept { - auto const key = this->extract_key(value); + auto const val = this->heterogeneous_value(value); + auto const key = this->extract_key(val); auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent()); while (true) { @@ -353,7 +354,7 @@ class open_addressing_ref_impl { (group.thread_rank() == src_lane) ? attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index, window_slots[intra_window_index], - value) + val) : insert_result::CONTINUE; switch (group.shfl(status, src_lane)) { @@ -374,7 +375,7 @@ class open_addressing_ref_impl { * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param value The element to insert * @@ -393,7 +394,8 @@ class open_addressing_ref_impl { "insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs."); #endif - auto const key = this->extract_key(value); + auto const val = this->heterogeneous_value(value); + auto const key = this->extract_key(val); auto probing_iter = probing_scheme_(key, storage_ref_.window_extent()); while (true) { @@ -414,14 +416,7 @@ class open_addressing_ref_impl { if (eq_res == detail::equal_result::EMPTY or cuco::detail::bitwise_compare(this->extract_key(window_slots[i]), this->erased_key_sentinel())) { - auto const res = [&]() { - if constexpr (sizeof(value_type) <= 8) { - return packed_cas(window_ptr + i, window_slots[i], value); - } else { - return cas_dependent_write(window_ptr + i, window_slots[i], value); - } - }(); - switch (res) { + switch (this->attempt_insert_stable(window_ptr + i, window_slots[i], val)) { case insert_result::SUCCESS: { if constexpr (has_payload) { // wait to ensure that the write to the value part also took place @@ -451,7 +446,7 @@ class open_addressing_ref_impl { * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param group The Cooperative Group used to perform group insert_and_find * @param value The element to insert @@ -471,7 +466,8 @@ class open_addressing_ref_impl { "insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs."); #endif - auto const key = this->extract_key(value); + auto const val = this->heterogeneous_value(value); + auto const key = this->extract_key(val); auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent()); while (true) { @@ -522,11 +518,7 @@ class open_addressing_ref_impl { auto const res = group.shfl(reinterpret_cast(slot_ptr), src_lane); auto const status = [&, target_idx = intra_window_index]() { if (group.thread_rank() != src_lane) { return insert_result::CONTINUE; } - if constexpr (sizeof(value_type) <= 8) { - return packed_cas(slot_ptr, window_slots[target_idx], value); - } else { - return cas_dependent_write(slot_ptr, window_slots[target_idx], value); - } + return this->attempt_insert_stable(slot_ptr, window_slots[target_idx], val); }(); switch (group.shfl(status, src_lane)) { @@ -561,7 +553,7 @@ class open_addressing_ref_impl { /** * @brief Erases an element. * - * @tparam ProbeKey Input type which is implicitly convertible to 'key_type' + * @tparam ProbeKey Input type which is convertible to 'key_type' * * @param value The element to erase * @@ -601,7 +593,7 @@ class open_addressing_ref_impl { /** * @brief Erases an element. * - * @tparam ProbeKey Input type which is implicitly convertible to 'key_type' + * @tparam ProbeKey Input type which is convertible to 'key_type' * * @param group The Cooperative Group used to perform group erase * @param value The element to erase @@ -661,7 +653,7 @@ class open_addressing_ref_impl { * @note If the probe key `key` was inserted into the container, returns true. Otherwise, returns * false. * - * @tparam ProbeKey Probe key type + * @tparam ProbeKey Input type which is convertible to 'key_type' * * @param key The key to search for * @@ -694,7 +686,7 @@ class open_addressing_ref_impl { * @note If the probe key `key` was inserted into the container, returns true. Otherwise, returns * false. * - * @tparam ProbeKey Probe key type + * @tparam ProbeKey Input type which is convertible to 'key_type' * * @param group The Cooperative Group used to perform group contains * @param key The key to search for @@ -734,7 +726,7 @@ class open_addressing_ref_impl { * @note Returns a un-incrementable input iterator to the element whose key is equivalent to * `key`. If no such element exists, returns `end()`. * - * @tparam ProbeKey Probe key type + * @tparam ProbeKey Input type which is convertible to 'key_type' * * @param key The key to search for * @@ -771,7 +763,7 @@ class open_addressing_ref_impl { * @note Returns a un-incrementable input iterator to the element whose key is equivalent to * `key`. If no such element exists, returns `end()`. * - * @tparam ProbeKey Probe key type + * @tparam ProbeKey Input type which is convertible to 'key_type' * * @param group The Cooperative Group used to perform this operation * @param key The key to search for @@ -902,22 +894,20 @@ class open_addressing_ref_impl { } } - private: /** * @brief Extracts the key from a given value type. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param value The input value * * @return The key */ template - [[nodiscard]] __host__ __device__ constexpr auto const& extract_key( - Value const& value) const noexcept + [[nodiscard]] __device__ constexpr auto const& extract_key(Value const& value) const noexcept { if constexpr (this->has_payload) { - return thrust::get<0>(thrust::raw_reference_cast(value)); + return thrust::raw_reference_cast(value).first; } else { return thrust::raw_reference_cast(value); } @@ -928,17 +918,65 @@ class open_addressing_ref_impl { * * @note This function is only available if `this->has_payload == true` * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param value The input value * * @return The payload */ template > - [[nodiscard]] __host__ __device__ constexpr auto const& extract_payload( - Value const& value) const noexcept + [[nodiscard]] __device__ constexpr auto const& extract_payload(Value const& value) const noexcept + { + return thrust::raw_reference_cast(value).second; + } + + /** + * @brief Converts the given type to the container's native `value_type`. + * + * @tparam T Input type which is convertible to 'value_type' + * + * @param value The input value + * + * @return The converted object + */ + template + [[nodiscard]] __device__ constexpr value_type native_value(T const& value) const noexcept { - return thrust::get<1>(thrust::raw_reference_cast(value)); + if constexpr (this->has_payload) { + return {static_cast(this->extract_key(value)), this->extract_payload(value)}; + } else { + return static_cast(value); + } + } + + /** + * @brief Converts the given type to the container's native `value_type` while maintaining the + * heterogeneous key type. + * + * @tparam T Input type which is convertible to 'value_type' + * + * @param value The input value + * + * @return The converted object + */ + template + [[nodiscard]] __device__ constexpr auto heterogeneous_value(T const& value) const noexcept + { + if constexpr (this->has_payload and not cuda::std::is_same_v) { + using mapped_type = decltype(this->empty_slot_sentinel_.second); + if constexpr (cuco::detail::is_cuda_std_pair_like::value) { + return cuco::pair{cuda::std::get<0>(value), + static_cast(cuda::std::get<1>(value))}; + } else if constexpr (cuco::detail::is_thrust_pair_like::value) { + return cuco::pair{thrust::get<0>(value), static_cast(thrust::get<1>(value))}; + } else { + // hail mary (convert using .first/.second members) + return cuco::pair{thrust::raw_reference_cast(value.first), + static_cast(value.second)}; + } + } else { + return thrust::raw_reference_cast(value); + } } /** @@ -958,7 +996,7 @@ class open_addressing_ref_impl { /** * @brief Inserts the specified element with one single CAS operation. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param address Pointer to the slot in memory * @param expected Element to compare against @@ -971,7 +1009,7 @@ class open_addressing_ref_impl { value_type const& expected, Value const& desired) noexcept { - auto old = compare_and_swap(address, expected, static_cast(desired)); + auto old = compare_and_swap(address, expected, this->native_value(desired)); auto* old_ptr = reinterpret_cast(&old); if (cuco::detail::bitwise_compare(this->extract_key(*old_ptr), this->extract_key(expected))) { return insert_result::SUCCESS; @@ -986,7 +1024,7 @@ class open_addressing_ref_impl { /** * @brief Inserts the specified element with two back-to-back CAS operations. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param address Pointer to the slot in memory * @param expected Element to compare against @@ -1004,10 +1042,9 @@ class open_addressing_ref_impl { auto const expected_key = expected.first; auto const expected_payload = expected.second; - auto old_key = compare_and_swap( - &address->first, expected_key, static_cast(thrust::get<0>(desired))); - auto old_payload = compare_and_swap( - &address->second, expected_payload, static_cast(thrust::get<1>(desired))); + auto old_key = + compare_and_swap(&address->first, expected_key, static_cast(desired.first)); + auto old_payload = compare_and_swap(&address->second, expected_payload, desired.second); auto* old_key_ptr = reinterpret_cast(&old_key); auto* old_payload_ptr = reinterpret_cast(&old_payload); @@ -1015,8 +1052,7 @@ class open_addressing_ref_impl { // if key success if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { while (not cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { - old_payload = compare_and_swap( - &address->second, expected_payload, static_cast(thrust::get<1>(desired))); + old_payload = compare_and_swap(&address->second, expected_payload, desired.second); } return insert_result::SUCCESS; } else if (cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { @@ -1025,9 +1061,7 @@ class open_addressing_ref_impl { // Our key was already present in the slot, so our key is a duplicate // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare - if (this->predicate_.equal_to(*old_key_ptr, - thrust::get<0>(thrust::raw_reference_cast(desired))) == - detail::equal_result::EQUAL) { + if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) { return insert_result::DUPLICATE; } @@ -1037,7 +1071,7 @@ class open_addressing_ref_impl { /** * @brief Inserts the specified element with CAS-dependent write operations. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param address Pointer to the slot in memory * @param expected Element to compare against @@ -1053,23 +1087,20 @@ class open_addressing_ref_impl { auto const expected_key = expected.first; - auto old_key = compare_and_swap( - &address->first, expected_key, static_cast(thrust::get<0>(desired))); + auto old_key = + compare_and_swap(&address->first, expected_key, static_cast(desired.first)); auto* old_key_ptr = reinterpret_cast(&old_key); // if key success if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { - atomic_store(&address->second, static_cast(thrust::get<1>(desired))); - + atomic_store(&address->second, desired.second); return insert_result::SUCCESS; } // Our key was already present in the slot, so our key is a duplicate // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare - if (this->predicate_.equal_to(*old_key_ptr, - thrust::get<0>(thrust::raw_reference_cast(desired))) == - detail::equal_result::EQUAL) { + if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) { return insert_result::DUPLICATE; } @@ -1082,7 +1113,7 @@ class open_addressing_ref_impl { * @note Dispatches the correct implementation depending on the container * type and presence of other operator mixins. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param address Pointer to the slot in memory * @param expected Element to compare against @@ -1106,10 +1137,40 @@ class open_addressing_ref_impl { } } + /** + * @brief Attempts to insert an element into a slot. + * + * @note Dispatches the correct implementation depending on the container + * type and presence of other operator mixins. + * + * @note `stable` indicates that the payload will only be updated once from the sentinel value to + * the desired value, meaning there can be no ABA situations. + * + * @tparam Value Input type which is convertible to 'value_type' + * + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert + * + * @return Result of this operation, i.e., success/continue/duplicate + */ + template + [[nodiscard]] __device__ insert_result attempt_insert_stable(value_type* address, + value_type const& expected, + Value const& desired) noexcept + { + if constexpr (sizeof(value_type) <= 8) { + return packed_cas(address, expected, desired); + } else { + return cas_dependent_write(address, expected, desired); + } + } + /** * @brief Waits until the slot payload has been updated * - * @note The function will return once the slot payload is no longer equal to the sentinel value. + * @note The function will return once the slot payload is no longer equal to the sentinel + * value. * * @tparam T Map slot type * diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index f27f21e76..c6a24bf7e 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -18,6 +18,8 @@ #include +#include + #include #include @@ -171,7 +173,8 @@ class operator_impl< using base_type = static_map_ref; using ref_type = static_map_ref; using key_type = typename base_type::key_type; - using value_type = typename base_type::value_type; + using value_type = typename base_type::value_type; + using mapped_type = T; static constexpr auto cg_size = base_type::cg_size; static constexpr auto window_size = base_type::window_size; @@ -180,7 +183,7 @@ class operator_impl< /** * @brief Inserts an element. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param value The element to insert * @@ -196,7 +199,7 @@ class operator_impl< /** * @brief Inserts an element. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param group The Cooperative Group used to perform group insert * @param value The element to insert @@ -225,7 +228,8 @@ class operator_impl< using base_type = static_map_ref; using ref_type = static_map_ref; using key_type = typename base_type::key_type; - using value_type = typename base_type::value_type; + using value_type = typename base_type::value_type; + using mapped_type = T; static constexpr auto cg_size = base_type::cg_size; static constexpr auto window_size = base_type::window_size; @@ -238,7 +242,7 @@ class operator_impl< * @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v` * to the mapped_type corresponding to the key `k`. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param value The element to insert */ @@ -247,8 +251,10 @@ class operator_impl< { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); - ref_type& ref_ = static_cast(*this); - auto const key = value.first; + ref_type& ref_ = static_cast(*this); + + auto const val = ref_.impl_.heterogeneous_value(value); + auto const key = ref_.impl_.extract_key(val); auto& probing_scheme = ref_.impl_.probing_scheme(); auto storage_ref = ref_.impl_.storage_ref(); auto probing_iter = probing_scheme(key, storage_ref.window_extent()); @@ -264,14 +270,14 @@ class operator_impl< auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); ref_.impl_.atomic_store( &((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second, - value.second); + val.second); return; } if (eq_res == detail::equal_result::EMPTY or cuco::detail::bitwise_compare(slot_content.first, ref_.impl_.erased_key_sentinel())) { auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); if (attempt_insert_or_assign( - (storage_ref.data() + *probing_iter)->data() + intra_window_index, value)) { + (storage_ref.data() + *probing_iter)->data() + intra_window_index, val)) { return; } } @@ -286,7 +292,7 @@ class operator_impl< * @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v` * to the mapped_type corresponding to the key `k`. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param group The Cooperative Group used to perform group insert * @param value The element to insert @@ -297,7 +303,8 @@ class operator_impl< { ref_type& ref_ = static_cast(*this); - auto const key = value.first; + auto const val = ref_.impl_.heterogeneous_value(value); + auto const key = ref_.impl_.extract_key(val); auto& probing_scheme = ref_.impl_.probing_scheme(); auto storage_ref = ref_.impl_.storage_ref(); auto probing_iter = probing_scheme(group, key, storage_ref.window_extent()); @@ -332,7 +339,7 @@ class operator_impl< if (group.thread_rank() == src_lane) { ref_.impl_.atomic_store( &((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second, - value.second); + val.second); } group.sync(); return; @@ -345,7 +352,7 @@ class operator_impl< auto const status = (group.thread_rank() == src_lane) ? attempt_insert_or_assign( - (storage_ref.data() + *probing_iter)->data() + intra_window_index, value) + (storage_ref.data() + *probing_iter)->data() + intra_window_index, val) : false; // Exit if inserted or assigned @@ -377,7 +384,8 @@ class operator_impl< ref_type& ref_ = static_cast(*this); auto const expected_key = ref_.impl_.empty_slot_sentinel().first; - auto old_key = ref_.impl_.compare_and_swap(&slot->first, expected_key, value.first); + auto old_key = + ref_.impl_.compare_and_swap(&slot->first, expected_key, static_cast(value.first)); auto* old_key_ptr = reinterpret_cast(&old_key); // if key success or key was already present in the map @@ -406,6 +414,7 @@ class operator_impl< using ref_type = static_map_ref; using key_type = typename base_type::key_type; using value_type = typename base_type::value_type; + using mapped_type = T; using iterator = typename base_type::iterator; using const_iterator = typename base_type::const_iterator; @@ -446,7 +455,7 @@ class operator_impl< * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param value The element to insert * @@ -467,7 +476,7 @@ class operator_impl< * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param group The Cooperative Group used to perform group insert_and_find * @param value The element to insert @@ -506,7 +515,7 @@ class operator_impl< /** * @brief Erases an element. * - * @tparam ProbeKey Input type which is implicitly convertible to 'key_type' + * @tparam ProbeKey Input key type which is convertible to 'key_type' * * @param key The element to erase * @@ -522,7 +531,7 @@ class operator_impl< /** * @brief Erases an element. * - * @tparam ProbeKey Input type which is implicitly convertible to 'key_type' + * @tparam ProbeKey Input key type which is convertible to 'key_type' * * @param group The Cooperative Group used to perform group insert * @param key The element to erase @@ -563,7 +572,7 @@ class operator_impl< * @note If the probe key `key` was inserted into the container, returns * true. Otherwise, returns false. * - * @tparam ProbeKey Probe key type + * @tparam ProbeKey Input key type which is convertible to 'key_type' * * @param key The key to search for * @@ -583,7 +592,7 @@ class operator_impl< * @note If the probe key `key` was inserted into the container, returns * true. Otherwise, returns false. * - * @tparam ProbeKey Probe key type + * @tparam ProbeKey Input key type which is convertible to 'key_type' * * @param group The Cooperative Group used to perform group contains * @param key The key to search for @@ -652,7 +661,7 @@ class operator_impl< * @note Returns a un-incrementable input iterator to the element whose key is equivalent to * `key`. If no such element exists, returns `end()`. * - * @tparam ProbeKey Probe key type + * @tparam ProbeKey Input key type which is convertible to 'key_type' * * @param key The key to search for * @@ -672,7 +681,7 @@ class operator_impl< * @note Returns a un-incrementable input iterator to the element whose key is equivalent to * `key`. If no such element exists, returns `end()`. * - * @tparam ProbeKey Probe key type + * @tparam ProbeKey Input key type which is convertible to 'key_type' * * @param group The Cooperative Group used to perform this operation * @param key The key to search for diff --git a/include/cuco/detail/static_set/static_set_ref.inl b/include/cuco/detail/static_set/static_set_ref.inl index c21e0dddb..f06a5f201 100644 --- a/include/cuco/detail/static_set/static_set_ref.inl +++ b/include/cuco/detail/static_set/static_set_ref.inl @@ -148,7 +148,7 @@ class operator_impl #include +#include #include - #include namespace cuco::detail { @@ -36,6 +36,20 @@ struct is_std_pair_like::value == 2, cuda::std::true_type, cuda::std::false_type> { }; +template +struct is_cuda_std_pair_like : cuda::std::false_type { +}; + +template +struct is_cuda_std_pair_like< + T, + cuda::std::void_t(cuda::std::declval())), + decltype(cuda::std::get<1>(cuda::std::declval()))>> + : cuda::std::conditional_t::value == 2, + cuda::std::true_type, + cuda::std::false_type> { +}; + template struct is_thrust_pair_like_impl : cuda::std::false_type { }; diff --git a/include/cuco/pair.cuh b/include/cuco/pair.cuh index d28cae5da..824a07f07 100644 --- a/include/cuco/pair.cuh +++ b/include/cuco/pair.cuh @@ -23,6 +23,7 @@ #include #include +#include #include namespace cuco { @@ -86,6 +87,19 @@ struct alignas(detail::pair_alignment()) pair { * @param p The input pair to copy from */ template ::value>* = nullptr> + __host__ __device__ constexpr pair(T const& p) + : pair{std::get<0>(thrust::raw_reference_cast(p)), std::get<1>(thrust::raw_reference_cast(p))} + { + } + + /** + * @brief Constructs a pair from the given cuda::std::pair-like `p`. + * + * @tparam T Type of the pair to copy from + * + * @param p The input pair to copy from + */ + template ::value>* = nullptr> __host__ __device__ constexpr pair(T const& p) : pair{cuda::std::get<0>(thrust::raw_reference_cast(p)), cuda::std::get<1>(thrust::raw_reference_cast(p))} diff --git a/tests/static_map/erase_test.cu b/tests/static_map/erase_test.cu index aab0df6d8..dc59ac2d9 100644 --- a/tests/static_map/erase_test.cu +++ b/tests/static_map/erase_test.cu @@ -19,31 +19,31 @@ #include #include -#include #include -#include -#include -#include +#include +#include #include +#include + using size_type = int32_t; template void test_erase(Map& map, size_type num_keys) { - using Key = typename Map::key_type; - using Value = typename Map::mapped_type; + using key_type = typename Map::key_type; + using mapped_type = typename Map::mapped_type; - thrust::device_vector d_keys(num_keys); - thrust::device_vector d_values(num_keys); thrust::device_vector d_keys_exist(num_keys); - thrust::sequence(thrust::device, d_keys.begin(), d_keys.end(), 1); - thrust::sequence(thrust::device, d_values.begin(), d_values.end(), 1); + auto keys_begin = thrust::counting_iterator(1); - auto pairs_begin = - thrust::make_zip_iterator(thrust::make_tuple(d_keys.begin(), d_values.begin())); + auto pairs_begin = thrust::make_transform_iterator( + keys_begin, + cuda::proclaim_return_type>([] __device__(key_type const& x) { + return cuco::pair(x, static_cast(x)); + })); SECTION("Check basic insert/erase") { @@ -51,11 +51,11 @@ void test_erase(Map& map, size_type num_keys) REQUIRE(map.size() == num_keys); - map.erase(d_keys.begin(), d_keys.end()); + map.erase(keys_begin, keys_begin + num_keys); REQUIRE(map.size() == 0); - map.contains(d_keys.begin(), d_keys.end(), d_keys_exist.begin()); + map.contains(keys_begin, keys_begin + num_keys, d_keys_exist.begin()); REQUIRE(cuco::test::none_of(d_keys_exist.begin(), d_keys_exist.end(), thrust::identity{})); @@ -63,12 +63,12 @@ void test_erase(Map& map, size_type num_keys) REQUIRE(map.size() == num_keys); - map.contains(d_keys.begin(), d_keys.end(), d_keys_exist.begin()); + map.contains(keys_begin, keys_begin + num_keys, d_keys_exist.begin()); REQUIRE(cuco::test::all_of(d_keys_exist.begin(), d_keys_exist.end(), thrust::identity{})); - map.erase(d_keys.begin(), d_keys.begin() + num_keys / 2); - map.contains(d_keys.begin(), d_keys.end(), d_keys_exist.begin()); + map.erase(keys_begin, keys_begin + num_keys / 2); + map.contains(keys_begin, keys_begin + num_keys, d_keys_exist.begin()); REQUIRE(cuco::test::none_of( d_keys_exist.begin(), d_keys_exist.begin() + num_keys / 2, thrust::identity{})); @@ -76,7 +76,7 @@ void test_erase(Map& map, size_type num_keys) REQUIRE(cuco::test::all_of( d_keys_exist.begin() + num_keys / 2, d_keys_exist.end(), thrust::identity{})); - map.erase(d_keys.begin() + num_keys / 2, d_keys.end()); + map.erase(keys_begin + num_keys / 2, keys_begin + num_keys); REQUIRE(map.size() == 0); } } diff --git a/tests/static_map/rehash_test.cu b/tests/static_map/rehash_test.cu index 69a73c6b3..bbfc4278f 100644 --- a/tests/static_map/rehash_test.cu +++ b/tests/static_map/rehash_test.cu @@ -16,13 +16,13 @@ #include -#include -#include -#include -#include +#include +#include #include +#include + TEST_CASE("Rehash", "") { using key_type = int; @@ -36,14 +36,13 @@ TEST_CASE("Rehash", "") cuco::empty_value{-1}, cuco::erased_key{-2}}; - thrust::device_vector d_keys(num_keys); - thrust::device_vector d_values(num_keys); - - thrust::sequence(d_keys.begin(), d_keys.end()); - thrust::sequence(d_values.begin(), d_values.end()); + auto keys_begin = thrust::counting_iterator(1); - auto pairs_begin = - thrust::make_zip_iterator(thrust::make_tuple(d_keys.begin(), d_values.begin())); + auto pairs_begin = thrust::make_transform_iterator( + keys_begin, + cuda::proclaim_return_type>([] __device__(key_type const& x) { + return cuco::pair(x, static_cast(x)); + })); map.insert(pairs_begin, pairs_begin + num_keys); @@ -53,7 +52,7 @@ TEST_CASE("Rehash", "") map.rehash(num_keys * 2); REQUIRE(map.size() == num_keys); - map.erase(d_keys.begin(), d_keys.begin() + num_erased_keys); + map.erase(keys_begin, keys_begin + num_erased_keys); map.rehash(); REQUIRE(map.size() == num_keys - num_erased_keys); }