From 17112a82c74cb9feacd6427b1fa87196ad4399f3 Mon Sep 17 00:00:00 2001 From: Daniel Juenger <2955913+sleeepyjack@users.noreply.github.com> Date: Tue, 17 Oct 2023 00:05:15 +0000 Subject: [PATCH 01/13] Add stable alternative to attemt_insert --- .../open_addressing_ref_impl.cuh | 43 +++++++++++++------ 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 683bf94b1..1756bc560 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -401,13 +401,7 @@ class open_addressing_ref_impl { if (eq_res == detail::equal_result::EMPTY or cuco::detail::bitwise_compare(this->extract_key(window_slots[i]), this->erased_key_sentinel())) { - switch ([&]() { - if constexpr (sizeof(value_type) <= 8) { - return packed_cas(window_ptr + i, window_slots[i], value); - } else { - return cas_dependent_write(window_ptr + i, window_slots[i], value); - } - }()) { + switch (this->attempt_insert_stable(window_ptr + i, window_slots[i], value)) { case insert_result::SUCCESS: { return {iterator{&window_ptr[i]}, true}; } @@ -485,11 +479,7 @@ class open_addressing_ref_impl { auto const res = group.shfl(reinterpret_cast(slot_ptr), src_lane); auto const status = [&, target_idx = intra_window_index]() { if (group.thread_rank() != src_lane) { return insert_result::CONTINUE; } - if constexpr (sizeof(value_type) <= 8) { - return packed_cas(slot_ptr, window_slots[target_idx], value); - } else { - return cas_dependent_write(slot_ptr, window_slots[target_idx], value); - } + return this->attempt_insert_stable(slot_ptr, window_slots[target_idx], value); }(); switch (group.shfl(status, src_lane)) { @@ -1054,6 +1044,35 @@ class open_addressing_ref_impl { } } + /** + * @brief Attempts to insert an element into a slot. + * + * @note Dispatches the correct implementation depending on the container + * type and presence of other operator mixins. + * + * @note `stable` here means that the payload will only be updated once from the sentinel value to + * the payload value + * + * @tparam Value Input type which is implicitly convertible to 'value_type' + * + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert + * + * @return Result of this operation, i.e., success/continue/duplicate + */ + template + [[nodiscard]] __device__ insert_result attempt_insert_stable(value_type* address, + value_type const& expected, + Value const& desired) noexcept + { + if constexpr (sizeof(value_type) <= 8) { + return packed_cas(address, expected, desired); + } else { + return cas_dependent_write(address, expected, desired); + } + } + // TODO: Clean up the sentinel handling since it's duplicated in ref and equal wrapper value_type empty_slot_sentinel_; ///< Sentinel value indicating an empty slot detail::equal_wrapper predicate_; ///< Key equality binary callable From 8bf6ed99c5cd99f1507ef06c91e173d611a7c180 Mon Sep 17 00:00:00 2001 From: Daniel Juenger <2955913+sleeepyjack@users.noreply.github.com> Date: Tue, 17 Oct 2023 00:16:35 +0000 Subject: [PATCH 02/13] Switch to thrust::get where needed --- .../cuco/detail/static_map/static_map_ref.inl | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index f27f21e76..4cc3297b4 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -18,6 +18,8 @@ #include +#include + #include #include @@ -248,7 +250,7 @@ class operator_impl< static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); ref_type& ref_ = static_cast(*this); - auto const key = value.first; + auto const key = thrust::get<0>(thrust::raw_reference_cast(value)); auto& probing_scheme = ref_.impl_.probing_scheme(); auto storage_ref = ref_.impl_.storage_ref(); auto probing_iter = probing_scheme(key, storage_ref.window_extent()); @@ -264,7 +266,7 @@ class operator_impl< auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); ref_.impl_.atomic_store( &((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second, - value.second); + static_cast(thrust::get<1>(value))); return; } if (eq_res == detail::equal_result::EMPTY or @@ -297,7 +299,7 @@ class operator_impl< { ref_type& ref_ = static_cast(*this); - auto const key = value.first; + auto const key = thrust::get<0>(thrust::raw_reference_cast(value)); auto& probing_scheme = ref_.impl_.probing_scheme(); auto storage_ref = ref_.impl_.storage_ref(); auto probing_iter = probing_scheme(group, key, storage_ref.window_extent()); @@ -332,7 +334,7 @@ class operator_impl< if (group.thread_rank() == src_lane) { ref_.impl_.atomic_store( &((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second, - value.second); + static_cast(thrust::get<1>(value))); } group.sync(); return; @@ -377,15 +379,17 @@ class operator_impl< ref_type& ref_ = static_cast(*this); auto const expected_key = ref_.impl_.empty_slot_sentinel().first; - auto old_key = ref_.impl_.compare_and_swap(&slot->first, expected_key, value.first); + auto old_key = ref_.impl_.compare_and_swap( + &slot->first, expected_key, static_cast(thrust::get<0>(value))); auto* old_key_ptr = reinterpret_cast(&old_key); // if key success or key was already present in the map if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key) or - (ref_.impl_.predicate().equal_to(*old_key_ptr, value.first) == + (ref_.impl_.predicate().equal_to(*old_key_ptr, + thrust::get<0>(thrust::raw_reference_cast(value))) == detail::equal_result::EQUAL)) { // Update payload - ref_.impl_.atomic_store(&slot->second, value.second); + ref_.impl_.atomic_store(&slot->second, static_cast(thrust::get<1>(value))); return true; } return false; From cf3a3cc1ed6f2a4631a2c8a670e38c38778bb111 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Thu, 2 Nov 2023 11:04:27 -0700 Subject: [PATCH 03/13] Move common kernels and functors to the open addressing folder (#388) There are no actual code changes but moving files around. --- .../{common_functors.cuh => open_addressing/functors.cuh} | 0 .../{common_kernels.cuh => open_addressing/kernels.cuh} | 0 include/cuco/detail/open_addressing/open_addressing_impl.cuh | 4 ++-- 3 files changed, 2 insertions(+), 2 deletions(-) rename include/cuco/detail/{common_functors.cuh => open_addressing/functors.cuh} (100%) rename include/cuco/detail/{common_kernels.cuh => open_addressing/kernels.cuh} (100%) diff --git a/include/cuco/detail/common_functors.cuh b/include/cuco/detail/open_addressing/functors.cuh similarity index 100% rename from include/cuco/detail/common_functors.cuh rename to include/cuco/detail/open_addressing/functors.cuh diff --git a/include/cuco/detail/common_kernels.cuh b/include/cuco/detail/open_addressing/kernels.cuh similarity index 100% rename from include/cuco/detail/common_kernels.cuh rename to include/cuco/detail/open_addressing/kernels.cuh diff --git a/include/cuco/detail/open_addressing/open_addressing_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_impl.cuh index 712ff85ee..79fbfb874 100644 --- a/include/cuco/detail/open_addressing/open_addressing_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_impl.cuh @@ -17,8 +17,8 @@ #pragma once #include -#include -#include +#include +#include #include #include #include From 3a9b747804de8cbe37a6060c7336fec31615e154 Mon Sep 17 00:00:00 2001 From: Daniel Juenger <2955913+sleeepyjack@users.noreply.github.com> Date: Fri, 3 Nov 2023 14:32:06 +0000 Subject: [PATCH 04/13] Update docs on stable insert --- .../cuco/detail/open_addressing/open_addressing_ref_impl.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 1756bc560..f8a2e7642 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -1050,8 +1050,8 @@ class open_addressing_ref_impl { * @note Dispatches the correct implementation depending on the container * type and presence of other operator mixins. * - * @note `stable` here means that the payload will only be updated once from the sentinel value to - * the payload value + * @note `stable` indicates that the payload will only be updated once from the sentinel value to + * the desired value, meaning there can be no ABA situations. * * @tparam Value Input type which is implicitly convertible to 'value_type' * From 1dd1648db8ddc9db980fbd26a3e794488ef7e2f8 Mon Sep 17 00:00:00 2001 From: Daniel Juenger <2955913+sleeepyjack@users.noreply.github.com> Date: Fri, 3 Nov 2023 23:34:33 +0000 Subject: [PATCH 05/13] Update pair traits to use cuda::std --- include/cuco/detail/traits.hpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/include/cuco/detail/traits.hpp b/include/cuco/detail/traits.hpp index 313f95430..602a93251 100644 --- a/include/cuco/detail/traits.hpp +++ b/include/cuco/detail/traits.hpp @@ -18,10 +18,9 @@ #include #include +#include #include -#include - namespace cuco::detail { template @@ -30,10 +29,11 @@ struct is_std_pair_like : cuda::std::false_type { template struct is_std_pair_like(cuda::std::declval())), - decltype(std::get<1>(cuda::std::declval()))>> - : cuda::std:: - conditional_t::value == 2, cuda::std::true_type, cuda::std::false_type> { + cuda::std::void_t(cuda::std::declval())), + decltype(cuda::std::get<1>(cuda::std::declval()))>> + : cuda::std::conditional_t::value == 2, + cuda::std::true_type, + cuda::std::false_type> { }; template From b7c114ed107aa7a435c336e3e5ddfce572e1fc83 Mon Sep 17 00:00:00 2001 From: Daniel Juenger <2955913+sleeepyjack@users.noreply.github.com> Date: Sat, 4 Nov 2023 01:00:45 +0000 Subject: [PATCH 06/13] Fix heterogeneous insert for static_map --- .../open_addressing_ref_impl.cuh | 84 +++++++++++-------- .../cuco/detail/static_map/static_map_ref.inl | 74 ++++++++-------- .../cuco/detail/static_set/static_set_ref.inl | 36 ++++---- tests/static_map/erase_test.cu | 32 ++++--- tests/static_map/rehash_test.cu | 19 ++--- 5 files changed, 126 insertions(+), 119 deletions(-) diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index f8a2e7642..d1f4396e4 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -23,7 +23,6 @@ #include #include -#include #include #include @@ -262,7 +261,7 @@ class open_addressing_ref_impl { /** * @brief Inserts an element. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param value The element to insert * @@ -304,7 +303,7 @@ class open_addressing_ref_impl { /** * @brief Inserts an element. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param group The Cooperative Group used to perform group insert * @param value The element to insert @@ -374,7 +373,7 @@ class open_addressing_ref_impl { * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param value The element to insert * @@ -423,7 +422,7 @@ class open_addressing_ref_impl { * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param group The Cooperative Group used to perform group insert_and_find * @param value The element to insert @@ -500,7 +499,7 @@ class open_addressing_ref_impl { /** * @brief Erases an element. * - * @tparam ProbeKey Input type which is implicitly convertible to 'key_type' + * @tparam ProbeKey Input type which is convertible to 'key_type' * * @param value The element to erase * @@ -540,7 +539,7 @@ class open_addressing_ref_impl { /** * @brief Erases an element. * - * @tparam ProbeKey Input type which is implicitly convertible to 'key_type' + * @tparam ProbeKey Input type which is convertible to 'key_type' * * @param group The Cooperative Group used to perform group erase * @param value The element to erase @@ -600,7 +599,7 @@ class open_addressing_ref_impl { * @note If the probe key `key` was inserted into the container, returns true. Otherwise, returns * false. * - * @tparam ProbeKey Probe key type + * @tparam ProbeKey Input type which is convertible to 'key_type' * * @param key The key to search for * @@ -633,7 +632,7 @@ class open_addressing_ref_impl { * @note If the probe key `key` was inserted into the container, returns true. Otherwise, returns * false. * - * @tparam ProbeKey Probe key type + * @tparam ProbeKey Input type which is convertible to 'key_type' * * @param group The Cooperative Group used to perform group contains * @param key The key to search for @@ -673,7 +672,7 @@ class open_addressing_ref_impl { * @note Returns a un-incrementable input iterator to the element whose key is equivalent to * `key`. If no such element exists, returns `end()`. * - * @tparam ProbeKey Probe key type + * @tparam ProbeKey Input type which is convertible to 'key_type' * * @param key The key to search for * @@ -710,7 +709,7 @@ class open_addressing_ref_impl { * @note Returns a un-incrementable input iterator to the element whose key is equivalent to * `key`. If no such element exists, returns `end()`. * - * @tparam ProbeKey Probe key type + * @tparam ProbeKey Input type which is convertible to 'key_type' * * @param group The Cooperative Group used to perform this operation * @param key The key to search for @@ -845,7 +844,7 @@ class open_addressing_ref_impl { /** * @brief Extracts the key from a given value type. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param value The input value * @@ -856,7 +855,7 @@ class open_addressing_ref_impl { Value const& value) const noexcept { if constexpr (this->has_payload) { - return thrust::get<0>(thrust::raw_reference_cast(value)); + return thrust::raw_reference_cast(value).first; } else { return thrust::raw_reference_cast(value); } @@ -867,7 +866,7 @@ class open_addressing_ref_impl { * * @note This function is only available if `this->has_payload == true` * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param value The input value * @@ -877,7 +876,26 @@ class open_addressing_ref_impl { [[nodiscard]] __host__ __device__ constexpr auto const& extract_payload( Value const& value) const noexcept { - return thrust::get<1>(thrust::raw_reference_cast(value)); + return thrust::raw_reference_cast(value).second; + } + + /** + * @brief Converts the given type to the container's native `value_type`. + * + * @tparam T Input type which is convertible to 'value_type' + * + * @param value The input value + * + * @return The converted object + */ + template + [[nodiscard]] __host__ __device__ constexpr value_type native_value(T const& value) const noexcept + { + if constexpr (this->has_payload) { + return {static_cast(this->extract_key(value)), this->extract_payload(value)}; + } else { + return static_cast(value); + } } /** @@ -897,7 +915,7 @@ class open_addressing_ref_impl { /** * @brief Inserts the specified element with one single CAS operation. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param address Pointer to the slot in memory * @param expected Element to compare against @@ -910,7 +928,7 @@ class open_addressing_ref_impl { value_type const& expected, Value const& desired) noexcept { - auto old = compare_and_swap(address, expected, static_cast(desired)); + auto old = compare_and_swap(address, expected, this->native_value(desired)); auto* old_ptr = reinterpret_cast(&old); if (cuco::detail::bitwise_compare(this->extract_key(*old_ptr), this->extract_key(expected))) { return insert_result::SUCCESS; @@ -925,7 +943,7 @@ class open_addressing_ref_impl { /** * @brief Inserts the specified element with two back-to-back CAS operations. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param address Pointer to the slot in memory * @param expected Element to compare against @@ -943,10 +961,9 @@ class open_addressing_ref_impl { auto const expected_key = expected.first; auto const expected_payload = expected.second; - auto old_key = compare_and_swap( - &address->first, expected_key, static_cast(thrust::get<0>(desired))); - auto old_payload = compare_and_swap( - &address->second, expected_payload, static_cast(thrust::get<1>(desired))); + auto old_key = + compare_and_swap(&address->first, expected_key, static_cast(desired.first)); + auto old_payload = compare_and_swap(&address->second, expected_payload, desired.second); auto* old_key_ptr = reinterpret_cast(&old_key); auto* old_payload_ptr = reinterpret_cast(&old_payload); @@ -954,8 +971,7 @@ class open_addressing_ref_impl { // if key success if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { while (not cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { - old_payload = compare_and_swap( - &address->second, expected_payload, static_cast(thrust::get<1>(desired))); + old_payload = compare_and_swap(&address->second, expected_payload, desired.second); } return insert_result::SUCCESS; } else if (cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) { @@ -964,9 +980,7 @@ class open_addressing_ref_impl { // Our key was already present in the slot, so our key is a duplicate // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare - if (this->predicate_.equal_to(*old_key_ptr, - thrust::get<0>(thrust::raw_reference_cast(desired))) == - detail::equal_result::EQUAL) { + if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) { return insert_result::DUPLICATE; } @@ -976,7 +990,7 @@ class open_addressing_ref_impl { /** * @brief Inserts the specified element with CAS-dependent write operations. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param address Pointer to the slot in memory * @param expected Element to compare against @@ -992,22 +1006,20 @@ class open_addressing_ref_impl { auto const expected_key = expected.first; - auto old_key = compare_and_swap( - &address->first, expected_key, static_cast(thrust::get<0>(desired))); + auto old_key = + compare_and_swap(&address->first, expected_key, static_cast(desired.first)); auto* old_key_ptr = reinterpret_cast(&old_key); // if key success if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) { - atomic_store(&address->second, static_cast(thrust::get<1>(desired))); + atomic_store(&address->second, desired.second); return insert_result::SUCCESS; } // Our key was already present in the slot, so our key is a duplicate // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare - if (this->predicate_.equal_to(*old_key_ptr, - thrust::get<0>(thrust::raw_reference_cast(desired))) == - detail::equal_result::EQUAL) { + if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) { return insert_result::DUPLICATE; } @@ -1020,7 +1032,7 @@ class open_addressing_ref_impl { * @note Dispatches the correct implementation depending on the container * type and presence of other operator mixins. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param address Pointer to the slot in memory * @param expected Element to compare against @@ -1053,7 +1065,7 @@ class open_addressing_ref_impl { * @note `stable` indicates that the payload will only be updated once from the sentinel value to * the desired value, meaning there can be no ABA situations. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param address Pointer to the slot in memory * @param expected Element to compare against diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index 4cc3297b4..c24d68842 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -173,7 +173,8 @@ class operator_impl< using base_type = static_map_ref; using ref_type = static_map_ref; using key_type = typename base_type::key_type; - using value_type = typename base_type::value_type; + using value_type = typename base_type::value_type; + using mapped_type = T; static constexpr auto cg_size = base_type::cg_size; static constexpr auto window_size = base_type::window_size; @@ -182,14 +183,14 @@ class operator_impl< /** * @brief Inserts an element. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam ProbeKey Input key type which is convertible to 'key_type' * * @param value The element to insert * * @return True if the given element is successfully inserted */ - template - __device__ bool insert(Value const& value) noexcept + template + __device__ bool insert(cuco::pair const& value) noexcept { ref_type& ref_ = static_cast(*this); return ref_.impl_.insert(value); @@ -198,16 +199,16 @@ class operator_impl< /** * @brief Inserts an element. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam ProbeKey Input key type which is convertible to 'key_type' * * @param group The Cooperative Group used to perform group insert * @param value The element to insert * * @return True if the given element is successfully inserted */ - template + template __device__ bool insert(cooperative_groups::thread_block_tile const& group, - Value const& value) noexcept + cuco::pair const& value) noexcept { auto& ref_ = static_cast(*this); return ref_.impl_.insert(group, value); @@ -227,7 +228,8 @@ class operator_impl< using base_type = static_map_ref; using ref_type = static_map_ref; using key_type = typename base_type::key_type; - using value_type = typename base_type::value_type; + using value_type = typename base_type::value_type; + using mapped_type = T; static constexpr auto cg_size = base_type::cg_size; static constexpr auto window_size = base_type::window_size; @@ -240,17 +242,17 @@ class operator_impl< * @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v` * to the mapped_type corresponding to the key `k`. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam ProbeKey Input key type which is convertible to 'key_type' * * @param value The element to insert */ - template - __device__ void insert_or_assign(Value const& value) noexcept + template + __device__ void insert_or_assign(cuco::pair const& value) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); ref_type& ref_ = static_cast(*this); - auto const key = thrust::get<0>(thrust::raw_reference_cast(value)); + auto const key = value.first; auto& probing_scheme = ref_.impl_.probing_scheme(); auto storage_ref = ref_.impl_.storage_ref(); auto probing_iter = probing_scheme(key, storage_ref.window_extent()); @@ -266,7 +268,7 @@ class operator_impl< auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); ref_.impl_.atomic_store( &((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second, - static_cast(thrust::get<1>(value))); + value.second); return; } if (eq_res == detail::equal_result::EMPTY or @@ -288,18 +290,18 @@ class operator_impl< * @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v` * to the mapped_type corresponding to the key `k`. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam ProbeKey Input key type which is convertible to 'key_type' * * @param group The Cooperative Group used to perform group insert * @param value The element to insert */ - template + template __device__ void insert_or_assign(cooperative_groups::thread_block_tile const& group, - Value const& value) noexcept + cuco::pair const& value) noexcept { ref_type& ref_ = static_cast(*this); - auto const key = thrust::get<0>(thrust::raw_reference_cast(value)); + auto const key = value.first; auto& probing_scheme = ref_.impl_.probing_scheme(); auto storage_ref = ref_.impl_.storage_ref(); auto probing_iter = probing_scheme(group, key, storage_ref.window_extent()); @@ -334,7 +336,7 @@ class operator_impl< if (group.thread_rank() == src_lane) { ref_.impl_.atomic_store( &((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second, - static_cast(thrust::get<1>(value))); + value.second); } group.sync(); return; @@ -379,17 +381,16 @@ class operator_impl< ref_type& ref_ = static_cast(*this); auto const expected_key = ref_.impl_.empty_slot_sentinel().first; - auto old_key = ref_.impl_.compare_and_swap( - &slot->first, expected_key, static_cast(thrust::get<0>(value))); + auto old_key = + ref_.impl_.compare_and_swap(&slot->first, expected_key, static_cast(value.first)); auto* old_key_ptr = reinterpret_cast(&old_key); // if key success or key was already present in the map if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key) or - (ref_.impl_.predicate().equal_to(*old_key_ptr, - thrust::get<0>(thrust::raw_reference_cast(value))) == + (ref_.impl_.predicate().equal_to(*old_key_ptr, value.first) == detail::equal_result::EQUAL)) { // Update payload - ref_.impl_.atomic_store(&slot->second, static_cast(thrust::get<1>(value))); + ref_.impl_.atomic_store(&slot->second, value.second); return true; } return false; @@ -410,6 +411,7 @@ class operator_impl< using ref_type = static_map_ref; using key_type = typename base_type::key_type; using value_type = typename base_type::value_type; + using mapped_type = T; using iterator = typename base_type::iterator; using const_iterator = typename base_type::const_iterator; @@ -450,15 +452,16 @@ class operator_impl< * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam ProbeKey Input key type which is convertible to 'key_type' * * @param value The element to insert * * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template - __device__ thrust::pair insert_and_find(Value const& value) noexcept + template + __device__ thrust::pair insert_and_find( + cuco::pair const& value) noexcept { ref_type& ref_ = static_cast(*this); return ref_.impl_.insert_and_find(value); @@ -471,7 +474,7 @@ class operator_impl< * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * - * @tparam Value Input type which is implicitly convertible to 'value_type' + * @tparam ProbeKey Input key type which is convertible to 'key_type' * * @param group The Cooperative Group used to perform group insert_and_find * @param value The element to insert @@ -479,9 +482,10 @@ class operator_impl< * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template + template __device__ thrust::pair insert_and_find( - cooperative_groups::thread_block_tile const& group, Value const& value) noexcept + cooperative_groups::thread_block_tile const& group, + cuco::pair const& value) noexcept { ref_type& ref_ = static_cast(*this); return ref_.impl_.insert_and_find(group, value); @@ -510,7 +514,7 @@ class operator_impl< /** * @brief Erases an element. * - * @tparam ProbeKey Input type which is implicitly convertible to 'key_type' + * @tparam ProbeKey Input key type which is convertible to 'key_type' * * @param key The element to erase * @@ -526,7 +530,7 @@ class operator_impl< /** * @brief Erases an element. * - * @tparam ProbeKey Input type which is implicitly convertible to 'key_type' + * @tparam ProbeKey Input key type which is convertible to 'key_type' * * @param group The Cooperative Group used to perform group insert * @param key The element to erase @@ -567,7 +571,7 @@ class operator_impl< * @note If the probe key `key` was inserted into the container, returns * true. Otherwise, returns false. * - * @tparam ProbeKey Probe key type + * @tparam ProbeKey Input key type which is convertible to 'key_type' * * @param key The key to search for * @@ -587,7 +591,7 @@ class operator_impl< * @note If the probe key `key` was inserted into the container, returns * true. Otherwise, returns false. * - * @tparam ProbeKey Probe key type + * @tparam ProbeKey Input key type which is convertible to 'key_type' * * @param group The Cooperative Group used to perform group contains * @param key The key to search for @@ -656,7 +660,7 @@ class operator_impl< * @note Returns a un-incrementable input iterator to the element whose key is equivalent to * `key`. If no such element exists, returns `end()`. * - * @tparam ProbeKey Probe key type + * @tparam ProbeKey Input key type which is convertible to 'key_type' * * @param key The key to search for * @@ -676,7 +680,7 @@ class operator_impl< * @note Returns a un-incrementable input iterator to the element whose key is equivalent to * `key`. If no such element exists, returns `end()`. * - * @tparam ProbeKey Probe key type + * @tparam ProbeKey Input key type which is convertible to 'key_type' * * @param group The Cooperative Group used to perform this operation * @param key The key to search for diff --git a/include/cuco/detail/static_set/static_set_ref.inl b/include/cuco/detail/static_set/static_set_ref.inl index c21e0dddb..f9da929f4 100644 --- a/include/cuco/detail/static_set/static_set_ref.inl +++ b/include/cuco/detail/static_set/static_set_ref.inl @@ -148,14 +148,14 @@ class operator_impl - __device__ bool insert(Value const& value) noexcept + template + __device__ bool insert(ProbeKey const& value) noexcept { ref_type& ref_ = static_cast(*this); return ref_.impl_.insert(value); @@ -164,16 +164,16 @@ class operator_impl + template __device__ bool insert(cooperative_groups::thread_block_tile const& group, - Value const& value) noexcept + ProbeKey const& value) noexcept { auto& ref_ = static_cast(*this); return ref_.impl_.insert(group, value); @@ -232,15 +232,15 @@ class operator_impl - __device__ thrust::pair insert_and_find(Value const& value) noexcept + template + __device__ thrust::pair insert_and_find(ProbeKey const& value) noexcept { ref_type& ref_ = static_cast(*this); return ref_.impl_.insert_and_find(value); @@ -253,7 +253,7 @@ class operator_impl + template __device__ thrust::pair insert_and_find( - cooperative_groups::thread_block_tile const& group, Value const& value) noexcept + cooperative_groups::thread_block_tile const& group, ProbeKey const& value) noexcept { ref_type& ref_ = static_cast(*this); return ref_.impl_.insert_and_find(group, value); @@ -290,7 +290,7 @@ class operator_impl #include -#include #include -#include -#include -#include +#include +#include #include @@ -32,18 +30,16 @@ using size_type = int32_t; template void test_erase(Map& map, size_type num_keys) { - using Key = typename Map::key_type; - using Value = typename Map::mapped_type; + using key_type = typename Map::key_type; + using mapped_type = typename Map::mapped_type; - thrust::device_vector d_keys(num_keys); - thrust::device_vector d_values(num_keys); thrust::device_vector d_keys_exist(num_keys); - thrust::sequence(thrust::device, d_keys.begin(), d_keys.end(), 1); - thrust::sequence(thrust::device, d_values.begin(), d_values.end(), 1); + auto keys_begin = thrust::counting_iterator(1); - auto pairs_begin = - thrust::make_zip_iterator(thrust::make_tuple(d_keys.begin(), d_values.begin())); + auto pairs_begin = thrust::make_transform_iterator(keys_begin, [] __device__(key_type const& x) { + return cuco::pair(x, static_cast(x)); + }); SECTION("Check basic insert/erase") { @@ -51,11 +47,11 @@ void test_erase(Map& map, size_type num_keys) REQUIRE(map.size() == num_keys); - map.erase(d_keys.begin(), d_keys.end()); + map.erase(keys_begin, keys_begin + num_keys); REQUIRE(map.size() == 0); - map.contains(d_keys.begin(), d_keys.end(), d_keys_exist.begin()); + map.contains(keys_begin, keys_begin + num_keys, d_keys_exist.begin()); REQUIRE(cuco::test::none_of(d_keys_exist.begin(), d_keys_exist.end(), thrust::identity{})); @@ -63,12 +59,12 @@ void test_erase(Map& map, size_type num_keys) REQUIRE(map.size() == num_keys); - map.contains(d_keys.begin(), d_keys.end(), d_keys_exist.begin()); + map.contains(keys_begin, keys_begin + num_keys, d_keys_exist.begin()); REQUIRE(cuco::test::all_of(d_keys_exist.begin(), d_keys_exist.end(), thrust::identity{})); - map.erase(d_keys.begin(), d_keys.begin() + num_keys / 2); - map.contains(d_keys.begin(), d_keys.end(), d_keys_exist.begin()); + map.erase(keys_begin, keys_begin + num_keys / 2); + map.contains(keys_begin, keys_begin + num_keys, d_keys_exist.begin()); REQUIRE(cuco::test::none_of( d_keys_exist.begin(), d_keys_exist.begin() + num_keys / 2, thrust::identity{})); @@ -76,7 +72,7 @@ void test_erase(Map& map, size_type num_keys) REQUIRE(cuco::test::all_of( d_keys_exist.begin() + num_keys / 2, d_keys_exist.end(), thrust::identity{})); - map.erase(d_keys.begin() + num_keys / 2, d_keys.end()); + map.erase(keys_begin + num_keys / 2, keys_begin + num_keys); REQUIRE(map.size() == 0); } } diff --git a/tests/static_map/rehash_test.cu b/tests/static_map/rehash_test.cu index 69a73c6b3..55693f31e 100644 --- a/tests/static_map/rehash_test.cu +++ b/tests/static_map/rehash_test.cu @@ -16,10 +16,8 @@ #include -#include -#include -#include -#include +#include +#include #include @@ -36,14 +34,11 @@ TEST_CASE("Rehash", "") cuco::empty_value{-1}, cuco::erased_key{-2}}; - thrust::device_vector d_keys(num_keys); - thrust::device_vector d_values(num_keys); + auto keys_begin = thrust::counting_iterator(1); - thrust::sequence(d_keys.begin(), d_keys.end()); - thrust::sequence(d_values.begin(), d_values.end()); - - auto pairs_begin = - thrust::make_zip_iterator(thrust::make_tuple(d_keys.begin(), d_values.begin())); + auto pairs_begin = thrust::make_transform_iterator(keys_begin, [] __device__(key_type const& x) { + return cuco::pair(x, static_cast(x)); + }); map.insert(pairs_begin, pairs_begin + num_keys); @@ -53,7 +48,7 @@ TEST_CASE("Rehash", "") map.rehash(num_keys * 2); REQUIRE(map.size() == num_keys); - map.erase(d_keys.begin(), d_keys.begin() + num_erased_keys); + map.erase(keys_begin, keys_begin + num_erased_keys); map.rehash(); REQUIRE(map.size() == num_keys - num_erased_keys); } From 614f7bf1e6138245234697c11faaa988105f6aa3 Mon Sep 17 00:00:00 2001 From: Daniel Juenger <2955913+sleeepyjack@users.noreply.github.com> Date: Mon, 20 Nov 2023 22:28:09 +0000 Subject: [PATCH 07/13] Fix is_(cuda_)std_pair_like --- include/cuco/detail/traits.hpp | 22 ++++++++++++++++------ include/cuco/pair.cuh | 14 ++++++++++++++ 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/include/cuco/detail/traits.hpp b/include/cuco/detail/traits.hpp index 602a93251..2b24cddc7 100644 --- a/include/cuco/detail/traits.hpp +++ b/include/cuco/detail/traits.hpp @@ -20,6 +20,7 @@ #include #include +#include namespace cuco::detail { @@ -29,15 +30,24 @@ struct is_std_pair_like : cuda::std::false_type { template struct is_std_pair_like(cuda::std::declval())), - decltype(cuda::std::get<1>(cuda::std::declval()))>> - : cuda::std::conditional_t::value == 2, - cuda::std::true_type, - cuda::std::false_type> { + cuda::std::void_t(cuda::std::declval())), + decltype(std::get<1>(cuda::std::declval()))>> + : cuda::std:: + conditional_t::value == 2, cuda::std::true_type, cuda::std::false_type> { }; template -struct is_thrust_pair_like_impl : cuda::std::false_type { +struct is_cuda_std_pair_like : cuda::std::false_type { +}; + +template +struct is_cuda_std_pair_like< + T, + cuda::std::void_t(cuda::std::declval())), + decltype(cuda::std::get<1>(cuda::std::declval()))>> + : cuda::std::conditional_t::value == 2, + cuda::std::true_type, + cuda::std::false_type> { }; template diff --git a/include/cuco/pair.cuh b/include/cuco/pair.cuh index d28cae5da..824a07f07 100644 --- a/include/cuco/pair.cuh +++ b/include/cuco/pair.cuh @@ -23,6 +23,7 @@ #include #include +#include #include namespace cuco { @@ -86,6 +87,19 @@ struct alignas(detail::pair_alignment()) pair { * @param p The input pair to copy from */ template ::value>* = nullptr> + __host__ __device__ constexpr pair(T const& p) + : pair{std::get<0>(thrust::raw_reference_cast(p)), std::get<1>(thrust::raw_reference_cast(p))} + { + } + + /** + * @brief Constructs a pair from the given cuda::std::pair-like `p`. + * + * @tparam T Type of the pair to copy from + * + * @param p The input pair to copy from + */ + template ::value>* = nullptr> __host__ __device__ constexpr pair(T const& p) : pair{cuda::std::get<0>(thrust::raw_reference_cast(p)), cuda::std::get<1>(thrust::raw_reference_cast(p))} From 9c935339bb1a7ac773ebd7a5223f80bc8cf16ca4 Mon Sep 17 00:00:00 2001 From: Daniel Juenger <2955913+sleeepyjack@users.noreply.github.com> Date: Mon, 20 Nov 2023 22:40:51 +0000 Subject: [PATCH 08/13] Fix merge error --- .../open_addressing_ref_impl.cuh | 55 ++++++++++--------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 3640add3a..3f013210f 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -1133,35 +1133,36 @@ class open_addressing_ref_impl { } else { return cas_dependent_write(address, expected, desired); } + } - /** - * @brief Waits until the slot payload has been updated - * - * @note The function will return once the slot payload is no longer equal to the sentinel - * value. - * - * @tparam T Map slot type - * - * @param slot The target slot to check payload with - * @param sentinel The slot sentinel value - */ - template - __device__ void wait_for_payload(T & slot, T const& sentinel) const noexcept - { - auto ref = cuda::atomic_ref{slot}; - T current; - // TODO exponential backoff strategy - do { - current = ref.load(cuda::std::memory_order_relaxed); - } while (cuco::detail::bitwise_compare(current, sentinel)); - } + /** + * @brief Waits until the slot payload has been updated + * + * @note The function will return once the slot payload is no longer equal to the sentinel + * value. + * + * @tparam T Map slot type + * + * @param slot The target slot to check payload with + * @param sentinel The slot sentinel value + */ + template + __device__ void wait_for_payload(T& slot, T const& sentinel) const noexcept + { + auto ref = cuda::atomic_ref{slot}; + T current; + // TODO exponential backoff strategy + do { + current = ref.load(cuda::std::memory_order_relaxed); + } while (cuco::detail::bitwise_compare(current, sentinel)); + } - // TODO: Clean up the sentinel handling since it's duplicated in ref and equal wrapper - value_type empty_slot_sentinel_; ///< Sentinel value indicating an empty slot - detail::equal_wrapper predicate_; ///< Key equality binary callable - probing_scheme_type probing_scheme_; ///< Probing scheme - storage_ref_type storage_ref_; ///< Slot storage ref - }; + // TODO: Clean up the sentinel handling since it's duplicated in ref and equal wrapper + value_type empty_slot_sentinel_; ///< Sentinel value indicating an empty slot + detail::equal_wrapper predicate_; ///< Key equality binary callable + probing_scheme_type probing_scheme_; ///< Probing scheme + storage_ref_type storage_ref_; ///< Slot storage ref +}; } // namespace detail } // namespace experimental From e86c65732aac926bb564fffaab5f472faa6731af Mon Sep 17 00:00:00 2001 From: Daniel Juenger <2955913+sleeepyjack@users.noreply.github.com> Date: Mon, 20 Nov 2023 22:45:44 +0000 Subject: [PATCH 09/13] Somehow lost the is_thrust_pair_like primary template --- include/cuco/detail/traits.hpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/cuco/detail/traits.hpp b/include/cuco/detail/traits.hpp index 2b24cddc7..9154c736a 100644 --- a/include/cuco/detail/traits.hpp +++ b/include/cuco/detail/traits.hpp @@ -50,6 +50,10 @@ struct is_cuda_std_pair_like< cuda::std::false_type> { }; +template +struct is_thrust_pair_like_impl : cuda::std::false_type { +}; + template struct is_thrust_pair_like_impl< T, From f5bf6d839a313bfbec3be8820a0fc7daee4e8883 Mon Sep 17 00:00:00 2001 From: Daniel Juenger <2955913+sleeepyjack@users.noreply.github.com> Date: Tue, 21 Nov 2023 22:36:23 +0000 Subject: [PATCH 10/13] Re-enable arbitrary input pair types --- .../open_addressing_ref_impl.cuh | 52 +++++++++++++++--- .../cuco/detail/static_map/static_map_ref.inl | 55 ++++++++++--------- 2 files changed, 71 insertions(+), 36 deletions(-) diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 3f013210f..314dad7f0 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -272,7 +272,8 @@ class open_addressing_ref_impl { { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); - auto const key = this->extract_key(value); + auto const val = this->heterogeneous_value(value); + auto const key = this->extract_key(val); auto probing_iter = probing_scheme_(key, storage_ref_.window_extent()); while (true) { @@ -289,7 +290,7 @@ class open_addressing_ref_impl { auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index, slot_content, - value)) { + val)) { case insert_result::CONTINUE: continue; case insert_result::SUCCESS: return true; case insert_result::DUPLICATE: return false; @@ -314,7 +315,8 @@ class open_addressing_ref_impl { __device__ bool insert(cooperative_groups::thread_block_tile const& group, Value const& value) noexcept { - auto const key = this->extract_key(value); + auto const val = this->heterogeneous_value(value); + auto const key = this->extract_key(val); auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent()); while (true) { @@ -352,7 +354,7 @@ class open_addressing_ref_impl { (group.thread_rank() == src_lane) ? attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index, window_slots[intra_window_index], - value) + val) : insert_result::CONTINUE; switch (group.shfl(status, src_lane)) { @@ -392,7 +394,8 @@ class open_addressing_ref_impl { "insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs."); #endif - auto const key = this->extract_key(value); + auto const val = this->heterogeneous_value(value); + auto const key = this->extract_key(val); auto probing_iter = probing_scheme_(key, storage_ref_.window_extent()); while (true) { @@ -413,7 +416,7 @@ class open_addressing_ref_impl { if (eq_res == detail::equal_result::EMPTY or cuco::detail::bitwise_compare(this->extract_key(window_slots[i]), this->erased_key_sentinel())) { - switch (this->attempt_insert_stable(window_ptr + i, window_slots[i], value)) { + switch (this->attempt_insert_stable(window_ptr + i, window_slots[i], val)) { case insert_result::SUCCESS: { if constexpr (has_payload) { // wait to ensure that the write to the value part also took place @@ -463,7 +466,8 @@ class open_addressing_ref_impl { "insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs."); #endif - auto const key = this->extract_key(value); + auto const val = this->heterogeneous_value(value); + auto const key = this->extract_key(val); auto probing_iter = probing_scheme_(group, key, storage_ref_.window_extent()); while (true) { @@ -514,7 +518,7 @@ class open_addressing_ref_impl { auto const res = group.shfl(reinterpret_cast(slot_ptr), src_lane); auto const status = [&, target_idx = intra_window_index]() { if (group.thread_rank() != src_lane) { return insert_result::CONTINUE; } - return this->attempt_insert_stable(slot_ptr, window_slots[target_idx], value); + return this->attempt_insert_stable(slot_ptr, window_slots[target_idx], val); }(); switch (group.shfl(status, src_lane)) { @@ -890,7 +894,6 @@ class open_addressing_ref_impl { } } - private: /** * @brief Extracts the key from a given value type. * @@ -948,6 +951,37 @@ class open_addressing_ref_impl { } } + /** + * @brief Converts the given type to the container's native `value_type` while maintaining the + * heterogeneous key type. + * + * @tparam T Input type which is convertible to 'value_type' + * + * @param value The input value + * + * @return The converted object + */ + template + [[nodiscard]] __host__ __device__ constexpr auto heterogeneous_value( + T const& value) const noexcept + { + if constexpr (this->has_payload and not cuda::std::is_same_v) { + using mapped_type = decltype(this->empty_slot_sentinel_.second); + if constexpr (cuco::detail::is_cuda_std_pair_like::value) { + return cuco::pair{cuda::std::get<0>(value), + static_cast(cuda::std::get<1>(value))}; + } else if constexpr (cuco::detail::is_thrust_pair_like::value) { + return cuco::pair{thrust::get<0>(value), static_cast(thrust::get<1>(value))}; + } else { + // hail mary (convert using .first/.second members) + return cuco::pair{thrust::raw_reference_cast(value.first), + static_cast(value.second)}; + } + } else { + return thrust::raw_reference_cast(value); + } + } + /** * @brief Gets the sentinel used to represent an erased slot. * diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index c24d68842..c6a24bf7e 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -183,14 +183,14 @@ class operator_impl< /** * @brief Inserts an element. * - * @tparam ProbeKey Input key type which is convertible to 'key_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param value The element to insert * * @return True if the given element is successfully inserted */ - template - __device__ bool insert(cuco::pair const& value) noexcept + template + __device__ bool insert(Value const& value) noexcept { ref_type& ref_ = static_cast(*this); return ref_.impl_.insert(value); @@ -199,16 +199,16 @@ class operator_impl< /** * @brief Inserts an element. * - * @tparam ProbeKey Input key type which is convertible to 'key_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param group The Cooperative Group used to perform group insert * @param value The element to insert * * @return True if the given element is successfully inserted */ - template + template __device__ bool insert(cooperative_groups::thread_block_tile const& group, - cuco::pair const& value) noexcept + Value const& value) noexcept { auto& ref_ = static_cast(*this); return ref_.impl_.insert(group, value); @@ -242,17 +242,19 @@ class operator_impl< * @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v` * to the mapped_type corresponding to the key `k`. * - * @tparam ProbeKey Input key type which is convertible to 'key_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param value The element to insert */ - template - __device__ void insert_or_assign(cuco::pair const& value) noexcept + template + __device__ void insert_or_assign(Value const& value) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); - ref_type& ref_ = static_cast(*this); - auto const key = value.first; + ref_type& ref_ = static_cast(*this); + + auto const val = ref_.impl_.heterogeneous_value(value); + auto const key = ref_.impl_.extract_key(val); auto& probing_scheme = ref_.impl_.probing_scheme(); auto storage_ref = ref_.impl_.storage_ref(); auto probing_iter = probing_scheme(key, storage_ref.window_extent()); @@ -268,14 +270,14 @@ class operator_impl< auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); ref_.impl_.atomic_store( &((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second, - value.second); + val.second); return; } if (eq_res == detail::equal_result::EMPTY or cuco::detail::bitwise_compare(slot_content.first, ref_.impl_.erased_key_sentinel())) { auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); if (attempt_insert_or_assign( - (storage_ref.data() + *probing_iter)->data() + intra_window_index, value)) { + (storage_ref.data() + *probing_iter)->data() + intra_window_index, val)) { return; } } @@ -290,18 +292,19 @@ class operator_impl< * @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v` * to the mapped_type corresponding to the key `k`. * - * @tparam ProbeKey Input key type which is convertible to 'key_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param group The Cooperative Group used to perform group insert * @param value The element to insert */ - template + template __device__ void insert_or_assign(cooperative_groups::thread_block_tile const& group, - cuco::pair const& value) noexcept + Value const& value) noexcept { ref_type& ref_ = static_cast(*this); - auto const key = value.first; + auto const val = ref_.impl_.heterogeneous_value(value); + auto const key = ref_.impl_.extract_key(val); auto& probing_scheme = ref_.impl_.probing_scheme(); auto storage_ref = ref_.impl_.storage_ref(); auto probing_iter = probing_scheme(group, key, storage_ref.window_extent()); @@ -336,7 +339,7 @@ class operator_impl< if (group.thread_rank() == src_lane) { ref_.impl_.atomic_store( &((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second, - value.second); + val.second); } group.sync(); return; @@ -349,7 +352,7 @@ class operator_impl< auto const status = (group.thread_rank() == src_lane) ? attempt_insert_or_assign( - (storage_ref.data() + *probing_iter)->data() + intra_window_index, value) + (storage_ref.data() + *probing_iter)->data() + intra_window_index, val) : false; // Exit if inserted or assigned @@ -452,16 +455,15 @@ class operator_impl< * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * - * @tparam ProbeKey Input key type which is convertible to 'key_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param value The element to insert * * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template - __device__ thrust::pair insert_and_find( - cuco::pair const& value) noexcept + template + __device__ thrust::pair insert_and_find(Value const& value) noexcept { ref_type& ref_ = static_cast(*this); return ref_.impl_.insert_and_find(value); @@ -474,7 +476,7 @@ class operator_impl< * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * - * @tparam ProbeKey Input key type which is convertible to 'key_type' + * @tparam Value Input type which is convertible to 'value_type' * * @param group The Cooperative Group used to perform group insert_and_find * @param value The element to insert @@ -482,10 +484,9 @@ class operator_impl< * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template + template __device__ thrust::pair insert_and_find( - cooperative_groups::thread_block_tile const& group, - cuco::pair const& value) noexcept + cooperative_groups::thread_block_tile const& group, Value const& value) noexcept { ref_type& ref_ = static_cast(*this); return ref_.impl_.insert_and_find(group, value); From 6cc9040acc17e1873d5b3976ac07438a7da3b644 Mon Sep 17 00:00:00 2001 From: Daniel Juenger <2955913+sleeepyjack@users.noreply.github.com> Date: Tue, 21 Nov 2023 23:45:16 +0000 Subject: [PATCH 11/13] Sprinkle in some cuda::proclaim_return_type --- tests/static_map/erase_test.cu | 10 +++++++--- tests/static_map/rehash_test.cu | 10 +++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/static_map/erase_test.cu b/tests/static_map/erase_test.cu index 803c200b3..dc59ac2d9 100644 --- a/tests/static_map/erase_test.cu +++ b/tests/static_map/erase_test.cu @@ -25,6 +25,8 @@ #include +#include + using size_type = int32_t; template @@ -37,9 +39,11 @@ void test_erase(Map& map, size_type num_keys) auto keys_begin = thrust::counting_iterator(1); - auto pairs_begin = thrust::make_transform_iterator(keys_begin, [] __device__(key_type const& x) { - return cuco::pair(x, static_cast(x)); - }); + auto pairs_begin = thrust::make_transform_iterator( + keys_begin, + cuda::proclaim_return_type>([] __device__(key_type const& x) { + return cuco::pair(x, static_cast(x)); + })); SECTION("Check basic insert/erase") { diff --git a/tests/static_map/rehash_test.cu b/tests/static_map/rehash_test.cu index 55693f31e..bbfc4278f 100644 --- a/tests/static_map/rehash_test.cu +++ b/tests/static_map/rehash_test.cu @@ -21,6 +21,8 @@ #include +#include + TEST_CASE("Rehash", "") { using key_type = int; @@ -36,9 +38,11 @@ TEST_CASE("Rehash", "") auto keys_begin = thrust::counting_iterator(1); - auto pairs_begin = thrust::make_transform_iterator(keys_begin, [] __device__(key_type const& x) { - return cuco::pair(x, static_cast(x)); - }); + auto pairs_begin = thrust::make_transform_iterator( + keys_begin, + cuda::proclaim_return_type>([] __device__(key_type const& x) { + return cuco::pair(x, static_cast(x)); + })); map.insert(pairs_begin, pairs_begin + num_keys); From b7ccda0711b951d0b9ab6f57e853810a2810f3b1 Mon Sep 17 00:00:00 2001 From: Daniel Juenger <2955913+sleeepyjack@users.noreply.github.com> Date: Tue, 21 Nov 2023 23:46:55 +0000 Subject: [PATCH 12/13] Remove __host__ qualifier --- .../open_addressing/open_addressing_ref_impl.cuh | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 314dad7f0..1750d5949 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -904,8 +904,7 @@ class open_addressing_ref_impl { * @return The key */ template - [[nodiscard]] __host__ __device__ constexpr auto const& extract_key( - Value const& value) const noexcept + [[nodiscard]] __device__ constexpr auto const& extract_key(Value const& value) const noexcept { if constexpr (this->has_payload) { return thrust::raw_reference_cast(value).first; @@ -926,8 +925,7 @@ class open_addressing_ref_impl { * @return The payload */ template > - [[nodiscard]] __host__ __device__ constexpr auto const& extract_payload( - Value const& value) const noexcept + [[nodiscard]] __device__ constexpr auto const& extract_payload(Value const& value) const noexcept { return thrust::raw_reference_cast(value).second; } @@ -942,7 +940,7 @@ class open_addressing_ref_impl { * @return The converted object */ template - [[nodiscard]] __host__ __device__ constexpr value_type native_value(T const& value) const noexcept + [[nodiscard]] __device__ constexpr value_type native_value(T const& value) const noexcept { if constexpr (this->has_payload) { return {static_cast(this->extract_key(value)), this->extract_payload(value)}; @@ -962,8 +960,7 @@ class open_addressing_ref_impl { * @return The converted object */ template - [[nodiscard]] __host__ __device__ constexpr auto heterogeneous_value( - T const& value) const noexcept + [[nodiscard]] __device__ constexpr auto heterogeneous_value(T const& value) const noexcept { if constexpr (this->has_payload and not cuda::std::is_same_v) { using mapped_type = decltype(this->empty_slot_sentinel_.second); From f933f728d8477461233f5def165912cd0c6f8233 Mon Sep 17 00:00:00 2001 From: Daniel Juenger <2955913+sleeepyjack@users.noreply.github.com> Date: Wed, 22 Nov 2023 00:28:23 +0000 Subject: [PATCH 13/13] Use Value instead of InsertKey --- .../cuco/detail/static_set/static_set_ref.inl | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/include/cuco/detail/static_set/static_set_ref.inl b/include/cuco/detail/static_set/static_set_ref.inl index f9da929f4..f06a5f201 100644 --- a/include/cuco/detail/static_set/static_set_ref.inl +++ b/include/cuco/detail/static_set/static_set_ref.inl @@ -148,14 +148,14 @@ class operator_impl - __device__ bool insert(ProbeKey const& value) noexcept + template + __device__ bool insert(Value const& value) noexcept { ref_type& ref_ = static_cast(*this); return ref_.impl_.insert(value); @@ -164,16 +164,16 @@ class operator_impl + template __device__ bool insert(cooperative_groups::thread_block_tile const& group, - ProbeKey const& value) noexcept + Value const& value) noexcept { auto& ref_ = static_cast(*this); return ref_.impl_.insert(group, value); @@ -232,15 +232,15 @@ class operator_impl - __device__ thrust::pair insert_and_find(ProbeKey const& value) noexcept + template + __device__ thrust::pair insert_and_find(Value const& value) noexcept { ref_type& ref_ = static_cast(*this); return ref_.impl_.insert_and_find(value); @@ -253,7 +253,7 @@ class operator_impl + template __device__ thrust::pair insert_and_find( - cooperative_groups::thread_block_tile const& group, ProbeKey const& value) noexcept + cooperative_groups::thread_block_tile const& group, Value const& value) noexcept { ref_type& ref_ = static_cast(*this); return ref_.impl_.insert_and_find(group, value);