Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions include/abstract_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class AbstractIndex
#ifdef EXEC_ENV_OLS
virtual void load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l) = 0;
#else
virtual void load(const char *index_file, uint32_t num_threads, uint32_t search_l, bool loadBitmaskLabelFile = false) = 0;
virtual void load(const char *index_file, uint32_t num_threads, uint32_t search_l, LabelFormatType label_format_type = LabelFormatType::String) = 0;
#endif

// For FastL2 search on optimized layout
Expand All @@ -63,8 +63,8 @@ class AbstractIndex
// Initialize space for res_vectors before calling.
template <typename data_type, typename tag_type>
size_t search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags,
float *distances, std::vector<data_type *> &res_vectors, bool use_filters = false,
const std::string filter_label = "");
float *distances, std::vector<data_type *> &res_vectors, bool use_filters,
const std::vector<std::string>& filter_labels);

// Added search overload that takes L as parameter, so that we
// can customize L on a per-query basis without tampering with "Parameters"
Expand All @@ -80,7 +80,7 @@ class AbstractIndex
// Filter support search
// IndexType is either uint32_t or uint64_t
template <typename IndexType>
std::pair<uint32_t, uint32_t> search_with_filters(const DataType &query, const std::string &raw_label,
std::pair<uint32_t, uint32_t> search_with_filters(const DataType &query, const std::vector<std::string> &raw_labels,
const size_t K, const uint32_t L, const uint32_t maxLperSeller,
IndexType *indices,
float *distances);
Expand Down Expand Up @@ -112,6 +112,9 @@ class AbstractIndex

template <typename label_type> void set_universal_label(const label_type universal_label);

virtual void enable_integer_label() = 0;
virtual bool integer_label_enabled() const = 0;

virtual bool is_label_valid(const std::string &raw_label) const = 0;
virtual bool is_set_universal_label() const = 0;
virtual TableStats get_table_stats() const = 0;
Expand All @@ -122,7 +125,7 @@ class AbstractIndex
std::any &indices, float *distances = nullptr) = 0;
virtual std::pair<uint32_t, uint32_t> _diverse_search(const DataType& query, const size_t K, const uint32_t L, const uint32_t maxLperSeller,
std::any& indices, float* distances = nullptr) = 0;
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query, const std::string &filter_label,
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query, const std::vector<std::string> &filter_labels,
const size_t K, const uint32_t L, const uint32_t maxLperSeller, std::any &indices,
float *distances) = 0;
virtual int _insert_point(const DataType &data_point, const TagType tag, const std::vector<std::string> &labels) = 0;
Expand All @@ -133,8 +136,8 @@ class AbstractIndex
virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) = 0;
virtual int _get_vector_by_tag(TagType &tag, DataType &vec) = 0;
virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags,
float *distances, DataVector &res_vectors, bool use_filters = false,
const std::string filter_label = "") = 0;
float *distances, DataVector &res_vectors, bool use_filters,
const std::vector<std::string>& filter_labels) = 0;
virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) = 0;
virtual void _set_universal_label(const LabelType universal_label) = 0;
};
Expand Down
4 changes: 2 additions & 2 deletions include/disk_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ DISKANN_DLLEXPORT int build_merged_vamana_index(std::string base_file, diskann::
uint32_t R, double sampling_rate, double ram_budget,
std::string mem_index_path, std::string medoids_file,
std::string centroids_file, size_t build_pq_bytes, bool use_opq,
uint32_t num_threads, bool use_filters = false,
uint32_t num_threads, bool use_filters = false, bool use_integer_labels = false,
const std::string &label_file = std::string(""),
const std::string &labels_to_medoids_file = std::string(""),
const std::string &universal_label = "", const uint32_t Lf = 0);
Expand All @@ -95,7 +95,7 @@ DISKANN_DLLEXPORT int build_disk_index(
const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters,
diskann::Metric _compareMetric, bool use_opq = false,
const std::string &codebook_prefix = "", // default is empty for no codebook pass in
bool use_filters = false,
bool use_filters = false, bool use_integer_labels = false,
const std::string &label_file = std::string(""), // default is empty string for no label_file
const std::string &universal_label = "", const uint32_t filter_threshold = 0,
const uint32_t Lf = 0,
Expand Down
66 changes: 66 additions & 0 deletions include/filter_match_proxy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#pragma once
#include "label_bitmask.h"
#include "integer_label_vector.h"

namespace diskann
{

class filter_match_proxy
{
public:
virtual bool contain_filtered_label(uint32_t id) = 0;
};

template <typename LabelT>
class bitmask_filter_match : public filter_match_proxy
{
public:
bitmask_filter_match(simple_bitmask_buf& bitmask_filters,
std::vector<std::uint64_t>& query_bitmask_buf,
const std::vector<LabelT>& filter_labels,
LabelT unv_label);

virtual bool contain_filtered_label(uint32_t id) override;

private:
simple_bitmask_buf& _bitmask_filters;
std::vector<std::uint64_t>& _query_bitmask_buf;
simple_bitmask_full_val _bitmask_full_val;
};

template <typename LabelT>
class integer_label_filter_match : public filter_match_proxy
{
public:
integer_label_filter_match(integer_label_vector& label_vector,
const std::vector<LabelT>& filter_labels,
LabelT unv_label);

virtual bool contain_filtered_label(uint32_t id) override;

private:
integer_label_vector& _label_vector;
const std::vector<LabelT>& _filter_labels;
LabelT _unv_label;
};

template <typename LabelT>
class label_filter_match_holder : public filter_match_proxy
{
public:
label_filter_match_holder(simple_bitmask_buf& bitmask_filters,
std::vector<std::uint64_t>& query_bitmask_buf,
integer_label_vector& label_vector,
const std::vector<LabelT>& filter_labels,
LabelT unv_label,
bool use_integer_labels);

virtual bool contain_filtered_label(uint32_t id) override;

private:
bitmask_filter_match<LabelT> _bitmask_filter_match;
integer_label_filter_match<LabelT> _integer_label_filter_match;
bool _use_integer_labels;
};

}
31 changes: 23 additions & 8 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "percentile_stats.h"
#include <bitset>
#include "label_bitmask.h"
#include "integer_label_vector.h"

#include "quantized_distance.h"
#include "pq_data_store.h"
Expand Down Expand Up @@ -80,7 +81,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
#ifdef EXEC_ENV_OLS
DISKANN_DLLEXPORT void load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l);
#else
DISKANN_DLLEXPORT void load(const char *index_file, uint32_t num_threads, uint32_t search_l, bool loadBitmaskLabelFile = false);
DISKANN_DLLEXPORT void load(const char *index_file, uint32_t num_threads, uint32_t search_l, LabelFormatType label_format_type = LabelFormatType::String);
#endif

// get some private variables
Expand Down Expand Up @@ -118,6 +119,10 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas

DISKANN_DLLEXPORT bool is_set_universal_label() const override;

DISKANN_DLLEXPORT void enable_integer_label() override;

DISKANN_DLLEXPORT bool integer_label_enabled() const override;

// Set starting point of an index before inserting any points incrementally.
// The data count should be equal to _num_frozen_pts * _aligned_dim.
DISKANN_DLLEXPORT void set_start_points(const T *data, size_t data_count);
Expand All @@ -144,15 +149,15 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas

// Initialize space for res_vectors before calling.
DISKANN_DLLEXPORT size_t search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags,
float *distances, std::vector<T *> &res_vectors, bool use_filters = false,
const std::string filter_label = "");
float *distances, std::vector<T *> &res_vectors, bool use_filters,
const std::vector<std::string>& filter_labels);

virtual std::pair<uint32_t, uint32_t> _diverse_search(const DataType& query, const size_t K, const uint32_t L, const uint32_t maxLperSeller,
std::any& indices, float* distances = nullptr) override;

// Filter support search
template <typename IndexType>
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search_with_filters(const T *query, const LabelT &filter_label,
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search_with_filters(const T *query, const std::vector<LabelT> &filter_labels,
const size_t K, const uint32_t L, const uint32_t maxLperSeller,
IndexType *indices, float *distances);

Expand Down Expand Up @@ -217,7 +222,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
virtual std::pair<uint32_t, uint32_t> _search(const DataType &query, const size_t K, const uint32_t L,
std::any &indices, float *distances = nullptr) override;
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query,
const std::string &filter_label_raw, const size_t K,
const std::vector<std::string> &filter_labels_raw, const size_t K,
const uint32_t L, const uint32_t maxLperSeller, std::any &indices,
float *distances) override;

Expand All @@ -237,8 +242,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) override;

virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags,
float *distances, DataVector &res_vectors, bool use_filters = false,
const std::string filter_label = "") override;
float *distances, DataVector &res_vectors, bool use_filters,
const std::vector<std::string>& filter_labels) override;

virtual void _set_universal_label(const LabelType universal_label) override;

Expand All @@ -253,11 +258,18 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
// determines navigating node of the graph by calculating medoid of datafopt
uint32_t calculate_entry_point();

void parse_label_file(const std::string &label_file, size_t &num_pts_labels);
void parse_label_file(const std::string &label_file, size_t &num_pts_labels, size_t& total_labels);
void parse_seller_file(const std::string& label_file, size_t& num_pts_labels);

void convert_pts_label_to_bitmask(std::vector<std::vector<LabelT>>& pts_to_labels, simple_bitmask_buf& bitmask_buf, size_t num_labels);

void convert_pts_label_to_integer_vector(std::vector<std::vector<LabelT>> &pts_to_labels,
integer_label_vector &int_label_vector, size_t total_labels);

void aggregate_points_by_bitmask_label(std::unordered_map<LabelT, std::vector<uint32_t>>& label_to_points, size_t num_points_to_load);

void aggregate_points_by_integer_label(std::unordered_map<LabelT, std::vector<uint32_t>>& label_to_points, size_t num_points_to_load);

std::unordered_map<std::string, LabelT> load_label_map(const std::string &map_file);

// Returns the locations of start point and frozen points suitable for use
Expand Down Expand Up @@ -463,6 +475,9 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas

simple_bitmask_buf _bitmask_buf;

bool _use_integer_labels = false;
integer_label_vector _label_vector;

TableStats _table_stats;

static const float INDEX_GROWTH_FACTOR;
Expand Down
41 changes: 41 additions & 0 deletions include/integer_label_vector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#pragma once
#include <vector>
#include <string>

namespace diskann
{

class integer_label_vector
{
public:
bool initialize(size_t numpoints, size_t total_labels);

bool initialize_from_file(const std::string &label_file, size_t &numpoints);

bool write_to_file(const std::string &label_file) const;

template <typename LabelT>
bool add_labels(uint32_t point_id, std::vector<LabelT> &labels);

bool check_label_exists(uint32_t point_id, uint32_t label);

template <typename LabelT>
bool check_label_exists(uint32_t point_id, const std::vector<LabelT> &labels);

bool check_label_full_contain(uint32_t source_point, uint32_t target_point);

const std::vector<size_t> &get_offset_vector() const;

const std::vector<uint32_t> &get_data_vector() const;

size_t get_memory_usage() const;

private:
bool binary_search(size_t start, size_t end, uint32_t label, size_t& last_check);

private:
std::vector<size_t> _offset;
std::vector<uint32_t> _data;
};

}
Loading
Loading