Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
90 commits
Select commit Hold shift + click to select a range
836e77a
First draft
sleeepyjack Jan 24, 2024
d3a1e2f
Merge remote-tracking branch 'upstream/dev' into hll
sleeepyjack Jan 24, 2024
6718560
Code style
sleeepyjack Jan 24, 2024
c59744e
Resolve merge conflicts
sleeepyjack Jan 24, 2024
b7533a0
Initialize shmem atomics through placement new
sleeepyjack Jan 24, 2024
f4bdac2
Improve naming
sleeepyjack Jan 24, 2024
cea2afb
Move some functionality to storage class
sleeepyjack Jan 24, 2024
0f0bd3f
Add inline docs for public APIs
sleeepyjack Jan 25, 2024
b21dcd1
Merge remote-tracking branch 'upstream/dev' into hll
sleeepyjack Jan 25, 2024
1c780c2
Add benchmark
sleeepyjack Jan 25, 2024
a83a3f3
Merge remote-tracking branch 'upstream/dev' into hll
sleeepyjack Jan 30, 2024
b478e01
Remove scope ctor parameter for now
sleeepyjack Jan 30, 2024
e3d401a
Update benchmark
sleeepyjack Jan 30, 2024
56520a6
Select cg reduce impl based on nvcc version
sleeepyjack Jan 31, 2024
3673772
Re-format tuning header
sleeepyjack Jan 31, 2024
799284e
Implement HLL++ bias correction step
sleeepyjack Feb 1, 2024
758977c
Merge remote-tracking branch 'upstream/dev' into hll
sleeepyjack Feb 1, 2024
abbeffa
Extend examples and fix some bugs along the way
sleeepyjack Feb 1, 2024
86d4618
Refactor thresholds
sleeepyjack Feb 1, 2024
d6a9a4e
Initialize shmem storage using placement new
sleeepyjack Feb 1, 2024
891d606
Add unit test
sleeepyjack Feb 1, 2024
3544195
Remove experimental cg async reduce since it is buggy
sleeepyjack Feb 1, 2024
3506ecb
Fix bit-shifting bug that lead to high error rates
sleeepyjack Feb 2, 2024
919d0ab
Storage cleanups
sleeepyjack Feb 2, 2024
12301df
Merge remote-tracking branch 'upstream/dev' into hll
sleeepyjack Feb 2, 2024
52f6e09
Update readme
sleeepyjack Feb 2, 2024
0a0119d
Fix typo
sleeepyjack Feb 2, 2024
ab50bed
Fix typo
sleeepyjack Feb 2, 2024
68d2df0
Apply suggestions from code review
sleeepyjack Feb 3, 2024
4567169
Merge remote-tracking branch 'upstream/dev' into hll
sleeepyjack Feb 6, 2024
93f68a2
Use CUDART_VERSION instead of (__CUDACC_VER_MAJOR__
sleeepyjack Feb 6, 2024
03a8572
Apply suggestions from code review
sleeepyjack Feb 6, 2024
33f7baf
Enable Precision>18; fix some bugs, extend tests.
sleeepyjack Feb 16, 2024
6e3683f
Merge remote-tracking branch 'upstream/dev' into hll
sleeepyjack Feb 16, 2024
b1253bf
Remove storage class and move host implementations to ref class
sleeepyjack Mar 13, 2024
64a5b70
Merge remote-tracking branch 'upstream/dev' into hll
sleeepyjack Mar 13, 2024
22c083d
Remove storage class
sleeepyjack Mar 14, 2024
56cdc6b
Add vectorized add kernel
sleeepyjack Mar 15, 2024
b8dc849
Merge remote-tracking branch 'upstream/dev' into hll
sleeepyjack Mar 15, 2024
9b4b612
Add missing kernel config
sleeepyjack Mar 15, 2024
30bd79d
Make tuning arrs accessible in non-constexpr context
sleeepyjack Mar 15, 2024
e93c248
Allow wider vector sizes
sleeepyjack Mar 15, 2024
204b8e2
Fix processing of remaining items
sleeepyjack Mar 15, 2024
fe1cf5a
Guard invoke_one with macro
sleeepyjack Mar 15, 2024
ae9e77c
Specify sketch size/precision at runtime
sleeepyjack Mar 18, 2024
65ff70a
Pre-compute register mask
sleeepyjack Mar 18, 2024
8068799
Fix unit test
sleeepyjack Mar 19, 2024
04c303d
Add sketch_size_kb strong type and fix stupid bug where I called a st…
sleeepyjack Mar 20, 2024
a7036ae
Fix benchmark
sleeepyjack Mar 20, 2024
3e25da7
More robust error estimation in benchmark
sleeepyjack Mar 20, 2024
e5d5112
Benchmark gmem fallback kernel
sleeepyjack Mar 20, 2024
99c0dee
Rename max_sketch_size_kb -> sketch_size_kb
sleeepyjack Mar 20, 2024
aeaecf4
Improve error handling and docs
sleeepyjack Mar 20, 2024
55fa312
Cleanup finalizer
sleeepyjack Mar 20, 2024
2229c68
Use double reduction
sleeepyjack Mar 20, 2024
156a843
Use .estimate() in device ref example
sleeepyjack Mar 20, 2024
80dde95
Add device ref test
sleeepyjack Mar 20, 2024
730bf73
Restructure to reduce fp error
sleeepyjack Mar 20, 2024
c50e795
Rename parameter for other estimator ref
sleeepyjack Mar 20, 2024
d5595da
Update benchmark
sleeepyjack Mar 20, 2024
16ad77a
Rebind allocator to register_type to ensure proper alignment
sleeepyjack Mar 20, 2024
b501a32
Use cudaMemcpyDefault
sleeepyjack Mar 20, 2024
0bf0a88
Mention alignment requirements in device_ref_example
sleeepyjack Mar 20, 2024
d03120c
Merge remote-tracking branch 'upstream/dev' into hll
sleeepyjack Mar 20, 2024
a361360
Pass T instead of Estimator to benchmark
sleeepyjack Mar 20, 2024
66870a7
Fix typo in benchmark script
sleeepyjack Mar 20, 2024
b990dca
Rename hash function
sleeepyjack Mar 21, 2024
c87309e
Use placement new to initialize sketch
sleeepyjack Mar 21, 2024
03d4b41
Remove custom_deleter member
sleeepyjack Mar 21, 2024
7de06fb
Rename sketch_size.hpp -> sktech_size.cuh
sleeepyjack Mar 21, 2024
185d3c4
Use std::abs
sleeepyjack Mar 21, 2024
023d080
Use std::vector instead of thrust::host_vector>
sleeepyjack Mar 21, 2024
53cdf37
Add note about shmem alignment
sleeepyjack Mar 21, 2024
2a81714
Remove comment
sleeepyjack Mar 21, 2024
d859b39
Remove device-sided error handling since it hurts performance
sleeepyjack Mar 21, 2024
43be0f0
Constexpr all the things!
sleeepyjack Mar 21, 2024
fbd6dab
Add constructor overload which takes the desired standard deviation
sleeepyjack Mar 22, 2024
dfe1a07
Remove stray include
sleeepyjack Mar 22, 2024
2629adc
Bugfixes
sleeepyjack Mar 23, 2024
1ad97e2
Fix merge
sleeepyjack Mar 27, 2024
3b0da20
Add Spark parity tests
sleeepyjack Mar 27, 2024
f80509f
Fix error calculation
sleeepyjack Mar 27, 2024
bbb7258
Merge remote-tracking branch 'upstream/dev' into hll
sleeepyjack Apr 3, 2024
b61a2db
Move include/cuco/sentinel.cuh -> include/cuco/types.cuh
sleeepyjack Apr 3, 2024
9b0ee68
Move HLL-related strong types to types.cuh
sleeepyjack Apr 3, 2024
75cd967
Apparently Doxygen has become even pickier...
sleeepyjack Apr 3, 2024
9436931
Merge remote-tracking branch 'upstream/dev' into hll
sleeepyjack Apr 3, 2024
6929f65
Clean up device ref example
sleeepyjack Apr 3, 2024
a496f9e
Update godbolt links
sleeepyjack Apr 3, 2024
7fecd7b
Clean up unique sequence unit test
sleeepyjack Apr 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,4 +234,12 @@ We plan to add many GPU-accelerated, concurrent data structures to `cuCollection
#### Examples:
- [Host-bulk APIs (TODO)]()

### `distinct_count_estimator`

`cuco::distinct_count_estimator` implements the well-established [HyperLogLog++ algorithm](https://static.googleusercontent.com/media/research.google.com/de//pubs/archive/40671.pdf) for approximating the count of distinct items in a multiset/stream.

#### Examples:
- [Host-bulk APIs](https://github.com/NVIDIA/cuCollections/blob/dev/examples/distinct_count_estimator/host_bulk_example.cu) (see [live example in godbolt](https://godbolt.org/z/ahjEoWM1E))
- [Device-ref APIs](https://github.com/NVIDIA/cuCollections/blob/dev/examples/distinct_count_estimator/device_ref_example.cu) (see [live example in godbolt](https://godbolt.org/z/qebYY8Goj))


5 changes: 5 additions & 0 deletions benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,8 @@ ConfigureBench(DYNAMIC_MAP_BENCH
# - hash function benchmarks ----------------------------------------------------------------------
ConfigureBench(HASH_BENCH
hash_bench.cu)

###################################################################################################
# - distinct_count_estimator benchmarks -----------------------------------------------------------
ConfigureBench(DISTINCT_COUNT_ESTIMATOR_BENCH
distinct_count_estimator_bench.cu)
164 changes: 164 additions & 0 deletions benchmarks/distinct_count_estimator_bench.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <defaults.hpp>
#include <utils.hpp>

#include <cuco/distinct_count_estimator.cuh>
#include <cuco/static_set.cuh>
#include <cuco/utility/key_generator.cuh>

#include <nvbench/nvbench.cuh>

#include <thrust/device_vector.h>
#include <thrust/iterator/transform_iterator.h>

#include <cuda/functional>

#include <cmath>
#include <cstddef>

using namespace cuco::benchmark;
using namespace cuco::utility;

template <typename InputIt>
[[nodiscard]] std::size_t exact_distinct_count(InputIt first, std::size_t n)
{
// TODO static_set currently only supports types up-to 8-bytes in size.
// Casting is valid since the keys generated are representable in int64_t.
using T = std::int64_t;

auto cast_iter = thrust::make_transform_iterator(
first, cuda::proclaim_return_type<T>([] __device__(auto i) { return static_cast<T>(i); }));

auto set = cuco::static_set{n, 0.8, cuco::empty_key<T>{-1}};
set.insert(cast_iter, cast_iter + n);
return set.size();
}

/**
 * @brief Measures the mean relative approximation error of `Estimator`.
 *
 * For each sample a fresh input of `NumInputs` items is generated, added to the estimator, and
 * the resulting estimate is compared against the exact distinct count of the same input. The
 * estimator is cleared after every sample so samples do not influence each other.
 *
 * @tparam Estimator Estimator type under test; must expose `value_type`
 * @tparam Dist Key distribution tag forwarded to `dist_from_state`
 *
 * @param state Benchmark state providing the `NumInputs` and `SketchSizeKB` axis values
 * @param num_samples Number of independently generated inputs to average over
 *
 * @return Sum of `|estimate / exact - 1|` over all samples divided by `num_samples`
 */
template <class Estimator, class Dist>
[[nodiscard]] double relative_error(nvbench::state& state, std::size_t num_samples)
{
  using value_type = typename Estimator::value_type;

  auto const input_size = state.get_int64("NumInputs");
  auto const size_kb    = state.get_int64("SketchSizeKB");

  thrust::device_vector<value_type> input(input_size);

  key_generator generator;
  Estimator estimator{cuco::sketch_size_kb(size_kb)};

  double accumulated_error = 0;
  for (std::size_t sample = 0; sample != num_samples; ++sample) {
    // Draw a new input for every sample
    generator.generate(dist_from_state<Dist>(state), input.begin(), input.end());
    estimator.add(input.begin(), input.end());

    double const approximate = estimator.estimate();
    double const exact       = exact_distinct_count(input.begin(), input_size);
    accumulated_error += std::abs(approximate / exact - 1.0);

    // Reset the sketch so the next sample starts from an empty estimator
    estimator.clear();
  }

  return accumulated_error / num_samples;
}

/**
 * @brief A benchmark evaluating `cuco::distinct_count_estimator` end-to-end performance
 *
 * The timed region covers adding all input items to the estimator and computing the estimate.
 * Before timing, the mean relative approximation error is measured and attached to the benchmark
 * state as the `MeanRelativeError` summary.
 *
 * @tparam T Item type
 * @tparam Dist Key distribution tag forwarded to `dist_from_state`
 *
 * @param state Benchmark state providing the `NumInputs` and `SketchSizeKB` axis values
 */
template <typename T, typename Dist>
void distinct_count_estimator_e2e(nvbench::state& state, nvbench::type_list<T, Dist>)
{
  using estimator_type = cuco::distinct_count_estimator<T>;

  auto const num_items      = state.get_int64("NumInputs");
  auto const sketch_size_kb = state.get_int64("SketchSizeKB");

  state.add_element_count(num_items);
  state.add_global_memory_reads<T>(num_items, "InputSize");

  // The unique distribution is sampled once; every other distribution is averaged over 5 samples
  auto const err_samples = (cuda::std::is_same_v<Dist, distribution::unique>) ? 1 : 5;
  auto const err         = relative_error<estimator_type, Dist>(state, err_samples);
  auto& summ             = state.add_summary("MeanRelativeError");
  summ.set_string("hint", "MRelErr");
  summ.set_string("short_name", "MeanRelativeError");
  summ.set_string("description", "Mean relative approximation error.");
  summ.set_float64("value", err);

  thrust::device_vector<T> items(num_items);

  key_generator gen;
  gen.generate(dist_from_state<Dist>(state), items.begin(), items.end());

  estimator_type estimator{cuco::sketch_size_kb(sketch_size_kb)};
  // Receives the estimate inside the timed region; the value is not used afterwards
  std::size_t estimated_cardinality = 0;
  state.exec(nvbench::exec_tag::sync | nvbench::exec_tag::timer,
             [&](nvbench::launch& launch, auto& timer) {
               timer.start();
               estimator.add_async(items.begin(), items.end(), {launch.get_stream()});
               estimated_cardinality = estimator.estimate({launch.get_stream()});
               timer.stop();

               // Reset the sketch outside of the timed region
               estimator.clear_async({launch.get_stream()});
             });
}

/**
 * @brief A benchmark evaluating `cuco::distinct_count_estimator::add` performance
 *
 * Only the asynchronous insertion of the whole input is timed; the sketch is cleared after the
 * timer has been stopped.
 *
 * @tparam T Item type
 * @tparam Dist Key distribution tag forwarded to `dist_from_state`
 *
 * @param state Benchmark state providing the `NumInputs` and `SketchSizeKB` axis values
 */
template <typename T, typename Dist>
void distinct_count_estimator_add(nvbench::state& state, nvbench::type_list<T, Dist>)
{
  using estimator_t = cuco::distinct_count_estimator<T>;

  auto const input_size = state.get_int64("NumInputs");
  auto const size_kb    = state.get_int64("SketchSizeKB");

  // Throughput bookkeeping
  state.add_element_count(input_size);
  state.add_global_memory_reads<T>(input_size, "InputSize");

  // Generate the benchmark input
  thrust::device_vector<T> input(input_size);
  key_generator generator;
  generator.generate(dist_from_state<Dist>(state), input.begin(), input.end());

  estimator_t estimator{cuco::sketch_size_kb(size_kb)};
  state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) {
    timer.start();
    estimator.add_async(input.begin(), input.end(), {launch.get_stream()});
    timer.stop();

    // Untimed: reset the sketch
    estimator.clear_async({launch.get_stream()});
  });
}

// Item types covered by both benchmarks: 4-, 8- and 16-byte integers
using TYPE_RANGE = nvbench::type_list<nvbench::int32_t, nvbench::int64_t, __int128_t>;

// Register the end-to-end (add + estimate) benchmark for every item type with uniformly
// distributed keys. `NumInputs` is swept over powers of two with exponents 28, 29 and 30.
// NOTE(review): the "Multiplicity" axis is not read in the benchmark bodies in this file;
// presumably it is consumed by `dist_from_state` (utils.hpp) — confirm.
NVBENCH_BENCH_TYPES(distinct_count_estimator_e2e,
NVBENCH_TYPE_AXES(TYPE_RANGE, nvbench::type_list<distribution::uniform>))
.set_name("distinct_count_estimator_e2e")
.set_type_axes_names({"T", "Distribution"})
.add_int64_power_of_two_axis("NumInputs", {28, 29, 30})
.add_int64_axis("SketchSizeKB", {8, 16, 32, 64, 128, 256}) // 256KB uses gmem fallback kernel
.add_int64_axis("Multiplicity", {1})
.set_max_noise(defaults::MAX_NOISE);

// Register the insertion-only benchmark with the same type, input-size and sketch-size axes as
// the end-to-end benchmark above.
NVBENCH_BENCH_TYPES(distinct_count_estimator_add,
NVBENCH_TYPE_AXES(TYPE_RANGE, nvbench::type_list<distribution::uniform>))
.set_name("distinct_count_estimator::add_async")
.set_type_axes_names({"T", "Distribution"})
.add_int64_power_of_two_axis("NumInputs", {28, 29, 30})
.add_int64_axis("SketchSizeKB", {8, 16, 32, 64, 128, 256})
.add_int64_axis("Multiplicity", {1})
.set_max_noise(defaults::MAX_NOISE);
2 changes: 2 additions & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,5 @@ ConfigureExample(STATIC_MAP_DEVICE_SIDE_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/sta
ConfigureExample(STATIC_MAP_CUSTOM_TYPE_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_map/custom_type_example.cu")
ConfigureExample(STATIC_MAP_COUNT_BY_KEY_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_map/count_by_key_example.cu")
ConfigureExample(STATIC_MULTIMAP_HOST_BULK_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/static_multimap/host_bulk_example.cu")
ConfigureExample(DISTINCT_COUNT_ESTIMATOR_HOST_BULK_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/distinct_count_estimator/host_bulk_example.cu")
ConfigureExample(DISTINCT_COUNT_ESTIMATOR_DEVICE_REF_EXAMPLE "${CMAKE_CURRENT_SOURCE_DIR}/distinct_count_estimator/device_ref_example.cu")
164 changes: 164 additions & 0 deletions examples/distinct_count_estimator/device_ref_example.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuco/distinct_count_estimator.cuh>

#include <thrust/device_vector.h>
#include <thrust/sequence.h>

#include <cstddef>
#include <iostream>

/**
* @file device_ref_example.cu
* @brief Demonstrates usage of `cuco::distinct_count_estimator` device-side APIs.
*
* This example demonstrates how the non-owning reference type `cuco::distinct_count_estimator_ref`
* can be used to implement a custom kernel that fuses the cardinality estimation step with any
* other workload that traverses the input data.
*/

/**
 * @brief Adds all items in `[first, first + n)` to a block-local sketch in shared memory and
 * merges that sketch into the estimator referenced by `ref`.
 *
 * The kernel must be launched with at least `ref.sketch_bytes()` bytes of dynamic shared memory.
 *
 * @tparam RefType Non-owning estimator reference type
 * @tparam InputIt Device-accessible input iterator type
 *
 * @param ref Reference to the estimator that receives the merged block-local sketches
 * @param first Beginning of the input range
 * @param n Number of items in the input range
 */
template <class RefType, class InputIt>
__global__ void fused_kernel(RefType ref, InputIt first, std::size_t n)
{
  // Transform the reference type (with device scope) to a reference type with block scope
  using local_ref_type = typename RefType::with_scope<cuda::thread_scope_block>;

  // Shared memory storage for the block-local estimator
  extern __shared__ std::byte local_sketch[];

  // The following check is optional since the base address of dynamic shared memory is guaranteed
  // to meet the alignment requirements
  /*
  auto const alignment =
    1ull << cuda::std::countr_zero(reinterpret_cast<std::uintptr_t>(local_sketch));
  assert(alignment >= local_ref_type::sketch_alignment());
  */

  // Index arithmetic is done in `std::size_t`: the built-in index variables are 32-bit, so an
  // index computed from them would wrap around before reaching `n` for inputs of 2^32 or more
  // items and the loop below would never terminate.
  auto const loop_stride = static_cast<std::size_t>(gridDim.x) * blockDim.x;
  auto idx               = static_cast<std::size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  auto const block       = cooperative_groups::this_thread_block();

  // Create the local estimator with the shared memory storage
  local_ref_type local_ref(cuda::std::span{local_sketch, ref.sketch_bytes()});

  // Initialize the local estimator
  local_ref.clear(block);
  block.sync();

  while (idx < n) {
    auto const& item = *(first + idx);

    // Add each item to the local estimator
    local_ref.add(item);

    /*
    Here we can add some custom workload that takes the input `item`.

    The idea is that cardinality estimation can be fused with any other workload that
    traverses the data. Since `local_ref.add` can run close to the speed-of-light (SOL) of the
    DRAM bandwidth, we get the estimate "for free" while performing other computations over the
    data.
    */

    idx += loop_stride;
  }
  // Make sure every thread of the block has finished adding before the sketch is read
  block.sync();

  /*
  We can also compute the local estimate on the device. The local estimate should approximately
  be `num_items`/`gridDim.x`:

  auto const local_estimate = local_ref.estimate(block);
  if (block.thread_rank() == 0) {
    printf("Estimate for block %d = %llu\n", blockIdx.x, local_estimate);
  }
  */

  // In the end, we merge the shared memory estimator into the global estimator which gives us the
  // final result
  ref.merge(block, local_ref);
}

/**
 * @brief Computes a cardinality estimate for `[in, in + n)` entirely on the device using a single
 * thread block and a sketch held in dynamic shared memory.
 *
 * The kernel must be launched with at least `Ref::sketch_bytes(sketch_size_kb)` bytes of dynamic
 * shared memory. Blocks other than the first one do nothing.
 *
 * @tparam Ref Estimator reference type constructible from a span over the sketch storage
 * @tparam InputIt Device-accessible input iterator type
 * @tparam OutputIt Device-accessible output iterator type
 *
 * @param sketch_size_kb Sketch size used to determine the number of sketch bytes
 * @param in Beginning of the input range
 * @param n Number of items in the input range
 * @param out Location the estimate is written to
 */
template <typename Ref, typename InputIt, typename OutputIt>
__global__ void device_estimate_kernel(cuco::sketch_size_kb sketch_size_kb,
                                       InputIt in,
                                       size_t n,
                                       OutputIt out)
{
  extern __shared__ std::byte local_sketch[];

  auto const block = cooperative_groups::this_thread_block();

  // only a single block computes the estimate
  if (block.group_index().x == 0) {
    Ref estimator(cuda::std::span(local_sketch, Ref::sketch_bytes(sketch_size_kb)));

    estimator.clear(block);
    block.sync();

    // The loop counter has the same type as `n`; an `int` counter would mix signed/unsigned
    // operands in the comparison and overflow for inputs with more than INT_MAX items.
    for (std::size_t i = block.thread_rank(); i < n; i += block.num_threads()) {
      estimator.add(*(in + i));
    }
    block.sync();
    // we can compute the final estimate on the device and return the result to the host
    auto const estimate = estimator.estimate(block);

    // A single thread publishes the result
    if (block.thread_rank() == 0) { *out = estimate; }
  }
}

/**
 * @brief Computes the cardinality estimate of the same input in three ways (host bulk API, custom
 * fused kernel, and device-side estimation) and checks that all three results agree.
 *
 * @return 0 if all three estimates are identical, 1 otherwise
 */
int main(void)
{
  using T                         = int;
  using estimator_type            = cuco::distinct_count_estimator<T>;
  constexpr std::size_t num_items = 1ull << 28;  // 1GB
  auto const sketch_size_kb       = 32_KB;

  thrust::device_vector<T> items(num_items);

  // Generate `num_items` distinct items
  thrust::sequence(items.begin(), items.end(), 0);

  // Initialize the estimator
  estimator_type estimator(sketch_size_kb);

  // Add all items to the estimator
  estimator.add(items.begin(), items.end());

  // Calculate the cardinality estimate from the bulk operation
  std::size_t const estimated_cardinality_bulk = estimator.estimate();

  // Clear the estimator so it can be reused
  estimator.clear();

  // Number of dynamic shared memory bytes required to store a CTA-local sketch
  auto const sketch_bytes = estimator.sketch_bytes();

  // Call the custom kernel and pass a non-owning reference to the estimator to the GPU
  fused_kernel<<<10, 512, sketch_bytes>>>(estimator.ref(), items.begin(), num_items);

  // Calculate the cardinality estimate from the custom kernel
  std::size_t const estimated_cardinality_custom = estimator.estimate();

  // Compute the estimate entirely on the device using a single block
  thrust::device_vector<std::size_t> device_estimate(1);
  device_estimate_kernel<typename estimator_type::ref_type<cuda::thread_scope_block>>
    <<<1, 512, sketch_bytes>>>(sketch_size_kb, items.begin(), num_items, device_estimate.begin());

  std::size_t const estimated_cardinality_device = device_estimate[0];

  if (estimated_cardinality_custom == estimated_cardinality_bulk and
      estimated_cardinality_device == estimated_cardinality_bulk) {
    std::cout << "Success! Cardinality estimates are identical" << std::endl;
    return 0;
  }

  // Report the mismatch instead of exiting silently with a success code
  std::cerr << "Failure! Cardinality estimates differ: bulk=" << estimated_cardinality_bulk
            << ", custom kernel=" << estimated_cardinality_custom
            << ", device=" << estimated_cardinality_device << std::endl;
  return 1;
}
Loading