diff --git a/AnnService/inc/Core/Common.h b/AnnService/inc/Core/Common.h index 1bad007dd..c098ade77 100644 --- a/AnnService/inc/Core/Common.h +++ b/AnnService/inc/Core/Common.h @@ -14,6 +14,7 @@ #include #include "inc/Helper/Logging.h" #include "inc/Helper/DiskIO.h" +#include #ifndef _MSC_VER #include @@ -152,17 +153,75 @@ enum class DistCalcMethod : std::uint8_t }; static_assert(static_cast(DistCalcMethod::Undefined) != 0, "Empty DistCalcMethod!"); - enum class VectorValueType : std::uint8_t { #define DefineVectorValueType(Name, Type) Name, #include "DefinitionList.h" #undef DefineVectorValueType - Undefined }; static_assert(static_cast(VectorValueType::Undefined) != 0, "Empty VectorValueType!"); +// remove_last is by Vladimir Reshetnikov, https://stackoverflow.com/a/51805324 +template +struct remove_last; + +template<> +struct remove_last>; // Define as you wish or leave undefined + +template +struct remove_last> +{ +private: + using Tuple = std::tuple; + + template + static std::tuple...> + extract(std::index_sequence); + +public: + using type = decltype(extract(std::make_index_sequence())); +}; + +template +using remove_last_t = typename remove_last::type; + +using VectorValueTypeTuple = remove_last_t>; + +// Dispatcher is based on https://stackoverflow.com/a/34046180 +template +std::function call_with_default(F&& f) +{ + return [f]() {f(T{}); }; +} + +template +void VectorValueTypeDispatch(VectorValueType vectorType, F&& f, std::index_sequence) +{ + std::function fs[] = { + call_with_default>(f)... + }; + fs[static_cast(vectorType)](); + +} + +template +void VectorValueTypeDispatch(VectorValueType vectorType, F f) +{ + constexpr auto VectorCount = std::tuple_size::value; + if ((int)vectorType < VectorCount) + { + VectorValueTypeDispatch(vectorType, f, std::make_index_sequence{}); + } + else + { + throw std::exception(); + } +} enum class IndexAlgoType : std::uint8_t { @@ -214,20 +273,10 @@ constexpr VectorValueType GetEnumValueType() \ inline std::size_t GetValueTypeSize(VectorValueType p_valueType) { - switch (p_valueType) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - return sizeof(Type); \ - -#include "DefinitionList.h" -#undef DefineVectorValueType - - default: - break; - } + std::size_t out = 0; + VectorValueTypeDispatch(p_valueType, [&](auto t) { out = sizeof(decltype(t)); }); - return 0; + return out; } enum class QuantizerType : std::uint8_t diff --git a/AnnService/inc/Core/Common/BKTree.h b/AnnService/inc/Core/Common/BKTree.h index ee81f138c..81976d187 100644 --- a/AnnService/inc/Core/Common/BKTree.h +++ b/AnnService/inc/Core/Common/BKTree.h @@ -421,18 +421,7 @@ namespace SPTAG float CountStd; if (args.m_pQuantizer) { - switch (args.m_pQuantizer->GetReconstructType()) - { -#define DefineVectorValueType(Name, Type) \ -case VectorValueType::Name: \ -CountStd = TryClustering(data, indices, first, last, args, samples, lambdaFactor, true); \ -break; - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } + VectorValueTypeDispatch(args.m_pQuantizer->GetReconstructType(), [&](auto t) { CountStd = TryClustering(data, indices, first, last, args, samples, lambdaFactor, true); }); } else { @@ -469,18 +458,11 @@ break; if (args.m_pQuantizer) { - switch (args.m_pQuantizer->GetReconstructType()) - { -#define DefineVectorValueType(Name, Type) \ -case VectorValueType::Name: \ -TryClustering(data, indices, first, last, args, samples, lambdaFactor, debug, abort); \ -break; - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } + VectorValueTypeDispatch(args.m_pQuantizer->GetReconstructType(), [&](auto t) + { + using Type = decltype(t); + TryClustering(data, indices, first, last, args, samples, lambdaFactor, debug, abort); + }); } else { diff --git a/AnnService/inc/Core/Common/KDTree.h b/AnnService/inc/Core/Common/KDTree.h index 8d16c2f68..59f45e05f 100644 --- a/AnnService/inc/Core/Common/KDTree.h +++ b/AnnService/inc/Core/Common/KDTree.h @@ -63,18 +63,11 @@ namespace SPTAG { if (m_pQuantizer) { - switch (m_pQuantizer->GetReconstructType()) - { -#define DefineVectorValueType(Name, Type) \ -case VectorValueType::Name: \ -BuildTreesCore(data, numOfThreads, indices, abort); \ -break; - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } + VectorValueTypeDispatch(m_pQuantizer->GetReconstructType(), [&](auto t) + { + using Type = decltype(t); + BuildTreesCore(data, numOfThreads, indices, abort); + }); } else { @@ -236,17 +229,11 @@ break; { if (m_pQuantizer) { - switch (m_pQuantizer->GetReconstructType()) - { -#define DefineVectorValueType(Name, Type) \ -case VectorValueType::Name: \ -return KDTSearchCore(p_data, fComputeDistance, p_query, p_space, node, distBound); - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } + return VectorValueTypeDispatch(m_pQuantizer->GetReconstructType(), [&](auto t) + { + using Type = decltype(t); + KDTSearchCore(p_data, fComputeDistance, p_query, p_space, node, distBound); + }); } else { diff --git a/AnnService/inc/Core/Common/NeighborhoodGraph.h b/AnnService/inc/Core/Common/NeighborhoodGraph.h index 807bd1f9e..41ad1700e 100644 --- a/AnnService/inc/Core/Common/NeighborhoodGraph.h +++ b/AnnService/inc/Core/Common/NeighborhoodGraph.h @@ -129,18 +129,11 @@ namespace SPTAG { if (index->m_pQuantizer) { - switch (index->m_pQuantizer->GetReconstructType()) - { -#define DefineVectorValueType(Name, Type) \ -case VectorValueType::Name: \ -PartitionByTptreeCore(index, indices, first, last, leaves); \ -break; - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } + return VectorValueTypeDispatch(index->m_pQuantizer->GetReconstructType(), [&](auto t) + { + using Type = decltype(t); + PartitionByTptreeCore(index, indices, first, last, leaves); + }); } else { diff --git a/AnnService/src/Aggregator/AggregatorService.cpp b/AnnService/src/Aggregator/AggregatorService.cpp index 96a3ce726..faf549cd7 100644 --- a/AnnService/src/Aggregator/AggregatorService.cpp +++ b/AnnService/src/Aggregator/AggregatorService.cpp @@ -222,29 +222,23 @@ AggregatorService::SearchRequestHanlder(Socket::ConnectionID p_localConnectionID size_t vectorSize; SizeType vectorDimension = 0; std::vector servers; - switch (context->GetSettings()->m_valueType) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - if (!queryParser.GetVectorElements().empty()) { \ - Service::ConvertVectorFromString(queryParser.GetVectorElements(), vector, vectorDimension); \ - } else if (queryParser.GetVectorBase64() != nullptr && queryParser.GetVectorBase64Length() != 0) { \ - vector = ByteArray::Alloc(Helper::Base64::CapacityForDecode(queryParser.GetVectorBase64Length())); \ - Helper::Base64::Decode(queryParser.GetVectorBase64(), queryParser.GetVectorBase64Length(), vector.Data(), vectorSize); \ - vectorDimension = (SizeType)(vectorSize / GetValueTypeSize(context->GetSettings()->m_valueType)); \ - } \ - for (int i = 0; i < context->GetCenters()->Count(); i++) { \ - servers.push_back(BasicResult(i, COMMON::DistanceUtils::ComputeDistance((Type*)vector.Data(), \ - (Type*)context->GetCenters()->GetVector(i), vectorDimension, context->GetSettings()->m_distMethod))); \ - } \ - break; \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: - break; - } + VectorValueTypeDispatch(context->GetSettings()->m_valueType, [&](auto t) + { + using Type = decltype(t); + if (!queryParser.GetVectorElements().empty()) { + Service::ConvertVectorFromString(queryParser.GetVectorElements(), vector, vectorDimension); + } + else if (queryParser.GetVectorBase64() != nullptr && queryParser.GetVectorBase64Length() != 0) { + vector = ByteArray::Alloc(Helper::Base64::CapacityForDecode(queryParser.GetVectorBase64Length())); + Helper::Base64::Decode(queryParser.GetVectorBase64(), queryParser.GetVectorBase64Length(), vector.Data(), vectorSize); + vectorDimension = (SizeType)(vectorSize / GetValueTypeSize(context->GetSettings()->m_valueType)); + } + for (int i = 0; i < context->GetCenters()->Count(); i++) { + servers.push_back(BasicResult(i, COMMON::DistanceUtils::ComputeDistance((Type*)vector.Data(), + (Type*)context->GetCenters()->GetVector(i), vectorDimension, context->GetSettings()->m_distMethod))); + } + }); + std::sort(servers.begin(), servers.end(), [](const BasicResult& a, const BasicResult& b) { return a.Dist < b.Dist; }); for (int i = 0; i < context->GetSettings()->m_topK; i++) { auto& server = context->GetRemoteServers().at(servers[i].VID); diff --git a/AnnService/src/Core/Common/IQuantizer.cpp b/AnnService/src/Core/Common/IQuantizer.cpp index 5f2629c26..9c2a0bb58 100644 --- a/AnnService/src/Core/Common/IQuantizer.cpp +++ b/AnnService/src/Core/Common/IQuantizer.cpp @@ -20,31 +20,11 @@ namespace SPTAG case QuantizerType::Undefined: break; case QuantizerType::PQQuantizer: - switch (reconstructType) { - #define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - ret.reset(new PQQuantizer()); \ - break; - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } - + VectorValueTypeDispatch(reconstructType, [&](auto t) {ret.reset(new PQQuantizer()); }); if (ret->LoadQuantizer(p_in) != ErrorCode::Success) ret.reset(); return ret; case QuantizerType::OPQQuantizer: - switch (reconstructType) { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - ret.reset(new OPQQuantizer()); \ - break; - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - default: break; - } + VectorValueTypeDispatch(reconstructType, [&](auto t) {ret.reset(new OPQQuantizer()); }); if (ret->LoadQuantizer(p_in) != ErrorCode::Success) ret.reset(); return ret; } @@ -68,31 +48,11 @@ namespace SPTAG case QuantizerType::Undefined: return ret; case QuantizerType::PQQuantizer: - switch (reconstructType) { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - ret.reset(new PQQuantizer()); \ - break; - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - default: break; - } - + VectorValueTypeDispatch(reconstructType, [&](auto t) {ret.reset(new PQQuantizer()); }); if (ret->LoadQuantizer(raw_bytes) != ErrorCode::Success) ret.reset(); return ret; case QuantizerType::OPQQuantizer: - switch (reconstructType) { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - ret.reset(new OPQQuantizer()); \ - break; - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - default: break; - } - + VectorValueTypeDispatch(reconstructType, [&](auto t) {ret.reset(new PQQuantizer()); }); if (ret->LoadQuantizer(raw_bytes) != ErrorCode::Success) ret.reset(); return ret; } diff --git a/AnnService/src/Core/VectorIndex.cpp b/AnnService/src/Core/VectorIndex.cpp index 952f40749..7e0003f24 100644 --- a/AnnService/src/Core/VectorIndex.cpp +++ b/AnnService/src/Core/VectorIndex.cpp @@ -543,45 +543,18 @@ VectorIndex::CreateInstance(IndexAlgoType p_algo, VectorValueType p_valuetype) { return nullptr; } - + std::shared_ptr out; if (p_algo == IndexAlgoType::BKT) { - switch (p_valuetype) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - return std::shared_ptr(new BKT::Index); \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } + VectorValueTypeDispatch(p_valuetype, [&](auto t) {out.reset(new BKT::Index); }); + return out; } else if (p_algo == IndexAlgoType::KDT) { - switch (p_valuetype) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - return std::shared_ptr(new KDT::Index); \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } + VectorValueTypeDispatch(p_valuetype, [&](auto t) {out.reset(new KDT::Index); }); + return out; } else if (p_algo == IndexAlgoType::SPANN) { - switch (p_valuetype) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - return std::shared_ptr(new SPANN::Index); \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } + VectorValueTypeDispatch(p_valuetype, [&](auto t) {out.reset(new SPANN::Index); }); + return out; } return nullptr; } @@ -933,16 +906,11 @@ void VectorIndex::ApproximateRNG(std::shared_ptr& fullVectors, std::u { reconstructed_vector = ALIGN_ALLOC(m_pQuantizer->ReconstructSize()); m_pQuantizer->ReconstructVector((const uint8_t*)fullVectors->GetVector(fullID), reconstructed_vector); - switch (m_pQuantizer->GetReconstructType()) { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - (*((COMMON::QueryResultSet*)&resultSet)).SetTarget(reinterpret_cast(reconstructed_vector), m_pQuantizer); \ - break; -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - default: - LOG(Helper::LogLevel::LL_Error, "Unable to get quantizer reconstruct type %s", Helper::Convert::ConvertToString(m_pQuantizer->GetReconstructType())); - } + VectorValueTypeDispatch(m_pQuantizer->GetReconstructType(), [&](auto t) + { + using Type = decltype(t); + (*((COMMON::QueryResultSet*) & resultSet)).SetTarget(reinterpret_cast(reconstructed_vector), m_pQuantizer); + }); } else { diff --git a/AnnService/src/Core/VectorSet.cpp b/AnnService/src/Core/VectorSet.cpp index 68cea06a4..416d423b5 100644 --- a/AnnService/src/Core/VectorSet.cpp +++ b/AnnService/src/Core/VectorSet.cpp @@ -142,16 +142,9 @@ SizeType BasicVectorSet::PerVectorDataSize() const void BasicVectorSet::Normalize(int p_threads) { - switch (m_valueType) - { -#define DefineVectorValueType(Name, Type) \ -case SPTAG::VectorValueType::Name: \ -SPTAG::COMMON::Utils::BatchNormalize(reinterpret_cast(m_data.Data()), m_vectorCount, m_dimension, SPTAG::COMMON::Utils::GetBase(), p_threads); \ -break; \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - default: - break; - } + VectorValueTypeDispatch(m_valueType, [&](auto t) + { + using Type = decltype(t); + SPTAG::COMMON::Utils::BatchNormalize(reinterpret_cast(m_data.Data()), m_vectorCount, m_dimension, SPTAG::COMMON::Utils::GetBase(), p_threads); + }); } \ No newline at end of file diff --git a/AnnService/src/Helper/VectorSetReaders/TxtReader.cpp b/AnnService/src/Helper/VectorSetReaders/TxtReader.cpp index 970bbb551..c0d7bc8fd 100644 --- a/AnnService/src/Helper/VectorSetReaders/TxtReader.cpp +++ b/AnnService/src/Helper/VectorSetReaders/TxtReader.cpp @@ -222,20 +222,11 @@ TxtVectorReader::LoadFileInternal(const std::string& p_filePath, } bool parseSuccess = false; - switch (m_options->m_inputValueType) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - parseSuccess = TranslateVector(currentLine.get() + tabIndex + 1, reinterpret_cast(vector.get())); \ - break; \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: - parseSuccess = false; - break; - } + VectorValueTypeDispatch(m_options->m_inputValueType, [&](auto t) + { + using Type = decltype(t); + parseSuccess = TranslateVector(currentLine.get() + tabIndex + 1, reinterpret_cast(vector.get())); + }); if (!parseSuccess) { diff --git a/AnnService/src/IndexSearcher/main.cpp b/AnnService/src/IndexSearcher/main.cpp index 8497f831f..2d9526223 100644 --- a/AnnService/src/IndexSearcher/main.cpp +++ b/AnnService/src/IndexSearcher/main.cpp @@ -386,17 +386,10 @@ int main(int argc, char** argv) vecIndex->UpdateIndex(); - switch (options->m_inputValueType) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - Process(options, *(vecIndex.get())); \ - break; \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType + VectorValueTypeDispatch(options->m_inputValueType, [&](auto t) + { + Process(options, *(vecIndex.get())); + }); - default: break; - } return 0; } diff --git a/AnnService/src/Quantizer/main.cpp b/AnnService/src/Quantizer/main.cpp index 512586902..0d150d032 100644 --- a/AnnService/src/Quantizer/main.cpp +++ b/AnnService/src/Quantizer/main.cpp @@ -99,16 +99,11 @@ int main(int argc, char* argv[]) std::shared_ptr quantized_vectors = std::make_shared(PQ_vector_array, VectorValueType::UInt8, options->m_quantizedDim, set->Count()); LOG(Helper::LogLevel::LL_Info, "Quantizer Does not exist. Training a new one.\n"); - switch (options->m_inputValueType) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - quantizer.reset(new COMMON::PQQuantizer(options->m_quantizedDim, 256, (DimensionType)(options->m_dimension/options->m_quantizedDim), false, TrainPQQuantizer(options, set, quantized_vectors))); \ - break; - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - } + VectorValueTypeDispatch(options->m_inputValueType, [&](auto t) + { + using Type = decltype(t); + quantizer.reset(new COMMON::PQQuantizer(options->m_quantizedDim, 256, (DimensionType)(options->m_dimension / options->m_quantizedDim), false, TrainPQQuantizer(options, set, quantized_vectors))); + }); auto ptr = SPTAG::f_createIO(); if (ptr != nullptr && ptr->Initialize(options->m_outputQuantizerFile.c_str(), std::ios::binary | std::ios::out)) diff --git a/AnnService/src/SSDServing/main.cpp b/AnnService/src/SSDServing/main.cpp index 7aa372b58..8f48519d3 100644 --- a/AnnService/src/SSDServing/main.cpp +++ b/AnnService/src/SSDServing/main.cpp @@ -107,13 +107,8 @@ namespace SPTAG { SPANN::Options* opts = nullptr; -#define DefineVectorValueType(Name, Type) \ - if (index->GetVectorValueType() == VectorValueType::Name) { \ - opts = ((SPANN::Index*)index.get())->GetOptions(); \ - } \ + VectorValueTypeDispatch(index->GetVectorValueType(), [&](auto t) { opts = ((SPANN::Index*)index.get())->GetOptions(); }); -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType if (opts == nullptr) { LOG(Helper::LogLevel::LL_Error, "Cannot get options.\n"); @@ -149,26 +144,25 @@ namespace SPTAG { omp_set_num_threads(opts->m_iSSDNumberOfThreads); -#define DefineVectorValueType(Name, Type) \ - if (opts->m_valueType == VectorValueType::Name) { \ - COMMON::TruthSet::GenerateTruth(querySet, vectorSet, opts->m_truthPath, \ - distCalcMethod, opts->m_resultNum, opts->m_truthType, index->m_pQuantizer); \ - } \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType + VectorValueTypeDispatch(opts->m_valueType, [&](auto t) + { + COMMON::TruthSet::GenerateTruth(querySet, + vectorSet, + opts->m_truthPath, + distCalcMethod, + opts->m_resultNum, + opts->m_truthType, + index->m_pQuantizer); + }); LOG(Helper::LogLevel::LL_Info, "End generating truth.\n"); } if (searchSSD) { -#define DefineVectorValueType(Name, Type) \ - if (opts->m_valueType == VectorValueType::Name) { \ - SSDIndex::Search((SPANN::Index*)(index.get())); \ - } \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType + VectorValueTypeDispatch(opts->m_valueType, [&](auto t) + { + SSDIndex::Search((SPANN::Index*)(index.get())); + }); } return 0; } diff --git a/AnnService/src/Server/SearchExecutionContext.cpp b/AnnService/src/Server/SearchExecutionContext.cpp index 45d83ec43..88a7c5a82 100644 --- a/AnnService/src/Server/SearchExecutionContext.cpp +++ b/AnnService/src/Server/SearchExecutionContext.cpp @@ -82,19 +82,12 @@ SearchExecutionContext::ExtractVector(VectorValueType p_targetType) { if (!m_queryParser.GetVectorElements().empty()) { - switch (p_targetType) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - return ConvertVectorFromString(m_queryParser.GetVectorElements(), m_vector, m_vectorDimension); \ - break; \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: - break; - } + ErrorCode err; + VectorValueTypeDispatch(p_targetType, [&](auto t) + { + err = ConvertVectorFromString(m_queryParser.GetVectorElements(), m_vector, m_vectorDimension); + }); + return err; } else if (m_queryParser.GetVectorBase64() != nullptr && m_queryParser.GetVectorBase64Length() != 0) diff --git a/Test/src/ReconstructIndexSimilarityTest.cpp b/Test/src/ReconstructIndexSimilarityTest.cpp index 2da00493b..59e11fc09 100644 --- a/Test/src/ReconstructIndexSimilarityTest.cpp +++ b/Test/src/ReconstructIndexSimilarityTest.cpp @@ -165,7 +165,7 @@ void GenerateReconstructData(std::shared_ptr& real_vecset, std::share if (ptr == nullptr || !ptr->Initialize(CODEBOOK_FILE.c_str(), std::ios::binary | std::ios::in)) { BOOST_ASSERT("Canot Open CODEBOOK_FILE to read!" == "Error"); } - quantizer->LoadIQuantizer(ptr); + quantizer = COMMON::IQuantizer::LoadIQuantizer(ptr); BOOST_ASSERT(quantizer); std::shared_ptr options(new Helper::ReaderOptions(GetEnumValueType(), m, VectorFileType::DEFAULT)); @@ -247,7 +247,7 @@ void GenerateReconstructData(std::shared_ptr& real_vecset, std::share if (!ptr->Initialize(CODEBOOK_FILE.c_str(), std::ios::binary | std::ios::in)) { BOOST_ASSERT("Canot Open CODEBOOK_FILE to read!" == "Error"); } - quantizer->LoadIQuantizer(ptr); + quantizer = COMMON::IQuantizer::LoadIQuantizer(ptr); BOOST_ASSERT(quantizer); rec_vecset.reset(new BasicVectorSet(ByteArray::Alloc(sizeof(R) * n * m), GetEnumValueType(), m, n)); diff --git a/Test/src/SSDServingTest.cpp b/Test/src/SSDServingTest.cpp index cba53866c..1dfcc71f8 100644 --- a/Test/src/SSDServingTest.cpp +++ b/Test/src/SSDServingTest.cpp @@ -73,19 +73,9 @@ void GenerateVectors(std::string fileName, SPTAG::SizeType rows, SPTAG::Dimensio } -void GenVec(std::string vectorsName, SPTAG::VectorValueType vecType, SPTAG::VectorFileType vecFileType, SPTAG::SizeType rows = 1000, SPTAG::DimensionType dims = 100) { - switch (vecType) - { -#define DefineVectorValueType(Name, Type) \ -case SPTAG::VectorValueType::Name: \ -GenerateVectors(vectorsName, rows, dims, vecFileType); \ -break; \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - default: - break; - } +void GenVec(std::string vectorsName, SPTAG::VectorValueType vecType, SPTAG::VectorFileType vecFileType, SPTAG::SizeType rows = 1000, SPTAG::DimensionType dims = 100) +{ + VectorValueTypeDispatch(vecType, [&](auto t) {GenerateVectors(vectorsName, rows, dims, vecFileType); }); } std::string CreateBaseConfig(SPTAG::VectorValueType p_valueType, SPTAG::DistCalcMethod p_distCalcMethod,