diff --git a/src/core/mixed_reducer/mixed_streamer_reducer.cc b/src/core/mixed_reducer/mixed_streamer_reducer.cc index b5e241bb..bb84e3d6 100644 --- a/src/core/mixed_reducer/mixed_streamer_reducer.cc +++ b/src/core/mixed_reducer/mixed_streamer_reducer.cc @@ -239,6 +239,9 @@ int MixedStreamerReducer::read_vec(size_t source_streamer_index, bool need_revert = (target_streamer_->meta().reformer_name() != streamer->meta().reformer_name() && reformer != nullptr); + if (target_builder_ && reformer) { + need_revert = true; + } IndexProvider::Pointer provider = streamer->create_provider(); IndexProvider::Iterator::Pointer iterator = provider->create_iterator(); @@ -366,7 +369,7 @@ void MixedStreamerReducer::add_vec_with_builder(int *result) { std::string out_vector_buffer = std::string( static_cast(vector), original_query_meta_.dimension() * original_query_meta_.unit_size()); - PushToDocCache(target_streamer_query_meta, (uint32_t)vector_item.pkey_, + PushToDocCache(original_query_meta_, (uint32_t)vector_item.pkey_, out_vector_buffer); } @@ -509,8 +512,6 @@ void MixedStreamerReducer::PushToDocCache(const IndexQueryMeta &meta, } int MixedStreamerReducer::IndexBuild() { - const bool need_convert = !is_target_and_source_same_reformer_ && - target_streamer_reformer_ != nullptr; IndexHolder::Pointer target_holder; if (original_query_meta_.data_type() == core::IndexMeta::DataType::DT_FP16) { auto holder = std::make_shared< @@ -563,7 +564,7 @@ int MixedStreamerReducer::IndexBuild() { LOG_ERROR("data_type is not support"); return core::IndexError_Runtime; } - if (target_builder_converter_ && need_convert) { + if (target_builder_converter_) { core::IndexConverter::TrainAndTransform(target_builder_converter_, target_holder); target_holder = target_builder_converter_->result(); diff --git a/src/core/quantizer/cosine_converter.cc b/src/core/quantizer/cosine_converter.cc index e8deaca8..dda76b01 100644 --- a/src/core/quantizer/cosine_converter.cc +++ b/src/core/quantizer/cosine_converter.cc @@ -64,8 +64,8 @@ class CosineConverterHolder : public IndexHolder { //! Retrieve pointer of data const void *data(void) const override { - return type_ == IndexMeta::DataType::DT_FP32 ? normalize_buffer_.data() - : buffer_.data(); + return type_ == original_type_ ? normalize_buffer_.data() + : buffer_.data(); } //! Test if the iterator is valid @@ -325,7 +325,7 @@ class CosineConverter : public IndexConverter { //! Transform the data int transform(IndexHolder::Pointer holder) override { - if (holder->data_type() != IndexMeta::DataType::DT_FP32 || + if (holder->data_type() != original_type_ || holder->dimension() != meta_.dimension() - ExtraDimension(dst_type_)) { return IndexError_Mismatch; } diff --git a/tests/core/algorithm/ivf/ivf_searcher_test.cc b/tests/core/algorithm/ivf/ivf_searcher_test.cc index 0ce94ced..4c4c8c02 100644 --- a/tests/core/algorithm/ivf/ivf_searcher_test.cc +++ b/tests/core/algorithm/ivf/ivf_searcher_test.cc @@ -320,6 +320,97 @@ TEST_F(IVFSearcherTest, TestSimple) { EXPECT_EQ(0, ret); } +TEST_F(IVFSearcherTest, TestSimpleCosine) { + IVFBuilder builder; + // index_meta_.set_major_order(IndexMeta::MO_ROW); + params_.set(PARAM_IVF_BUILDER_CENTROID_COUNT, "1"); + params_.set(PARAM_IVF_BUILDER_CLUSTER_CLASS, "KmeansCluster"); + + Params converter_params; + auto converter = IndexFactory::CreateConverter("CosineNormalizeConverter"); + ASSERT_TRUE(converter != nullptr); + auto original_index_meta = index_meta_; + original_index_meta.set_metric("Cosine", 0, Params()); + EXPECT_EQ(0, converter->init(original_index_meta, converter_params)); + IndexMeta index_meta = converter->meta(); + auto reformer = IndexFactory::CreateReformer(index_meta.reformer_name()); + ASSERT_TRUE(reformer != nullptr); + ASSERT_EQ(0, reformer->init(index_meta.reformer_params())); + + int ret = builder.init(index_meta, params_); + EXPECT_EQ(0, ret); + prepare_index_holder(0, 33); + converter->transform(holder_); + auto holder = converter->result(); + + EXPECT_EQ(0, builder.train(threads_, holder)); + EXPECT_EQ(0, builder.build(threads_, holder)); + IndexDumper::Pointer dumper = IndexFactory::CreateDumper("FileDumper"); + EXPECT_EQ(0, dumper->create(index_path_)); + + ret = builder.dump(dumper); + EXPECT_EQ((size_t)33, builder.stats().built_count()); + EXPECT_EQ((size_t)33, builder.stats().dumped_count()); + EXPECT_EQ((size_t)0, builder.stats().discarded_count()); + EXPECT_EQ(0, dumper->close()); + + IVFSearcher searcher; + Params params; + params.set(PARAM_IVF_SEARCHER_SCAN_RATIO, 1.0); + params.set(PARAM_IVF_SEARCHER_BRUTE_FORCE_THRESHOLD, 1); + + ret = searcher.init(params); + EXPECT_EQ(0, ret); + + IndexStorage::Pointer container = + IndexFactory::CreateStorage("MMapFileReadStorage"); + EXPECT_TRUE(!!container); + + Params container_params; + container_params.set("proxima.mmap_file.container.memory_warmup", true); + container->init(container_params); + ret = container->open(index_path_, false); + EXPECT_EQ(0, ret); + + ret = searcher.load(container, IndexMetric::Pointer()); + EXPECT_EQ(0, ret); + + std::vector query; + for (size_t i = 0; i < dimension_; ++i) { + query.push_back(32.0f + i); + } + + size_t qnum = 33; + std::vector query1; + for (size_t i = 0; i < dimension_ * qnum; ++i) { + query1.push_back(i / dimension_); + } + auto context = searcher.create_context(); + IndexQueryMeta qmeta(IndexMeta::DataType::DT_FP32, dimension_); + + // single bf search + { + size_t topk = 33; + context->set_topk(topk); + + std::string new_vec; + IndexQueryMeta new_meta; + ASSERT_EQ(0, reformer->convert(query.data(), qmeta, &new_vec, &new_meta)); + + ret = searcher.search_bf_impl(new_vec.data(), new_meta, context); + EXPECT_EQ(0, ret); + + const IndexDocumentList &result = context->result(0); + EXPECT_EQ((size_t)topk, result.size()); + for (size_t i = 0; i < 1; ++i) { + // ASSERT_EQ(29, result[i].key()); + EXPECT_NEAR(0, result[i].score(), 1e-2); + } + } + ret = searcher.unload(); + EXPECT_EQ(0, ret); +} + TEST_F(IVFSearcherTest, TestColumnMajorFloatWithBuildMemory) { IVFBuilder builder; // index_meta_.set_major_order(IndexMeta::MO_ROW);