From 947c28a3c8e7bb626447c97b2a19f884c3e12c41 Mon Sep 17 00:00:00 2001 From: rakri Date: Wed, 21 Aug 2024 03:25:54 -0700 Subject: [PATCH 01/18] first cut changes to add diversity feature --- include/index.h | 5 +- include/neighbor.h | 32 ++++++++++ src/index.cpp | 144 ++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 179 insertions(+), 2 deletions(-) diff --git a/include/index.h b/include/index.h index b9bf4f384..2700d4ada 100644 --- a/include/index.h +++ b/include/index.h @@ -258,7 +258,7 @@ template clas // The query to use is placed in scratch->aligned_query std::pair iterate_to_fixed_point(InMemQueryScratch *scratch, const uint32_t Lindex, const std::vector &init_ids, bool use_filter, - const std::vector &filters, bool search_invocation); + const std::vector &filters, bool search_invocation, uint32_t maxLperSeller = 0); void search_for_point_and_prune(int location, uint32_t Lindex, std::vector &pruned_list, InMemQueryScratch *scratch, bool use_filter = false, @@ -384,6 +384,9 @@ template clas std::unordered_map _label_to_start_id; std::unordered_map _medoid_counts; + bool _diverse_index = false; + std::vector _location_to_seller; + bool _use_universal_label = false; LabelT _universal_label = 0; uint32_t _filterIndexingQueueSize; diff --git a/include/neighbor.h b/include/neighbor.h index d7c0c25ed..1117a7f5e 100644 --- a/include/neighbor.h +++ b/include/neighbor.h @@ -93,6 +93,38 @@ class NeighborPriorityQueue } } + + // Deletes the item if found. + void delete_id(const uint32_t &id) + { + size_t lo = 0, hi = _size; + size_t loc = std::numeric_limits::max(); + while (lo < hi) + { + size_t mid = (lo + hi) >> 1; + if (id < _data[mid].id) + { + hi = mid; + } + else if (_data[mid].id == id) + { + loc = mid; + break; + } + else + { + lo = mid + 1; + } + } + + if (loc != std::numeric_limits::max()) + { + std::memmove(&_data[loc], &_data[loc+1], (_size - loc - 1) * sizeof(Neighbor)); + _size--; + } + } + + Neighbor closest_unexpanded() { _data[_cur].expanded = true; diff --git a/src/index.cpp b/src/index.cpp index bf93344fa..970a1eb0f 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -791,11 +791,19 @@ bool Index::detect_common_filters(uint32_t point_id, bool searc template std::pair Index::iterate_to_fixed_point( InMemQueryScratch *scratch, const uint32_t Lsize, const std::vector &init_ids, bool use_filter, - const std::vector &filter_labels, bool search_invocation) + const std::vector &filter_labels, bool search_invocation, uint32_t maxLperSeller) { + bool diverse_search = false; + if (maxLperSeller == 0) + maxLperSeller = Lsize; + else + diverse_search = true; std::vector &expanded_nodes = scratch->pool(); NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); best_L_nodes.reserve(Lsize); + + tsl::robin_map color_to_nodes; + tsl::robin_set &inserted_into_pool_rs = scratch->inserted_into_pool_rs(); boost::dynamic_bitset<> &inserted_into_pool_bs = scratch->inserted_into_pool_bs(); std::vector &id_scratch = scratch->id_scratch(); @@ -874,6 +882,12 @@ std::pair Index::iterate_to_fixed_point( Neighbor nn = Neighbor(id, distance); best_L_nodes.insert(nn); + if (diverse_search) { + auto &col = _location_to_seller[id]; + if (color_to_nodes.find(col) == color_to_nodes.end()) + color_to_nodes[col] = NeighborPriorityQueue(maxLperSeller); + color_to_nodes[col].insert(nn); + } } } @@ -969,7 +983,39 @@ std::pair Index::iterate_to_fixed_point( // Insert pairs into the pool of candidates for (size_t m = 0; m < id_scratch.size(); ++m) { + if (diverse_search) { + auto cur_id = id_scratch[n]; + auto cur_dist = dist_scratch[m]; + if (color_to_nodes.find(_location_to_seller[cur_id]) == color_to_nodes.end()) { + color_to_nodes[_location_to_seller[cur_id]] = NeighborPriorityQueue(maxLperSeller); + } + auto &cur_list = color_to_nodes[_location_to_seller[cur_id]]; + if (cur_list.size() < maxLperSeller && best_L_nodes.size() < Lsize) { + cur_list.insert(Neighbor(cur_id, cur_dist)); + best_L_nodes.insert(Neighbor(cur_id, cur_dist)); + } else if (cur_list.size() == maxLperSeller && best_L_nodes.size() < Lsize) { + if (cur_dist < cur_list[maxLperSeller-1].distance) { + best_L_nodes.delete_id(cur_list[maxLperSeller-1].id); + cur_list.insert(Neighbor(cur_id, cur_dist)); + best_L_nodes.insert(Neighbor(cur_id, cur_dist)); + } + } else if (cur_list.size() < maxLperSeller && best_L_nodes.size() == Lsize) { + if (cur_dist < best_L_nodes[Lsize-1].distance) { + color_to_nodes[_location_to_seller[best_L_nodes[Lsize-1].id]].delete_id(best_L_nodes[Lsize-1].id); + cur_list.insert(Neighbor(cur_id, cur_dist)); + best_L_nodes.insert(Neighbor(cur_id, cur_dist)); + } + } else { + if (cur_dist < cur_list[maxLperSeller-1].distance) { + best_L_nodes.delete_id(cur_list[maxLperSeller-1].id); + cur_list.insert(Neighbor(cur_id, cur_dist)); + best_L_nodes.insert(Neighbor(cur_id, cur_dist)); + } + } + } + else { best_L_nodes.insert(Neighbor(id_scratch[m], dist_scratch[m])); + } } } return std::make_pair(hops, cmps); @@ -1053,6 +1099,7 @@ void Index::search_for_point_and_prune(int location, uint32_t L assert(_graph_store->get_total_points() == _max_points + _num_frozen_pts); } +/* template void Index::occlude_list(const uint32_t location, std::vector &pool, const float alpha, const uint32_t degree, const uint32_t maxc, std::vector &result, @@ -1148,6 +1195,101 @@ void Index::occlude_list(const uint32_t location, std::vector +void Index::occlude_list(const uint32_t location, std::vector &pool, const float alpha, + const uint32_t degree, const uint32_t maxc, std::vector &result, + InMemQueryScratch *scratch, + const tsl::robin_set *const delete_set_ptr) +{ + if (pool.size() == 0) + return; + + // Truncate pool at maxc and initialize scratch spaces + assert(std::is_sorted(pool.begin(), pool.end())); + assert(result.size() == 0); + if (pool.size() > maxc) + pool.resize(maxc); + std::vector &occlude_factor = scratch->occlude_factor(); + // occlude_list can be called with the same scratch more than once by + // search_for_point_and_add_link through inter_insert. + occlude_factor.clear(); + // Initialize occlude_factor to pool.size() many 0.0f values for correctness + occlude_factor.insert(occlude_factor.end(), pool.size(), 0.0f); + float cur_alpha = alpha; + { + // used for MIPS, where we store a value of eps in cur_alpha to + // denote pruned out entries which we can skip in later rounds. + float eps = cur_alpha + 0.01f; + for (auto iter = pool.begin(); result.size() < degree && iter != pool.end(); ++iter) + { + if (occlude_factor[iter - pool.begin()] > cur_alpha) + { + continue; + } + // Set the entry to float::max so that is not considered again + occlude_factor[iter - pool.begin()] = std::numeric_limits::max(); + // Add the entry to the result if its not been deleted, and doesn't + // add a self loop + if (delete_set_ptr == nullptr || delete_set_ptr->find(iter->id) == delete_set_ptr->end()) + { + if (iter->id != location) + { + result.push_back(iter->id); + } + } + + // Update occlude factor for points from iter+1 to pool.end() + for (auto iter2 = iter + 1; iter2 != pool.end(); iter2++) + { + auto t = iter2 - pool.begin(); + if (occlude_factor[t] > alpha) + continue; + + bool prune_allowed = true; + if (_filtered_index) + { + uint32_t a = iter->id; + uint32_t b = iter2->id; + if (_location_to_labels.size() < b || _location_to_labels.size() < a) + continue; + for (auto &x : _location_to_labels[b]) + { + if (std::find(_location_to_labels[a].begin(), _location_to_labels[a].end(), x) == + _location_to_labels[a].end()) + { + prune_allowed = false; + } + if (!prune_allowed) + break; + } + } + if (!prune_allowed) + continue; + + float djk = _data_store->get_distance(iter2->id, iter->id); + if (_dist_metric == diskann::Metric::L2 || _dist_metric == diskann::Metric::COSINE) + { + occlude_factor[t] = (djk == 0) ? std::numeric_limits::max() + : std::max(occlude_factor[t], iter2->distance / djk); + } + else if (_dist_metric == diskann::Metric::INNER_PRODUCT) + { + // Improvization for flipping max and min dist for MIPS + float x = -iter2->distance; + float y = -djk; + if (y > cur_alpha * x) + { + occlude_factor[t] = std::max(occlude_factor[t], eps); + } + } + } + } + } +} + template void Index::prune_neighbors(const uint32_t location, std::vector &pool, From af855d43c4d324d63df754c54ac481c83a5b1127 Mon Sep 17 00:00:00 2001 From: rakri Date: Fri, 23 Aug 2024 00:01:16 -0700 Subject: [PATCH 02/18] more changes towards a diversity index --- apps/build_memory_index.cpp | 17 +++++- include/defaults.h | 4 ++ include/index.h | 5 ++ include/parameters.h | 32 ++++++++++- include/program_options_utils.hpp | 2 + src/index.cpp | 93 ++++++++++++++++++++++++++++--- 6 files changed, 141 insertions(+), 12 deletions(-) diff --git a/apps/build_memory_index.cpp b/apps/build_memory_index.cpp index 544e42dee..a9e3b61bd 100644 --- a/apps/build_memory_index.cpp +++ b/apps/build_memory_index.cpp @@ -24,10 +24,11 @@ namespace po = boost::program_options; int main(int argc, char **argv) { - std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type; - uint32_t num_threads, R, L, Lf, build_PQ_bytes; + std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type, seller_file; + uint32_t num_threads, R, L, Lf, build_PQ_bytes, num_diverse_build; float alpha; bool use_pq_build, use_opq; + bool diverse_index=false; po::options_description desc{ program_options_utils::make_program_description("build_memory_index", "Build a memory-based DiskANN index.")}; @@ -70,6 +71,12 @@ int main(int argc, char **argv) program_options_utils::FILTERED_LBUILD); optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), program_options_utils::LABEL_TYPE_DESCRIPTION); + optional_configs.add_options()("seller_file", po::value(&seller_file)->default_value(""), + program_options_utils::DIVERSITY_FILE); + optional_configs.add_options()("NumDiverse", po::value(&Lf)->default_value(1), + program_options_utils::NUM_DIVERSE); + + // Merge required and optional parameters desc.add(required_configs).add(optional_configs); @@ -112,6 +119,9 @@ int main(int argc, char **argv) return -1; } + if(seller_file != "") + diverse_index = true; + try { diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L << " alpha: " << alpha @@ -125,6 +135,9 @@ int main(int argc, char **argv) .with_alpha(alpha) .with_saturate_graph(false) .with_num_threads(num_threads) + .with_diverse_index(diverse_index) + .with_seller_file(seller_file) + .with_num_diverse_build(num_diverse_build) .build(); auto filter_params = diskann::IndexFilterParamsBuilder() diff --git a/include/defaults.h b/include/defaults.h index ef1750fcf..fc10c4fb2 100644 --- a/include/defaults.h +++ b/include/defaults.h @@ -30,5 +30,9 @@ const uint32_t MAX_DEGREE = 64; const uint32_t BUILD_LIST_SIZE = 100; const uint32_t SATURATE_GRAPH = false; const uint32_t SEARCH_LIST_SIZE = 100; + +const bool DIVERSE_INDEX = false; +const std::string EMPTY_STRING = ""; +const bool NUM_DIVERSE_BUILD = 1; } // namespace defaults } // namespace diskann diff --git a/include/index.h b/include/index.h index 2700d4ada..b0e7ce812 100644 --- a/include/index.h +++ b/include/index.h @@ -248,6 +248,7 @@ template clas uint32_t calculate_entry_point(); void parse_label_file(const std::string &label_file, size_t &num_pts_labels); + void parse_seller_file(const std::string &label_file, size_t &num_pts_labels); std::unordered_map load_label_map(const std::string &map_file); @@ -385,7 +386,11 @@ template clas std::unordered_map _medoid_counts; bool _diverse_index = false; + uint32_t _num_diverse_build =1; + uint32_t _max_L_per_seller = 0; std::vector _location_to_seller; + std::string _seller_file; + bool _use_universal_label = false; LabelT _universal_label = 0; diff --git a/include/parameters.h b/include/parameters.h index 0206814bd..d0afa756a 100644 --- a/include/parameters.h +++ b/include/parameters.h @@ -23,13 +23,16 @@ class IndexWriteParameters const float alpha; const uint32_t num_threads; const uint32_t filter_list_size; // Lf + const bool diversity_index; + const std::string base_seller_labels; + const uint32_t num_diverse_sellers; IndexWriteParameters(const uint32_t search_list_size, const uint32_t max_degree, const bool saturate_graph, const uint32_t max_occlusion_size, const float alpha, const uint32_t num_threads, - const uint32_t filter_list_size) + const uint32_t filter_list_size, const bool diversity_index, const std::string base_sellers, const uint32_t num_diverse_sellers) : search_list_size(search_list_size), max_degree(max_degree), saturate_graph(saturate_graph), max_occlusion_size(max_occlusion_size), alpha(alpha), num_threads(num_threads), - filter_list_size(filter_list_size) + filter_list_size(filter_list_size), diversity_index(diversity_index), base_seller_labels(base_sellers), num_diverse_sellers(num_diverse_sellers) { } @@ -73,6 +76,26 @@ class IndexWriteParametersBuilder return *this; } + + IndexWriteParametersBuilder &with_diverse_index(const bool diverse_index) + { + _diverse_index = diverse_index; + return *this; + } + + IndexWriteParametersBuilder &with_seller_file(const std::string seller_file) + { + _base_sellers = seller_file; + return *this; + } + + IndexWriteParametersBuilder &with_num_diverse_build(const uint32_t num_diverse_build) + { + _num_diverse_build = num_diverse_build; + return *this; + } + + IndexWriteParametersBuilder &with_alpha(const float alpha) { _alpha = alpha; @@ -94,7 +117,7 @@ class IndexWriteParametersBuilder IndexWriteParameters build() const { return IndexWriteParameters(_search_list_size, _max_degree, _saturate_graph, _max_occlusion_size, _alpha, - _num_threads, _filter_list_size); + _num_threads, _filter_list_size, _diverse_index, _base_sellers, _num_diverse_build); } IndexWriteParametersBuilder(const IndexWriteParameters &wp) @@ -114,6 +137,9 @@ class IndexWriteParametersBuilder float _alpha{defaults::ALPHA}; uint32_t _num_threads{defaults::NUM_THREADS}; uint32_t _filter_list_size{defaults::FILTER_LIST_SIZE}; + bool _diverse_index{defaults::DIVERSE_INDEX}; + std::string _base_sellers{defaults::EMPTY_STRING}; + uint32_t _num_diverse_build{defaults::NUM_DIVERSE_BUILD}; }; } // namespace diskann diff --git a/include/program_options_utils.hpp b/include/program_options_utils.hpp index 2be60595b..3686e3885 100644 --- a/include/program_options_utils.hpp +++ b/include/program_options_utils.hpp @@ -77,5 +77,7 @@ const char *UNIVERSAL_LABEL = "in the labels file instead of listing all labels for a node. DiskANN will not automatically assign a " "universal label to a node."; const char *FILTERED_LBUILD = "Build complexity for filtered points, higher value results in better graphs"; +const char *DIVERSITY_FILE = "Seller diversity file for diverse index"; +const char *NUM_DIVERSE = "Number of diverse edges needed per node in each local region"; } // namespace program_options_utils diff --git a/src/index.cpp b/src/index.cpp index 970a1eb0f..acecd7519 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -101,6 +101,8 @@ Index::Index(const IndexConfig &index_config, std::shared_ptrfilter_list_size; _indexingThreads = index_config.index_write_params->num_threads; _saturate_graph = index_config.index_write_params->saturate_graph; + _diverse_index = index_config.index_write_params->diversity_index; + _seller_file = index_config.index_write_params->base_seller_labels; if (index_config.index_search_params != nullptr) { @@ -110,7 +112,7 @@ Index::Index(const IndexConfig &index_config, std::shared_ptr Index::Index(Metric m, const size_t dim, const size_t max_points, const std::shared_ptr index_parameters, @@ -1213,24 +1215,40 @@ void Index::occlude_list(const uint32_t location, std::vector maxc) pool.resize(maxc); std::vector &occlude_factor = scratch->occlude_factor(); + std::vector> blockers(pool.size()); // occlude_list can be called with the same scratch more than once by // search_for_point_and_add_link through inter_insert. occlude_factor.clear(); // Initialize occlude_factor to pool.size() many 0.0f values for correctness occlude_factor.insert(occlude_factor.end(), pool.size(), 0.0f); - float cur_alpha = alpha; + float cur_alpha = 1; + while (cur_alpha <= alpha && result.size() < degree) { + std::vector> blockers(pool.size()); // used for MIPS, where we store a value of eps in cur_alpha to // denote pruned out entries which we can skip in later rounds. float eps = cur_alpha + 0.01f; for (auto iter = pool.begin(); result.size() < degree && iter != pool.end(); ++iter) { + bool need_to_add_edge= true; + bool edge_added = false; + if (occlude_factor[iter - pool.begin()] == std::numeric_limits::min()) { + need_to_add_edge = false; // added as an edge in earlier round + edge_added =true; + } if (occlude_factor[iter - pool.begin()] > cur_alpha) { - continue; + if (blockers[iter - pool.begin()].size() >= _num_diverse_build) + need_to_add_edge = false; + else if (blockers[iter - pool.begin()].find(_location_to_seller[iter->id]) != blockers[iter - pool.begin()].end()) + need_to_add_edge = false; } - // Set the entry to float::max so that is not considered again - occlude_factor[iter - pool.begin()] = std::numeric_limits::max(); + + // Set the entry to float::max so that is not considered again, similarly add its own color as a blocking color +// blockers[iter - pool.begin()].insert(_location_to_seller[iter->id]); + + if (need_to_add_edge) { + occlude_factor[iter - pool.begin()] = std::numeric_limits::min(); // Add the entry to the result if its not been deleted, and doesn't // add a self loop if (delete_set_ptr == nullptr || delete_set_ptr->find(iter->id) == delete_set_ptr->end()) @@ -1240,13 +1258,15 @@ void Index::occlude_list(const uint32_t location, std::vectorid); } } + } + if (need_to_add_edge || edge_added) { // Update occlude factor for points from iter+1 to pool.end() for (auto iter2 = iter + 1; iter2 != pool.end(); iter2++) { auto t = iter2 - pool.begin(); - if (occlude_factor[t] > alpha) - continue; +// if (occlude_factor[t] > alpha) +// continue; bool prune_allowed = true; if (_filtered_index) @@ -1274,6 +1294,9 @@ void Index::occlude_list(const uint32_t location, std::vector::max() : std::max(occlude_factor[t], iter2->distance / djk); + if (iter2->distance / djk > cur_alpha) { + blockers[t].insert(_location_to_seller[iter->id]); + } } else if (_dist_metric == diskann::Metric::INNER_PRODUCT) { @@ -1283,10 +1306,13 @@ void Index::occlude_list(const uint32_t location, std::vector cur_alpha * x) { occlude_factor[t] = std::max(occlude_factor[t], eps); + blockers[t].insert(_location_to_seller[iter->id]); } } } + } } + cur_alpha *= 1.2f; } } @@ -1672,6 +1698,12 @@ void Index::build_with_data_populated(const std::vector & } } + if (_diverse_index) { + uint64_t nrows; + parse_seller_file(_seller_file, nrows); + std::cout<<"Parsed seller file with " << nrows <<" rows" << std::endl; + } + uint32_t index_R = _indexingRange; uint32_t num_threads_index = _indexingThreads; uint32_t index_L = _indexingQueueSize; @@ -1984,6 +2016,53 @@ void Index::parse_label_file(const std::string &label_file, siz diskann::cout << "Identified " << _labels.size() << " distinct label(s)" << std::endl; } + +template +void Index::parse_seller_file(const std::string &label_file, size_t &num_points) +{ + // Format of Label txt file: filters with comma separators + + std::ifstream infile(label_file); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); + } + + std::string line, token; + uint32_t line_cnt = 0; + std::set sellers; + while (std::getline(infile, line)) + { + line_cnt++; + } + _location_to_seller.resize(line_cnt); + + infile.clear(); + infile.seekg(0, std::ios::beg); + line_cnt = 0; + + while (std::getline(infile, line)) + { + std::istringstream iss(line); + getline(iss, token, '\t'); + std::istringstream new_iss(token); + uint32_t seller; + while (getline(new_iss, token, ',')) + { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + uint32_t token_as_num = (uint32_t)std::stoul(token); + seller = token_as_num; + sellers.insert(seller); + } + + _location_to_seller[line_cnt] = seller; + line_cnt++; + } + num_points = (size_t)line_cnt; + diskann::cout << "Identified " << sellers.size() << " distinct seller(s)" << std::endl; +} + template void Index::_set_universal_label(const LabelType universal_label) { From 347c9cb4072b1c326c22d9c239f794a2a41ded32 Mon Sep 17 00:00:00 2001 From: rakri Date: Mon, 26 Aug 2024 07:54:07 -0700 Subject: [PATCH 03/18] code compiles but crashes --- apps/search_memory_index.cpp | 30 ++++++--- include/abstract_index.h | 6 ++ include/index.h | 8 ++- src/abstract_index.cpp | 25 ++++++++ src/index.cpp | 114 ++++++++++++++++++++++++++--------- 5 files changed, 144 insertions(+), 39 deletions(-) diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 1a9acc285..64ecc284d 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -30,7 +30,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads, const uint32_t recall_at, const bool print_all_recalls, const std::vector &Lvec, const bool dynamic, const bool tags, const bool show_qps_per_thread, - const std::vector &query_filters, const float fail_if_recall_below) + const std::vector &query_filters, const float fail_if_recall_below, const uint32_t num_diverse_sellers = 0) { using TagT = uint32_t; // Load the query file @@ -157,6 +157,8 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, query_result_dists[test_id].resize(recall_at * query_num); std::vector res = std::vector(); + uint32_t maxLperSeller = (num_diverse_sellers > 0) ? (1.0*L)/(1.0*num_diverse_sellers) : L; + maxLperSeller = (maxLperSeller == 0)? 1 : maxLperSeller; auto s = std::chrono::high_resolution_clock::now(); omp_set_num_threads(num_threads); #pragma omp parallel for schedule(dynamic, 1) @@ -199,10 +201,17 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, } else { + if (maxLperSeller != L) cmp_stats[i] = index - ->search(query + i * query_aligned_dim, recall_at, L, + ->diverse_search(query + i * query_aligned_dim, recall_at, L, maxLperSeller, + query_result_ids[test_id].data() + i * recall_at) + .second; + else + cmp_stats[i] = index + ->search(query + i * query_aligned_dim, recall_at, L, query_result_ids[test_id].data() + i * recall_at) .second; + } auto qe = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = qe - qs; @@ -279,7 +288,7 @@ int main(int argc, char **argv) { std::string data_type, dist_fn, index_path_prefix, result_path, query_file, gt_file, filter_label, label_type, query_filters_file; - uint32_t num_threads, K; + uint32_t num_threads, K, num_diverse_sellers; std::vector Lvec; bool print_all_recalls, dynamic, tags, show_qps_per_thread; float fail_if_recall_below = 0.0f; @@ -323,6 +332,9 @@ int main(int argc, char **argv) optional_configs.add_options()("num_threads,T", po::value(&num_threads)->default_value(omp_get_num_procs()), program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()("num_diverse_sellers", + po::value(&num_diverse_sellers)->default_value(0), + "How many diverse sellers we want search results to contain"); optional_configs.add_options()( "dynamic", po::value(&dynamic)->default_value(false), "Whether the index is dynamic. Dynamic indices must have associated tags. Default false."); @@ -421,19 +433,19 @@ int main(int argc, char **argv) { return search_memory_index( metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, - Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below); + Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, num_diverse_sellers); } else if (data_type == std::string("uint8")) { return search_memory_index( metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, - Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below); + Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, num_diverse_sellers); } else if (data_type == std::string("float")) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); + show_qps_per_thread, query_filters, fail_if_recall_below, num_diverse_sellers); } else { @@ -447,19 +459,19 @@ int main(int argc, char **argv) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); + show_qps_per_thread, query_filters, fail_if_recall_below, num_diverse_sellers); } else if (data_type == std::string("uint8")) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); + show_qps_per_thread, query_filters, fail_if_recall_below, num_diverse_sellers); } else if (data_type == std::string("float")) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); + show_qps_per_thread, query_filters, fail_if_recall_below, num_diverse_sellers); } else { diff --git a/include/abstract_index.h b/include/abstract_index.h index 059866f7c..c8b01105c 100644 --- a/include/abstract_index.h +++ b/include/abstract_index.h @@ -72,6 +72,10 @@ class AbstractIndex std::pair search(const data_type *query, const size_t K, const uint32_t L, IDType *indices, float *distances = nullptr); + template + std::pair diverse_search(const data_type *query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, IDType *indices, + float *distances = nullptr); + // Filter support search // IndexType is either uint32_t or uint64_t template @@ -110,6 +114,8 @@ class AbstractIndex virtual void _build(const DataType &data, const size_t num_points_to_load, TagVector &tags) = 0; virtual std::pair _search(const DataType &query, const size_t K, const uint32_t L, std::any &indices, float *distances = nullptr) = 0; + virtual std::pair _diverse_search(const DataType &query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, + std::any &indices, float *distances = nullptr) = 0; virtual std::pair _search_with_filters(const DataType &query, const std::string &filter_label, const size_t K, const uint32_t L, std::any &indices, float *distances) = 0; diff --git a/include/index.h b/include/index.h index b0e7ce812..71a1c37f2 100644 --- a/include/index.h +++ b/include/index.h @@ -132,7 +132,11 @@ template clas // can customize L on a per-query basis without tampering with "Parameters" template DISKANN_DLLEXPORT std::pair search(const T *query, const size_t K, const uint32_t L, - IDType *indices, float *distances = nullptr); + IDType *indices, float *distances = nullptr, const uint32_t maxLperSeller = 0); + + template + std::pair diverse_search(const T *query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, IDType *indices, + float *distances = nullptr); // Initialize space for res_vectors before calling. DISKANN_DLLEXPORT size_t search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags, @@ -210,6 +214,8 @@ template clas const std::string &filter_label_raw, const size_t K, const uint32_t L, std::any &indices, float *distances) override; + virtual std::pair _diverse_search(const DataType &query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, + std::any &indices, float *distances = nullptr) override; virtual int _insert_point(const DataType &data_point, const TagType tag) override; virtual int _insert_point(const DataType &data_point, const TagType tag, Labelvector &labels) override; diff --git a/src/abstract_index.cpp b/src/abstract_index.cpp index 92665825f..fd2aafa20 100644 --- a/src/abstract_index.cpp +++ b/src/abstract_index.cpp @@ -22,6 +22,15 @@ std::pair AbstractIndex::search(const data_type *query, cons return _search(any_query, K, L, any_indices, distances); } +template +std::pair AbstractIndex::diverse_search(const data_type *query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, + IDType *indices, float *distances) +{ + auto any_indices = std::any(indices); + auto any_query = std::any(query); + return _diverse_search(any_query, K, L, maxLperSeller, any_indices, distances); +} + template size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, float *distances, std::vector &res_vectors, bool use_filters, @@ -155,6 +164,22 @@ template DISKANN_DLLEXPORT std::pair AbstractIndex::search AbstractIndex::search( const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( + const float *query, const size_t K, const uint32_t L, const uint32_t maxL, uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( + const uint8_t *query, const size_t K, const uint32_t L, const uint32_t maxL,uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( + const int8_t *query, const size_t K, const uint32_t L, const uint32_t maxL,uint32_t *indices, float *distances); + +template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( + const float *query, const size_t K, const uint32_t L, const uint32_t maxL,uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( + const uint8_t *query, const size_t K, const uint32_t L, const uint32_t maxL,uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( + const int8_t *query, const size_t K, const uint32_t L, const uint32_t maxL,uint64_t *indices, float *distances); + + + template DISKANN_DLLEXPORT std::pair AbstractIndex::search_with_filters( const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); diff --git a/src/index.cpp b/src/index.cpp index acecd7519..a8f428c67 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -4,7 +4,7 @@ #include #include - +#include #include "boost/dynamic_bitset.hpp" #include "index_factory.h" #include "memory_mapper.h" @@ -103,6 +103,7 @@ Index::Index(const IndexConfig &index_config, std::shared_ptrsaturate_graph; _diverse_index = index_config.index_write_params->diversity_index; _seller_file = index_config.index_write_params->base_seller_labels; + _num_diverse_build = index_config.index_write_params->num_diverse_sellers; if (index_config.index_search_params != nullptr) { @@ -292,6 +293,12 @@ void Index::save(const char *filename, bool compact_before_save if (!_save_as_one_file) { + if(_diverse_index) { + std::string index_seller_file = std::string(filename) + "_sellers.txt"; + std::filesystem::copy(_seller_file, index_seller_file); + std::cout<<"Saved seller file to " << index_seller_file <<"." << std::endl; + } + if (_filtered_index) { if (_label_to_start_id.size() > 0) @@ -590,6 +597,13 @@ void Index::load(const char *filename, uint32_t num_threads, ui throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } + std::string index_seller_file = std::string(filename) + "_sellers.txt"; + if(file_exists(index_seller_file)) { + uint64_t nrows_seller_file; + parse_seller_file(index_seller_file, nrows_seller_file); + _diverse_index = true; + } + if (file_exists(labels_file)) { _label_map = load_label_map(labels_map_file); @@ -1032,10 +1046,15 @@ void Index::search_for_point_and_prune(int location, uint32_t L const std::vector init_ids = get_init_ids(); const std::vector unused_filter_label; + uint32_t maxLperSeller = 0; + if (_diverse_index) { + maxLperSeller = (Lindex/_num_diverse_build > 0)? Lindex/_num_diverse_build : 1; + } + if (!use_filter) { _data_store->get_vector(location, scratch->aligned_query()); - iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false); + iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false, maxLperSeller); } else { @@ -1051,7 +1070,7 @@ void Index::search_for_point_and_prune(int location, uint32_t L _data_store->get_vector(location, scratch->aligned_query()); iterate_to_fixed_point(scratch, filteredLindex, filter_specific_start_nodes, true, - _location_to_labels[location], false); + _location_to_labels[location], false, maxLperSeller); // combine candidate pools obtained with filter and unfiltered criteria. std::set best_candidate_pool; @@ -2176,10 +2195,46 @@ std::pair Index::_search(const DataType &qu } } + +template +std::pair Index::_diverse_search(const DataType &query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, + std::any &indices, float *distances) +{ + try + { + auto typed_query = std::any_cast(query); + if (typeid(uint32_t *) == indices.type()) + { + auto u32_ptr = std::any_cast(indices); + return this->search(typed_query, K, L, u32_ptr, distances, maxLperSeller); + } + else if (typeid(uint64_t *) == indices.type()) + { + auto u64_ptr = std::any_cast(indices); + return this->search(typed_query, K, L, u64_ptr, distances, maxLperSeller); + } + else + { + throw ANNException("Error: indices type can only be uint64_t or uint32_t.", -1); + } + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad any cast while searching. " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } +} + + + + template template std::pair Index::search(const T *query, const size_t K, const uint32_t L, - IdType *indices, float *distances) + IdType *indices, float *distances, const uint32_t maxLperSeller) { if (K > (uint64_t)L) { @@ -2204,7 +2259,7 @@ std::pair Index::search(const T *query, con _data_store->preprocess_query(query, scratch); - auto retval = iterate_to_fixed_point(scratch, L, init_ids, false, unused_filter_label, true); + auto retval = iterate_to_fixed_point(scratch, L, init_ids, false, unused_filter_label, true, maxLperSeller); NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); @@ -3576,30 +3631,31 @@ template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); + template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, @@ -3640,30 +3696,30 @@ template DISKANN_DLLEXPORT std::pair Index Index::search( - const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, From edd1bb1841f7a7fca8b648e7a214498cd02c429e Mon Sep 17 00:00:00 2001 From: rakri Date: Mon, 26 Aug 2024 10:26:27 -0700 Subject: [PATCH 04/18] fixed minor bug --- apps/build_memory_index.cpp | 4 +++- apps/search_memory_index.cpp | 2 ++ src/index.cpp | 13 +++++++++++-- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/apps/build_memory_index.cpp b/apps/build_memory_index.cpp index a9e3b61bd..435c7f3a7 100644 --- a/apps/build_memory_index.cpp +++ b/apps/build_memory_index.cpp @@ -73,7 +73,7 @@ int main(int argc, char **argv) program_options_utils::LABEL_TYPE_DESCRIPTION); optional_configs.add_options()("seller_file", po::value(&seller_file)->default_value(""), program_options_utils::DIVERSITY_FILE); - optional_configs.add_options()("NumDiverse", po::value(&Lf)->default_value(1), + optional_configs.add_options()("NumDiverse", po::value(&num_diverse_build)->default_value(1), program_options_utils::NUM_DIVERSE); @@ -130,6 +130,8 @@ int main(int argc, char **argv) size_t data_num, data_dim; diskann::get_bin_metadata(data_path, data_num, data_dim); + std::cout<<"Num diverse build: " << num_diverse_build << std::endl; + auto index_build_params = diskann::IndexWriteParametersBuilder(L, R) .with_filter_list_size(Lf) .with_alpha(alpha) diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 64ecc284d..f49cfb889 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -68,6 +68,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, } } + query_num = 1; const size_t num_frozen_pts = diskann::get_graph_num_frozen_points(index_path); auto config = diskann::IndexConfigBuilder() @@ -159,6 +160,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, uint32_t maxLperSeller = (num_diverse_sellers > 0) ? (1.0*L)/(1.0*num_diverse_sellers) : L; maxLperSeller = (maxLperSeller == 0)? 1 : maxLperSeller; + std::cout<<"MaxLperSeller = " << maxLperSeller << std::endl; auto s = std::chrono::high_resolution_clock::now(); omp_set_num_threads(num_threads); #pragma omp parallel for schedule(dynamic, 1) diff --git a/src/index.cpp b/src/index.cpp index a8f428c67..08219eea1 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -104,6 +104,7 @@ Index::Index(const IndexConfig &index_config, std::shared_ptrdiversity_index; _seller_file = index_config.index_write_params->base_seller_labels; _num_diverse_build = index_config.index_write_params->num_diverse_sellers; + std::cout<<"Set _num_diverse_build to " << _num_diverse_build << std::endl; if (index_config.index_search_params != nullptr) { @@ -912,8 +913,15 @@ std::pair Index::iterate_to_fixed_point( while (best_L_nodes.has_unexpanded_node()) { + for (auto &x : color_to_nodes) { + if (x.second.size() > 0) { + std::cout< Index::iterate_to_fixed_point( for (size_t m = 0; m < id_scratch.size(); ++m) { if (diverse_search) { - auto cur_id = id_scratch[n]; + auto cur_id = id_scratch[m]; auto cur_dist = dist_scratch[m]; + //std::cout<::parse_seller_file(const std::string &label_file, si line_cnt++; } num_points = (size_t)line_cnt; - diskann::cout << "Identified " << sellers.size() << " distinct seller(s)" << std::endl; + diskann::cout << "Identified " << sellers.size() << " distinct seller(s) across " << num_points <<" points." << std::endl; } template From 827a6703c9ca3b7e209e4f7ceafaa83f1f6290f8 Mon Sep 17 00:00:00 2001 From: rakri Date: Mon, 26 Aug 2024 11:53:32 -0700 Subject: [PATCH 05/18] fixed minor bug --- include/neighbor.h | 16 +++++++++++++--- src/index.cpp | 47 ++++++++++++++++++++++++++-------------------- 2 files changed, 40 insertions(+), 23 deletions(-) diff --git a/include/neighbor.h b/include/neighbor.h index 1117a7f5e..ba98fdb96 100644 --- a/include/neighbor.h +++ b/include/neighbor.h @@ -95,18 +95,18 @@ class NeighborPriorityQueue // Deletes the item if found. - void delete_id(const uint32_t &id) + void delete_id(const Neighbor &nbr) { size_t lo = 0, hi = _size; size_t loc = std::numeric_limits::max(); while (lo < hi) { size_t mid = (lo + hi) >> 1; - if (id < _data[mid].id) + if (nbr.distance < _data[mid].distance) { hi = mid; } - else if (_data[mid].id == id) + else if (_data[mid].id == nbr.id) { loc = mid; break; @@ -121,6 +121,16 @@ class NeighborPriorityQueue { std::memmove(&_data[loc], &_data[loc+1], (_size - loc - 1) * sizeof(Neighbor)); _size--; + _cur = 0; + while (_cur < _size && _data[_cur].expanded) // RK: inefficient! + { + _cur++; + } + } else { + std::cout<<"Found a problem! " << lo <<" " << hi <<" " <_data) + std::cout< Index::iterate_to_fixed_point( while (best_L_nodes.has_unexpanded_node()) { - for (auto &x : color_to_nodes) { +/* for (auto &x : color_to_nodes) { if (x.second.size() > 0) { std::cout< Index::iterate_to_fixed_point( if (is_not_visited(id)) { id_scratch.push_back(id); - } - } - } - - // Mark nodes visited - for (auto id : id_scratch) - { if (fast_iterate) { inserted_into_pool_bs[id] = 1; @@ -998,8 +991,12 @@ std::pair Index::iterate_to_fixed_point( { inserted_into_pool_rs.insert(id); } + + } + } } + assert(dist_scratch.capacity() >= id_scratch.size()); compute_dists(id_scratch, dist_scratch); cmps += (uint32_t)id_scratch.size(); @@ -1007,6 +1004,17 @@ std::pair Index::iterate_to_fixed_point( // Insert pairs into the pool of candidates for (size_t m = 0; m < id_scratch.size(); ++m) { +/* std::cout<<"Going to insert " << id_scratch[m] << " (nbr of " << n <<"), color " << _location_to_seller[id_scratch[m]] << std::endl; + for (auto &x : color_to_nodes) { + if (x.second.size() > 0) { + std::cout< Index::iterate_to_fixed_point( if (cur_list.size() < maxLperSeller && best_L_nodes.size() < Lsize) { cur_list.insert(Neighbor(cur_id, cur_dist)); best_L_nodes.insert(Neighbor(cur_id, cur_dist)); - } else if (cur_list.size() == maxLperSeller && best_L_nodes.size() < Lsize) { + } else if (cur_list.size() == maxLperSeller) { if (cur_dist < cur_list[maxLperSeller-1].distance) { - best_L_nodes.delete_id(cur_list[maxLperSeller-1].id); + best_L_nodes.delete_id(cur_list[maxLperSeller-1]); cur_list.insert(Neighbor(cur_id, cur_dist)); best_L_nodes.insert(Neighbor(cur_id, cur_dist)); } } else if (cur_list.size() < maxLperSeller && best_L_nodes.size() == Lsize) { if (cur_dist < best_L_nodes[Lsize-1].distance) { - color_to_nodes[_location_to_seller[best_L_nodes[Lsize-1].id]].delete_id(best_L_nodes[Lsize-1].id); +/* if (color_to_nodes[_location_to_seller[best_L_nodes[Lsize-1].id]].size() == 0) { + std::cout<<"Trying to delete from empty Q. " << best_L_nodes[Lsize-1].id <<" of color " << _location_to_seller[best_L_nodes[Lsize-1].id] << std::endl; + }*/ + color_to_nodes[_location_to_seller[best_L_nodes[Lsize-1].id]].delete_id(best_L_nodes[Lsize-1]); cur_list.insert(Neighbor(cur_id, cur_dist)); best_L_nodes.insert(Neighbor(cur_id, cur_dist)); } - } else { - if (cur_dist < cur_list[maxLperSeller-1].distance) { - best_L_nodes.delete_id(cur_list[maxLperSeller-1].id); - cur_list.insert(Neighbor(cur_id, cur_dist)); - best_L_nodes.insert(Neighbor(cur_id, cur_dist)); - } } } else { @@ -2295,9 +2300,11 @@ std::pair Index::search(const T *query, con if (pos == K) break; } - if (pos < K) + while (pos < K) { - diskann::cerr << "Found pos: " << pos << "fewer than K elements " << K << " for query" << std::endl; + indices[pos] = std::numeric_limits::max(); + pos++; +// diskann::cerr << "Found pos: " << pos << "fewer than K elements " << K << " for query" << std::endl; } return retval; From b06ba4b85df50f5f278460d944f24e1f3782328f Mon Sep 17 00:00:00 2001 From: rakri Date: Thu, 29 Aug 2024 10:57:23 -0700 Subject: [PATCH 06/18] started work on diverse GT --- apps/search_memory_index.cpp | 2 +- apps/utils/CMakeLists.txt | 6 + apps/utils/compute_diverse_groundtruth.cpp | 629 +++++++++++++++++++++ 3 files changed, 636 insertions(+), 1 deletion(-) create mode 100644 apps/utils/compute_diverse_groundtruth.cpp diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index f49cfb889..03f8e6faa 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -68,7 +68,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, } } - query_num = 1; + //query_num = 1; const size_t num_frozen_pts = diskann::get_graph_num_frozen_points(index_path); auto config = diskann::IndexConfigBuilder() diff --git a/apps/utils/CMakeLists.txt b/apps/utils/CMakeLists.txt index 3b8cf223c..98edfce5d 100644 --- a/apps/utils/CMakeLists.txt +++ b/apps/utils/CMakeLists.txt @@ -51,6 +51,12 @@ add_executable(compute_groundtruth compute_groundtruth.cpp) target_include_directories(compute_groundtruth PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES}) target_link_libraries(compute_groundtruth ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options) + +# Compute ground truth thing outside of DiskANN main source that depends on MKL. +add_executable(compute_diverse_groundtruth compute_diverse_groundtruth.cpp) +target_include_directories(compute_diverse_groundtruth PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES}) +target_link_libraries(compute_diverse_groundtruth ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options) + add_executable(compute_groundtruth_for_filters compute_groundtruth_for_filters.cpp) target_include_directories(compute_groundtruth_for_filters PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES}) target_link_libraries(compute_groundtruth_for_filters ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options) diff --git a/apps/utils/compute_diverse_groundtruth.cpp b/apps/utils/compute_diverse_groundtruth.cpp new file mode 100644 index 000000000..d89adc71f --- /dev/null +++ b/apps/utils/compute_diverse_groundtruth.cpp @@ -0,0 +1,629 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WINDOWS +#include +#else +#include +#endif +#include "filter_utils.h" +#include "utils.h" + +// WORKS FOR UPTO 2 BILLION POINTS (as we use INT INSTEAD OF UNSIGNED) + +#define PARTSIZE 10000000 +#define ALIGNMENT 512 + +// custom types (for readability) +typedef tsl::robin_set label_set; +typedef std::string path; + +namespace po = boost::program_options; + +template T div_round_up(const T numerator, const T denominator) +{ + return (numerator % denominator == 0) ? (numerator / denominator) : 1 + (numerator / denominator); +} + +using pairIF = std::pair; +struct cmpmaxstruct +{ + bool operator()(const pairIF &l, const pairIF &r) + { + return l.second < r.second; + }; +}; + +using maxPQIFCS = std::priority_queue, cmpmaxstruct>; + +template T *aligned_malloc(const size_t n, const size_t alignment) +{ +#ifdef _WINDOWS + return (T *)_aligned_malloc(sizeof(T) * n, alignment); +#else + return static_cast(aligned_alloc(alignment, sizeof(T) * n)); +#endif +} + +inline bool custom_dist(const std::pair &a, const std::pair &b) +{ + return a.second < b.second; +} + +void compute_l2sq(float *const points_l2sq, const float *const matrix, const int64_t num_points, const uint64_t dim) +{ + assert(points_l2sq != NULL); +#pragma omp parallel for schedule(static, 65536) + for (int64_t d = 0; d < num_points; ++d) + points_l2sq[d] = cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1, + matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1); +} + +void distsq_to_points(const size_t dim, + float *dist_matrix, // Col Major, cols are queries, rows are points + size_t npoints, const float *const points, + const float *const points_l2sq, // points in Col major + size_t nqueries, const float *const queries, + const float *const queries_l2sq, // queries in Col major + float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 +{ + bool ones_vec_alloc = false; + if (ones_vec == NULL) + { + ones_vec = new float[nqueries > npoints ? nqueries : npoints]; + std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0); + ones_vec_alloc = true; + } + cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-2.0, points, dim, queries, dim, + (float)0.0, dist_matrix, npoints); + cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, points_l2sq, npoints, + ones_vec, nqueries, (float)1.0, dist_matrix, npoints); + cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, ones_vec, npoints, + queries_l2sq, nqueries, (float)1.0, dist_matrix, npoints); + if (ones_vec_alloc) + delete[] ones_vec; +} + +void inner_prod_to_points(const size_t dim, + float *dist_matrix, // Col Major, cols are queries, rows are points + size_t npoints, const float *const points, size_t nqueries, const float *const queries, + float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 +{ + bool ones_vec_alloc = false; + if (ones_vec == NULL) + { + ones_vec = new float[nqueries > npoints ? nqueries : npoints]; + std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0); + ones_vec_alloc = true; + } + cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-1.0, points, dim, queries, dim, + (float)0.0, dist_matrix, npoints); + + if (ones_vec_alloc) + delete[] ones_vec; +} + +void exact_knn(const size_t dim, const size_t k, + size_t *const closest_points, // k * num_queries preallocated, col + // major, queries columns + float *const dist_closest_points, // k * num_queries + // preallocated, Dist to + // corresponding closes_points + size_t npoints, + float *points_in, // points in Col major + size_t nqueries, float *queries_in, + diskann::Metric metric = diskann::Metric::L2) // queries in Col major +{ + float *points_l2sq = new float[npoints]; + float *queries_l2sq = new float[nqueries]; + compute_l2sq(points_l2sq, points_in, npoints, dim); + compute_l2sq(queries_l2sq, queries_in, nqueries, dim); + + float *points = points_in; + float *queries = queries_in; + + if (metric == diskann::Metric::COSINE) + { // we convert cosine distance as + // normalized L2 distnace + points = new float[npoints * dim]; + queries = new float[nqueries * dim]; +#pragma omp parallel for schedule(static, 4096) + for (int64_t i = 0; i < (int64_t)npoints; i++) + { + float norm = std::sqrt(points_l2sq[i]); + if (norm == 0) + { + norm = std::numeric_limits::epsilon(); + } + for (uint32_t j = 0; j < dim; j++) + { + points[i * dim + j] = points_in[i * dim + j] / norm; + } + } + +#pragma omp parallel for schedule(static, 4096) + for (int64_t i = 0; i < (int64_t)nqueries; i++) + { + float norm = std::sqrt(queries_l2sq[i]); + if (norm == 0) + { + norm = std::numeric_limits::epsilon(); + } + for (uint32_t j = 0; j < dim; j++) + { + queries[i * dim + j] = queries_in[i * dim + j] / norm; + } + } + // recalculate norms after normalizing, they should all be one. + compute_l2sq(points_l2sq, points, npoints, dim); + compute_l2sq(queries_l2sq, queries, nqueries, dim); + } + + std::cout << "Going to compute " << k << " NNs for " << nqueries << " queries over " << npoints << " points in " + << dim << " dimensions using"; + if (metric == diskann::Metric::INNER_PRODUCT) + std::cout << " MIPS "; + else if (metric == diskann::Metric::COSINE) + std::cout << " Cosine "; + else + std::cout << " L2 "; + std::cout << "distance fn. " << std::endl; + + size_t q_batch_size = (1 << 9); + float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints]; + + for (size_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b) + { + int64_t q_b = b * q_batch_size; + int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size; + + if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE) + { + distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b, + queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, queries_l2sq + q_b); + } + else + { + inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b, + queries + (ptrdiff_t)q_b * (ptrdiff_t)dim); + } + std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" << std::endl; + +#pragma omp parallel for schedule(dynamic, 16) + for (long long q = q_b; q < q_e; q++) + { + maxPQIFCS point_dist; + for (size_t p = 0; p < k; p++) + point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); + for (size_t p = k; p < npoints; p++) + { + if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]) + point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); + if (point_dist.size() > k) + point_dist.pop(); + } + for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l) + { + closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().first; + dist_closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().second; + point_dist.pop(); + } + assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k, + dist_closest_points + (ptrdiff_t)(q + 1) * (ptrdiff_t)k)); + } + std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e << ")" << std::endl; + } + + delete[] dist_matrix; + + delete[] points_l2sq; + delete[] queries_l2sq; + + if (metric == diskann::Metric::COSINE) + { + delete[] points; + delete[] queries; + } +} + +template inline int get_num_parts(const char *filename) +{ + std::ifstream reader; + reader.exceptions(std::ios::failbit | std::ios::badbit); + reader.open(filename, std::ios::binary); + std::cout << "Reading bin file " << filename << " ...\n"; + int npts_i32, ndims_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&ndims_i32, sizeof(int)); + std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl; + reader.close(); + uint32_t num_parts = + (npts_i32 % PARTSIZE) == 0 ? npts_i32 / PARTSIZE : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1; + std::cout << "Number of parts: " << num_parts << std::endl; + return num_parts; +} + +template +inline void load_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims, int part_num) +{ + std::ifstream reader; + reader.exceptions(std::ios::failbit | std::ios::badbit); + reader.open(filename, std::ios::binary); + std::cout << "Reading bin file " << filename << " ...\n"; + int npts_i32, ndims_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&ndims_i32, sizeof(int)); + uint64_t start_id = part_num * PARTSIZE; + uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32); + npts = end_id - start_id; + ndims = (uint64_t)ndims_i32; + std::cout << "#pts in part = " << npts << ", #dims = " << ndims << ", size = " << npts * ndims * sizeof(T) << "B" + << std::endl; + + reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg); + T *data_T = new T[npts * ndims]; + reader.read((char *)data_T, sizeof(T) * npts * ndims); + std::cout << "Finished reading part of the bin file." << std::endl; + reader.close(); + data = aligned_malloc(npts * ndims, ALIGNMENT); +#pragma omp parallel for schedule(dynamic, 32768) + for (int64_t i = 0; i < (int64_t)npts; i++) + { + for (int64_t j = 0; j < (int64_t)ndims; j++) + { + float cur_val_float = (float)data_T[i * ndims + j]; + std::memcpy((char *)(data + i * ndims + j), (char *)&cur_val_float, sizeof(float)); + } + } + delete[] data_T; + std::cout << "Finished converting part data to float." << std::endl; +} + +template inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims) +{ + std::ofstream writer; + writer.exceptions(std::ios::failbit | std::ios::badbit); + writer.open(filename, std::ios::binary | std::ios::out); + std::cout << "Writing bin: " << filename << "\n"; + int npts_i32 = (int)npts, ndims_i32 = (int)ndims; + writer.write((char *)&npts_i32, sizeof(int)); + writer.write((char *)&ndims_i32, sizeof(int)); + std::cout << "bin: #pts = " << npts << ", #dims = " << ndims + << ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B" << std::endl; + + writer.write((char *)data, npts * ndims * sizeof(T)); + writer.close(); + std::cout << "Finished writing bin" << std::endl; +} + +inline void save_groundtruth_as_one_file(const std::string filename, int32_t *data, float *distances, size_t npts, + size_t ndims) +{ + std::ofstream writer(filename, std::ios::binary | std::ios::out); + int npts_i32 = (int)npts, ndims_i32 = (int)ndims; + writer.write((char *)&npts_i32, sizeof(int)); + writer.write((char *)&ndims_i32, sizeof(int)); + std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, " + "npts*dim dist-matrix) with npts = " + << npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int) + << "B" << std::endl; + + writer.write((char *)data, npts * ndims * sizeof(uint32_t)); + writer.write((char *)distances, npts * ndims * sizeof(float)); + writer.close(); + std::cout << "Finished writing truthset" << std::endl; +} + +template +std::vector>> processUnfilteredParts(const std::string &base_file, + size_t &nqueries, size_t &npoints, + size_t &dim, size_t &k, float *query_data, + const diskann::Metric &metric, + std::vector &location_to_tag, std::vector &location_to_seller, uint32_t kperseller) +{ + float *base_data = nullptr; + int num_parts = get_num_parts(base_file.c_str()); + std::vector>> res(nqueries); + for (int p = 0; p < num_parts; p++) + { + size_t start_id = p * PARTSIZE; + load_bin_as_float(base_file.c_str(), base_data, npoints, dim, p); + + size_t *closest_points_part = new size_t[nqueries * k]; + float *dist_closest_points_part = new float[nqueries * k]; + + auto part_k = k < npoints ? k : npoints; + exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data, + metric); + + for (size_t i = 0; i < nqueries; i++) + { + for (size_t j = 0; j < part_k; j++) + { + if (!location_to_tag.empty()) + if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0) + continue; + + res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id), + dist_closest_points_part[i * part_k + j])); + } + } + + delete[] closest_points_part; + delete[] dist_closest_points_part; + + diskann::aligned_free(base_data); + } + return res; +}; + + +void parse_seller_file(const std::string &label_file, size_t &num_points, std::vector &location_to_seller) +{ + // Format of Label txt file: filters with comma separators + + std::ifstream infile(label_file); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); + } + + std::string line, token; + uint32_t line_cnt = 0; + std::set sellers; + while (std::getline(infile, line)) + { + line_cnt++; + } + location_to_seller.resize(line_cnt); + + infile.clear(); + infile.seekg(0, std::ios::beg); + line_cnt = 0; + + while (std::getline(infile, line)) + { + std::istringstream iss(line); + getline(iss, token, '\t'); + std::istringstream new_iss(token); + uint32_t seller; + while (getline(new_iss, token, ',')) + { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + uint32_t token_as_num = (uint32_t)std::stoul(token); + seller = token_as_num; + sellers.insert(seller); + } + + location_to_seller[line_cnt] = seller; + line_cnt++; + } + num_points = (size_t)line_cnt; + diskann::cout << "Identified " << sellers.size() << " distinct seller(s) across " << num_points <<" points." << std::endl; +} + + +template +int aux_main(const std::string &base_file, const std::string &query_file, const std::string &seller_file, const std::string >_file, size_t k, size_t kperseller, + const diskann::Metric &metric, const std::string &tags_file = std::string("")) +{ + size_t npoints, nqueries, dim; + + float *query_data; + + load_bin_as_float(query_file.c_str(), query_data, nqueries, dim, 0); + if (nqueries > PARTSIZE) + std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE + << ". Computing GT only for the first " << PARTSIZE << " queries." << std::endl; + + // load tags + const bool tags_enabled = tags_file.empty() ? false : true; + std::vector location_to_tag = diskann::loadTags(tags_file, base_file); + + int *closest_points = new int[nqueries * k]; + float *dist_closest_points = new float[nqueries * k]; + + std::vector location_to_seller; + uint64_t num_pts_seller_file; + parse_seller_file(seller_file, num_pts_seller_file, location_to_seller); + + std::vector>> results = + processUnfilteredParts(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag, location_to_seller, kperseller); + + for (size_t i = 0; i < nqueries; i++) + { + std::vector> &cur_res = results[i]; + std::sort(cur_res.begin(), cur_res.end(), custom_dist); + size_t j = 0; + for (auto iter : cur_res) + { + if (j == k) + break; + if (tags_enabled) + { + std::uint32_t index_with_tag = location_to_tag[iter.first]; + closest_points[i * k + j] = (int32_t)index_with_tag; + } + else + { + closest_points[i * k + j] = (int32_t)iter.first; + } + + if (metric == diskann::Metric::INNER_PRODUCT) + dist_closest_points[i * k + j] = -iter.second; + else + dist_closest_points[i * k + j] = iter.second; + + ++j; + } + if (j < k) + std::cout << "WARNING: found less than k GT entries for query " << i << std::endl; + } + + save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k); + delete[] closest_points; + delete[] dist_closest_points; + diskann::aligned_free(query_data); + + return 0; +} + +void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim) +{ + size_t read_blk_size = 64 * 1024 * 1024; + cached_ifstream reader(bin_file, read_blk_size); + diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl; + size_t actual_file_size = reader.get_file_size(); + + int npts_i32, dim_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&dim_i32, sizeof(int)); + npts = (uint32_t)npts_i32; + dim = (uint32_t)dim_i32; + + diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl; + + int truthset_type = -1; // 1 means truthset has ids and distances, 2 means + // only ids, -1 is error + size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_with_dists) + truthset_type = 1; + + size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_just_ids) + truthset_type = 2; + + if (truthset_type == -1) + { + std::stringstream stream; + stream << "Error. File size mismatch. File should have bin format, with " + "npts followed by ngt followed by npts*ngt ids and optionally " + "followed by npts*ngt distance values; actual size: " + << actual_file_size << ", expected: " << expected_file_size_with_dists << " or " + << expected_file_size_just_ids; + diskann::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + + ids = new uint32_t[npts * dim]; + reader.read((char *)ids, npts * dim * sizeof(uint32_t)); + + if (truthset_type == 1) + { + dists = new float[npts * dim]; + reader.read((char *)dists, npts * dim * sizeof(float)); + } +} + +int main(int argc, char **argv) +{ + std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file, seller_file; + uint64_t K, KperSeller; + + try + { + po::options_description desc{"Arguments"}; + + desc.add_options()("help,h", "Print information on arguments"); + + desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); + desc.add_options()("dist_fn", po::value(&dist_fn)->required(), + "distance function "); + desc.add_options()("base_file", po::value(&base_file)->required(), + "File containing the base vectors in binary format"); + desc.add_options()("query_file", po::value(&query_file)->required(), + "File containing the query vectors in binary format"); + desc.add_options()("seller_file", po::value(&seller_file)->required(), + "File containing the seller per point"); + desc.add_options()("gt_file", po::value(>_file)->required(), + "File name for the writing ground truth in binary " + "format, please don' append .bin at end if " + "no filter_label or filter_label_file is provided it " + "will save the file with '.bin' at end." + "else it will save the file as filename_label.bin"); + desc.add_options()("K", po::value(&K)->required(), + "Number of ground truth nearest neighbors to compute"); + desc.add_options()("KperSeller", po::value(&KperSeller)->required(), + "Number of ground truth nearest neighbors to compute per Seller"); + desc.add_options()("tags_file", po::value(&tags_file)->default_value(std::string()), + "File containing the tags in binary format"); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) + { + std::cout << desc; + return 0; + } + po::notify(vm); + } + catch (const std::exception &ex) + { + std::cerr << ex.what() << '\n'; + return -1; + } + + if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8")) + { + std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl; + return -1; + } + + diskann::Metric metric; + if (dist_fn == std::string("l2")) + { + metric = diskann::Metric::L2; + } + else if (dist_fn == std::string("mips")) + { + metric = diskann::Metric::INNER_PRODUCT; + } + else if (dist_fn == std::string("cosine")) + { + metric = diskann::Metric::COSINE; + } + else + { + std::cerr << "Unsupported distance function. Use l2/mips/cosine." << std::endl; + return -1; + } + + try + { + if (data_type == std::string("float")) + aux_main(base_file, query_file, seller_file, gt_file, K, KperSeller, metric, tags_file); + if (data_type == std::string("int8")) + aux_main(base_file, query_file, seller_file, gt_file, K, KperSeller,metric, tags_file); + if (data_type == std::string("uint8")) + aux_main(base_file, query_file, seller_file, gt_file, K, KperSeller, metric, tags_file); + } + catch (const std::exception &e) + { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Compute GT failed." << std::endl; + return -1; + } +} From 2427965d783ea75426f0866687050764a52f76e7 Mon Sep 17 00:00:00 2001 From: rakri Date: Thu, 29 Aug 2024 12:01:33 -0700 Subject: [PATCH 07/18] minor bug fix --- include/neighbor.h | 32 ++++++++++++++++++++++++-------- src/index.cpp | 28 +++++++++++++++++++++------- 2 files changed, 45 insertions(+), 15 deletions(-) diff --git a/include/neighbor.h b/include/neighbor.h index ba98fdb96..c95bc9955 100644 --- a/include/neighbor.h +++ b/include/neighbor.h @@ -99,21 +99,34 @@ class NeighborPriorityQueue { size_t lo = 0, hi = _size; size_t loc = std::numeric_limits::max(); - while (lo < hi) + while ((lo < hi) && loc == std::numeric_limits::max()) { size_t mid = (lo + hi) >> 1; if (nbr.distance < _data[mid].distance) { hi = mid; } - else if (_data[mid].id == nbr.id) + else if (nbr.distance > _data[mid].distance) { - loc = mid; - break; + lo = mid+1; } else { - lo = mid + 1; + uint32_t itr = 0; + for (;; itr++) { + if (mid + itr < hi) { + if (_data[mid+itr].id == nbr.id) { + loc = mid+itr; + break; + } + } + if(mid - itr >= lo) { + if (_data[mid-itr].id == nbr.id) { + loc = mid-itr; + break; + } + } + } } } @@ -127,9 +140,12 @@ class NeighborPriorityQueue _cur++; } } else { - std::cout<<"Found a problem! " << lo <<" " << hi <<" " <_data) - std::cout<_data) { + std::cout< Index::iterate_to_fixed_point( while (best_L_nodes.has_unexpanded_node()) { -/* for (auto &x : color_to_nodes) { - if (x.second.size() > 0) { - std::cout< Index::iterate_to_fixed_point( } } std::cout<<" == " < Index::iterate_to_fixed_point( if (cur_list.size() < maxLperSeller && best_L_nodes.size() < Lsize) { cur_list.insert(Neighbor(cur_id, cur_dist)); best_L_nodes.insert(Neighbor(cur_id, cur_dist)); + run_flag = 1; } else if (cur_list.size() == maxLperSeller) { if (cur_dist < cur_list[maxLperSeller-1].distance) { best_L_nodes.delete_id(cur_list[maxLperSeller-1]); cur_list.insert(Neighbor(cur_id, cur_dist)); best_L_nodes.insert(Neighbor(cur_id, cur_dist)); + run_flag = 2; } } else if (cur_list.size() < maxLperSeller && best_L_nodes.size() == Lsize) { if (cur_dist < best_L_nodes[Lsize-1].distance) { @@ -1040,12 +1036,30 @@ std::pair Index::iterate_to_fixed_point( color_to_nodes[_location_to_seller[best_L_nodes[Lsize-1].id]].delete_id(best_L_nodes[Lsize-1]); cur_list.insert(Neighbor(cur_id, cur_dist)); best_L_nodes.insert(Neighbor(cur_id, cur_dist)); + run_flag = 3; } } } else { best_L_nodes.insert(Neighbor(id_scratch[m], dist_scratch[m])); } + uint32_t sum_local_lists = 0; + for (auto &x : color_to_nodes) { + if (x.second.size() > 0) { +// std::cout< 0) { + std::cout< color_to_nodes; + uint32_t _Lsize = 0; + uint32_t _maxLperSeller = 0; + std::vector &_location_to_seller; + + bestCandidates(uint32_t Lsize, uint32_t maxLperSeller, std::vector &location_to_seller) : _location_to_seller(location_to_seller) { + _Lsize = Lsize; + _maxLperSeller = maxLperSeller; + best_L_nodes = diskann::NeighborPriorityQueue(_Lsize); + } + void insert(uint32_t cur_id, float cur_dist) { + //std::cout< &running_results, diskann::Metric metric = diskann::Metric::L2) // queries in Col major { float *points_l2sq = new float[npoints]; @@ -212,24 +264,15 @@ void exact_knn(const size_t dim, const size_t k, #pragma omp parallel for schedule(dynamic, 16) for (long long q = q_b; q < q_e; q++) { - maxPQIFCS point_dist; - for (size_t p = 0; p < k; p++) - point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); - for (size_t p = k; p < npoints; p++) - { - if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]) - point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); - if (point_dist.size() > k) - point_dist.pop(); - } - for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l) + bestCandidates & cur_query_best_results = running_results[q]; +// for (size_t p = 0; p < k; p++) +// point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); + for (size_t p = 0; p < npoints; p++) { - closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().first; - dist_closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().second; - point_dist.pop(); + uint32_t cur_id = p + start_id; + float cur_dist = dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]; + cur_query_best_results.insert(cur_id, cur_dist); } - assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k, - dist_closest_points + (ptrdiff_t)(q + 1) * (ptrdiff_t)k)); } std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e << ")" << std::endl; } @@ -344,6 +387,7 @@ std::vector>> processUnfilteredParts(cons float *base_data = nullptr; int num_parts = get_num_parts(base_file.c_str()); std::vector>> res(nqueries); + std::vector running_results(nqueries, bestCandidates(k, kperseller, location_to_seller)); for (int p = 0; p < num_parts; p++) { size_t start_id = p * PARTSIZE; @@ -353,19 +397,22 @@ std::vector>> processUnfilteredParts(cons float *dist_closest_points_part = new float[nqueries * k]; auto part_k = k < npoints ? k : npoints; - exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data, + exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, start_id, base_data, nqueries, query_data, kperseller, running_results, metric); for (size_t i = 0; i < nqueries; i++) { - for (size_t j = 0; j < part_k; j++) + auto & cur_results = running_results[i]; +// for (size_t j = 0; j < part_k; j++) + for (uint32_t x = 0; x < cur_results.best_L_nodes.size(); x++) { + auto &nbr = cur_results.best_L_nodes[x]; if (!location_to_tag.empty()) - if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0) + if (location_to_tag[nbr.id] == 0) continue; - res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id), - dist_closest_points_part[i * part_k + j])); + res[i].push_back(std::make_pair((uint32_t)(nbr.id), + nbr.distance)); } } @@ -431,7 +478,7 @@ int aux_main(const std::string &base_file, const std::string &query_file, const size_t npoints, nqueries, dim; float *query_data; - + std::cout<<"Inside k=" << k <<", and kPerSeller=" << kperseller << std::endl; load_bin_as_float(query_file.c_str(), query_data, nqueries, dim, 0); if (nqueries > PARTSIZE) std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE @@ -616,7 +663,7 @@ int main(int argc, char **argv) if (data_type == std::string("float")) aux_main(base_file, query_file, seller_file, gt_file, K, KperSeller, metric, tags_file); if (data_type == std::string("int8")) - aux_main(base_file, query_file, seller_file, gt_file, K, KperSeller,metric, tags_file); + aux_main(base_file, query_file, seller_file, gt_file, K, KperSeller, metric, tags_file); if (data_type == std::string("uint8")) aux_main(base_file, query_file, seller_file, gt_file, K, KperSeller, metric, tags_file); } From 37d08e91b2b645db969bc2c7681346835c574f0f Mon Sep 17 00:00:00 2001 From: rakri Date: Fri, 6 Sep 2024 05:13:25 -0700 Subject: [PATCH 09/18] minor changes to search interface --- apps/search_memory_index.cpp | 26 +++++++++++++------------- src/index.cpp | 19 ++++++++++++++----- 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 03f8e6faa..91b4eff87 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -30,7 +30,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads, const uint32_t recall_at, const bool print_all_recalls, const std::vector &Lvec, const bool dynamic, const bool tags, const bool show_qps_per_thread, - const std::vector &query_filters, const float fail_if_recall_below, const uint32_t num_diverse_sellers = 0) + const std::vector &query_filters, const float fail_if_recall_below, const uint32_t max_L_per_seller = 0) { using TagT = uint32_t; // Load the query file @@ -158,9 +158,9 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, query_result_dists[test_id].resize(recall_at * query_num); std::vector res = std::vector(); - uint32_t maxLperSeller = (num_diverse_sellers > 0) ? (1.0*L)/(1.0*num_diverse_sellers) : L; + uint32_t maxLperSeller = (max_L_per_seller > 0) ? max_L_per_seller : L; maxLperSeller = (maxLperSeller == 0)? 1 : maxLperSeller; - std::cout<<"MaxLperSeller = " << maxLperSeller << std::endl; + auto s = std::chrono::high_resolution_clock::now(); omp_set_num_threads(num_threads); #pragma omp parallel for schedule(dynamic, 1) @@ -290,7 +290,7 @@ int main(int argc, char **argv) { std::string data_type, dist_fn, index_path_prefix, result_path, query_file, gt_file, filter_label, label_type, query_filters_file; - uint32_t num_threads, K, num_diverse_sellers; + uint32_t num_threads, K, max_L_per_seller; std::vector Lvec; bool print_all_recalls, dynamic, tags, show_qps_per_thread; float fail_if_recall_below = 0.0f; @@ -334,9 +334,9 @@ int main(int argc, char **argv) optional_configs.add_options()("num_threads,T", po::value(&num_threads)->default_value(omp_get_num_procs()), program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()("num_diverse_sellers", - po::value(&num_diverse_sellers)->default_value(0), - "How many diverse sellers we want search results to contain"); + optional_configs.add_options()("max_L_per_seller", + po::value(&max_L_per_seller)->default_value(0), + "How many results per seller we want search results to contain"); optional_configs.add_options()( "dynamic", po::value(&dynamic)->default_value(false), "Whether the index is dynamic. Dynamic indices must have associated tags. Default false."); @@ -435,19 +435,19 @@ int main(int argc, char **argv) { return search_memory_index( metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, - Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, num_diverse_sellers); + Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller); } else if (data_type == std::string("uint8")) { return search_memory_index( metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, - Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, num_diverse_sellers); + Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller); } else if (data_type == std::string("float")) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below, num_diverse_sellers); + show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller); } else { @@ -461,19 +461,19 @@ int main(int argc, char **argv) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below, num_diverse_sellers); + show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller); } else if (data_type == std::string("uint8")) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below, num_diverse_sellers); + show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller); } else if (data_type == std::string("float")) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below, num_diverse_sellers); + show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller); } else { diff --git a/src/index.cpp b/src/index.cpp index 907b350cb..e143475fd 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1039,10 +1039,7 @@ std::pair Index::iterate_to_fixed_point( run_flag = 3; } } - } - else { - best_L_nodes.insert(Neighbor(id_scratch[m], dist_scratch[m])); - } + uint32_t sum_local_lists = 0; for (auto &x : color_to_nodes) { if (x.second.size() > 0) { @@ -1060,6 +1057,11 @@ std::pair Index::iterate_to_fixed_point( exit(-1); } + } + else { + best_L_nodes.insert(Neighbor(id_scratch[m], dist_scratch[m])); + } + } } return std::make_pair(hops, cmps); @@ -1285,10 +1287,14 @@ void Index::occlude_list(const uint32_t location, std::vector cur_alpha) { + if (_diverse_index) { if (blockers[iter - pool.begin()].size() >= _num_diverse_build) need_to_add_edge = false; else if (blockers[iter - pool.begin()].find(_location_to_seller[iter->id]) != blockers[iter - pool.begin()].end()) need_to_add_edge = false; + } else { + need_to_add_edge = false; + } } // Set the entry to float::max so that is not considered again, similarly add its own color as a blocking color @@ -1341,9 +1347,11 @@ void Index::occlude_list(const uint32_t location, std::vector::max() : std::max(occlude_factor[t], iter2->distance / djk); + if (_diverse_index) { if (iter2->distance / djk > cur_alpha) { blockers[t].insert(_location_to_seller[iter->id]); } + } } else if (_dist_metric == diskann::Metric::INNER_PRODUCT) { @@ -1353,7 +1361,8 @@ void Index::occlude_list(const uint32_t location, std::vector cur_alpha * x) { occlude_factor[t] = std::max(occlude_factor[t], eps); - blockers[t].insert(_location_to_seller[iter->id]); + if (_diverse_index) + blockers[t].insert(_location_to_seller[iter->id]); } } } From c483ee06f4568c98cc052a39bbcd7950f812ccc6 Mon Sep 17 00:00:00 2001 From: rakri Date: Tue, 24 Sep 2024 09:09:50 -0700 Subject: [PATCH 10/18] search code needs some change to support post-processing and diversity search --- apps/search_memory_index.cpp | 130 ++++++++++++++++++++++++++++++++--- include/neighbor.h | 4 +- 2 files changed, 124 insertions(+), 10 deletions(-) diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 91b4eff87..9af147c00 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -25,6 +25,97 @@ namespace po = boost::program_options; +struct bestCandidates { + diskann::NeighborPriorityQueue best_L_nodes; + tsl::robin_map color_to_nodes; + uint32_t _Lsize = 0; + uint32_t _maxLperSeller = 0; + std::vector &_location_to_seller; + + bestCandidates(uint32_t Lsize, uint32_t maxLperSeller, std::vector &location_to_seller) : _location_to_seller(location_to_seller) { + _Lsize = Lsize; + _maxLperSeller = maxLperSeller; + best_L_nodes = diskann::NeighborPriorityQueue(_Lsize); + } + void insert(uint32_t cur_id, float cur_dist) { + //std::cout< &location_to_seller) +{ + // Format of Label txt file: filters with comma separators + + std::ifstream infile(label_file); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); + } + + std::string line, token; + uint32_t line_cnt = 0; + std::set sellers; + while (std::getline(infile, line)) + { + line_cnt++; + } + location_to_seller.resize(line_cnt); + + infile.clear(); + infile.seekg(0, std::ios::beg); + line_cnt = 0; + + while (std::getline(infile, line)) + { + std::istringstream iss(line); + getline(iss, token, '\t'); + std::istringstream new_iss(token); + uint32_t seller; + while (getline(new_iss, token, ',')) + { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + uint32_t token_as_num = (uint32_t)std::stoul(token); + seller = token_as_num; + sellers.insert(seller); + } + + location_to_seller[line_cnt] = seller; + line_cnt++; + } + num_points = (size_t)line_cnt; + diskann::cout << " Search code: Identified " << sellers.size() << " distinct seller(s) across " << num_points <<" points." << std::endl; +} + + + template int search_memory_index(diskann::Metric &metric, const std::string &index_path, const std::string &result_path_prefix, const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads, @@ -32,6 +123,13 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, const bool dynamic, const bool tags, const bool show_qps_per_thread, const std::vector &query_filters, const float fail_if_recall_below, const uint32_t max_L_per_seller = 0) { + std::vector location_to_sellers; + std::string seller_file = index_path +"_sellers.txt"; + if (file_exists(seller_file)) { + std::cout<<"Here" << std::endl; + uint64_t num_pts_seller_file; + parse_seller_file(seller_file, num_pts_seller_file, location_to_sellers); + } using TagT = uint32_t; // Load the query file T *query = nullptr; @@ -68,7 +166,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, } } - //query_num = 1; +// query_num = 1; const size_t num_frozen_pts = diskann::get_graph_num_frozen_points(index_path); auto config = diskann::IndexConfigBuilder() @@ -159,7 +257,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, std::vector res = std::vector(); uint32_t maxLperSeller = (max_L_per_seller > 0) ? max_L_per_seller : L; - maxLperSeller = (maxLperSeller == 0)? 1 : maxLperSeller; + //maxLperSeller = (maxLperSeller == 0)? 1 : maxLperSeller; auto s = std::chrono::high_resolution_clock::now(); omp_set_num_threads(num_threads); @@ -203,17 +301,33 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, } else { - if (maxLperSeller != L) +/* if (maxLperSeller != L) cmp_stats[i] = index ->diverse_search(query + i * query_aligned_dim, recall_at, L, maxLperSeller, query_result_ids[test_id].data() + i * recall_at) - .second; - else + .second; */ +// else { + { + std::vector results(L,0); + std::vector dists(L,0); cmp_stats[i] = index - ->search(query + i * query_aligned_dim, recall_at, L, - query_result_ids[test_id].data() + i * recall_at) + ->search(query + i * query_aligned_dim, L, L, + results.data(), dists.data()) .second; - + bestCandidates final_results(recall_at, maxLperSeller, location_to_sellers); + for (uint32_t rr = 0; rr < L; rr++) { + final_results.insert(results[rr], dists[rr]); + // std::cout< diff = qe - qs; diff --git a/include/neighbor.h b/include/neighbor.h index c95bc9955..c26416500 100644 --- a/include/neighbor.h +++ b/include/neighbor.h @@ -75,7 +75,7 @@ class NeighborPriorityQueue else { lo = mid + 1; - } + } } if (lo < _capacity) @@ -202,7 +202,7 @@ class NeighborPriorityQueue _cur = 0; } - private: + public: size_t _size, _capacity, _cur; std::vector _data; }; From 56c51a0e12223bec88de70feb89ba2ebf5d4d986 Mon Sep 17 00:00:00 2001 From: rakri Date: Wed, 25 Sep 2024 08:57:42 -0700 Subject: [PATCH 11/18] some code clean up --- apps/search_memory_index.cpp | 79 +++++----------------- apps/utils/compute_diverse_groundtruth.cpp | 13 ++-- include/neighbor.h | 48 +++++++++++++ 3 files changed, 70 insertions(+), 70 deletions(-) diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 9af147c00..06056ae8b 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -25,49 +25,6 @@ namespace po = boost::program_options; -struct bestCandidates { - diskann::NeighborPriorityQueue best_L_nodes; - tsl::robin_map color_to_nodes; - uint32_t _Lsize = 0; - uint32_t _maxLperSeller = 0; - std::vector &_location_to_seller; - - bestCandidates(uint32_t Lsize, uint32_t maxLperSeller, std::vector &location_to_seller) : _location_to_seller(location_to_seller) { - _Lsize = Lsize; - _maxLperSeller = maxLperSeller; - best_L_nodes = diskann::NeighborPriorityQueue(_Lsize); - } - void insert(uint32_t cur_id, float cur_dist) { - //std::cout< &location_to_seller) { @@ -121,7 +78,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads, const uint32_t recall_at, const bool print_all_recalls, const std::vector &Lvec, const bool dynamic, const bool tags, const bool show_qps_per_thread, - const std::vector &query_filters, const float fail_if_recall_below, const uint32_t max_L_per_seller = 0) + const std::vector &query_filters, const float fail_if_recall_below, const uint32_t max_L_per_seller = 0, const bool post_process = false) { std::vector location_to_sellers; std::string seller_file = index_path +"_sellers.txt"; @@ -301,32 +258,26 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, } else { -/* if (maxLperSeller != L) + if (maxLperSeller != L && !post_process) cmp_stats[i] = index ->diverse_search(query + i * query_aligned_dim, recall_at, L, maxLperSeller, query_result_ids[test_id].data() + i * recall_at) - .second; */ -// else { - { + .second; + else { +// { std::vector results(L,0); std::vector dists(L,0); cmp_stats[i] = index ->search(query + i * query_aligned_dim, L, L, results.data(), dists.data()) .second; - bestCandidates final_results(recall_at, maxLperSeller, location_to_sellers); + diskann::bestCandidates final_results(recall_at, maxLperSeller, location_to_sellers); for (uint32_t rr = 0; rr < L; rr++) { final_results.insert(results[rr], dists[rr]); - // std::cout< Lvec; - bool print_all_recalls, dynamic, tags, show_qps_per_thread; + bool print_all_recalls, dynamic, tags, show_qps_per_thread, post_process; float fail_if_recall_below = 0.0f; po::options_description desc{ @@ -451,6 +402,10 @@ int main(int argc, char **argv) optional_configs.add_options()("max_L_per_seller", po::value(&max_L_per_seller)->default_value(0), "How many results per seller we want search results to contain"); + optional_configs.add_options()("post_process", + po::value(&post_process)->default_value(false), + "Whether to do vanilla search + post-processing for diversity"); + optional_configs.add_options()( "dynamic", po::value(&dynamic)->default_value(false), "Whether the index is dynamic. Dynamic indices must have associated tags. Default false."); @@ -549,19 +504,19 @@ int main(int argc, char **argv) { return search_memory_index( metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, - Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller); + Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, post_process); } else if (data_type == std::string("uint8")) { return search_memory_index( metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, - Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller); + Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, post_process); } else if (data_type == std::string("float")) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller); + show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, post_process); } else { @@ -575,19 +530,19 @@ int main(int argc, char **argv) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller); + show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, post_process); } else if (data_type == std::string("uint8")) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller); + show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, post_process); } else if (data_type == std::string("float")) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller); + show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, post_process); } else { diff --git a/apps/utils/compute_diverse_groundtruth.cpp b/apps/utils/compute_diverse_groundtruth.cpp index 2e318913c..1f270055f 100644 --- a/apps/utils/compute_diverse_groundtruth.cpp +++ b/apps/utils/compute_diverse_groundtruth.cpp @@ -105,7 +105,7 @@ void distsq_to_points(const size_t dim, delete[] ones_vec; } - +/* struct bestCandidates { diskann::NeighborPriorityQueue best_L_nodes; tsl::robin_map color_to_nodes; @@ -137,9 +137,6 @@ struct bestCandidates { } } else if (cur_list.size() < _maxLperSeller && best_L_nodes.size() == _Lsize) { if (cur_dist < best_L_nodes[_Lsize-1].distance) { -/* if (color_to_nodes[_location_to_seller[best_L_nodes[Lsize-1].id]].size() == 0) { - std::cout<<"Trying to delete from empty Q. " << best_L_nodes[Lsize-1].id <<" of color " << _location_to_seller[best_L_nodes[Lsize-1].id] << std::endl; - }*/ color_to_nodes[_location_to_seller[best_L_nodes[_Lsize-1].id]].delete_id(best_L_nodes[_Lsize-1]); cur_list.insert(diskann::Neighbor(cur_id, cur_dist)); best_L_nodes.insert(diskann::Neighbor(cur_id, cur_dist)); @@ -148,7 +145,7 @@ struct bestCandidates { } } }; - +*/ void inner_prod_to_points(const size_t dim, float *dist_matrix, // Col Major, cols are queries, rows are points @@ -183,7 +180,7 @@ void exact_knn(const size_t dim, const size_t k, // corresponding closes_points size_t npoints, size_t start_id, float *points_in, // points in Col major - size_t nqueries, float *queries_in, uint32_t kPerSeller, std::vector &running_results, + size_t nqueries, float *queries_in, uint32_t kPerSeller, std::vector &running_results, diskann::Metric metric = diskann::Metric::L2) // queries in Col major { float *points_l2sq = new float[npoints]; @@ -264,7 +261,7 @@ void exact_knn(const size_t dim, const size_t k, #pragma omp parallel for schedule(dynamic, 16) for (long long q = q_b; q < q_e; q++) { - bestCandidates & cur_query_best_results = running_results[q]; + diskann::bestCandidates & cur_query_best_results = running_results[q]; // for (size_t p = 0; p < k; p++) // point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); for (size_t p = 0; p < npoints; p++) @@ -387,7 +384,7 @@ std::vector>> processUnfilteredParts(cons float *base_data = nullptr; int num_parts = get_num_parts(base_file.c_str()); std::vector>> res(nqueries); - std::vector running_results(nqueries, bestCandidates(k, kperseller, location_to_seller)); + std::vector running_results(nqueries, diskann::bestCandidates(k, kperseller, location_to_seller)); for (int p = 0; p < num_parts; p++) { size_t start_id = p * PARTSIZE; diff --git a/include/neighbor.h b/include/neighbor.h index c26416500..feb628e79 100644 --- a/include/neighbor.h +++ b/include/neighbor.h @@ -6,6 +6,7 @@ #include #include #include +#include #include "utils.h" namespace diskann @@ -207,4 +208,51 @@ class NeighborPriorityQueue std::vector _data; }; + +struct bestCandidates { + NeighborPriorityQueue best_L_nodes; + tsl::robin_map color_to_nodes; + uint32_t _Lsize = 0; + uint32_t _maxLperSeller = 0; + std::vector &_location_to_seller; + + bestCandidates(uint32_t Lsize, uint32_t maxLperSeller, std::vector &location_to_seller) : _location_to_seller(location_to_seller) { + _Lsize = Lsize; + _maxLperSeller = maxLperSeller; + best_L_nodes = NeighborPriorityQueue(_Lsize); + } + void insert(uint32_t cur_id, float cur_dist) { + //std::cout< Date: Sat, 28 Sep 2024 03:25:12 -0700 Subject: [PATCH 12/18] code is now much more performant, can try to squeeze more juice after paper deadline --- apps/search_memory_index.cpp | 2 +- include/neighbor.h | 19 +++++++ include/scratch.h | 7 ++- src/index.cpp | 98 +++++++++++------------------------- src/scratch.cpp | 6 ++- 5 files changed, 60 insertions(+), 72 deletions(-) diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 06056ae8b..98bc51f06 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -123,7 +123,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, } } -// query_num = 1; + //query_num = 2; const size_t num_frozen_pts = diskann::get_graph_num_frozen_points(index_path); auto config = diskann::IndexConfigBuilder() diff --git a/include/neighbor.h b/include/neighbor.h index feb628e79..1363f85f2 100644 --- a/include/neighbor.h +++ b/include/neighbor.h @@ -48,6 +48,10 @@ class NeighborPriorityQueue { } + void setup(uint32_t capacity) { + _data.resize(capacity+1); + _capacity = capacity; + } // Inserts the item ordered into the set up to the sets capacity. // The item will be dropped if it is the same id as an exiting // set item or it has a greated distance than the final @@ -216,11 +220,26 @@ struct bestCandidates { uint32_t _maxLperSeller = 0; std::vector &_location_to_seller; + bestCandidates(std::vector &location_to_seller) : _location_to_seller(location_to_seller) { + } + bestCandidates(uint32_t Lsize, uint32_t maxLperSeller, std::vector &location_to_seller) : _location_to_seller(location_to_seller) { _Lsize = Lsize; _maxLperSeller = maxLperSeller; best_L_nodes = NeighborPriorityQueue(_Lsize); } + + void clear() { + best_L_nodes.clear(); + color_to_nodes.clear(); + } + + void setup(uint32_t Lsize, uint32_t maxLperSeller) { + _Lsize = Lsize; + _maxLperSeller = maxLperSeller; + best_L_nodes = NeighborPriorityQueue(_Lsize); + } + void insert(uint32_t cur_id, float cur_dist) { //std::cout< class InMemQueryScratch : public AbstractScratch public: ~InMemQueryScratch(); InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, uint32_t r, uint32_t maxc, size_t dim, size_t aligned_dim, - size_t alignment_factor, bool init_pq_scratch = false); + size_t alignment_factor, std::vector &location_to_sellers, bool init_pq_scratch = false); void resize_for_new_L(uint32_t new_search_l); void clear(); @@ -61,6 +61,10 @@ template class InMemQueryScratch : public AbstractScratch { return _best_l_nodes; } + inline bestCandidates &best_diverse_nodes() + { + return _best_diverse_nodes; + } inline std::vector &occlude_factor() { return _occlude_factor; @@ -107,6 +111,7 @@ template class InMemQueryScratch : public AbstractScratch // _best_l_nodes is reserved for storing best L entries // Underlying storage is L+1 to support inserts NeighborPriorityQueue _best_l_nodes; + bestCandidates _best_diverse_nodes; // _occlude_factor.size() >= pool.size() in occlude_list function // _pool is clipped to maxc in occlude_list before affecting _occlude_factor diff --git a/src/index.cpp b/src/index.cpp index e143475fd..318014461 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -192,7 +192,7 @@ void Index::initialize_query_scratch(uint32_t num_threads, uint for (uint32_t i = 0; i < num_threads; i++) { auto scratch = new InMemQueryScratch(search_l, indexing_l, r, maxc, dim, _data_store->get_aligned_dim(), - _data_store->get_alignment_factor(), _pq_dist); + _data_store->get_alignment_factor(), _location_to_seller, _pq_dist); _query_scratch.push(scratch); } } @@ -816,10 +816,17 @@ std::pair Index::iterate_to_fixed_point( else diverse_search = true; std::vector &expanded_nodes = scratch->pool(); - NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); - best_L_nodes.reserve(Lsize); + NeighborPriorityQueue &best_L_nodes_ref = scratch->best_l_nodes(); + bestCandidates &best_diverse_nodes_ref = scratch->best_diverse_nodes(); + best_L_nodes_ref.reserve(Lsize); + best_diverse_nodes_ref.setup(Lsize, maxLperSeller); - tsl::robin_map color_to_nodes; + NeighborPriorityQueue* best_L_nodes; + if(diverse_search) { + best_L_nodes = &(best_diverse_nodes_ref.best_L_nodes); + } else { + best_L_nodes = &(best_L_nodes_ref); + } tsl::robin_set &inserted_into_pool_rs = scratch->inserted_into_pool_rs(); boost::dynamic_bitset<> &inserted_into_pool_bs = scratch->inserted_into_pool_bs(); @@ -898,12 +905,10 @@ std::pair Index::iterate_to_fixed_point( distance = distances[0]; Neighbor nn = Neighbor(id, distance); - best_L_nodes.insert(nn); if (diverse_search) { - auto &col = _location_to_seller[id]; - if (color_to_nodes.find(col) == color_to_nodes.end()) - color_to_nodes[col] = NeighborPriorityQueue(maxLperSeller); - color_to_nodes[col].insert(nn); + best_diverse_nodes_ref.insert(id, distance); + } else { + best_L_nodes->insert(nn); } } } @@ -911,9 +916,9 @@ std::pair Index::iterate_to_fixed_point( uint32_t hops = 0; uint32_t cmps = 0; - while (best_L_nodes.has_unexpanded_node()) + while (best_L_nodes->has_unexpanded_node()) { - auto nbr = best_L_nodes.closest_unexpanded(); + auto nbr = best_L_nodes->closest_unexpanded(); auto n = nbr.id; // std::cout< Index::iterate_to_fixed_point( std::cout<<" == " < 0) { -// std::cout< 0) { - std::cout< location_to_sellers; std::string seller_file = index_path +"_sellers.txt"; if (file_exists(seller_file)) { @@ -209,12 +210,18 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, continue; } - query_result_ids[test_id].resize(recall_at * query_num); - query_result_dists[test_id].resize(recall_at * query_num); + query_result_ids[test_id].resize(recall_at * query_num, std::numeric_limits::max()); + query_result_dists[test_id].resize(recall_at * query_num, std::numeric_limits::max()); std::vector res = std::vector(); - uint32_t maxLperSeller = (max_L_per_seller > 0) ? max_L_per_seller : L; + //uint32_t maxLperSeller = (max_L_per_seller > 0) ? max_L_per_seller : L; + //maxLperSeller = (maxLperSeller == 0)? 1 : maxLperSeller; + uint32_t maxLperSeller = max_K_per_seller; + if (diverse_search && scale_seller_limits) { + maxLperSeller = (1.0*L* max_K_per_seller)/(1.0*recall_at); + // std::cout< results(L,std::numeric_limits::max()); + std::vector dists(L,std::numeric_limits::max()); + uint32_t K_to_use = (post_process == true) ? L : recall_at; + + if (diverse_search) { + cmp_stats[i] = index - ->diverse_search(query + i * query_aligned_dim, recall_at, L, maxLperSeller, - query_result_ids[test_id].data() + i * recall_at) - .second; - else { + ->diverse_search(query + i * query_aligned_dim, K_to_use, L, maxLperSeller, + results.data(), dists.data()) + .second; + } else { // { - std::vector results(L,0); - std::vector dists(L,0); cmp_stats[i] = index - ->search(query + i * query_aligned_dim, L, L, + ->search(query + i * query_aligned_dim, K_to_use, L, results.data(), dists.data()) .second; - diskann::bestCandidates final_results(recall_at, maxLperSeller, location_to_sellers); + } + if (post_process) { + diskann::bestCandidates final_results(recall_at, max_K_per_seller, location_to_sellers); for (uint32_t rr = 0; rr < L; rr++) { final_results.insert(results[rr], dists[rr]); } - for (uint32_t ctr = 0; ctr < final_results.best_L_nodes.size(); ctr++) { + + for (uint32_t ctr = 0; ctr < std::min(final_results.best_L_nodes.size(), (uint64_t)recall_at); ctr++) { query_result_ids[test_id][recall_at * i + ctr] = final_results.best_L_nodes._data[ctr].id; + query_result_dists[test_id][recall_at * i + ctr] = final_results.best_L_nodes._data[ctr].distance; + } + } else { + for (uint32_t ctr = 0; ctr < std::min(results.size(),(uint64_t)recall_at); ctr++) { + query_result_ids[test_id][recall_at * i + ctr] = results[ctr]; + query_result_dists[test_id][recall_at * i + ctr] = dists[ctr]; } } } @@ -298,7 +317,10 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++) { recalls.push_back(diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim, - query_result_ids[test_id].data(), recall_at, curr_recall)); + query_result_ids[test_id].data(), recall_at, curr_recall, query_result_dists[test_id].data())); +// recalls.push_back(diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim, +// query_result_ids[test_id].data(), recall_at, curr_recall)); + } } @@ -357,7 +379,7 @@ int main(int argc, char **argv) query_filters_file; uint32_t num_threads, K, max_L_per_seller; std::vector Lvec; - bool print_all_recalls, dynamic, tags, show_qps_per_thread, post_process; + bool print_all_recalls, dynamic, tags, show_qps_per_thread, post_process, diverse_search, scale_seller_limits; float fail_if_recall_below = 0.0f; po::options_description desc{ @@ -399,12 +421,19 @@ int main(int argc, char **argv) optional_configs.add_options()("num_threads,T", po::value(&num_threads)->default_value(omp_get_num_procs()), program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()("max_L_per_seller", + optional_configs.add_options()("max_K_per_seller", po::value(&max_L_per_seller)->default_value(0), "How many results per seller we want search results to contain"); + optional_configs.add_options()("diverse_search", + po::value(&diverse_search)->default_value(false), + "Whether to run diverse search or baseline search"); + optional_configs.add_options()("scale_seller_limits", + po::value(&scale_seller_limits)->default_value(false), + "Whether to run scale the max_L_per_seller based on the L value"); optional_configs.add_options()("post_process", po::value(&post_process)->default_value(false), - "Whether to do vanilla search + post-processing for diversity"); + "Whether to post-processing to ensure correct diversity"); + optional_configs.add_options()( "dynamic", po::value(&dynamic)->default_value(false), @@ -504,19 +533,19 @@ int main(int argc, char **argv) { return search_memory_index( metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, - Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, post_process); + Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, diverse_search, scale_seller_limits, post_process); } else if (data_type == std::string("uint8")) { return search_memory_index( metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, - Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, post_process); + Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, diverse_search, scale_seller_limits, post_process); } else if (data_type == std::string("float")) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, post_process); + show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, diverse_search, scale_seller_limits, post_process); } else { @@ -530,19 +559,19 @@ int main(int argc, char **argv) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, post_process); + show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, diverse_search, scale_seller_limits, post_process); } else if (data_type == std::string("uint8")) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, post_process); + show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, diverse_search, scale_seller_limits, post_process); } else if (data_type == std::string("float")) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, post_process); + show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, diverse_search, scale_seller_limits, post_process); } else { diff --git a/apps/utils/compute_diverse_groundtruth.cpp b/apps/utils/compute_diverse_groundtruth.cpp index 1f270055f..48eac0a52 100644 --- a/apps/utils/compute_diverse_groundtruth.cpp +++ b/apps/utils/compute_diverse_groundtruth.cpp @@ -397,6 +397,13 @@ std::vector>> processUnfilteredParts(cons exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, start_id, base_data, nqueries, query_data, kperseller, running_results, metric); + + delete[] closest_points_part; + delete[] dist_closest_points_part; + + diskann::aligned_free(base_data); + } + for (size_t i = 0; i < nqueries; i++) { auto & cur_results = running_results[i]; @@ -413,11 +420,6 @@ std::vector>> processUnfilteredParts(cons } } - delete[] closest_points_part; - delete[] dist_closest_points_part; - - diskann::aligned_free(base_data); - } return res; }; @@ -521,8 +523,14 @@ int aux_main(const std::string &base_file, const std::string &query_file, const ++j; } - if (j < k) + if (j < k) { std::cout << "WARNING: found less than k GT entries for query " << i << std::endl; + while (j::max(); + closest_points[i * k + j] = std::numeric_limits::max(); + j++; + } + } } save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k); diff --git a/include/neighbor.h b/include/neighbor.h index 1363f85f2..5cd852f4f 100644 --- a/include/neighbor.h +++ b/include/neighbor.h @@ -44,12 +44,12 @@ class NeighborPriorityQueue { } - explicit NeighborPriorityQueue(size_t capacity) : _size(0), _capacity(capacity), _cur(0), _data(capacity + 1) + explicit NeighborPriorityQueue(size_t capacity) : _size(0), _capacity(capacity), _cur(0), _data(capacity + 1, Neighbor(std::numeric_limits::max(), std::numeric_limits::max())) { } void setup(uint32_t capacity) { - _data.resize(capacity+1); + _data.resize(capacity+1,Neighbor(std::numeric_limits::max(), std::numeric_limits::max())); _capacity = capacity; } // Inserts the item ordered into the set up to the sets capacity. diff --git a/include/utils.h b/include/utils.h index d3af5c3a9..532bffd87 100644 --- a/include/utils.h +++ b/include/utils.h @@ -673,7 +673,7 @@ inline void copy_file(std::string in_file, std::string out_file) } DISKANN_DLLEXPORT double calculate_recall(unsigned num_queries, unsigned *gold_std, float *gs_dist, unsigned dim_gs, - unsigned *our_results, unsigned dim_or, unsigned recall_at); + unsigned *our_results, unsigned dim_or, unsigned recall_at, float* algo_distances = nullptr); DISKANN_DLLEXPORT double calculate_recall(unsigned num_queries, unsigned *gold_std, float *gs_dist, unsigned dim_gs, unsigned *our_results, unsigned dim_or, unsigned recall_at, diff --git a/src/utils.cpp b/src/utils.cpp index 3773cda22..74875d4be 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -126,14 +126,19 @@ void normalize_data_file(const std::string &inFileName, const std::string &outFi diskann::cout << "Wrote normalized points to file: " << outFileName << std::endl; } + double calculate_recall(uint32_t num_queries, uint32_t *gold_std, float *gs_dist, uint32_t dim_gs, - uint32_t *our_results, uint32_t dim_or, uint32_t recall_at) + uint32_t *our_results, uint32_t dim_or, uint32_t recall_at, float* algo_distances) { + bool use_distances_to_break_ties = false; + if (algo_distances != nullptr) { + use_distances_to_break_ties = true; + } double total_recall = 0; std::set gt, res; - for (size_t i = 0; i < num_queries; i++) { + if (!use_distances_to_break_ties) { gt.clear(); res.clear(); uint32_t *gt_vec = gold_std + dim_gs * i; @@ -160,6 +165,14 @@ double calculate_recall(uint32_t num_queries, uint32_t *gold_std, float *gs_dist } } total_recall += cur_recall; + } else { // only works if dim_or == dim_gs. Not for the k-recall@k' regime. + uint32_t cur_recall =0; + for (uint32_t rr = 0; rr < std::min(dim_or, dim_gs); rr++) { + if (algo_distances[i*dim_or + rr] <= gs_dist[i*dim_gs + (recall_at-1)]) + cur_recall++; + } + total_recall += cur_recall; + } } return total_recall / (num_queries) * (100.0 / recall_at); } From 641054854030b90aec1103508aa5c5ebdd85ba99 Mon Sep 17 00:00:00 2001 From: rakri Date: Mon, 25 Nov 2024 09:33:52 -0800 Subject: [PATCH 14/18] added untested code for diversity in PQFlashIndex --- apps/search_disk_index.cpp | 4 +- include/pq_flash_index.h | 19 ++++- include/scratch.h | 5 +- python/src/static_disk_index.cpp | 2 +- src/disk_utils.cpp | 2 +- src/index.cpp | 2 +- src/pq_flash_index.cpp | 140 ++++++++++++++++++++++++++----- src/scratch.cpp | 5 +- 8 files changed, 146 insertions(+), 33 deletions(-) diff --git a/apps/search_disk_index.cpp b/apps/search_disk_index.cpp index 7e2a7ac6d..16ae7fbee 100644 --- a/apps/search_disk_index.cpp +++ b/apps/search_disk_index.cpp @@ -232,7 +232,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre _pFlashIndex->cached_beam_search(query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at), query_result_dists[test_id].data() + (i * recall_at), - optimized_beamwidth, use_reorder_data, stats + i); + optimized_beamwidth, std::numeric_limits::max(), use_reorder_data, stats + i); } else { @@ -247,7 +247,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre } _pFlashIndex->cached_beam_search( query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at), - query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search, + query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search, std::numeric_limits::max(), use_reorder_data, stats + i); } } diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index ba5258e18..f8da87fdd 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -62,23 +62,23 @@ template class PQFlashIndex const bool shuffle = false); DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, - uint64_t *res_ids, float *res_dists, const uint64_t beam_width, + uint64_t *res_ids, float *res_dists, const uint64_t beam_width, const uint32_t max_l_per_seller = std::numeric_limits::max(), const bool use_reorder_data = false, QueryStats *stats = nullptr); DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, uint64_t *res_ids, float *res_dists, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, + const bool use_filter, const LabelT &filter_label, const uint32_t max_l_per_seller = std::numeric_limits::max(), const bool use_reorder_data = false, QueryStats *stats = nullptr); DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, uint64_t *res_ids, float *res_dists, const uint64_t beam_width, - const uint32_t io_limit, const bool use_reorder_data = false, + const uint32_t io_limit, const uint32_t max_l_per_seller = std::numeric_limits::max(), const bool use_reorder_data = false, QueryStats *stats = nullptr); DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, uint64_t *res_ids, float *res_dists, const uint64_t beam_width, const bool use_filter, const LabelT &filter_label, - const uint32_t io_limit, const bool use_reorder_data = false, + const uint32_t io_limit, const uint32_t max_l_per_seller = std::numeric_limits::max(), const bool use_reorder_data = false, QueryStats *stats = nullptr); DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &filter_label); @@ -118,10 +118,14 @@ template class PQFlashIndex DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id); std::unordered_map load_label_map(std::basic_istream &infile); DISKANN_DLLEXPORT void parse_label_file(std::basic_istream &infile, size_t &num_pts_labels); + DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts, uint32_t &num_total_labels); DISKANN_DLLEXPORT void generate_random_labels(std::vector &labels, const uint32_t num_labels, const uint32_t nthreads); + + DISKANN_DLLEXPORT void parse_seller_file(const std::string &label_file, size_t &num_pts_labels); + void reset_stream_for_reading(std::basic_istream &infile); // sector # on disk where node_id is present with in the graph part @@ -234,6 +238,13 @@ template class PQFlashIndex tsl::robin_map> _real_to_dummy_map; std::unordered_map _label_map; + + bool _diverse_index = false; + uint32_t _max_L_per_seller = 0; + std::vector _location_to_seller; + std::string _seller_file; + + #ifdef EXEC_ENV_OLS // Set to a larger value than the actual header to accommodate // any additions we make to the header. This is an outer limit diff --git a/include/scratch.h b/include/scratch.h index 83240e77c..af1c6e421 100644 --- a/include/scratch.h +++ b/include/scratch.h @@ -154,8 +154,9 @@ template class SSDQueryScratch : public AbstractScratch tsl::robin_set visited; NeighborPriorityQueue retset; std::vector full_retset; + bestCandidates best_diverse_nodes; - SSDQueryScratch(size_t aligned_dim, size_t visited_reserve); + SSDQueryScratch(size_t aligned_dim, size_t visited_reserve, std::vector &location_to_sellers); ~SSDQueryScratch(); void reset(); @@ -167,7 +168,7 @@ template class SSDThreadData SSDQueryScratch scratch; IOContext ctx; - SSDThreadData(size_t aligned_dim, size_t visited_reserve); + SSDThreadData(size_t aligned_dim, size_t visited_reserve, std::vector &location_to_sellers); void clear(); }; diff --git a/python/src/static_disk_index.cpp b/python/src/static_disk_index.cpp index 9e86b0ad5..6bf307e28 100644 --- a/python/src/static_disk_index.cpp +++ b/python/src/static_disk_index.cpp @@ -65,7 +65,7 @@ NeighborsAndDistances StaticDiskIndex
::search( std::vector u64_ids(knn); diskann::QueryStats stats; - _index.cached_beam_search(query.data(), knn, complexity, u64_ids.data(), dists.mutable_data(), beam_width, false, + _index.cached_beam_search(query.data(), knn, complexity, u64_ids.data(), dists.mutable_data(), beam_width, std::numeric_limits::max(), false, &stats); auto r = ids.mutable_unchecked<1>(); diff --git a/src/disk_utils.cpp b/src/disk_utils.cpp index 016560217..51c42daed 100644 --- a/src/disk_utils.cpp +++ b/src/disk_utils.cpp @@ -813,7 +813,7 @@ uint32_t optimize_beamwidth(std::unique_ptr> &p { pFlashIndex->cached_beam_search(tuning_sample + (i * tuning_sample_aligned_dim), 1, L, tuning_sample_result_ids_64.data() + (i * 1), - tuning_sample_result_dists.data() + (i * 1), cur_bw, false, stats + i); + tuning_sample_result_dists.data() + (i * 1), cur_bw, std::numeric_limits::max(), false, stats + i); } auto e = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = e - s; diff --git a/src/index.cpp b/src/index.cpp index 318014461..873c7ceca 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -598,7 +598,7 @@ void Index::load(const char *filename, uint32_t num_threads, ui throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } - std::string index_seller_file = std::string(filename) + "_sellers.txt"; + std::string index_seller_file = std::string(filename) + "_sellers.txt"; if(file_exists(index_seller_file)) { uint64_t nrows_seller_file; parse_seller_file(index_seller_file, nrows_seller_file); diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index fbb81d55f..c117440e4 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -131,7 +131,7 @@ void PQFlashIndex::setup_thread_data(uint64_t nthreads, uint64_t visi { #pragma omp critical { - SSDThreadData *data = new SSDThreadData(this->_aligned_dim, visited_reserve); + SSDThreadData *data = new SSDThreadData(this->_aligned_dim, visited_reserve, this->_location_to_seller); this->reader->register_thread(); data->ctx = this->reader->get_ctx(); this->_thread_data.push(data); @@ -326,7 +326,7 @@ void PQFlashIndex::generate_cache_list_from_sample_queries(std::strin // concurrently update the node_visit_counter to track most visited nodes. The last false is to not use the // "use_reorder_data" option which enables a final reranking if the disk index itself contains only PQ data. cached_beam_search(samples + (i * sample_aligned_dim), 1, l_search, tmp_result_ids_64.data() + i, - tmp_result_dists.data() + i, beamwidth, filtered_search, label_for_search, false); + tmp_result_dists.data() + i, beamwidth, filtered_search, label_for_search, std::numeric_limits::max(), false); } std::sort(this->_node_visit_counter.begin(), _node_visit_counter.end(), @@ -752,6 +752,53 @@ void PQFlashIndex::parse_label_file(std::basic_istream &infile, reset_stream_for_reading(infile); } +template +void PQFlashIndex::parse_seller_file(const std::string &label_file, size_t &num_points) +{ + // Format of Label txt file: filters with comma separators + + std::ifstream infile(label_file); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); + } + + std::string line, token; + uint32_t line_cnt = 0; + std::set sellers; + while (std::getline(infile, line)) + { + line_cnt++; + } + _location_to_seller.resize(line_cnt); + + infile.clear(); + infile.seekg(0, std::ios::beg); + line_cnt = 0; + + while (std::getline(infile, line)) + { + std::istringstream iss(line); + getline(iss, token, '\t'); + std::istringstream new_iss(token); + uint32_t seller; + while (getline(new_iss, token, ',')) + { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + uint32_t token_as_num = (uint32_t)std::stoul(token); + seller = token_as_num; + sellers.insert(seller); + } + + _location_to_seller[line_cnt] = seller; + line_cnt++; + } + num_points = (size_t)line_cnt; + diskann::cout << "Identified " << sellers.size() << " distinct seller(s) across " << num_points <<" points." << std::endl; +} + + template void PQFlashIndex::set_universal_label(const LabelT &label) { _use_universal_label = true; @@ -1013,6 +1060,18 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons << std::endl; } + +#ifndef EXEC_ENV_OLS +// TODO: Make this friendly for DLVS + this->_seller_file = std ::string(_disk_index_file) + "_sellers.txt"; + if(file_exists(this->_seller_file)) { + uint64_t nrows_seller_file; + parse_seller_file(this->_seller_file, nrows_seller_file); + this->_diverse_index = true; + } +#endif + + // read index metadata #ifdef EXEC_ENV_OLS // This is a bit tricky. We have to read the header from the @@ -1241,31 +1300,31 @@ bool getNextCompletedRequest(std::shared_ptr &reader, IOConte template void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, - uint64_t *indices, float *distances, const uint64_t beam_width, + uint64_t *indices, float *distances, const uint64_t beam_width, const uint32_t max_l_per_seller, const bool use_reorder_data, QueryStats *stats) { - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, std::numeric_limits::max(), + cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, std::numeric_limits::max(), max_l_per_seller, use_reorder_data, stats); } template void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, uint64_t *indices, float *distances, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, + const bool use_filter, const LabelT &filter_label, const uint32_t max_l_per_seller, const bool use_reorder_data, QueryStats *stats) { cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, use_filter, filter_label, - std::numeric_limits::max(), use_reorder_data, stats); + std::numeric_limits::max(), max_l_per_seller, use_reorder_data, stats); } template void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, - uint64_t *indices, float *distances, const uint64_t beam_width, - const uint32_t io_limit, const bool use_reorder_data, + uint64_t *indices, float *distances, const uint64_t beam_width, + const uint32_t io_limit, const uint32_t max_l_per_seller, const bool use_reorder_data, QueryStats *stats) { LabelT dummy_filter = 0; - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, false, dummy_filter, io_limit, + cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, false, dummy_filter, io_limit, max_l_per_seller, use_reorder_data, stats); } @@ -1273,10 +1332,15 @@ template void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, uint64_t *indices, float *distances, const uint64_t beam_width, const bool use_filter, const LabelT &filter_label, - const uint32_t io_limit, const bool use_reorder_data, + const uint32_t io_limit, const uint32_t max_l_per_seller, const bool use_reorder_data, QueryStats *stats) { + bool diverse_search = false; + if (max_l_per_seller != std::numeric_limits::max()) + diverse_search = true; + + uint64_t num_sector_per_nodes = DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); if (beam_width > num_sector_per_nodes * defaults::MAX_N_SECTOR_READS) throw ANNException("Beamwidth can not be higher than defaults::MAX_N_SECTOR_READS", -1, __FUNCSIG__, __FILE__, @@ -1358,8 +1422,18 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t Timer query_timer, io_timer, cpu_timer; tsl::robin_set &visited = query_scratch->visited; - NeighborPriorityQueue &retset = query_scratch->retset; - retset.reserve(l_search); + //NeighborPriorityQueue &retset = query_scratch->retset; + bestCandidates &best_diverse_nodes_ref = query_scratch->best_diverse_nodes; + + NeighborPriorityQueue* retset; + if(diverse_search) { + best_diverse_nodes_ref.setup(l_search, max_l_per_seller); + retset = &(best_diverse_nodes_ref.best_L_nodes); + } else { + retset = &(query_scratch->retset); + retset->reserve(l_search); + } + std::vector &full_retset = query_scratch->full_retset; uint32_t best_medoid = 0; @@ -1402,7 +1476,13 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t } compute_dists(&best_medoid, 1, dist_scratch); - retset.insert(Neighbor(best_medoid, dist_scratch[0])); + if (diverse_search) { + best_diverse_nodes_ref.insert(best_medoid, dist_scratch[0]); + } else { + retset->insert(Neighbor(best_medoid, dist_scratch[0])); + } + + //retset->insert(Neighbor(best_medoid, dist_scratch[0])); visited.insert(best_medoid); uint32_t cmps = 0; @@ -1419,7 +1499,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t std::vector>> cached_nhoods; cached_nhoods.reserve(2 * beam_width); - while (retset.has_unexpanded_node() && num_ios < io_limit) + while (retset->has_unexpanded_node() && num_ios < io_limit) { // clear iteration state frontier.clear(); @@ -1429,9 +1509,9 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t sector_scratch_idx = 0; // find new beam uint32_t num_seen = 0; - while (retset.has_unexpanded_node() && frontier.size() < beam_width && num_seen < beam_width) + while (retset->has_unexpanded_node() && frontier.size() < beam_width && num_seen < beam_width) { - auto nbr = retset.closest_unexpanded(); + auto nbr = retset->closest_unexpanded(); num_seen++; auto iter = _nhood_cache.find(nbr.id); if (iter != _nhood_cache.end()) @@ -1533,8 +1613,13 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t continue; cmps++; float dist = dist_scratch[m]; - Neighbor nn(id, dist); - retset.insert(nn); + +// retset->insert(nn); + if (diverse_search) { + best_diverse_nodes_ref.insert(id, dist); + } else { + retset->insert(Neighbor(id, dist)); + } } } } @@ -1602,7 +1687,12 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t } Neighbor nn(id, dist); - retset.insert(nn); +// retset->insert(nn); + if (diverse_search) { + best_diverse_nodes_ref.insert(id, dist); + } else { + retset->insert(Neighbor(id, dist)); + } } } @@ -1616,6 +1706,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t } // re-sort by distance + std::sort(full_retset.begin(), full_retset.end()); if (use_reorder_data) @@ -1668,6 +1759,15 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t std::sort(full_retset.begin(), full_retset.end()); } + + if (diverse_search) { + best_diverse_nodes_ref.clear(); + for (auto &x : full_retset) { + best_diverse_nodes_ref.insert(x.id, x.distance); + } + full_retset = best_diverse_nodes_ref.best_L_nodes._data; + } + // copy k_search values for (uint64_t i = 0; i < k_search; i++) { @@ -1725,7 +1825,7 @@ uint32_t PQFlashIndex::range_search(const T *query1, const double ran cur_bw = (cur_bw > 100) ? 100 : cur_bw; for (auto &x : distances) x = std::numeric_limits::max(); - this->cached_beam_search(query1, l_search, l_search, indices.data(), distances.data(), cur_bw, false, stats); + this->cached_beam_search(query1, l_search, l_search, indices.data(), distances.data(), cur_bw, std::numeric_limits::max(), false, stats); for (uint32_t i = 0; i < l_search; i++) { if (distances[i] > (float)range) diff --git a/src/scratch.cpp b/src/scratch.cpp index 650c0a1ce..a7a7e9c98 100644 --- a/src/scratch.cpp +++ b/src/scratch.cpp @@ -93,9 +93,10 @@ template void SSDQueryScratch::reset() visited.clear(); retset.clear(); full_retset.clear(); + best_diverse_nodes.clear(); } -template SSDQueryScratch::SSDQueryScratch(size_t aligned_dim, size_t visited_reserve) +template SSDQueryScratch::SSDQueryScratch(size_t aligned_dim, size_t visited_reserve, std::vector &location_to_sellers) : best_diverse_nodes(location_to_sellers) { size_t coord_alloc_size = ROUND_UP(sizeof(T) * aligned_dim, 256); @@ -123,7 +124,7 @@ template SSDQueryScratch::~SSDQueryScratch() } template -SSDThreadData::SSDThreadData(size_t aligned_dim, size_t visited_reserve) : scratch(aligned_dim, visited_reserve) +SSDThreadData::SSDThreadData(size_t aligned_dim, size_t visited_reserve, std::vector &location_to_sellers) : scratch(aligned_dim, visited_reserve, location_to_sellers) { } From 3d0e0ce6e83311ba380ccf65866fd4eed881168d Mon Sep 17 00:00:00 2001 From: rakri Date: Mon, 10 Feb 2025 22:36:02 -0800 Subject: [PATCH 15/18] minor pq fix --- apps/utils/generate_pq.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/utils/generate_pq.cpp b/apps/utils/generate_pq.cpp index a881b1104..50540e0fb 100644 --- a/apps/utils/generate_pq.cpp +++ b/apps/utils/generate_pq.cpp @@ -31,7 +31,7 @@ bool generate_pq(const std::string &data_path, const std::string &index_prefix_p (uint32_t)num_pq_chunks, KMEANS_ITERS_FOR_PQ, pq_pivots_path); } diskann::generate_pq_data_from_pivots(data_path, (uint32_t)num_pq_centers, (uint32_t)num_pq_chunks, - pq_pivots_path, pq_compressed_vectors_path, true); + pq_pivots_path, pq_compressed_vectors_path, opq); delete[] train_data; From 803fc33a3e2e9ac21d4a51d998352d4d1cdc0847 Mon Sep 17 00:00:00 2001 From: rakri Date: Thu, 13 Feb 2025 04:14:26 -0800 Subject: [PATCH 16/18] minor code for adding diversity in driver file search_disk_index --- apps/search_disk_index.cpp | 22 +++++++++++++--------- src/pq_flash_index.cpp | 11 ++++++++--- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/apps/search_disk_index.cpp b/apps/search_disk_index.cpp index 16ae7fbee..203079e8e 100644 --- a/apps/search_disk_index.cpp +++ b/apps/search_disk_index.cpp @@ -53,7 +53,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre const uint32_t num_threads, const uint32_t recall_at, const uint32_t beamwidth, const uint32_t num_nodes_to_cache, const uint32_t search_io_limit, const std::vector &Lvec, const float fail_if_recall_below, - const std::vector &query_filters, const bool use_reorder_data = false) + const std::vector &query_filters, const bool use_reorder_data = false, const uint32_t max_K_per_seller = std::numeric_limits::max()) { diskann::cout << "Search parameters: #threads: " << num_threads << ", "; if (beamwidth <= 0) @@ -232,7 +232,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre _pFlashIndex->cached_beam_search(query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at), query_result_dists[test_id].data() + (i * recall_at), - optimized_beamwidth, std::numeric_limits::max(), use_reorder_data, stats + i); + optimized_beamwidth, max_K_per_seller, use_reorder_data, stats + i); } else { @@ -314,7 +314,8 @@ int main(int argc, char **argv) { std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file, filter_label, label_type, query_filters_file; - uint32_t num_threads, K, W, num_nodes_to_cache, search_io_limit; + uint32_t num_threads, K, W, num_nodes_to_cache, search_io_limit; + uint32_t max_K_per_seller = std::numeric_limits::max(); std::vector Lvec; bool use_reorder_data = false; float fail_if_recall_below = 0.0f; @@ -372,6 +373,9 @@ int main(int argc, char **argv) optional_configs.add_options()("fail_if_recall_below", po::value(&fail_if_recall_below)->default_value(0.0f), program_options_utils::FAIL_IF_RECALL_BELOW); + optional_configs.add_options()("max_K_per_seller", po::value(&max_K_per_seller)->default_value(std::numeric_limits::max()), + "Diverse search, max number of results per seller"); + // Merge required and optional parameters desc.add(required_configs).add(optional_configs); @@ -451,15 +455,15 @@ int main(int argc, char **argv) if (data_type == std::string("float")) return search_disk_index( metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, - num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); + num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller); else if (data_type == std::string("int8")) return search_disk_index( metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, - num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); + num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller); else if (data_type == std::string("uint8")) return search_disk_index( metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, - num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); + num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller); else { std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl; @@ -471,15 +475,15 @@ int main(int argc, char **argv) if (data_type == std::string("float")) return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - fail_if_recall_below, query_filters, use_reorder_data); + fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller); else if (data_type == std::string("int8")) return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - fail_if_recall_below, query_filters, use_reorder_data); + fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller); else if (data_type == std::string("uint8")) return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - fail_if_recall_below, query_filters, use_reorder_data); + fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller); else { std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl; diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index c117440e4..b6293c390 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -1332,14 +1332,17 @@ template void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, uint64_t *indices, float *distances, const uint64_t beam_width, const bool use_filter, const LabelT &filter_label, - const uint32_t io_limit, const uint32_t max_l_per_seller, const bool use_reorder_data, + const uint32_t io_limit, const uint32_t max_k_per_seller, const bool use_reorder_data, QueryStats *stats) { bool diverse_search = false; - if (max_l_per_seller != std::numeric_limits::max()) + uint32_t max_l_per_seller = std::numeric_limits::max(); + if (max_k_per_seller != std::numeric_limits::max()) + { diverse_search = true; - + max_l_per_seller = max_k_per_seller * (l_search / k_search); + } uint64_t num_sector_per_nodes = DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); if (beam_width > num_sector_per_nodes * defaults::MAX_N_SECTOR_READS) @@ -1762,6 +1765,8 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t if (diverse_search) { best_diverse_nodes_ref.clear(); + best_diverse_nodes_ref.setup(k_search, max_k_per_seller); + for (auto &x : full_retset) { best_diverse_nodes_ref.insert(x.id, x.distance); } From 802870f00dde4d738dc347290da169751cbbcb7a Mon Sep 17 00:00:00 2001 From: rakri Date: Tue, 18 Feb 2025 03:20:12 -0800 Subject: [PATCH 17/18] added diversity in build index --- apps/build_disk_index.cpp | 16 +++++++++------- include/disk_utils.h | 2 +- src/disk_utils.cpp | 19 ++++++++++++------- src/pq_flash_index.cpp | 5 +++-- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/apps/build_disk_index.cpp b/apps/build_disk_index.cpp index f48b61726..56e4f2723 100644 --- a/apps/build_disk_index.cpp +++ b/apps/build_disk_index.cpp @@ -16,7 +16,7 @@ namespace po = boost::program_options; int main(int argc, char **argv) { std::string data_type, dist_fn, data_path, index_path_prefix, codebook_prefix, label_file, universal_label, - label_type; + label_type, seller_file; uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold; float B, M; bool append_reorder_data = false; @@ -78,6 +78,8 @@ int main(int argc, char **argv) "internally where each node has a maximum F labels."); optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), program_options_utils::LABEL_TYPE_DESCRIPTION); + optional_configs.add_options()("seller_file", po::value(&seller_file)->default_value(""), + "In case of diverse index, need the seller file"); // Merge required and optional parameters desc.add(required_configs).add(optional_configs); @@ -146,15 +148,15 @@ int main(int argc, char **argv) if (data_type == std::string("int8")) return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); + universal_label, filter_threshold, Lf, seller_file); else if (data_type == std::string("uint8")) return diskann::build_disk_index( data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, - use_filters, label_file, universal_label, filter_threshold, Lf); + use_filters, label_file, universal_label, filter_threshold, Lf, seller_file); else if (data_type == std::string("float")) return diskann::build_disk_index( data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, - use_filters, label_file, universal_label, filter_threshold, Lf); + use_filters, label_file, universal_label, filter_threshold, Lf, seller_file); else { diskann::cerr << "Error. Unsupported data type" << std::endl; @@ -166,15 +168,15 @@ int main(int argc, char **argv) if (data_type == std::string("int8")) return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); + universal_label, filter_threshold, Lf, seller_file); else if (data_type == std::string("uint8")) return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); + universal_label, filter_threshold, Lf, seller_file); else if (data_type == std::string("float")) return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); + universal_label, filter_threshold, Lf, seller_file); else { diskann::cerr << "Error. Unsupported data type" << std::endl; diff --git a/include/disk_utils.h b/include/disk_utils.h index 08f046dcd..f5d22e2c8 100644 --- a/include/disk_utils.h +++ b/include/disk_utils.h @@ -98,7 +98,7 @@ DISKANN_DLLEXPORT int build_disk_index( bool use_filters = false, const std::string &label_file = std::string(""), // default is empty string for no label_file const std::string &universal_label = "", const uint32_t filter_threshold = 0, - const uint32_t Lf = 0); // default is empty string for no universal label + const uint32_t Lf = 0, const std::string &seller_file = ""); // default is empty string for no universal label template DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, const std::string mem_index_file, diff --git a/src/disk_utils.cpp b/src/disk_utils.cpp index 51c42daed..d1d25dc57 100644 --- a/src/disk_utils.cpp +++ b/src/disk_utils.cpp @@ -1101,7 +1101,7 @@ template int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, const uint32_t filter_threshold, - const uint32_t Lf) + const uint32_t Lf, const std::string &seller_file) { std::stringstream parser; parser << std::string(indexBuildParameters); @@ -1368,6 +1368,11 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const if (use_disk_pq) std::remove(disk_pq_compressed_vectors_path.c_str()); + if (seller_file != "") { + std::string disk_index_seller_file = disk_index_path + "_sellers.txt"; + copy_file(seller_file, disk_index_seller_file); + } + auto e = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = e - s; diskann::cout << "Indexing time: " << diff.count() << std::endl; @@ -1432,21 +1437,21 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *da const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); + const uint32_t filter_threshold, const uint32_t Lf, const std::string &seller_file); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); + const uint32_t filter_threshold, const uint32_t Lf, const std::string &seller_file); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); + const uint32_t filter_threshold, const uint32_t Lf, const std::string &seller_file); // LabelT = uint16 template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, @@ -1454,21 +1459,21 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *da const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); + const uint32_t filter_threshold, const uint32_t Lf, const std::string &seller_file); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); + const uint32_t filter_threshold, const uint32_t Lf, const std::string &seller_file); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); + const uint32_t filter_threshold, const uint32_t Lf, const std::string &seller_file); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index b6293c390..5510e0b15 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -1063,7 +1063,8 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons #ifndef EXEC_ENV_OLS // TODO: Make this friendly for DLVS - this->_seller_file = std ::string(_disk_index_file) + "_sellers.txt"; + this->_seller_file = std ::string(index_filepath) + "_sellers.txt"; + std::cout<_seller_file << std::endl; if(file_exists(this->_seller_file)) { uint64_t nrows_seller_file; parse_seller_file(this->_seller_file, nrows_seller_file); @@ -1715,7 +1716,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t if (use_reorder_data) { if (!(this->_reorder_data_exists)) - { + { throw ANNException("Requested use of reordering data which does " "not exist in index " "file", From 659041e1b12e3ab41c70ad12d11a29a883f17e15 Mon Sep 17 00:00:00 2001 From: pianand <123442561+pianand@users.noreply.github.com> Date: Tue, 18 Mar 2025 10:28:00 +0530 Subject: [PATCH 18/18] Added code for building diverse index (#633) --- apps/build_disk_index.cpp | 26 ++++++++----- include/disk_utils.h | 6 ++- include/program_options_utils.hpp | 1 + src/disk_utils.cpp | 62 +++++++++++++++++++++++-------- src/index.cpp | 2 +- 5 files changed, 68 insertions(+), 29 deletions(-) diff --git a/apps/build_disk_index.cpp b/apps/build_disk_index.cpp index 56e4f2723..c94df0866 100644 --- a/apps/build_disk_index.cpp +++ b/apps/build_disk_index.cpp @@ -17,10 +17,12 @@ int main(int argc, char **argv) { std::string data_type, dist_fn, data_path, index_path_prefix, codebook_prefix, label_file, universal_label, label_type, seller_file; - uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold; + uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold, num_diverse_build; float B, M; bool append_reorder_data = false; bool use_opq = false; + bool diverse_index = false; + po::options_description desc{ program_options_utils::make_program_description("build_disk_index", "Build a disk-based index.")}; @@ -80,6 +82,9 @@ int main(int argc, char **argv) program_options_utils::LABEL_TYPE_DESCRIPTION); optional_configs.add_options()("seller_file", po::value(&seller_file)->default_value(""), "In case of diverse index, need the seller file"); + optional_configs.add_options()("NumDiverse", po::value(&num_diverse_build)->default_value(0), + program_options_utils::NUM_DIVERSE); + // Merge required and optional parameters desc.add(required_configs).add(optional_configs); @@ -140,23 +145,24 @@ int main(int argc, char **argv) std::string(std::to_string(num_threads)) + " " + std::string(std::to_string(disk_PQ)) + " " + std::string(std::to_string(append_reorder_data)) + " " + std::string(std::to_string(build_PQ)) + " " + std::string(std::to_string(QD)); - + if(seller_file != "") + diverse_index = true; try { if (label_file != "" && label_type == "ushort") { if (data_type == std::string("int8")) - return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), - metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf, seller_file); + return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), + metric, use_opq, codebook_prefix, use_filters, label_file, + universal_label, filter_threshold, Lf, diverse_index, seller_file, num_diverse_build); else if (data_type == std::string("uint8")) return diskann::build_disk_index( data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, - use_filters, label_file, universal_label, filter_threshold, Lf, seller_file); + use_filters, label_file, universal_label, filter_threshold, Lf, diverse_index, seller_file, num_diverse_build); else if (data_type == std::string("float")) return diskann::build_disk_index( data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, - use_filters, label_file, universal_label, filter_threshold, Lf, seller_file); + use_filters, label_file, universal_label, filter_threshold, Lf, diverse_index, seller_file, num_diverse_build); else { diskann::cerr << "Error. Unsupported data type" << std::endl; @@ -168,15 +174,15 @@ int main(int argc, char **argv) if (data_type == std::string("int8")) return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf, seller_file); + universal_label, filter_threshold, Lf, diverse_index, seller_file, num_diverse_build); else if (data_type == std::string("uint8")) return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf, seller_file); + universal_label, filter_threshold, Lf, diverse_index, seller_file, num_diverse_build); else if (data_type == std::string("float")) return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf, seller_file); + universal_label, filter_threshold, Lf, diverse_index, seller_file, num_diverse_build); else { diskann::cerr << "Error. Unsupported data type" << std::endl; diff --git a/include/disk_utils.h b/include/disk_utils.h index f5d22e2c8..35e60dfea 100644 --- a/include/disk_utils.h +++ b/include/disk_utils.h @@ -82,7 +82,8 @@ DISKANN_DLLEXPORT int build_merged_vamana_index(std::string base_file, diskann:: uint32_t num_threads, bool use_filters = false, const std::string &label_file = std::string(""), const std::string &labels_to_medoids_file = std::string(""), - const std::string &universal_label = "", const uint32_t Lf = 0); + const std::string &universal_label = "", const uint32_t Lf = 0, + bool diverse_index = false, const std::string &seller_file = std::string(""), size_t num_diverse_build = 0) ; template DISKANN_DLLEXPORT uint32_t optimize_beamwidth(std::unique_ptr> &_pFlashIndex, @@ -98,7 +99,8 @@ DISKANN_DLLEXPORT int build_disk_index( bool use_filters = false, const std::string &label_file = std::string(""), // default is empty string for no label_file const std::string &universal_label = "", const uint32_t filter_threshold = 0, - const uint32_t Lf = 0, const std::string &seller_file = ""); // default is empty string for no universal label + const uint32_t Lf = 0, // default is empty string for no universal label + bool diverse_index = false, const std::string &seller_file = std::string(""), const uint32_t num_diverse_build = 0); template DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, const std::string mem_index_file, diff --git a/include/program_options_utils.hpp b/include/program_options_utils.hpp index 3686e3885..383bdadaa 100644 --- a/include/program_options_utils.hpp +++ b/include/program_options_utils.hpp @@ -68,6 +68,7 @@ const char *GRAPH_BUILD_ALPHA = "Alpha controls density and diameter of graph, s "denser graphs with lower diameter"; const char *BUIlD_GRAPH_PQ_BYTES = "Number of PQ bytes to build the index; 0 for full precision build"; const char *USE_OPQ = "Use Optimized Product Quantization (OPQ)."; +const char *DIVERSE_INDEX = "Build Diverse Index"; const char *LABEL_FILE = "Input label file in txt format for Filtered Index build. The file should contain comma " "separated filters for each node with each line corresponding to a graph node"; const char *UNIVERSAL_LABEL = diff --git a/src/disk_utils.cpp b/src/disk_utils.cpp index d1d25dc57..f1af2abd9 100644 --- a/src/disk_utils.cpp +++ b/src/disk_utils.cpp @@ -630,7 +630,7 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr std::string medoids_file, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, const std::string &labels_to_medoids_file, const std::string &universal_label, - const uint32_t Lf) + const uint32_t Lf, bool diverse_index, const std::string &seller_file, uint32_t num_diverse_build) { size_t base_num, base_dim; diskann::get_bin_metadata(base_file, base_num, base_dim); @@ -647,6 +647,9 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr .with_filter_list_size(Lf) .with_saturate_graph(!use_filters) .with_num_threads(num_threads) + .with_diverse_index(diverse_index) + .with_seller_file(seller_file) + .with_num_diverse_build(num_diverse_build) .build(); using TagT = uint32_t; diskann::Index _index(compareMetric, base_dim, base_num, @@ -706,15 +709,24 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr std::string shard_ids_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_ids_uint32.bin"; std::string shard_labels_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_labels.txt"; + std::string shard_sellers_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_sellers.txt"; retrieve_shard_data_from_ids(base_file, shard_ids_file, shard_base_file); std::string shard_index_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_mem.index"; + if(diverse_index) + { + diskann::extract_shard_labels(seller_file, shard_ids_file, shard_sellers_file); + } + diskann::IndexWriteParameters low_degree_params = diskann::IndexWriteParametersBuilder(L, 2 * R / 3) .with_filter_list_size(Lf) .with_saturate_graph(false) .with_num_threads(num_threads) + .with_diverse_index(diverse_index) + .with_seller_file(shard_sellers_file) + .with_num_diverse_build(num_diverse_build) .build(); uint64_t shard_base_dim, shard_base_pts; @@ -724,8 +736,9 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr std::make_shared(low_degree_params), nullptr, defaults::NUM_FROZEN_POINTS_STATIC, false, false, false, build_pq_bytes > 0, build_pq_bytes, use_opq); + if (!use_filters) - { + { _index.build(shard_base_file.c_str(), shard_base_pts); } else @@ -736,7 +749,9 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr LabelT unv_label_as_num = 0; _index.set_universal_label(unv_label_as_num); } + _index.build_filtered_index(shard_base_file.c_str(), shard_labels_file, shard_base_pts); + } _index.save(shard_index_file.c_str()); // copy universal label file from first shard to the final destination @@ -768,11 +783,13 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr std::string shard_labels_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_labels.txt"; std::string shard_index_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_mem.index"; std::string shard_index_file_data = shard_index_file + ".data"; + std::string shard_sellers_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_sellers.txt"; std::remove(shard_base_file.c_str()); std::remove(shard_id_file.c_str()); std::remove(shard_index_file.c_str()); std::remove(shard_index_file_data.c_str()); + std::remove(shard_sellers_file.c_str()); if (use_filters) { std::string shard_index_label_file = shard_index_file + "_labels.txt"; @@ -1101,7 +1118,7 @@ template int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, const uint32_t filter_threshold, - const uint32_t Lf, const std::string &seller_file) + const uint32_t Lf, bool diverse_index, const std::string &seller_file, const uint32_t num_diverse_build) { std::stringstream parser; parser << std::string(indexBuildParameters); @@ -1326,7 +1343,7 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const diskann::build_merged_vamana_index(data_file_to_use.c_str(), diskann::Metric::L2, L, R, p_val, indexing_ram_budget, mem_index_path, medoids_path, centroids_path, build_pq_bytes, use_opq, num_threads, use_filters, labels_file_to_use, - labels_to_medoids_path, universal_label, Lf); + labels_to_medoids_path, universal_label, Lf, diverse_index, seller_file, num_diverse_build); diskann::cout << timer.elapsed_seconds_for_step("building merged vamana index") << std::endl; timer.reset(); @@ -1437,21 +1454,24 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *da const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf, const std::string &seller_file); + const uint32_t filter_threshold, const uint32_t Lf, bool diverse_index, + const std::string &seller_file, const uint32_t num_diverse_build); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf, const std::string &seller_file); + const uint32_t filter_threshold, const uint32_t Lf, bool diverse_index, + const std::string &seller_file, const uint32_t num_diverse_build); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf, const std::string &seller_file); + const uint32_t filter_threshold, const uint32_t Lf, bool diverse_index, + const std::string &seller_file, const uint32_t num_diverse_build); // LabelT = uint16 template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, @@ -1459,51 +1479,61 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *da const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf, const std::string &seller_file); + const uint32_t filter_threshold, const uint32_t Lf, bool diverse_index, + const std::string &seller_file, const uint32_t num_diverse_build); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf, const std::string &seller_file); + const uint32_t filter_threshold, const uint32_t Lf, bool diverse_index, + const std::string &seller_file, const uint32_t num_diverse_build); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf, const std::string &seller_file); + const uint32_t filter_threshold, const uint32_t Lf, bool diverse_index, + const std::string &seller_file, const uint32_t num_diverse_build); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf, + bool diverse_index, const std::string &seller_file, uint32_t num_diverse_build); + template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf, + bool diverse_index, const std::string &seller_file, uint32_t num_diverse_build); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf, + bool diverse_index, const std::string &seller_file, uint32_t num_diverse_build); // Label=16_t template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf, + bool diverse_index, const std::string &seller_file, uint32_t num_diverse_build); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf, + bool diverse_index, const std::string &seller_file, uint32_t num_diverse_build); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf, + bool diverse_index, const std::string &seller_file, uint32_t num_diverse_build); }; // namespace diskann diff --git a/src/index.cpp b/src/index.cpp index 873c7ceca..f12415dcf 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1070,7 +1070,7 @@ void Index::search_for_point_and_prune(int location, uint32_t L scratch->clear(); _data_store->get_vector(location, scratch->aligned_query()); - iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false); + iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false, maxLperSeller); for (auto unfiltered_neighbour : scratch->pool()) {