From 2be1d3c73165de55ed15e156a1ec4a6cca4e11d2 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Thu, 16 Oct 2025 16:14:20 +0800 Subject: [PATCH 01/10] change filtered search interface --- include/abstract_index.h | 4 +- include/index.h | 4 +- include/pq_flash_index.h | 4 +- src/abstract_index.cpp | 8 ++-- src/index.cpp | 92 +++++++++++++++++++++++----------------- src/pq_flash_index.cpp | 62 +++++++++++++++------------ 6 files changed, 100 insertions(+), 74 deletions(-) diff --git a/include/abstract_index.h b/include/abstract_index.h index 1bb94fa06..4f021adea 100644 --- a/include/abstract_index.h +++ b/include/abstract_index.h @@ -80,7 +80,7 @@ class AbstractIndex // Filter support search // IndexType is either uint32_t or uint64_t template - std::pair search_with_filters(const DataType &query, const std::string &raw_label, + std::pair search_with_filters(const DataType &query, const std::vector &raw_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, IndexType *indices, float *distances); @@ -122,7 +122,7 @@ class AbstractIndex std::any &indices, float *distances = nullptr) = 0; virtual std::pair _diverse_search(const DataType& query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, std::any& indices, float* distances = nullptr) = 0; - virtual std::pair _search_with_filters(const DataType &query, const std::string &filter_label, + virtual std::pair _search_with_filters(const DataType &query, const std::vector &filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, std::any &indices, float *distances) = 0; virtual int _insert_point(const DataType &data_point, const TagType tag, const std::vector &labels) = 0; diff --git a/include/index.h b/include/index.h index 33059b43b..f74f10d1e 100644 --- a/include/index.h +++ b/include/index.h @@ -152,7 +152,7 @@ template clas // Filter support search template - DISKANN_DLLEXPORT std::pair search_with_filters(const T *query, const LabelT &filter_label, + DISKANN_DLLEXPORT std::pair search_with_filters(const T *query, const std::vector &filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, IndexType *indices, float *distances); @@ -217,7 +217,7 @@ template clas virtual std::pair _search(const DataType &query, const size_t K, const uint32_t L, std::any &indices, float *distances = nullptr) override; virtual std::pair _search_with_filters(const DataType &query, - const std::string &filter_label_raw, const size_t K, + const std::vector &filter_labels_raw, const size_t K, const uint32_t L, const uint32_t maxLperSeller, std::any &indices, float *distances) override; diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index c0ecaa73d..67850922d 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -64,7 +64,7 @@ template class PQFlashIndex DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, uint64_t *res_ids, float *res_dists, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, + const bool use_filter, const std::vector &filter_labels, uint32_t maxLperSeller = 0, const bool use_reorder_data = false, std::function rerank_fn = nullptr, @@ -79,7 +79,7 @@ template class PQFlashIndex DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, uint64_t *res_ids, float *res_dists, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, + const bool use_filter, const std::vector &filter_labels, const uint32_t io_limit, uint32_t maxLperSeller = 0, const bool use_reorder_data = false, std::function rerank_fn = nullptr, diff --git a/src/abstract_index.cpp b/src/abstract_index.cpp index a2c85e08e..4f27af092 100644 --- a/src/abstract_index.cpp +++ b/src/abstract_index.cpp @@ -43,12 +43,12 @@ size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K, } template -std::pair AbstractIndex::search_with_filters(const DataType &query, const std::string &raw_label, +std::pair AbstractIndex::search_with_filters(const DataType &query, const std::vector &raw_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, IndexType *indices, float *distances) { auto any_indices = std::any(indices); - return _search_with_filters(query, raw_label, K, L, maxLperSeller, any_indices, distances); + return _search_with_filters(query, raw_labels, K, L, maxLperSeller, any_indices, distances); } template @@ -173,11 +173,11 @@ template DISKANN_DLLEXPORT std::pair AbstractIndex::search AbstractIndex::search_with_filters( - const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, + const DataType &query, const std::vector &raw_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair AbstractIndex::search_with_filters( - const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, + const DataType &query, const std::vector& raw_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( diff --git a/src/index.cpp b/src/index.cpp index 3d6a15c7a..3a94e93bd 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2433,20 +2433,27 @@ std::pair Index::search(const T *query, con template std::pair Index::_search_with_filters(const DataType &query, - const std::string &raw_label, const size_t K, + const std::vector &raw_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, std::any &indices, float *distances) { - auto converted_label = this->get_converted_label(raw_label); + std::vector converted_labels; + converted_labels.reserve(raw_labels.size()); + for (const auto &raw_label : raw_labels) + { + auto converted_label = this->get_converted_label(raw_label); + converted_labels.push_back(converted_label); + } + if (typeid(uint64_t *) == indices.type()) { auto ptr = std::any_cast(indices); - return this->search_with_filters(std::any_cast(query), converted_label, K, L, maxLperSeller, ptr, distances); + return this->search_with_filters(std::any_cast(query), converted_labels, K, L, maxLperSeller, ptr, distances); } else if (typeid(uint32_t *) == indices.type()) { auto ptr = std::any_cast(indices); - return this->search_with_filters(std::any_cast(query), converted_label, K, L, maxLperSeller, ptr, distances); + return this->search_with_filters(std::any_cast(query), converted_labels, K, L, maxLperSeller, ptr, distances); } else { @@ -2456,7 +2463,7 @@ std::pair Index::_search_with_filters(const template template -std::pair Index::search_with_filters(const T *query, const LabelT &filter_label, +std::pair Index::search_with_filters(const T *query, const std::vector &filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, IdType *indices, float *distances) { @@ -2477,6 +2484,8 @@ std::pair Index::search_with_filters(const } std::vector filter_vec; + filter_vec.reserve(filter_labels.size()); + std::vector init_ids = get_init_ids(); std::shared_lock lock(_update_lock); @@ -2484,20 +2493,27 @@ std::pair Index::search_with_filters(const if (_dynamic_index) tl.lock(); - if (_label_to_start_id.find(filter_label) != _label_to_start_id.end()) + for (auto& filter_label : filter_labels) { - init_ids.emplace_back(_label_to_start_id[filter_label]); - } - else - { - diskann::cout << "No filtered medoid found. exitting " - << std::endl; // RKNOTE: If universal label found start there - throw diskann::ANNException("No filtered medoid found. exitting ", -1); + if (_label_to_start_id.find(filter_label) != _label_to_start_id.end()) + { + init_ids.emplace_back(_label_to_start_id[filter_label]); + } + else + { + diskann::cout << "No filtered medoid found. exitting " + << std::endl; // RKNOTE: If universal label found start there + throw diskann::ANNException("No filtered medoid found. exitting ", -1); + } } + if (_dynamic_index) tl.unlock(); - filter_vec.emplace_back(filter_label); + for (auto& filter_label : filter_labels) + { + filter_vec.emplace_back(filter_label); + } _data_store->preprocess_query(query, scratch); auto retval = iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true, maxLperSeller); @@ -3772,41 +3788,41 @@ template DISKANN_DLLEXPORT std::pair Index Index::search_with_filters< - uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, + uint64_t>(const float *query, const std::vector &filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, + uint32_t>(const float *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, + uint64_t>(const uint8_t *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, + uint32_t>(const uint8_t *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, + uint64_t>(const int8_t *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, + uint32_t>(const int8_t *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, + uint64_t>(const float *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, + uint32_t>(const float *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, + uint64_t>(const uint8_t *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, + uint32_t>(const uint8_t *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, + uint64_t>(const int8_t *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, + uint32_t>(const int8_t *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( @@ -3836,40 +3852,40 @@ template DISKANN_DLLEXPORT std::pair Index Index::search_with_filters< - uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, + uint64_t>(const float *query, const std::vector &filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, + uint32_t>(const float *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, + uint64_t>(const uint8_t *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, + uint32_t>(const uint8_t *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, + uint64_t>(const int8_t *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, + uint32_t>(const int8_t *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, + uint64_t>(const float *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, + uint32_t>(const float *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, + uint64_t>(const uint8_t *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, + uint32_t>(const uint8_t *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, + uint64_t>(const int8_t *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, + uint32_t>(const int8_t *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); } // namespace diskann diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 990f40bea..1f2f1e49d 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -1140,18 +1140,18 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t QueryStats *stats) { cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, std::numeric_limits::max(), - maxLperSeller, use_reorder_data, stats); + maxLperSeller, use_reorder_data, nullptr, stats); } template void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, uint64_t *indices, float *distances, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, + const bool use_filter, const std::vector &filter_labels, uint32_t maxLperSeller, const bool use_reorder_data, std::function rerank_fn, QueryStats *stats) { - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, use_filter, filter_label, + cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, use_filter, filter_labels, std::numeric_limits::max(), maxLperSeller, use_reorder_data, rerank_fn, stats); } @@ -1163,15 +1163,15 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t std::function rerank_fn, QueryStats *stats) { - LabelT dummy_filter = 0; - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, false, dummy_filter, io_limit, + std::vector dummy_filters; + cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, false, dummy_filters, io_limit, maxLperSeller, use_reorder_data, rerank_fn, stats); } template void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, uint64_t *indices, float *distances, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, + const bool use_filter, const std::vector &filter_labels, const uint32_t io_limit, uint32_t maxLperSeller, const bool use_reorder_data, std::function rerank_fn, @@ -1212,8 +1212,11 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t query_bitmask_buf.resize(_bitmask_buf._bitmask_size, 0); bitmask_full_val._mask = query_bitmask_buf.data(); - auto bitmask_val = simple_bitmask::get_bitmask_val(filter_label); - bitmask_full_val.merge_bitmask_val(bitmask_val); + for (const auto& filter_label : filter_labels) + { + auto bitmask_val = simple_bitmask::get_bitmask_val(filter_label); + bitmask_full_val.merge_bitmask_val(bitmask_val); + } if (_use_universal_label) { @@ -1312,35 +1315,42 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t best_dist = cur_expanded_dist; } } + + compute_dists(&best_medoid, 1, dist_scratch); + retset->insert(Neighbor(best_medoid, dist_scratch[0])); + visited.insert(best_medoid); } else { - if (_filter_to_medoid_ids.find(filter_label) != _filter_to_medoid_ids.end()) + for (const auto& filter_label : filter_labels) { - const auto &medoid_ids = _filter_to_medoid_ids[filter_label]; - for (uint64_t cur_m = 0; cur_m < medoid_ids.size(); cur_m++) + if (_filter_to_medoid_ids.find(filter_label) != _filter_to_medoid_ids.end()) { - // for filtered index, we dont store global centroid data as for unfiltered index, so we use PQ distance - // as approximation to decide closest medoid matching the query filter. - compute_dists(&medoid_ids[cur_m], 1, dist_scratch); - float cur_expanded_dist = dist_scratch[0]; - if (cur_expanded_dist < best_dist) + const auto& medoid_ids = _filter_to_medoid_ids[filter_label]; + for (uint64_t cur_m = 0; cur_m < medoid_ids.size(); cur_m++) { - best_medoid = medoid_ids[cur_m]; - best_dist = cur_expanded_dist; + // for filtered index, we dont store global centroid data as for unfiltered index, so we use PQ distance + // as approximation to decide closest medoid matching the query filter. + compute_dists(&medoid_ids[cur_m], 1, dist_scratch); + float cur_expanded_dist = dist_scratch[0]; + if (cur_expanded_dist < best_dist) + { + best_medoid = medoid_ids[cur_m]; + best_dist = cur_expanded_dist; + } } } - } - else - { - throw ANNException("Cannot find medoid for specified filter.", -1, __FUNCSIG__, __FILE__, __LINE__); + else + { + throw ANNException("Cannot find medoid for specified filter.", -1, __FUNCSIG__, __FILE__, __LINE__); + } + + compute_dists(&best_medoid, 1, dist_scratch); + retset->insert(Neighbor(best_medoid, dist_scratch[0])); + visited.insert(best_medoid); } } - compute_dists(&best_medoid, 1, dist_scratch); - retset->insert(Neighbor(best_medoid, dist_scratch[0])); - visited.insert(best_medoid); - uint32_t cmps = 0; uint32_t hops = 0; uint32_t num_ios = 0; From 6ecedd971c895d17f17170faef4190e5daf4ca21 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Thu, 16 Oct 2025 19:52:29 +0800 Subject: [PATCH 02/10] fix streaming interface --- include/abstract_index.h | 8 ++++---- include/index.h | 8 ++++---- src/index.cpp | 34 +++++++++++++++++++--------------- 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/include/abstract_index.h b/include/abstract_index.h index 4f021adea..f472dbfaa 100644 --- a/include/abstract_index.h +++ b/include/abstract_index.h @@ -63,8 +63,8 @@ class AbstractIndex // Initialize space for res_vectors before calling. template size_t search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, - float *distances, std::vector &res_vectors, bool use_filters = false, - const std::string filter_label = ""); + float *distances, std::vector &res_vectors, bool use_filters, + const std::vector& filter_labels); // Added search overload that takes L as parameter, so that we // can customize L on a per-query basis without tampering with "Parameters" @@ -133,8 +133,8 @@ class AbstractIndex virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) = 0; virtual int _get_vector_by_tag(TagType &tag, DataType &vec) = 0; virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, - float *distances, DataVector &res_vectors, bool use_filters = false, - const std::string filter_label = "") = 0; + float *distances, DataVector &res_vectors, bool use_filters, + const std::vector& filter_labels) = 0; virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) = 0; virtual void _set_universal_label(const LabelType universal_label) = 0; }; diff --git a/include/index.h b/include/index.h index f74f10d1e..b82fe810c 100644 --- a/include/index.h +++ b/include/index.h @@ -144,8 +144,8 @@ template clas // Initialize space for res_vectors before calling. DISKANN_DLLEXPORT size_t search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags, - float *distances, std::vector &res_vectors, bool use_filters = false, - const std::string filter_label = ""); + float *distances, std::vector &res_vectors, bool use_filters, + const std::vector& filter_labels); virtual std::pair _diverse_search(const DataType& query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, std::any& indices, float* distances = nullptr) override; @@ -237,8 +237,8 @@ template clas virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) override; virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, - float *distances, DataVector &res_vectors, bool use_filters = false, - const std::string filter_label = "") override; + float *distances, DataVector &res_vectors, bool use_filters, + const std::vector& filter_labels) override; virtual void _set_universal_label(const LabelType universal_label) override; diff --git a/src/index.cpp b/src/index.cpp index 3a94e93bd..0fdd754d2 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2559,12 +2559,12 @@ std::pair Index::search_with_filters(const template size_t Index::_search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, float *distances, DataVector &res_vectors, - bool use_filters, const std::string filter_label) + bool use_filters, const std::vector& filter_labels) { try { return this->search_with_tags(std::any_cast(query), K, L, std::any_cast(tags), distances, - res_vectors.get>(), use_filters, filter_label); + res_vectors.get>(), use_filters, filter_labels); } catch (const std::bad_any_cast &e) { @@ -2579,7 +2579,7 @@ size_t Index::_search_with_tags(const DataType &query, const ui template size_t Index::search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags, float *distances, std::vector &res_vectors, bool use_filters, - const std::string filter_label) + const std::vector& filter_labels) { if (K > (uint64_t)L) { @@ -2611,20 +2611,24 @@ size_t Index::search_with_tags(const T *query, const uint64_t K else { std::vector filter_vec; - auto converted_label = this->get_converted_label(filter_label); - - if (_label_to_start_id.find(converted_label) != _label_to_start_id.end()) - { - init_ids.emplace_back(_label_to_start_id[converted_label]); - } - else + for (const auto& filter_label : filter_labels) { - diskann::cout << "No filtered medoid found. exitting " - << std::endl; // RKNOTE: If universal label found start there - throw diskann::ANNException("No filtered medoid found. exitting ", -1); - } + auto converted_label = this->get_converted_label(filter_label); - filter_vec.push_back(converted_label); + if (_label_to_start_id.find(converted_label) != _label_to_start_id.end()) + { + init_ids.emplace_back(_label_to_start_id[converted_label]); + } + else + { + diskann::cout << "No filtered medoid found. exitting " + << std::endl; // RKNOTE: If universal label found start there + throw diskann::ANNException("No filtered medoid found. exitting ", -1); + } + + filter_vec.push_back(converted_label); + } + iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true); } From 506718d21bf69ec1267d27d33ce2b607c506031b Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Thu, 16 Oct 2025 22:20:59 +0800 Subject: [PATCH 03/10] fix compile issue --- src/abstract_index.cpp | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/abstract_index.cpp b/src/abstract_index.cpp index 4f27af092..2084bac6d 100644 --- a/src/abstract_index.cpp +++ b/src/abstract_index.cpp @@ -34,12 +34,12 @@ std::pair AbstractIndex::diverse_search(const data_type* que template size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, float *distances, std::vector &res_vectors, bool use_filters, - const std::string filter_label) + const std::vector& filter_labels) { auto any_query = std::any(query); auto any_tags = std::any(tags); auto any_res_vectors = DataVector(res_vectors); - return this->_search_with_tags(any_query, K, L, any_tags, distances, any_res_vectors, use_filters, filter_label); + return this->_search_with_tags(any_query, K, L, any_tags, distances, any_res_vectors, use_filters, filter_labels); } template @@ -196,63 +196,63 @@ template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_ template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const float *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, - std::vector &res_vectors, bool use_filters, const std::string filter_label); + std::vector &res_vectors, bool use_filters, const std::vector& filter_labels); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const uint8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, - std::vector &res_vectors, bool use_filters, const std::string filter_label); + std::vector &res_vectors, bool use_filters, const std::vector& filter_labels); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const int8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, - std::vector &res_vectors, bool use_filters, const std::string filter_label); + std::vector &res_vectors, bool use_filters, const std::vector& filter_labels); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const float *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, - std::vector &res_vectors, bool use_filters, const std::string filter_label); + std::vector &res_vectors, bool use_filters, const std::vector& filter_labels); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const uint8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, - std::vector &res_vectors, bool use_filters, const std::string filter_label); + std::vector &res_vectors, bool use_filters, const std::vector& filter_labels); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const int8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, - std::vector &res_vectors, bool use_filters, const std::string filter_label); + std::vector &res_vectors, bool use_filters, const std::vector& filter_labels); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const float *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, - std::vector &res_vectors, bool use_filters, const std::string filter_label); + std::vector &res_vectors, bool use_filters, const std::vector& filter_labels); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const uint8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, - std::vector &res_vectors, bool use_filters, const std::string filter_label); + std::vector &res_vectors, bool use_filters, const std::vector& filter_labels); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const int8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, - std::vector &res_vectors, bool use_filters, const std::string filter_label); + std::vector &res_vectors, bool use_filters, const std::vector& filter_labels); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const float *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, - std::vector &res_vectors, bool use_filters, const std::string filter_label); + std::vector &res_vectors, bool use_filters, const std::vector& filter_labels); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const uint8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, - std::vector &res_vectors, bool use_filters, const std::string filter_label); + std::vector &res_vectors, bool use_filters, const std::vector& filter_labels); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const int8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, - std::vector &res_vectors, bool use_filters, const std::string filter_label); + std::vector &res_vectors, bool use_filters, const std::vector& filter_labels); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const float* query, const uint64_t K, const uint32_t L, tag_uint128* tags, float* distances, - std::vector& res_vectors, bool use_filters, const std::string filter_label); + std::vector& res_vectors, bool use_filters, const std::vector& filter_labels); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const uint8_t* query, const uint64_t K, const uint32_t L, tag_uint128* tags, float* distances, - std::vector& res_vectors, bool use_filters, const std::string filter_label); + std::vector& res_vectors, bool use_filters, const std::vector& filter_labels); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const int8_t* query, const uint64_t K, const uint32_t L, tag_uint128* tags, float* distances, - std::vector& res_vectors, bool use_filters, const std::string filter_label); + std::vector& res_vectors, bool use_filters, const std::vector& filter_labels); template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout(const float *query, size_t K, size_t L, uint32_t *indices); From 595f3a0fdce6c4a750f5fe2d014b862cba65ee7a Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Mon, 27 Oct 2025 16:09:35 +0800 Subject: [PATCH 04/10] support integer filter label --- include/abstract_index.h | 2 +- include/filter_match_proxy.h | 66 +++++++++ include/index.h | 19 ++- include/integer_label_vector.h | 43 ++++++ include/label_helper.h | 109 ++++++++++++++ include/parameters.h | 7 + include/pq_flash_index.h | 8 +- src/filter_match_proxy.cpp | 93 ++++++++++++ src/index.cpp | 264 +++++++++++++++++++++------------ src/integer_label_vector.cpp | 198 +++++++++++++++++++++++++ src/label_helper.cpp | 112 ++++++++++++++ src/pq_flash_index.cpp | 69 ++++----- 12 files changed, 855 insertions(+), 135 deletions(-) create mode 100644 include/filter_match_proxy.h create mode 100644 include/integer_label_vector.h create mode 100644 src/filter_match_proxy.cpp create mode 100644 src/integer_label_vector.cpp diff --git a/include/abstract_index.h b/include/abstract_index.h index f472dbfaa..5ec0fd6bd 100644 --- a/include/abstract_index.h +++ b/include/abstract_index.h @@ -53,7 +53,7 @@ class AbstractIndex #ifdef EXEC_ENV_OLS virtual void load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l) = 0; #else - virtual void load(const char *index_file, uint32_t num_threads, uint32_t search_l, bool loadBitmaskLabelFile = false) = 0; + virtual void load(const char *index_file, uint32_t num_threads, uint32_t search_l, LabelFormatType label_format_type = LabelFormatType::String) = 0; #endif // For FastL2 search on optimized layout diff --git a/include/filter_match_proxy.h b/include/filter_match_proxy.h new file mode 100644 index 000000000..51ec52e9e --- /dev/null +++ b/include/filter_match_proxy.h @@ -0,0 +1,66 @@ +#pragma once +#include "label_bitmask.h" +#include "integer_label_vector.h" + +namespace diskann +{ + + class filter_match_proxy + { + public: + virtual bool contain_filtered_label(uint32_t id) = 0; + }; + + template + class bitmask_filter_match : public filter_match_proxy + { + public: + bitmask_filter_match(simple_bitmask_buf& bitmask_filters, + std::vector& query_bitmask_buf, + const std::vector& filter_labels, + LabelT unv_label); + + virtual bool contain_filtered_label(uint32_t id) override; + + private: + simple_bitmask_buf& _bitmask_filters; + std::vector& _query_bitmask_buf; + simple_bitmask_full_val _bitmask_full_val; + }; + + template + class integer_label_filter_match : public filter_match_proxy + { + public: + integer_label_filter_match(integer_label_vector& label_vector, + const std::vector& filter_labels, + LabelT unv_label); + + virtual bool contain_filtered_label(uint32_t id) override; + + private: + integer_label_vector& _label_vector; + const std::vector& _filter_labels; + LabelT _unv_label; + }; + +template +class label_filter_match_holder : public filter_match_proxy +{ +public: + label_filter_match_holder(simple_bitmask_buf& bitmask_filters, + std::vector& query_bitmask_buf, + integer_label_vector& label_vector, + const std::vector& filter_labels, + LabelT unv_label, + bool use_integer_labels); + + virtual bool contain_filtered_label(uint32_t id) override; + +private: + bitmask_filter_match _bitmask_filter_match; + integer_label_filter_match _integer_label_filter_match; + bool _use_integer_labels; +}; + +} \ No newline at end of file diff --git a/include/index.h b/include/index.h index b82fe810c..2a007fa18 100644 --- a/include/index.h +++ b/include/index.h @@ -24,6 +24,7 @@ #include "percentile_stats.h" #include #include "label_bitmask.h" +#include "integer_label_vector.h" #include "quantized_distance.h" #include "pq_data_store.h" @@ -80,7 +81,7 @@ template clas #ifdef EXEC_ENV_OLS DISKANN_DLLEXPORT void load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l); #else - DISKANN_DLLEXPORT void load(const char *index_file, uint32_t num_threads, uint32_t search_l, bool loadBitmaskLabelFile = false); + DISKANN_DLLEXPORT void load(const char *index_file, uint32_t num_threads, uint32_t search_l, LabelFormatType label_format_type = LabelFormatType::String); #endif // get some private variables @@ -118,6 +119,10 @@ template clas DISKANN_DLLEXPORT bool is_set_universal_label() const override; + DISKANN_DLLEXPORT void enable_integer_label(); + + DISKANN_DLLEXPORT bool integer_label_enabled(); + // Set starting point of an index before inserting any points incrementally. // The data count should be equal to _num_frozen_pts * _aligned_dim. DISKANN_DLLEXPORT void set_start_points(const T *data, size_t data_count); @@ -253,11 +258,18 @@ template clas // determines navigating node of the graph by calculating medoid of datafopt uint32_t calculate_entry_point(); - void parse_label_file(const std::string &label_file, size_t &num_pts_labels); + void parse_label_file(const std::string &label_file, size_t &num_pts_labels, size_t& total_labels); void parse_seller_file(const std::string& label_file, size_t& num_pts_labels); void convert_pts_label_to_bitmask(std::vector>& pts_to_labels, simple_bitmask_buf& bitmask_buf, size_t num_labels); + void convert_pts_label_to_integer_vector(std::vector> &pts_to_labels, + integer_label_vector &int_label_vector, size_t total_labels); + + void aggregate_points_by_bitmask_label(std::unordered_map>& label_to_points, size_t num_points_to_load); + + void aggregate_points_by_integer_label(std::unordered_map>& label_to_points, size_t num_points_to_load); + std::unordered_map load_label_map(const std::string &map_file); // Returns the locations of start point and frozen points suitable for use @@ -463,6 +475,9 @@ template clas simple_bitmask_buf _bitmask_buf; + bool _use_integer_labels = false; + integer_label_vector _label_vector; + TableStats _table_stats; static const float INDEX_GROWTH_FACTOR; diff --git a/include/integer_label_vector.h b/include/integer_label_vector.h new file mode 100644 index 000000000..b2ec405fe --- /dev/null +++ b/include/integer_label_vector.h @@ -0,0 +1,43 @@ +#pragma once +#include +#include + +namespace diskann +{ + +class integer_label_vector +{ +public: + bool initialize(size_t numpoints, size_t total_labels); + + bool initialize_from_file(const std::string &label_file, size_t &numpoints); + + bool write_to_file(const std::string &label_file) const; + + template + bool add_labels(uint32_t point_id, std::vector &labels); + + bool check_label_exists(uint32_t point_id, uint32_t label); + + template + bool check_label_exists(uint32_t point_id, const std::vector &labels); + + bool check_label_full_contain(uint32_t point_id, const std::vector &labels); + + bool check_label_full_contain(uint32_t source_point, uint32_t target_point); + + const std::vector &get_offset_vector() const; + + const std::vector &get_data_vector() const; + + size_t get_memory_usage() const; + + private: + bool binary_search(size_t start, size_t end, uint32_t label, size_t& last_check); + +private: + std::vector _offset; + std::vector _data; +}; + +} \ No newline at end of file diff --git a/include/label_helper.h b/include/label_helper.h index f59b88b11..ef60bfb10 100644 --- a/include/label_helper.h +++ b/include/label_helper.h @@ -1,6 +1,8 @@ #pragma once #include "label_bitmask.h" +#include "integer_label_vector.h" #include "percentile_stats.h" +#include "tsl/robin_set.h" #include namespace diskann @@ -20,6 +22,113 @@ class label_helper bool read_bitmask_from_file(const std::string &bitmask_label_file, simple_bitmask_buf &bitmask_buf, size_t& num_points); + + bool parse_label_file_in_integer( + const std::string& label_file, + size_t& num_points, + integer_label_vector& integer_vector, + tsl::robin_set& labels, TableStats &table_stats); + + template + bool load_label_map( + const std::string& label_map_file, + std::unordered_map& label_map) + { + std::ifstream infile(label_map_file, std::ios::binary); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + label_map_file, -1); + } + infile.seekg(0, std::ios::end); + size_t file_size = infile.tellg(); + + std::string buffer(file_size, ' '); + + infile.seekg(0, std::ios::beg); + infile.read(&buffer[0], file_size); + infile.close(); + + unsigned line_cnt = 0; + + size_t cur_pos = 0; + size_t next_pos = 0; + size_t lbl_pos = 0; + std::string token; + std::string labe_str; + while (cur_pos < file_size && cur_pos != std::string::npos) + { + next_pos = buffer.find('\n', cur_pos); + if (next_pos == std::string::npos) + { + break; + } + + lbl_pos = search_string_range(buffer, '\t', cur_pos, next_pos); + labe_str.assign(buffer.c_str() + cur_pos, lbl_pos - cur_pos); + + token.assign(buffer.c_str() + lbl_pos + 1, next_pos - lbl_pos - 1); + LabelT label_num = (LabelT)std::stoul(token); + + label_map[labe_str] = label_num; + + cur_pos = next_pos + 1; + + line_cnt++; + } + + return true; + } + + template + bool load_label_medoids( + const std::string& label_medoids_file, + std::unordered_map& label_to_start_id) + { + std::ifstream infile(label_medoids_file, std::ios::binary); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + label_medoids_file, -1); + } + infile.seekg(0, std::ios::end); + size_t file_size = infile.tellg(); + + std::string buffer(file_size, ' '); + + infile.seekg(0, std::ios::beg); + infile.read(&buffer[0], file_size); + infile.close(); + + unsigned line_cnt = 0; + + size_t cur_pos = 0; + size_t next_pos = 0; + size_t lbl_pos = 0; + std::string token; + while (cur_pos < file_size && cur_pos != std::string::npos) + { + next_pos = buffer.find('\n', cur_pos); + if (next_pos == std::string::npos) + { + break; + } + + lbl_pos = search_string_range(buffer, ',', cur_pos, next_pos); + token.assign(buffer.c_str() + cur_pos, lbl_pos - cur_pos); + LabelT label_num = (LabelT)std::stoul(token); + + token.assign(buffer.c_str() + lbl_pos + 1, next_pos - lbl_pos - 1); + uint32_t medoid = (uint32_t)std::stoul(token); + + label_to_start_id[label_num] = medoid; + + cur_pos = next_pos + 1; + + line_cnt++; + } + + return true; + } + private: size_t search_string_range(const std::string& str, char ch, size_t start, size_t end); }; diff --git a/include/parameters.h b/include/parameters.h index a20fea693..01a8b834c 100644 --- a/include/parameters.h +++ b/include/parameters.h @@ -12,6 +12,13 @@ namespace diskann { +enum class LabelFormatType :uint8_t +{ + String = 0, + BitMask = 1, + Integer = 2 +}; + class IndexWriteParameters { diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index 67850922d..10a40edcb 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -16,6 +16,7 @@ #include "tsl/robin_map.h" #include "tsl/robin_set.h" #include "label_bitmask.h" +#include "integer_label_vector.h" #define FULL_PRECISION_REORDER_MULTIPLIER 3 @@ -47,7 +48,7 @@ template class PQFlashIndex const char* labels_filepath, const char* labels_to_medoids_filepath, const char* labels_map_filepath, const char* unv_label_filepath, const char* seller_filepath, - bool load_bitmask_label = false); + LabelFormatType label_format_type = LabelFormatType::String); #endif DISKANN_DLLEXPORT void load_cache_list(std::vector &node_list); @@ -234,9 +235,12 @@ template class PQFlashIndex // filter support simple_bitmask_buf _bitmask_buf; + bool _use_integer_labels = false; + integer_label_vector _label_vector; + std::unordered_map> _filter_to_medoid_ids; bool _use_universal_label = false; - LabelT _universal_filter_label; + LabelT _universal_filter_label = 0; tsl::robin_set _dummy_pts; tsl::robin_set _has_dummy_pts; tsl::robin_map _dummy_to_real_map; diff --git a/src/filter_match_proxy.cpp b/src/filter_match_proxy.cpp new file mode 100644 index 000000000..4ba606a50 --- /dev/null +++ b/src/filter_match_proxy.cpp @@ -0,0 +1,93 @@ +#include "filter_match_proxy.h" + +namespace diskann +{ + +template +bitmask_filter_match::bitmask_filter_match( + simple_bitmask_buf& bitmask_filters, + std::vector& query_bitmask_buf, + const std::vector& filter_labels, + LabelT unv_label) + : _bitmask_filters(bitmask_filters), + _query_bitmask_buf(query_bitmask_buf) +{ + // _bitmask_size == 0 means no filter is set + if (_bitmask_filters._bitmask_size > 0) + { + query_bitmask_buf.resize(_bitmask_filters._bitmask_size, 0); + _bitmask_full_val._mask = query_bitmask_buf.data(); + + for (const auto& filter_label : filter_labels) + { + auto bitmask_val = simple_bitmask::get_bitmask_val(filter_label); + _bitmask_full_val.merge_bitmask_val(bitmask_val); + } + + // if unv isn't set, it will be default value 0 + auto bitmask_val = simple_bitmask::get_bitmask_val(unv_label); + _bitmask_full_val.merge_bitmask_val(bitmask_val); + } +} + +template +bool bitmask_filter_match::contain_filtered_label(uint32_t id) +{ + simple_bitmask bm(_bitmask_filters.get_bitmask(id), _bitmask_filters._bitmask_size); + + return bm.test_full_mask_val(_bitmask_full_val); +} + +template +integer_label_filter_match::integer_label_filter_match( + integer_label_vector& label_vector, + const std::vector& filter_labels, + LabelT unv_label) + : _label_vector(label_vector), + _filter_labels(filter_labels), + _unv_label(unv_label) +{ +} + +template +bool integer_label_filter_match::contain_filtered_label(uint32_t id) +{ + // if unv isn't set, it will be default value 0, and there will be no match + return _label_vector.check_label_exists(id, _filter_labels) + || _label_vector.check_label_exists(id, _unv_label); +} + +template +label_filter_match_holder::label_filter_match_holder(simple_bitmask_buf& bitmask_filters, + std::vector& query_bitmask_buf, + integer_label_vector& label_vector, + const std::vector& filter_labels, + LabelT unv_label, + bool use_integer_labels) + : _bitmask_filter_match(bitmask_filters, query_bitmask_buf, filter_labels, unv_label), + _integer_label_filter_match(label_vector, filter_labels, unv_label), + _use_integer_labels(use_integer_labels) +{ +} + +template +bool label_filter_match_holder::contain_filtered_label(uint32_t id) +{ + if (_use_integer_labels) + { + return _integer_label_filter_match.contain_filtered_label(id); + } + else + { + return _bitmask_filter_match.contain_filtered_label(id); + } +} + +template class bitmask_filter_match; +template class bitmask_filter_match; +template class integer_label_filter_match; +template class integer_label_filter_match; +template class label_filter_match_holder; +template class label_filter_match_holder; + +} \ No newline at end of file diff --git a/src/index.cpp b/src/index.cpp index 0fdd754d2..6951fc14d 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -16,6 +16,8 @@ #include "tag_uint128.h" #include "label_helper.h" #include "color_helper.h" +#include "filter_match_proxy.h" + #if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) #include "gperftools/malloc_extension.h" #endif @@ -384,6 +386,21 @@ void Index::save(const char *filename, bool compact_before_save throw diskann::ANNException(std::string("Failed to save bitmask labels to ") + bitmask_label_file, -1); } } + + if (_use_integer_labels) + { + std::string integer_label_file = std::string(filename) + "_integer_labels.bin"; + if (_label_vector.write_to_file(integer_label_file)) + { + diskann::cout << "Integer labels saved to " << integer_label_file << std::endl; + } + else + { + diskann::cerr << "Failed to save integer labels to " << integer_label_file << std::endl; + throw diskann::ANNException(std::string("Failed to save integer labels to ") + integer_label_file, + -1); + } + } } std::string graph_file = std::string(filename); @@ -566,7 +583,7 @@ template void Index::load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l) { #else -void Index::load(const char *filename, uint32_t num_threads, uint32_t search_l, bool loadBitmaskLabelFile) +void Index::load(const char *filename, uint32_t num_threads, uint32_t search_l, LabelFormatType label_format_type) { #endif std::unique_lock ul(_update_lock); @@ -658,9 +675,13 @@ void Index::load(const char *filename, uint32_t num_threads, ui _diverse_index = true; } - if (file_exists(labels_file)) + std::string bitmask_label_file = std::string(filename) + "_bitmask_labels.bin"; + std::string integer_label_file = std::string(filename) + "_integer_labels.bin"; + if (file_exists(labels_file) + || file_exists(bitmask_label_file) + || file_exists(integer_label_file)) { - _label_map = load_label_map(labels_map_file); + label_helper().load_label_map(labels_map_file, _label_map); this->_table_stats.label_count = _label_map.size(); if (_enable_tags) @@ -668,7 +689,7 @@ void Index::load(const char *filename, uint32_t num_threads, ui // resize bitmask buffer to max points label_num_pts = _max_points; } - if (!loadBitmaskLabelFile) + if (label_format_type == LabelFormatType::String) { label_helper().parse_label_file_in_bitset( labels_file, @@ -676,11 +697,11 @@ void Index::load(const char *filename, uint32_t num_threads, ui _label_map.size(), _bitmask_buf, _table_stats); + this->_table_stats.label_mem_usage = _bitmask_buf._buf.size() * sizeof(std::uint64_t); } - else + else if (label_format_type == LabelFormatType::BitMask) { // load bitmask labels from file - std::string bitmask_label_file = std::string(filename) + "_bitmask_labels.bin"; if (label_helper().read_bitmask_from_file(bitmask_label_file, _bitmask_buf, label_num_pts)) { diskann::cout << "Bitmask labels loaded from " << bitmask_label_file << std::endl; @@ -691,39 +712,31 @@ void Index::load(const char *filename, uint32_t num_threads, ui throw diskann::ANNException(std::string("Failed to load bitmask labels from ") + bitmask_label_file, -1); } + + this->_table_stats.label_mem_usage = _bitmask_buf._buf.size() * sizeof(std::uint64_t); } - - assert(label_num_pts == data_file_num_pts); - this->_table_stats.label_mem_usage = _bitmask_buf._buf.size() * sizeof(std::uint64_t); - if (file_exists(labels_to_medoids)) + else if (label_format_type == LabelFormatType::Integer) { - std::ifstream medoid_stream(labels_to_medoids); - std::string line, token; - uint32_t line_cnt = 0; - - _label_to_start_id.clear(); - - while (std::getline(medoid_stream, line)) + // load integer labels from file + if (_label_vector.initialize_from_file(integer_label_file, label_num_pts)) { - std::istringstream iss(line); - uint32_t cnt = 0; - uint32_t medoid = 0; - LabelT label; - while (std::getline(iss, token, ',')) - { - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - LabelT token_as_num = (LabelT)std::stoul(token); - if (cnt == 0) - label = token_as_num; - else - medoid = token_as_num; - cnt++; - } - _label_to_start_id[label] = medoid; - line_cnt++; + diskann::cout << "Integer labels loaded from " << integer_label_file << std::endl; + _use_integer_labels = true; + } + else + { + diskann::cerr << "Failed to load integer labels from " << integer_label_file << std::endl; + throw diskann::ANNException(std::string("Failed to load integer labels from ") + integer_label_file, + -1); } + + this->_table_stats.label_mem_usage = _label_vector.get_memory_usage(); } + + assert(label_num_pts == data_file_num_pts); + + _label_to_start_id.clear(); + label_helper().load_label_medoids(labels_to_medoids, _label_to_start_id); std::string universal_label_file(filename); universal_label_file += "_universal_label.txt"; @@ -935,34 +948,13 @@ std::pair Index::iterate_to_fixed_point( _pq_data_store->get_distance(scratch->aligned_query(), ids, dists_out, scratch); }; - // only support one filter label - std::array local_buf; - simple_bitmask_full_val bitmask_full_val; - if (use_filter) - { - if (_bitmask_buf._bitmask_size <= 10) - { - local_buf.fill(0); - bitmask_full_val._mask = local_buf.data(); - } - else - { - query_bitmask_buf.resize(_bitmask_buf._bitmask_size, 0); - bitmask_full_val._mask = query_bitmask_buf.data(); - } - - for (size_t i = 0; i < filter_labels.size(); i++) - { - auto bitmask_val = simple_bitmask::get_bitmask_val(filter_labels[i]); - bitmask_full_val.merge_bitmask_val(bitmask_val); - } - - if (_use_universal_label) - { - auto bitmask_val = simple_bitmask::get_bitmask_val(_universal_label); - bitmask_full_val.merge_bitmask_val(bitmask_val); - } - } + label_filter_match_holder match_proxy( + _bitmask_buf, + query_bitmask_buf, + _label_vector, + filter_labels, + _universal_label, + _use_integer_labels); // Initialize the candidate pool with starting points for (auto id : init_ids) @@ -976,9 +968,7 @@ std::pair Index::iterate_to_fixed_point( if (use_filter) { - simple_bitmask bm(_bitmask_buf.get_bitmask(id), _bitmask_buf._bitmask_size); - - if (!bm.test_full_mask_val(bitmask_full_val)) + if (!match_proxy.contain_filtered_label(id)) { continue; } @@ -1052,10 +1042,7 @@ std::pair Index::iterate_to_fixed_point( cmps++; if (use_filter) { - // NOTE: NEED TO CHECK IF THIS CORRECT WITH NEW LOCKS. - simple_bitmask bm(_bitmask_buf.get_bitmask(id), _bitmask_buf._bitmask_size); - - if (!bm.test_full_mask_val(bitmask_full_val)) + if (!match_proxy.contain_filtered_label(id)) { continue; } @@ -1084,10 +1071,7 @@ std::pair Index::iterate_to_fixed_point( cmps++; if (use_filter) { - // NOTE: NEED TO CHECK IF THIS CORRECT WITH NEW LOCKS. - simple_bitmask bm(_bitmask_buf.get_bitmask(id), _bitmask_buf._bitmask_size); - - if (!bm.test_full_mask_val(bitmask_full_val)) + if (!match_proxy.contain_filtered_label(id)) { continue; } @@ -1335,10 +1319,18 @@ void Index::occlude_list(const uint32_t location, std::vectorid; uint32_t b = iter2->id; - simple_bitmask bm1(_bitmask_buf.get_bitmask(a), _bitmask_buf._bitmask_size); - simple_bitmask bm2(_bitmask_buf.get_bitmask(b), _bitmask_buf._bitmask_size); + if (!_use_integer_labels) + { + simple_bitmask bm1(_bitmask_buf.get_bitmask(a), _bitmask_buf._bitmask_size); + simple_bitmask bm2(_bitmask_buf.get_bitmask(b), _bitmask_buf._bitmask_size); + prune_allowed = bm1.test_full_mask_contain(bm2); + } + else + { + prune_allowed = _label_vector.check_label_full_contain(a, b); + } - prune_allowed = bm1.test_full_mask_contain(bm2); + } if (!prune_allowed) @@ -2074,6 +2066,18 @@ bool Index::is_set_universal_label() const return _use_universal_label; } +template +void Index::enable_integer_label() +{ + _use_integer_labels = true; +} + +template +bool Index::integer_label_enabled() +{ + return _use_integer_labels; +} + template bool Index::is_label_valid(const std::string& raw_label) const { @@ -2086,7 +2090,7 @@ bool Index::is_label_valid(const std::string& raw_label) const } template -void Index::parse_label_file(const std::string &label_file, size_t &num_points) +void Index::parse_label_file(const std::string &label_file, size_t &num_points, size_t& total_labels) { // Format of Label txt file: filters with comma separators @@ -2116,6 +2120,8 @@ void Index::parse_label_file(const std::string &label_file, siz infile.seekg(0, std::ios::beg); line_cnt = 0; + total_labels = 0; + while (std::getline(infile, line)) { std::istringstream iss(line); @@ -2134,6 +2140,7 @@ void Index::parse_label_file(const std::string &label_file, siz std::sort(lbls.begin(), lbls.end()); _location_to_labels[line_cnt] = lbls; line_cnt++; + total_labels += lbls.size(); } num_points = (size_t)line_cnt; diskann::cout << "Identified " << _labels.size() << " distinct label(s)" << std::endl; @@ -2155,6 +2162,18 @@ void Index::convert_pts_label_to_bitmask(std::vector +void Index::convert_pts_label_to_integer_vector( + std::vector>& pts_to_labels, + integer_label_vector& int_label_vector, size_t total_labels) +{ + int_label_vector.initialize(pts_to_labels.size(), total_labels); + for (size_t i = 0; i < pts_to_labels.size(); i++) + { + int_label_vector.add_labels(static_cast(i), pts_to_labels[i]); + } +} + template void Index::_set_universal_label(const LabelType universal_label) { @@ -2216,27 +2235,17 @@ void Index::parse_seller_file(const std::string& label_file, si } template -void Index::build_filtered_index(const char *filename, const std::string &label_file, - const size_t num_points_to_load, const std::vector &tags) +void Index::aggregate_points_by_bitmask_label( + std::unordered_map>& label_to_points, + size_t num_points_to_load) { - _filtered_index = true; - _label_to_start_id.clear(); - size_t num_points_labels = 0; - - parse_label_file(label_file, - num_points_labels); // determines medoid for each label and identifies - // the points to label mapping - - convert_pts_label_to_bitmask(_location_to_labels, _bitmask_buf, _labels.size()); - - std::unordered_map> label_to_points; std::vector label_bitmask; for (int lbl = 0; lbl < _labels.size(); lbl++) { auto itr = _labels.begin(); std::advance(itr, lbl); - auto &x = *itr; - + auto& x = *itr; + label_bitmask.clear(); label_bitmask.resize(_bitmask_buf._bitmask_size, 0); @@ -2264,6 +2273,64 @@ void Index::build_filtered_index(const char *filename, const st label_to_points[x] = labeled_points; } +} + +template +void Index::aggregate_points_by_integer_label( + std::unordered_map>& label_to_points, + size_t num_points_to_load) +{ + for (int lbl = 0; lbl < _labels.size(); lbl++) + { + auto itr = _labels.begin(); + std::advance(itr, lbl); + auto& x = *itr; + + std::vector labeled_points; + for (uint32_t point_id = 0; point_id < num_points_to_load; point_id++) + { + if (_label_vector.check_label_exists(point_id, x) + || (_use_universal_label && _label_vector.check_label_exists(point_id, _universal_label))) + { + labeled_points.emplace_back(point_id); + } + } + + label_to_points[x] = labeled_points; + } +} + +template +void Index::build_filtered_index(const char *filename, const std::string &label_file, + const size_t num_points_to_load, const std::vector &tags) +{ + _filtered_index = true; + _label_to_start_id.clear(); + size_t num_points_labels = 0; + size_t total_labels = 0; + + parse_label_file(label_file, + num_points_labels, total_labels); // determines medoid for each label and identifies + // the points to label mapping + + if (!_use_integer_labels) + { + convert_pts_label_to_bitmask(_location_to_labels, _bitmask_buf, _labels.size()); + } + else + { + convert_pts_label_to_integer_vector(_location_to_labels, _label_vector, total_labels); + } + + std::unordered_map> label_to_points; + if (!_use_integer_labels) + { + aggregate_points_by_bitmask_label(label_to_points, num_points_to_load); + } + else + { + aggregate_points_by_integer_label(label_to_points, num_points_to_load); + } uint32_t num_cands = 25; for (auto itr = _labels.begin(); itr != _labels.end(); itr++) @@ -2890,6 +2957,12 @@ template void Indexsize() > 0) { throw ANNException("Can not compact data when index has non-empty _delete_set of " @@ -3273,6 +3346,12 @@ int Index::insert_point(const T *point, const TagT tag, const s << std::endl; return -1; } + + if (_use_integer_labels) + { + std::cerr << "Error: integer labels isn't support in streaming " << std::endl; + return -1; + } } std::shared_lock shared_ul(_update_lock); @@ -3731,7 +3810,6 @@ TableStats Index::get_table_stats() const /* Internals of the library */ template const float Index::INDEX_GROWTH_FACTOR = 1.5f; - // EXPORTS template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; diff --git a/src/integer_label_vector.cpp b/src/integer_label_vector.cpp new file mode 100644 index 000000000..9b12f65bd --- /dev/null +++ b/src/integer_label_vector.cpp @@ -0,0 +1,198 @@ +#include "integer_label_vector.h" +#include "ann_exception.h" +#include +#include + +namespace diskann +{ + +bool integer_label_vector::initialize(size_t numpoints, size_t total_labels) { + _offset.resize(numpoints + 1); + _offset[0] = 0; + + _data.reserve(total_labels); + return true; +} + +bool integer_label_vector::initialize_from_file(const std::string& label_file, size_t& numpoints) +{ + //format: + // format version: uint8 + // num_points: uint32 + // offset array: uint64[num_points + 1] + // label data: uint32[total_labels] + std::ifstream infile(label_file, std::ios::binary); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); + } + uint8_t format_version = 0; + infile.read((char*)(&format_version), sizeof(uint8_t)); + if (format_version != 1) + { + throw diskann::ANNException(std::string("Unsupported label file format version ") + + std::to_string(format_version), + -1); + } + uint32_t num_points_in_file = 0; + infile.read((char*)(&num_points_in_file), sizeof(uint32_t)); + _offset.resize(num_points_in_file + 1); + infile.read((char *)_offset.data(), _offset.size() * sizeof(size_t)); + size_t total_labels = _offset[num_points_in_file]; + _data.resize(total_labels); + infile.read((char *)_data.data(), _data.size() * sizeof(uint32_t)); + infile.close(); + + numpoints = static_cast(num_points_in_file); + + return true; +} + +template +bool integer_label_vector::add_labels(uint32_t point_id, std::vector &labels) { + if (point_id >= _offset.size() - 1) + { + return false; + } + + auto start = _offset[point_id]; + for (const auto &label : labels) { + _data.push_back(static_cast(label)); + } + + _offset[point_id + 1] = _data.size(); + + return true; +} + +bool integer_label_vector::check_label_exists(uint32_t point_id, uint32_t label) { + if (point_id >= _offset.size() - 1) return false; + + auto start = _offset[point_id]; + auto end = _offset[point_id + 1]; + size_t last_check = 0; + return binary_search(start, end, label, last_check); +} + +template +bool integer_label_vector::check_label_exists(uint32_t point_id, const std::vector &labels) { + if (point_id >= _offset.size() - 1) return false; + + auto start = _offset[point_id]; + auto end = _offset[point_id + 1]; + + for (const auto &label : labels) { + size_t last_check = 0; + if (binary_search(start, end, static_cast(label), last_check)) { + return true; + } + start = last_check; + } + + return false; +} + +bool integer_label_vector::check_label_full_contain(uint32_t point_id, const std::vector& labels) +{ + if (point_id >= _offset.size() - 1) return false; + + auto start = _offset[point_id]; + auto end = _offset[point_id + 1]; + + + for (const auto &label : labels) { + size_t last_check = 0; + if (!binary_search(start, end, label, last_check)) { + return false; + } + start = last_check; + } + + return true; +} + +bool integer_label_vector::check_label_full_contain(uint32_t source_point, uint32_t target_point) +{ + if (source_point >= _offset.size() - 1 || target_point >= _offset.size() - 1) return false; + + auto start = _offset[source_point]; + auto end = _offset[source_point + 1]; + auto target_start = _offset[target_point]; + auto target_end = _offset[target_point + 1]; + + for (size_t i = target_start; i < target_end; i++) + { + size_t last_check = 0; + if (!binary_search(start, end, _data[i], last_check)) { + return false; + } + start = last_check; + } + + return true; +} + +bool integer_label_vector::binary_search(size_t start, size_t end, uint32_t label, size_t& last_check) +{ + while (start < end) { + size_t mid = (start + end) >> 1; + + if (_data[mid] == label) { + last_check = mid; + return true; + } + if (_data[mid] < label) start = mid + 1; + else end = mid; + } + + last_check = start; + + return false; +} + +size_t integer_label_vector::get_memory_usage() const +{ + return _offset.capacity() * sizeof(size_t) + _data.capacity() * sizeof(uint32_t); +} + +bool integer_label_vector::write_to_file(const std::string& label_file) const +{ + //format: + // format version: uint8 + // num_points: uint32 + // offset array: uint64[num_points + 1] + // label data: uint32[total_labels] + const uint8_t format_version = 1; + std::ofstream outfile(label_file, std::ios::binary); + if (outfile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); + } + outfile.write((char*)(&format_version), sizeof(uint8_t)); + + uint32_t num_points = static_cast(_offset.size() - 1); + outfile.write((char*)(&num_points), sizeof(uint32_t)); + outfile.write((char*)_offset.data(), _offset.size() * sizeof(size_t)); + + outfile.write((char*)_data.data(), _data.size() * sizeof(uint32_t)); + outfile.close(); + + return true; +} + +const std::vector &integer_label_vector::get_offset_vector() const +{ + return _offset; +} + +const std::vector& integer_label_vector::get_data_vector() const +{ + return _data; +} + +template bool integer_label_vector::add_labels(uint32_t point_id, std::vector& labels); +template bool integer_label_vector::add_labels(uint32_t point_id, std::vector& labels); +template bool integer_label_vector::check_label_exists(uint32_t point_id, const std::vector& labels); +template bool integer_label_vector::check_label_exists(uint32_t point_id, const std::vector& labels); + +} \ No newline at end of file diff --git a/src/label_helper.cpp b/src/label_helper.cpp index 7fb348f5b..df41fe0f4 100644 --- a/src/label_helper.cpp +++ b/src/label_helper.cpp @@ -195,4 +195,116 @@ size_t label_helper::search_string_range(const std::string& str, char ch, size_t return std::string::npos; } +bool label_helper::parse_label_file_in_integer( + const std::string& label_file, + size_t& num_points, + integer_label_vector& integer_vector, + tsl::robin_set& labels, TableStats& table_stats) +{ + std::ifstream infile(label_file, std::ios::binary); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); + } + infile.seekg(0, std::ios::end); + size_t file_size = infile.tellg(); + + std::string buffer(file_size, ' '); + + infile.seekg(0, std::ios::beg); + infile.read(&buffer[0], file_size); + infile.close(); + + unsigned line_cnt = 0; + + size_t cur_pos = 0; + size_t next_pos = 0; + while (cur_pos < file_size && cur_pos != std::string::npos) + { + next_pos = buffer.find('\n', cur_pos); + if (next_pos == std::string::npos) + { + break; + } + + cur_pos = next_pos + 1; + + line_cnt++; + } + + const size_t rough_avg_labels_per_point = 10; + + if (num_points > line_cnt) + { + size_t rough_total_labels = num_points * rough_avg_labels_per_point; + integer_vector.initialize(num_points, rough_total_labels); + } + else + { + size_t rough_total_labels = line_cnt * rough_avg_labels_per_point; + integer_vector.initialize(line_cnt, rough_total_labels); + } + + infile.clear(); + infile.seekg(0, std::ios::beg); + line_cnt = 0; + + std::string label_str; + std::vector current_labels; + + cur_pos = 0; + next_pos = 0; + while (cur_pos < file_size && cur_pos != std::string::npos) + { + next_pos = buffer.find('\n', cur_pos); + if (next_pos == std::string::npos) + { + break; + } + + current_labels.clear(); + + size_t lbl_pos = cur_pos; + size_t next_lbl_pos = 0; + while (lbl_pos < next_pos && lbl_pos != std::string::npos) + { + next_lbl_pos = search_string_range(buffer, ',', lbl_pos, next_pos); + if (next_lbl_pos == std::string::npos) // the last label in the whole file + { + next_lbl_pos = next_pos; + } + + if (next_lbl_pos > next_pos) // the last label in one line + { + next_lbl_pos = next_pos; + } + + label_str.assign(buffer.c_str() + lbl_pos, next_lbl_pos - lbl_pos); + if (label_str[label_str.length() - 1] == '\t') + { + label_str.erase(label_str.length() - 1); + } + + size_t token_as_num = std::stoul(label_str); + current_labels.push_back(static_cast(token_as_num)); + + labels.insert(static_cast(token_as_num)); + table_stats.label_total_count++; + + lbl_pos = next_lbl_pos + 1; + } + + integer_vector.add_labels(line_cnt, current_labels); + + cur_pos = next_pos + 1; + + line_cnt++; + } + + num_points = (size_t)line_cnt; + diskann::cout << "Identified " << labels.size() << " distinct label(s)" << std::endl; + + return true; +} + } \ No newline at end of file diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 1f2f1e49d..7c3e84137 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -10,6 +10,7 @@ #include "pq_flash_index.h" #include "cosine_similarity.h" #include "color_helper.h" +#include "filter_match_proxy.h" #include #include @@ -591,7 +592,7 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons const char *pivots_filepath, const char *compressed_filepath, const char* labels_filepath, const char* labels_to_medoids_filepath, const char* labels_map_filepath, const char* unv_label_filepath, - const char* seller_filepath, bool load_bitmask_label) + const char* seller_filepath, LabelFormatType label_format_type) { #endif std::string pq_table_bin = pivots_filepath; @@ -645,19 +646,15 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons { FileContent &content_labels = files.getContent(labels_file); std::stringstream infile(std::string((const char *)content_labels._content, content_labels._size)); + _label_map = load_label_map(infile); #else if (file_exists(labels_file)) { - std::ifstream map_reader(labels_map_file); #endif - _label_map = load_label_map(map_reader); + label_helper().load_label_map(labels_to_medoids, _label_map); this->_table_stats.label_count = _label_map.size(); -#ifndef EXEC_ENV_OLS - map_reader.close(); -#endif - - if (!load_bitmask_label) + if (label_format_type == LabelFormatType::String) { label_helper().parse_label_file_in_bitset( labels_file, @@ -666,7 +663,7 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons _bitmask_buf, _table_stats); } - else + else if (label_format_type == LabelFormatType::BitMask) { if (label_helper().read_bitmask_from_file(labels_file, _bitmask_buf, num_pts_in_label_file)) { @@ -679,6 +676,23 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons -1); } } + else if (label_format_type == LabelFormatType::Integer) + { + if (_label_vector.initialize_from_file(labels_file, num_pts_in_label_file)) + { + diskann::cout << "Integer labels loaded from " << labels_file << std::endl; + _use_integer_labels = true; + } + else + { + diskann::cerr << "Failed to load integer labels from " << labels_file << std::endl; + throw diskann::ANNException(std::string("Failed to load integer labels from ") + labels_file, + -1); + } + + this->_table_stats.label_mem_usage = _label_vector.get_memory_usage(); + + } #ifdef EXEC_ENV_OLS if (files.fileExists(labels_to_medoids)) @@ -1204,26 +1218,15 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t float *query_float = pq_query_scratch->aligned_query_float; float *query_rotated = pq_query_scratch->rotated_query; - simple_bitmask_full_val bitmask_full_val; - std::vector& query_bitmask_buf = query_scratch->query_label_bitmask(); - if (use_filter) - { - query_bitmask_buf.resize(_bitmask_buf._bitmask_size, 0); - bitmask_full_val._mask = query_bitmask_buf.data(); - for (const auto& filter_label : filter_labels) - { - auto bitmask_val = simple_bitmask::get_bitmask_val(filter_label); - bitmask_full_val.merge_bitmask_val(bitmask_val); - } - - if (_use_universal_label) - { - auto bitmask_val = simple_bitmask::get_bitmask_val(_universal_filter_label); - bitmask_full_val.merge_bitmask_val(bitmask_val); - } - } + label_filter_match_holder match_proxy( + _bitmask_buf, + query_bitmask_buf, + _label_vector, + filter_labels, + _universal_filter_label, + _use_integer_labels); // normalization step. for cosine, we simply normalize the query // for mips, we normalize the first d-1 dims, and add a 0 for last dim, since an extra coordinate was used to @@ -1476,12 +1479,8 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t if (use_filter) { - simple_bitmask bm(_bitmask_buf.get_bitmask(id), _bitmask_buf._bitmask_size); - - if (!bm.test_full_mask_val(bitmask_full_val)) - { + if (!match_proxy.contain_filtered_label(id)) continue; - } } cmps++; @@ -1546,12 +1545,8 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t if (use_filter) { - simple_bitmask bm(_bitmask_buf.get_bitmask(id), _bitmask_buf._bitmask_size); - - if (!bm.test_full_mask_val(bitmask_full_val)) - { + if (!match_proxy.contain_filtered_label(id)) continue; - } } cmps++; From d54db9463d0695017547d4985238daa6f44ffef8 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Mon, 27 Oct 2025 16:55:59 +0800 Subject: [PATCH 05/10] fix ssd index building --- include/disk_utils.h | 4 ++-- src/disk_utils.cpp | 43 ++++++++++++++++++++++++++++--------------- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/include/disk_utils.h b/include/disk_utils.h index e300a10ee..def83a8a5 100644 --- a/include/disk_utils.h +++ b/include/disk_utils.h @@ -79,7 +79,7 @@ DISKANN_DLLEXPORT int build_merged_vamana_index(std::string base_file, diskann:: uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_file, std::string centroids_file, size_t build_pq_bytes, bool use_opq, - uint32_t num_threads, bool use_filters = false, + uint32_t num_threads, bool use_filters = false, bool use_integer_labels = false, const std::string &label_file = std::string(""), const std::string &labels_to_medoids_file = std::string(""), const std::string &universal_label = "", const uint32_t Lf = 0); @@ -95,7 +95,7 @@ DISKANN_DLLEXPORT int build_disk_index( const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric _compareMetric, bool use_opq = false, const std::string &codebook_prefix = "", // default is empty for no codebook pass in - bool use_filters = false, + bool use_filters = false, bool use_integer_labels = false, const std::string &label_file = std::string(""), // default is empty string for no label_file const std::string &universal_label = "", const uint32_t filter_threshold = 0, const uint32_t Lf = 0, diff --git a/src/disk_utils.cpp b/src/disk_utils.cpp index 846fb3b89..2fe2d2bb8 100644 --- a/src/disk_utils.cpp +++ b/src/disk_utils.cpp @@ -628,7 +628,7 @@ template int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_file, std::string centroids_file, size_t build_pq_bytes, bool use_opq, - uint32_t num_threads, bool use_filters, const std::string &label_file, + uint32_t num_threads, bool use_filters, bool use_integer_labels, const std::string &label_file, const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf, uint32_t universal_label_num = 0, const char* seller_file_path = nullptr, @@ -663,6 +663,10 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr std::make_shared(paras), nullptr, defaults::NUM_FROZEN_POINTS_STATIC, false, false, false, build_pq_bytes > 0, build_pq_bytes, use_opq, use_filters); + if (use_integer_labels) + { + _index.enable_integer_label(); + } if (!use_filters) _index.build(base_file.c_str(), base_num); else @@ -1113,7 +1117,7 @@ void create_disk_layout(const std::string base_file, const std::string mem_index template int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, - diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, + diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, bool use_integer_labels, const std::string &label_file, const std::string &universal_label, const uint32_t filter_threshold, const uint32_t Lf, const char* reorderDataFilePath, @@ -1203,6 +1207,7 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const std::string mem_labels_file = mem_index_path + "_labels.txt"; std::string disk_labels_file = disk_index_path + "_labels.txt"; std::string disk_bitmask_labels_file = disk_index_path + "_bitmask_labels.bin"; + std::string disk_integer_labels_file = disk_index_path + "_integer_labels.bin"; std::string mem_univ_label_file = mem_index_path + "_universal_label.txt"; std::string disk_univ_label_file = disk_index_path + "_universal_label.txt"; std::string disk_labels_int_map_file = disk_index_path + "_labels_map.txt"; @@ -1346,7 +1351,7 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const timer.reset(); diskann::build_merged_vamana_index(data_file_to_use.c_str(), diskann::Metric::L2, L, R, p_val, indexing_ram_budget, mem_index_path, medoids_path, centroids_path, - build_pq_bytes, use_opq, num_threads, use_filters, labels_file_to_use, + build_pq_bytes, use_opq, num_threads, use_filters, use_integer_labels, labels_file_to_use, labels_to_medoids_path, universal_label, Lf, universal_label_id, sellerFilePath, num_diverse_build); diskann::cout << timer.elapsed_seconds_for_step("building merged vamana index") << std::endl; @@ -1388,6 +1393,14 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const std::remove(bitmask_label_file.c_str()); } + // rename integer label file + std::string integer_label_file = std::string(mem_index_path) + "_integer_labels.bin"; + if (file_exists(integer_label_file)) + { + copy_file(integer_label_file, disk_integer_labels_file); + std::remove(integer_label_file.c_str()); + } + std::remove(augmented_data_file.c_str()); std::remove(augmented_labels_file.c_str()); std::remove(labels_file_to_use.c_str()); @@ -1475,7 +1488,7 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *da const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, + bool use_integer_labels, const std::string &label_file, const std::string &universal_label, const uint32_t filter_threshold, const uint32_t Lf, const char* reorderDataFilePath, const char* sellerFilePath, @@ -1484,7 +1497,7 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *d const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, + bool use_integer_labels, const std::string &label_file, const std::string &universal_label, const uint32_t filter_threshold, const uint32_t Lf, const char* reorderDataFilePath, const char* sellerFilePath, @@ -1493,7 +1506,7 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *dat const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, + bool use_integer_labels, const std::string &label_file, const std::string &universal_label, const uint32_t filter_threshold, const uint32_t Lf, const char* reorderDataFilePath, const char* sellerFilePath, @@ -1503,7 +1516,7 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *da const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, + bool use_integer_labels, const std::string &label_file, const std::string &universal_label, const uint32_t filter_threshold, const uint32_t Lf, const char* reorderDataFilePath, const char* sellerFilePath, @@ -1512,7 +1525,7 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *d const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, + bool use_integer_labels, const std::string &label_file, const std::string &universal_label, const uint32_t filter_threshold, const uint32_t Lf, const char* reorderDataFilePath, const char* sellerFilePath, @@ -1521,7 +1534,7 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *dat const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, - const std::string &label_file, + bool use_integer_labels, const std::string &label_file, const std::string &universal_label, const uint32_t filter_threshold, const uint32_t Lf, const char* reorderDataFilePath, const char* sellerFilePath, @@ -1530,32 +1543,32 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *dat template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, - size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, + size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, bool use_integer_labels, const std::string &label_file, const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, - size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, + size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, bool use_integer_labels, const std::string &label_file, const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, - size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, + size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, bool use_integer_labels, const std::string &label_file, const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); // Label=16_t template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, - size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, + size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, bool use_integer_labels, const std::string &label_file, const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, - size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, + size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, bool use_integer_labels, const std::string &label_file, const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, - size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, + size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, bool use_integer_labels, const std::string &label_file, const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); }; // namespace diskann From 6ae08a272ef4c9efaeb5482ed55a1d1ae96ac910 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Mon, 27 Oct 2025 17:21:14 +0800 Subject: [PATCH 06/10] expose integer label get/set --- include/abstract_index.h | 3 +++ include/index.h | 4 ++-- src/index.cpp | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/include/abstract_index.h b/include/abstract_index.h index 5ec0fd6bd..175509552 100644 --- a/include/abstract_index.h +++ b/include/abstract_index.h @@ -112,6 +112,9 @@ class AbstractIndex template void set_universal_label(const label_type universal_label); + virtual void enable_integer_label() = 0; + virtual bool integer_label_enabled() const = 0; + virtual bool is_label_valid(const std::string &raw_label) const = 0; virtual bool is_set_universal_label() const = 0; virtual TableStats get_table_stats() const = 0; diff --git a/include/index.h b/include/index.h index 2a007fa18..b4eb6f681 100644 --- a/include/index.h +++ b/include/index.h @@ -119,9 +119,9 @@ template clas DISKANN_DLLEXPORT bool is_set_universal_label() const override; - DISKANN_DLLEXPORT void enable_integer_label(); + DISKANN_DLLEXPORT void enable_integer_label() override; - DISKANN_DLLEXPORT bool integer_label_enabled(); + DISKANN_DLLEXPORT bool integer_label_enabled() const override; // Set starting point of an index before inserting any points incrementally. // The data count should be equal to _num_frozen_pts * _aligned_dim. diff --git a/src/index.cpp b/src/index.cpp index 6951fc14d..146456313 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2073,7 +2073,7 @@ void Index::enable_integer_label() } template -bool Index::integer_label_enabled() +bool Index::integer_label_enabled() const { return _use_integer_labels; } From 31c859766f8c7b686f316f0fe7455d9c1e944583 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Tue, 28 Oct 2025 12:08:57 +0800 Subject: [PATCH 07/10] clean up --- src/integer_label_vector.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/integer_label_vector.cpp b/src/integer_label_vector.cpp index 9b12f65bd..11616b562 100644 --- a/src/integer_label_vector.cpp +++ b/src/integer_label_vector.cpp @@ -1,5 +1,4 @@ #include "integer_label_vector.h" -#include "ann_exception.h" #include #include @@ -24,15 +23,13 @@ bool integer_label_vector::initialize_from_file(const std::string& label_file, s std::ifstream infile(label_file, std::ios::binary); if (infile.fail()) { - throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); + return false; } uint8_t format_version = 0; infile.read((char*)(&format_version), sizeof(uint8_t)); if (format_version != 1) { - throw diskann::ANNException(std::string("Unsupported label file format version ") + - std::to_string(format_version), - -1); + return false; } uint32_t num_points_in_file = 0; infile.read((char*)(&num_points_in_file), sizeof(uint32_t)); @@ -54,8 +51,7 @@ bool integer_label_vector::add_labels(uint32_t point_id, std::vector &la { return false; } - - auto start = _offset[point_id]; + for (const auto &label : labels) { _data.push_back(static_cast(label)); } @@ -166,7 +162,7 @@ bool integer_label_vector::write_to_file(const std::string& label_file) const std::ofstream outfile(label_file, std::ios::binary); if (outfile.fail()) { - throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); + return false; } outfile.write((char*)(&format_version), sizeof(uint8_t)); From 6d699210026619a4d16cba647aa67453b4d9cea8 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Tue, 28 Oct 2025 14:16:23 +0800 Subject: [PATCH 08/10] Fix ssd medoid file loading --- include/label_helper.h | 50 ++++++++++++++++++++++++++++++++++++++++++ src/pq_flash_index.cpp | 15 ++++++------- 2 files changed, 57 insertions(+), 8 deletions(-) diff --git a/include/label_helper.h b/include/label_helper.h index ef60bfb10..96ba931f3 100644 --- a/include/label_helper.h +++ b/include/label_helper.h @@ -129,6 +129,56 @@ class label_helper return true; } + // duplicate of above to fit ssd index API, NEED TO UNIFY LATER + template + bool load_label_medoids( + const std::string& label_medoids_file, + std::unordered_map>& label_to_start_ids) + { + std::ifstream infile(label_medoids_file, std::ios::binary); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + label_medoids_file, -1); + } + infile.seekg(0, std::ios::end); + size_t file_size = infile.tellg(); + + std::string buffer(file_size, ' '); + + infile.seekg(0, std::ios::beg); + infile.read(&buffer[0], file_size); + infile.close(); + + unsigned line_cnt = 0; + + size_t cur_pos = 0; + size_t next_pos = 0; + size_t lbl_pos = 0; + std::string token; + while (cur_pos < file_size && cur_pos != std::string::npos) + { + next_pos = buffer.find('\n', cur_pos); + if (next_pos == std::string::npos) + { + break; + } + + lbl_pos = search_string_range(buffer, ',', cur_pos, next_pos); + token.assign(buffer.c_str() + cur_pos, lbl_pos - cur_pos); + LabelT label_num = (LabelT)std::stoul(token); + + token.assign(buffer.c_str() + lbl_pos + 1, next_pos - lbl_pos - 1); + uint32_t medoid = (uint32_t)std::stoul(token); + + label_to_start_ids[label_num].push_back(medoid); + + cur_pos = next_pos + 1; + + line_cnt++; + } + + return true; + } private: size_t search_string_range(const std::string& str, char ch, size_t start, size_t end); }; diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 7c3e84137..1d87f596d 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -651,7 +651,7 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons if (file_exists(labels_file)) { #endif - label_helper().load_label_map(labels_to_medoids, _label_map); + label_helper().load_label_map(labels_map_file, _label_map); this->_table_stats.label_count = _label_map.size(); if (label_format_type == LabelFormatType::String) @@ -700,12 +700,6 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons FileContent &content_labels_to_meoids = files.getContent(labels_to_medoids); std::stringstream medoid_stream( std::string((const char *)content_labels_to_meoids._content, content_labels_to_meoids._size)); -#else - if (file_exists(labels_to_medoids)) - { - std::ifstream medoid_stream(labels_to_medoids); - assert(medoid_stream.is_open()); -#endif std::string line, token; _filter_to_medoid_ids.clear(); @@ -728,11 +722,16 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons _filter_to_medoid_ids[label].swap(medoids); } } - catch (std::system_error &e) + catch (std::system_error& e) { throw FileException(labels_to_medoids, e, __FUNCSIG__, __FILE__, __LINE__); } } +#else + _filter_to_medoid_ids.clear(); + label_helper().load_label_medoids(labels_to_medoids, _filter_to_medoid_ids); +#endif + std::string univ_label_file = (unv_label_filepath == nullptr ? "" : unv_label_filepath); #ifdef EXEC_ENV_OLS From 9c665d673317889255b454997e8c85e2394dfb08 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Tue, 28 Oct 2025 16:32:43 +0800 Subject: [PATCH 09/10] sort label befor searching --- src/index.cpp | 5 ++++- src/pq_flash_index.cpp | 17 +++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index 146456313..d3eff8b7b 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2550,7 +2550,8 @@ std::pair Index::search_with_filters(const diskann::cout << "Resize completed. New scratch->L is " << scratch->get_L() << std::endl; } - std::vector filter_vec; + thread_local std::vector filter_vec; + filter_vec.clear(); filter_vec.reserve(filter_labels.size()); std::vector init_ids = get_init_ids(); @@ -2582,6 +2583,8 @@ std::pair Index::search_with_filters(const filter_vec.emplace_back(filter_label); } + std::sort(filter_vec.begin(), filter_vec.end()); + _data_store->preprocess_query(query, scratch); auto retval = iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true, maxLperSeller); diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 1d87f596d..2124fec9f 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -1201,6 +1201,19 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t maxLperSeller = static_cast(l_search); } + thread_local std::vector local_filter_labels; + local_filter_labels.clear(); + local_filter_labels.reserve(filter_labels.size()); + for (const auto& label : filter_labels) + { + local_filter_labels.push_back(label); + } + + if (local_filter_labels.size() > 0) + { + std::sort(local_filter_labels.begin(), local_filter_labels.end()); + } + ScratchStoreManager> manager(this->_thread_data); auto data = manager.scratch_space(); IOContext &ctx = data->ctx; @@ -1223,7 +1236,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t _bitmask_buf, query_bitmask_buf, _label_vector, - filter_labels, + local_filter_labels, _universal_filter_label, _use_integer_labels); @@ -1324,7 +1337,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t } else { - for (const auto& filter_label : filter_labels) + for (const auto& filter_label : local_filter_labels) { if (_filter_to_medoid_ids.find(filter_label) != _filter_to_medoid_ids.end()) { From 5d329eb5354af6c28e00bfbf271682e497e49e40 Mon Sep 17 00:00:00 2001 From: Sanhaoji2 Date: Wed, 29 Oct 2025 19:15:11 +0800 Subject: [PATCH 10/10] resolve comment --- include/integer_label_vector.h | 2 -- src/index.cpp | 1 + src/integer_label_vector.cpp | 19 ------------------- 3 files changed, 1 insertion(+), 21 deletions(-) diff --git a/include/integer_label_vector.h b/include/integer_label_vector.h index b2ec405fe..68688419f 100644 --- a/include/integer_label_vector.h +++ b/include/integer_label_vector.h @@ -22,8 +22,6 @@ class integer_label_vector template bool check_label_exists(uint32_t point_id, const std::vector &labels); - bool check_label_full_contain(uint32_t point_id, const std::vector &labels); - bool check_label_full_contain(uint32_t source_point, uint32_t target_point); const std::vector &get_offset_vector() const; diff --git a/src/index.cpp b/src/index.cpp index d3eff8b7b..030238281 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2699,6 +2699,7 @@ size_t Index::search_with_tags(const T *query, const uint64_t K filter_vec.push_back(converted_label); } + std::sort(filter_vec.begin(), filter_vec.end()); iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true); } diff --git a/src/integer_label_vector.cpp b/src/integer_label_vector.cpp index 11616b562..c467050f1 100644 --- a/src/integer_label_vector.cpp +++ b/src/integer_label_vector.cpp @@ -88,25 +88,6 @@ bool integer_label_vector::check_label_exists(uint32_t point_id, const std::vect return false; } -bool integer_label_vector::check_label_full_contain(uint32_t point_id, const std::vector& labels) -{ - if (point_id >= _offset.size() - 1) return false; - - auto start = _offset[point_id]; - auto end = _offset[point_id + 1]; - - - for (const auto &label : labels) { - size_t last_check = 0; - if (!binary_search(start, end, label, last_check)) { - return false; - } - start = last_check; - } - - return true; -} - bool integer_label_vector::check_label_full_contain(uint32_t source_point, uint32_t target_point) { if (source_point >= _offset.size() - 1 || target_point >= _offset.size() - 1) return false;