Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 122 additions & 61 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh

Large diffs are not rendered by default.

53 changes: 31 additions & 22 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

#include <cuco/operator.hpp>

#include <thrust/tuple.h>

#include <cuda/atomic>

#include <cooperative_groups.h>
Expand Down Expand Up @@ -171,7 +173,8 @@ class operator_impl<
using base_type = static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef>;
using ref_type = static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>;
using key_type = typename base_type::key_type;
using value_type = typename base_type::value_type;
using value_type = typename base_type::value_type;
using mapped_type = T;

static constexpr auto cg_size = base_type::cg_size;
static constexpr auto window_size = base_type::window_size;
Expand All @@ -180,7 +183,7 @@ class operator_impl<
/**
* @brief Inserts an element.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param value The element to insert
*
Expand All @@ -196,7 +199,7 @@ class operator_impl<
/**
* @brief Inserts an element.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert
* @param value The element to insert
Expand Down Expand Up @@ -225,7 +228,8 @@ class operator_impl<
using base_type = static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef>;
using ref_type = static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>;
using key_type = typename base_type::key_type;
using value_type = typename base_type::value_type;
using value_type = typename base_type::value_type;
using mapped_type = T;

static constexpr auto cg_size = base_type::cg_size;
static constexpr auto window_size = base_type::window_size;
Expand All @@ -238,7 +242,7 @@ class operator_impl<
* @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v`
* to the mapped_type corresponding to the key `k`.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param value The element to insert
*/
Expand All @@ -247,8 +251,10 @@ class operator_impl<
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");

ref_type& ref_ = static_cast<ref_type&>(*this);
auto const key = value.first;
ref_type& ref_ = static_cast<ref_type&>(*this);

auto const val = ref_.impl_.heterogeneous_value(value);
auto const key = ref_.impl_.extract_key(val);
auto& probing_scheme = ref_.impl_.probing_scheme();
auto storage_ref = ref_.impl_.storage_ref();
auto probing_iter = probing_scheme(key, storage_ref.window_extent());
Expand All @@ -264,14 +270,14 @@ class operator_impl<
auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content);
ref_.impl_.atomic_store(
&((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second,
value.second);
val.second);
return;
}
if (eq_res == detail::equal_result::EMPTY or
cuco::detail::bitwise_compare(slot_content.first, ref_.impl_.erased_key_sentinel())) {
auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content);
if (attempt_insert_or_assign(
(storage_ref.data() + *probing_iter)->data() + intra_window_index, value)) {
(storage_ref.data() + *probing_iter)->data() + intra_window_index, val)) {
return;
}
}
Expand All @@ -286,7 +292,7 @@ class operator_impl<
* @brief Inserts a key-value pair `{k, v}` if it's not present in the map. Otherwise, assigns `v`
* to the mapped_type corresponding to the key `k`.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert
* @param value The element to insert
Expand All @@ -297,7 +303,8 @@ class operator_impl<
{
ref_type& ref_ = static_cast<ref_type&>(*this);

auto const key = value.first;
auto const val = ref_.impl_.heterogeneous_value(value);
auto const key = ref_.impl_.extract_key(val);
auto& probing_scheme = ref_.impl_.probing_scheme();
auto storage_ref = ref_.impl_.storage_ref();
auto probing_iter = probing_scheme(group, key, storage_ref.window_extent());
Expand Down Expand Up @@ -332,7 +339,7 @@ class operator_impl<
if (group.thread_rank() == src_lane) {
ref_.impl_.atomic_store(
&((storage_ref.data() + *probing_iter)->data() + intra_window_index)->second,
value.second);
val.second);
}
group.sync();
return;
Expand All @@ -345,7 +352,7 @@ class operator_impl<
auto const status =
(group.thread_rank() == src_lane)
? attempt_insert_or_assign(
(storage_ref.data() + *probing_iter)->data() + intra_window_index, value)
(storage_ref.data() + *probing_iter)->data() + intra_window_index, val)
: false;

// Exit if inserted or assigned
Expand Down Expand Up @@ -377,7 +384,8 @@ class operator_impl<
ref_type& ref_ = static_cast<ref_type&>(*this);
auto const expected_key = ref_.impl_.empty_slot_sentinel().first;

auto old_key = ref_.impl_.compare_and_swap(&slot->first, expected_key, value.first);
auto old_key =
ref_.impl_.compare_and_swap(&slot->first, expected_key, static_cast<key_type>(value.first));
auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key);

// if key success or key was already present in the map
Expand Down Expand Up @@ -406,6 +414,7 @@ class operator_impl<
using ref_type = static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>;
using key_type = typename base_type::key_type;
using value_type = typename base_type::value_type;
using mapped_type = T;
using iterator = typename base_type::iterator;
using const_iterator = typename base_type::const_iterator;

Expand Down Expand Up @@ -446,7 +455,7 @@ class operator_impl<
* element that prevented the insertion) and a `bool` denoting whether the insertion took place or
* not.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param value The element to insert
*
Expand All @@ -467,7 +476,7 @@ class operator_impl<
* element that prevented the insertion) and a `bool` denoting whether the insertion took place or
* not.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert_and_find
* @param value The element to insert
Expand Down Expand Up @@ -506,7 +515,7 @@ class operator_impl<
/**
* @brief Erases an element.
*
* @tparam ProbeKey Input type which is implicitly convertible to 'key_type'
* @tparam ProbeKey Input key type which is convertible to 'key_type'
*
* @param key The element to erase
*
Expand All @@ -522,7 +531,7 @@ class operator_impl<
/**
* @brief Erases an element.
*
* @tparam ProbeKey Input type which is implicitly convertible to 'key_type'
* @tparam ProbeKey Input key type which is convertible to 'key_type'
*
* @param group The Cooperative Group used to perform group insert
* @param key The element to erase
Expand Down Expand Up @@ -563,7 +572,7 @@ class operator_impl<
* @note If the probe key `key` was inserted into the container, returns
* true. Otherwise, returns false.
*
* @tparam ProbeKey Probe key type
* @tparam ProbeKey Input key type which is convertible to 'key_type'
*
* @param key The key to search for
*
Expand All @@ -583,7 +592,7 @@ class operator_impl<
* @note If the probe key `key` was inserted into the container, returns
* true. Otherwise, returns false.
*
* @tparam ProbeKey Probe key type
* @tparam ProbeKey Input key type which is convertible to 'key_type'
*
* @param group The Cooperative Group used to perform group contains
* @param key The key to search for
Expand Down Expand Up @@ -652,7 +661,7 @@ class operator_impl<
* @note Returns a un-incrementable input iterator to the element whose key is equivalent to
* `key`. If no such element exists, returns `end()`.
*
* @tparam ProbeKey Probe key type
* @tparam ProbeKey Input key type which is convertible to 'key_type'
*
* @param key The key to search for
*
Expand All @@ -672,7 +681,7 @@ class operator_impl<
* @note Returns a un-incrementable input iterator to the element whose key is equivalent to
* `key`. If no such element exists, returns `end()`.
*
* @tparam ProbeKey Probe key type
* @tparam ProbeKey Input key type which is convertible to 'key_type'
*
* @param group The Cooperative Group used to perform this operation
* @param key The key to search for
Expand Down
20 changes: 10 additions & 10 deletions include/cuco/detail/static_set/static_set_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class operator_impl<op::insert_tag,
/**
* @brief Inserts an element.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param value The element to insert
*
Expand All @@ -164,7 +164,7 @@ class operator_impl<op::insert_tag,
/**
* @brief Inserts an element.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert
* @param value The element to insert
Expand Down Expand Up @@ -232,7 +232,7 @@ class operator_impl<op::insert_and_find_tag,
* element that prevented the insertion) and a `bool` denoting whether the insertion took place or
* not.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param value The element to insert
*
Expand All @@ -253,7 +253,7 @@ class operator_impl<op::insert_and_find_tag,
* element that prevented the insertion) and a `bool` denoting whether the insertion took place or
* not.
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Value Input type which is convertible to 'value_type'
*
* @param group The Cooperative Group used to perform group insert_and_find
* @param value The element to insert
Expand Down Expand Up @@ -290,7 +290,7 @@ class operator_impl<op::erase_tag,
/**
* @brief Erases an element.
*
* @tparam ProbeKey Input type which is implicitly convertible to 'key_type'
* @tparam ProbeKey Input type which is convertible to 'key_type'
*
* @param key The element to erase
*
Expand All @@ -306,7 +306,7 @@ class operator_impl<op::erase_tag,
/**
* @brief Erases an element.
*
* @tparam ProbeKey Input type which is implicitly convertible to 'key_type'
* @tparam ProbeKey Input type which is convertible to 'key_type'
*
* @param group The Cooperative Group used to perform group erase
* @param value The element to erase
Expand Down Expand Up @@ -345,7 +345,7 @@ class operator_impl<op::contains_tag,
* @note If the probe key `key` was inserted into the container, returns true. Otherwise, returns
* false.
*
* @tparam ProbeKey Probe key type
* @tparam ProbeKey Input type which is convertible to 'key_type'
*
* @param key The key to search for
*
Expand All @@ -364,7 +364,7 @@ class operator_impl<op::contains_tag,
* @note If the probe key `key` was inserted into the container, returns true. Otherwise, returns
* false.
*
* @tparam ProbeKey Probe key type
* @tparam ProbeKey Input type which is convertible to 'key_type'
*
* @param group The Cooperative Group used to perform group contains
* @param key The key to search for
Expand Down Expand Up @@ -431,7 +431,7 @@ class operator_impl<op::find_tag,
* @note Returns a un-incrementable input iterator to the element whose key is equivalent to
* `key`. If no such element exists, returns `end()`.
*
* @tparam ProbeKey Probe key type
* @tparam ProbeKey Input type which is convertible to 'key_type'
*
* @param key The key to search for
*
Expand All @@ -451,7 +451,7 @@ class operator_impl<op::find_tag,
* @note Returns a un-incrementable input iterator to the element whose key is equivalent to
* `key`. If no such element exists, returns `end()`.
*
* @tparam ProbeKey Probe key type
* @tparam ProbeKey Input type which is convertible to 'key_type'
*
* @param group The Cooperative Group used to perform this operation
* @param key The key to search for
Expand Down
16 changes: 15 additions & 1 deletion include/cuco/detail/traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
#include <thrust/device_reference.h>
#include <thrust/tuple.h>

#include <cuda/std/tuple>
#include <cuda/std/type_traits>

#include <tuple>

namespace cuco::detail {
Expand All @@ -36,6 +36,20 @@ struct is_std_pair_like<T,
conditional_t<std::tuple_size<T>::value == 2, cuda::std::true_type, cuda::std::false_type> {
};

template <typename T, typename = void>
struct is_cuda_std_pair_like : cuda::std::false_type {
};

template <typename T>
struct is_cuda_std_pair_like<
T,
cuda::std::void_t<decltype(cuda::std::get<0>(cuda::std::declval<T>())),
decltype(cuda::std::get<1>(cuda::std::declval<T>()))>>
: cuda::std::conditional_t<cuda::std::tuple_size<T>::value == 2,
cuda::std::true_type,
cuda::std::false_type> {
};

template <typename T, typename = void>
struct is_thrust_pair_like_impl : cuda::std::false_type {
};
Expand Down
14 changes: 14 additions & 0 deletions include/cuco/pair.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <thrust/tuple.h>

#include <cuda/std/tuple>
#include <tuple>
#include <type_traits>

namespace cuco {
Expand Down Expand Up @@ -86,6 +87,19 @@ struct alignas(detail::pair_alignment<First, Second>()) pair {
* @param p The input pair to copy from
*/
template <typename T, std::enable_if_t<detail::is_std_pair_like<T>::value>* = nullptr>
__host__ __device__ constexpr pair(T const& p)
: pair{std::get<0>(thrust::raw_reference_cast(p)), std::get<1>(thrust::raw_reference_cast(p))}
{
}

/**
* @brief Constructs a pair from the given cuda::std::pair-like `p`.
*
* @tparam T Type of the pair to copy from
*
* @param p The input pair to copy from
*/
template <typename T, std::enable_if_t<detail::is_cuda_std_pair_like<T>::value>* = nullptr>
__host__ __device__ constexpr pair(T const& p)
: pair{cuda::std::get<0>(thrust::raw_reference_cast(p)),
cuda::std::get<1>(thrust::raw_reference_cast(p))}
Expand Down
Loading