-
Notifications
You must be signed in to change notification settings - Fork 25
[WIP] libcuml and libcudf groupby integration
#37
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
shaneding
wants to merge
7
commits into
rapidsai:main
Choose a base branch
from
shaneding:libcuml-and-libcudf-groupby
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
cc0e579
finished grouping
shaneding f8408f1
updated stable_name
shaneding 9ab3665
added basic data transformation example
shaneding bd469a8
cuml debug
shaneding 37e1853
updated to work on multiple groups
shaneding 11394ef
v1 done
shaneding 363fddd
updated files
shaneding File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| cmake_minimum_required(VERSION 3.18) | ||
|
|
||
| project(basic_example VERSION 0.0.1 LANGUAGES CXX CUDA) | ||
|
|
||
| set(CPM_DOWNLOAD_VERSION v0.32.2) | ||
| file(DOWNLOAD https://github.com/cpm-cmake/CPM.cmake/releases/download/${CPM_DOWNLOAD_VERSION}/get_cpm.cmake ${CMAKE_BINARY_DIR}/cmake/get_cpm.cmake) | ||
| include(${CMAKE_BINARY_DIR}/cmake/get_cpm.cmake) | ||
|
|
||
| set(CUDF_TAG branch-21.10) | ||
| CPMFindPackage(NAME cudf | ||
| GIT_REPOSITORY https://github.com/rapidsai/cudf | ||
| GIT_TAG ${CUDF_TAG} | ||
| GIT_SHALLOW TRUE | ||
| SOURCE_SUBDIR cpp | ||
| ) | ||
|
|
||
| set(CUDF_TAG branch-21.10) | ||
| CPMFindPackage(NAME cuml | ||
| GIT_REPOSITORY https://github.com/rapidsai/cuml | ||
| GIT_TAG ${CUDF_TAG} | ||
| GIT_SHALLOW TRUE | ||
| SOURCE_SUBDIR cpp | ||
| ) | ||
|
|
||
| # Configure your project here | ||
| add_executable(basic_example src/process_csv.cpp) | ||
| target_link_libraries(basic_example PRIVATE | ||
| cudf::cudf | ||
| cuml::cuml++ | ||
| ) | ||
| target_compile_features(basic_example PRIVATE cxx_std_17) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| # libcuml and libcudf groupby | ||
|
|
||
| This C++ example demonstrates running a `libcuml` regression model on groups from a `libcudf` `groupby` operation | ||
|
|
||
| ## Compile and execute | ||
|
|
||
| ```bash | ||
| # Configure project | ||
| cmake -S . -B build/ | ||
| # Build | ||
| cmake --build build/ --parallel $PARALLEL_LEVEL | ||
| # Execute | ||
| build/basic_example | ||
| ``` | ||
|
|
||
| If your machine does not come with a pre-built libcudf binary, expect the | ||
| first build to take some time, as it would build libcudf on the host machine. | ||
| It may be sped up by configuring the proper `PARALLEL_LEVEL` number. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,96 @@ | ||
| #pragma once | ||
|
|
||
| #include <cudf/column/column_view.hpp> | ||
| #include <cudf/utilities/error.hpp> | ||
|
|
||
| #include <rmm/thrust_rmm_allocator.h> | ||
| #include <rmm/device_buffer.hpp> | ||
| #include <rmm/device_uvector.hpp> | ||
|
|
||
| #include <thrust/copy.h> | ||
| #include <thrust/device_ptr.h> | ||
| #include <thrust/execution_policy.h> | ||
| #include <thrust/fill.h> | ||
|
|
||
| #include <iostream> | ||
| #include <string> | ||
|
|
||
| namespace cuspatial { | ||
| namespace debug { | ||
|
|
||
| template <typename T> | ||
| void print(std::vector<T> const& vec, | ||
| std::ostream& os = std::cout, | ||
| std::string const& delimiter = ",") | ||
| { | ||
| std::vector<double> f64s(vec.size()); | ||
| std::copy(vec.begin(), vec.end(), f64s.begin()); | ||
| os << "size: " << vec.size() << " [" << std::endl << " "; | ||
| std::copy(f64s.begin(), f64s.end(), std::ostream_iterator<double>(os, delimiter.data())); | ||
| os << std::endl << "]" << std::endl; | ||
| } | ||
|
|
||
| template <typename T> | ||
| void print(rmm::device_vector<T> const& vec, | ||
| std::ostream& os = std::cout, | ||
| std::string const& delimiter = ",", | ||
| cudaStream_t stream = 0) | ||
| { | ||
| CUDA_TRY(cudaStreamSynchronize(stream)); | ||
| std::vector<T> hvec(vec.size()); | ||
| std::fill(hvec.begin(), hvec.end(), T{0}); | ||
| thrust::copy(vec.begin(), vec.end(), hvec.begin()); | ||
| print<T>(hvec, os, delimiter); | ||
| } | ||
|
|
||
| template <typename T> | ||
| void print(rmm::device_uvector<T> const& uvec, | ||
| std::ostream& os = std::cout, | ||
| std::string const& delimiter = ",", | ||
| cudaStream_t stream = 0) | ||
| { | ||
| rmm::device_vector<T> dvec(uvec.size()); | ||
| std::fill(dvec.begin(), dvec.end(), T{0}); | ||
| thrust::copy(rmm::exec_policy(stream)->on(stream), uvec.begin(), uvec.end(), dvec.begin()); | ||
| print<T>(dvec, os, delimiter, stream); | ||
| } | ||
|
|
||
| template <typename T> | ||
| void print(rmm::device_buffer const& buf, | ||
| std::ostream& os = std::cout, | ||
| std::string const& delimiter = ",", | ||
| cudaStream_t stream = 0) | ||
| { | ||
| auto ptr = thrust::device_pointer_cast<const T>(buf.data()); | ||
| rmm::device_vector<T> dvec(buf.size() / sizeof(T)); | ||
| thrust::fill(dvec.begin(), dvec.end(), T{0}); | ||
| thrust::copy(rmm::exec_policy(stream)->on(stream), ptr, ptr + dvec.size(), dvec.begin()); | ||
| print<T>(dvec, os, delimiter, stream); | ||
| } | ||
|
|
||
| template <typename T> | ||
| void print(cudf::column_view const& col, | ||
| std::ostream& os = std::cout, | ||
| std::string const& delimiter = ",", | ||
| cudaStream_t stream = 0) | ||
| { | ||
| rmm::device_vector<T> dvec(col.size()); | ||
| std::fill(dvec.begin(), dvec.end(), T{0}); | ||
| thrust::copy(rmm::exec_policy(stream)->on(stream), col.begin<T>(), col.end<T>(), dvec.begin()); | ||
| print<T>(dvec, os, delimiter, stream); | ||
| } | ||
|
|
||
| template <typename T> | ||
| void print(thrust::device_ptr<T> const& ptr, | ||
| cudf::size_type size, | ||
| std::ostream& os = std::cout, | ||
| std::string const& delimiter = ",", | ||
| cudaStream_t stream = 0) | ||
| { | ||
| rmm::device_vector<T> dvec(size); | ||
| std::fill(dvec.begin(), dvec.end(), T{0}); | ||
| thrust::copy(rmm::exec_policy(stream)->on(stream), ptr, ptr + size, dvec.begin()); | ||
| print<T>(dvec, os, delimiter, stream); | ||
| } | ||
| } // namespace debug | ||
| } // namespace cuspatial |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,138 @@ | ||
| #include <cudf/aggregation.hpp> | ||
| #include <cudf/groupby.hpp> | ||
| #include <cudf/io/csv.hpp> | ||
| #include <cudf/copying.hpp> | ||
| #include <cudf/reshape.hpp> | ||
| #include <cudf/transpose.hpp> | ||
| #include <cudf/table/table.hpp> | ||
| #include <cudf/column/column_factories.hpp> | ||
|
|
||
|
|
||
| #include <cuml/linear_model/glm.hpp> | ||
| #include <raft/handle.hpp> | ||
| #include <raft/cudart_utils.h> | ||
|
|
||
| #include <memory> | ||
| #include <string> | ||
| #include <utility> | ||
| #include <vector> | ||
| #include <iostream> | ||
|
|
||
| #include <rmm/exec_policy.hpp> | ||
| #include <cuda_runtime.h> | ||
|
|
||
| #include <thrust/uninitialized_fill.h> | ||
| #include <thrust/device_malloc.h> | ||
| #include <thrust/device_vector.h> | ||
| #include <thrust/copy.h> | ||
|
|
||
|
|
||
|
|
||
|
|
||
| #ifndef CUDA_RT_CALL | ||
| #define CUDA_RT_CALL(call) \ | ||
| { \ | ||
| cudaError_t cudaStatus = call; \ | ||
| if (cudaSuccess != cudaStatus) \ | ||
| fprintf(stderr, \ | ||
| "ERROR: CUDA RT call \"%s\" in line %d of file %s failed with " \ | ||
| "%s (%d).\n", \ | ||
| #call, \ | ||
| __LINE__, \ | ||
| __FILE__, \ | ||
| cudaGetErrorString(cudaStatus), \ | ||
| cudaStatus); \ | ||
| } | ||
| #endif // CUDA_RT_CALL | ||
|
|
||
|
|
||
| cudf::io::table_with_metadata read_csv(std::string const& file_path) | ||
| { | ||
| auto source_info = cudf::io::source_info(file_path); | ||
| auto builder = cudf::io::csv_reader_options::builder(source_info); | ||
| auto options = builder.build(); | ||
| return cudf::io::read_csv(options); | ||
| } | ||
|
|
||
| void write_csv(cudf::table_view const& tbl_view, std::string const& file_path) | ||
| { | ||
| auto sink_info = cudf::io::sink_info(file_path); | ||
| auto builder = cudf::io::csv_writer_options::builder(sink_info, tbl_view); | ||
| auto options = builder.build(); | ||
| cudf::io::write_csv(options); | ||
| } | ||
|
|
||
|
|
||
| std::unique_ptr<cudf::column> generate_grouped_arr(cudf::table_view values, cudf::size_type start, cudf::size_type end) | ||
| { | ||
| auto sliced_table = cudf::slice(values, {start, end}).front(); | ||
| auto [_, transposed_table] = cudf::transpose(sliced_table); | ||
|
|
||
| return cudf::interleave_columns(transposed_table); | ||
| } | ||
|
|
||
| std::unique_ptr<cudf::table> cuml_regression_on_groupby(cudf::table_view input_table) | ||
| { | ||
| // Schema: Name | X | Y | target | ||
| auto keys = cudf::table_view{{input_table.column(0)}}; // name | ||
|
|
||
| cudf::groupby::groupby grpby_obj(keys); | ||
| cudf::groupby::groupby::groups gb_groups = grpby_obj.get_groups(input_table.select({1,2,3})); | ||
| auto values_view = (gb_groups.values)->view(); | ||
|
|
||
| auto interleaved = generate_grouped_arr(values_view, 0, 3); | ||
|
|
||
| // cuml setup | ||
| int n_cols = 2; | ||
| raft::handle_t handle; | ||
| cudaStream_t stream = rmm::cuda_stream_default.value(); | ||
| CUDA_RT_CALL(cudaStreamCreate(&stream)); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are we recreating the stream if already using the default rmm stream? |
||
| handle.set_stream(stream); | ||
|
|
||
| //thrust::device_vector<double> coef1(gb_groups.offsets.size() - 1); | ||
| //thrust::device_vector<double> coef2(gb_groups.offsets.size() - 1); | ||
| // looping through each group | ||
|
|
||
| auto coef1 = cudf::make_numeric_column(cudf::data_type{cudf::type_id::FLOAT64}, (gb_groups.offsets.size() - 1) * 2); | ||
| auto coef2 = cudf::make_numeric_column(cudf::data_type{cudf::type_id::FLOAT64}, gb_groups.offsets.size() - 1); | ||
|
|
||
| for (int i = 1; i < gb_groups.offsets.size(); i++) { | ||
|
|
||
| cudf::size_type offset1 = gb_groups.offsets[i-1], offset2 = gb_groups.offsets[i]; | ||
| int n_rows = offset2 - offset1; | ||
|
|
||
| auto interleaved = generate_grouped_arr(values_view, offset1, offset2); | ||
| double *matrix_pointer = interleaved->mutable_view().data<double>(); | ||
|
|
||
| // original values | ||
| raft::print_device_vector<double>("values", matrix_pointer, n_rows * (n_cols + 1), std::cout); | ||
| thrust::device_ptr<double> coef = thrust::device_malloc<double>(n_cols); | ||
| double intercept; | ||
|
|
||
| // label is stored in matrix_pointer + n_rows * n_cols | ||
| ML::GLM::olsFit(handle, matrix_pointer, n_rows, n_cols, matrix_pointer + n_rows * n_cols, coef.get(), &intercept, false, false); | ||
| // raft::print_device_vector<double>("values", matrix_pointer, n_rows * (n_cols + 1), std::cout); | ||
|
|
||
| // loops through n_cols | ||
| thrust::copy(thrust::device, coef.get(), coef.get()+2, coef1->mutable_view().data<double>() + i - 1); | ||
| //thrust::copy(coef.get()+1, coef.get()+2, coef2->mutable_view().data<double>() + i - 1); | ||
|
|
||
| raft::print_device_vector<double>("coef", coef.get(), n_cols, std::cout); | ||
| } | ||
| //raft::print_device_vector<double>("coef", coef1.data().get(), gb_groups.offsets.size() - 1, std::cout); | ||
| return std::make_unique<cudf::table>(cudf::table_view({coef1->view()}).select({0})); | ||
| } | ||
|
|
||
| int main(int argc, char** argv) | ||
| { | ||
| // Read data | ||
| auto sample_table = read_csv("test2.csv"); | ||
|
|
||
| // Process | ||
| auto result = cuml_regression_on_groupby(*sample_table.tbl); | ||
|
|
||
| // Write out result | ||
| write_csv(*result, "test_out.csv"); | ||
|
|
||
| return 0; | ||
| } | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be CUML_TAG