diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSCodec.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSCodec.java index c3ddc809c4d3..eaf4c6127d54 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSCodec.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSCodec.java @@ -36,7 +36,10 @@ public CuVSCodec(String name, Codec delegate) { super(name, delegate); KnnVectorsFormat format; try { - format = new CuVSVectorsFormat(1, 128, 64, MergeStrategy.NON_TRIVIAL_MERGE, IndexType.CAGRA); + boolean useHNSW = Boolean.parseBoolean(System.getProperty("lucene.cuvs.hnsw", "true")); + format = + new CuVSVectorsFormat( + 1, 128, 64, MergeStrategy.NON_TRIVIAL_MERGE, IndexType.CAGRA, useHNSW); setKnnFormat(format); } catch (LibraryException ex) { Logger log = Logger.getLogger(CuVSCodec.class.getName()); diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsFormat.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsFormat.java index e0d4678aa5fe..ed7c812dd66e 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsFormat.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsFormat.java @@ -48,6 +48,7 @@ public class CuVSVectorsFormat extends KnnVectorsFormat { public static final int DEFAULT_GRAPH_DEGREE = 64; public static final MergeStrategy DEFAULT_MERGE_STRATEGY = MergeStrategy.NON_TRIVIAL_MERGE; public static final IndexType DEFAULT_INDEX_TYPE = IndexType.CAGRA; + public static final boolean DEFAULT_USE_HNSW = true; static CuVSResources resources = cuVSResourcesOrNull(); @@ -61,6 +62,7 @@ public class CuVSVectorsFormat extends KnnVectorsFormat { final int graphDegree; final MergeStrategy mergeStrategy; final CuVSVectorsWriter.IndexType indexType; // the index type to build, when writing + final boolean useHNSW; /** * Creates a CuVSVectorsFormat, with default values. @@ -73,7 +75,8 @@ public CuVSVectorsFormat() { DEFAULT_INTERMEDIATE_GRAPH_DEGREE, DEFAULT_GRAPH_DEGREE, DEFAULT_MERGE_STRATEGY, - DEFAULT_INDEX_TYPE); + DEFAULT_INDEX_TYPE, + DEFAULT_USE_HNSW); } /** @@ -86,15 +89,22 @@ public CuVSVectorsFormat( int intGraphDegree, int graphDegree, MergeStrategy mergeStrategy, - IndexType indexType) { + IndexType indexType, + boolean useHNSW) + throws LibraryException { super("CuVSVectorsFormat"); this.mergeStrategy = mergeStrategy; this.cuvsWriterThreads = cuvsWriterThreads; this.intGraphDegree = intGraphDegree; this.graphDegree = graphDegree; this.indexType = indexType; + this.useHNSW = useHNSW; } + public boolean isHNSWEnabled() { + return useHNSW; +} + private static CuVSResources cuVSResourcesOrNull() { try { resources = CuVSResources.create(); @@ -133,14 +143,15 @@ public CuVSVectorsWriter fieldsWriter(SegmentWriteState state) throws IOExceptio mergeStrategy, indexType, resources, - flatWriter); + flatWriter, + useHNSW); } @Override public CuVSVectorsReader fieldsReader(SegmentReadState state) throws IOException { checkSupported(); var flatReader = flatVectorsFormat.fieldsReader(state); - return new CuVSVectorsReader(state, resources, flatReader); + return new CuVSVectorsReader(state, resources, flatReader, useHNSW); } @Override @@ -156,6 +167,7 @@ public String toString() { sb.append("graphDegree=").append(graphDegree); sb.append("mergeStrategy=").append(mergeStrategy); sb.append("resources=").append(resources); + sb.append("useHNSW=").append(useHNSW); sb.append(")"); return sb.toString(); } diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsReader.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsReader.java index cfb59121e36e..4505a88eb1a9 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsReader.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsReader.java @@ -69,6 +69,7 @@ public class CuVSVectorsReader extends KnnVectorsReader { private static final Logger log = Logger.getLogger(CuVSVectorsReader.class.getName()); private final CuVSResources resources; + private final boolean useHNSW; private final FlatVectorsReader flatVectorsReader; // for reading the raw vectors private final FieldInfos fieldInfos; private final IntObjectHashMap fields; @@ -76,11 +77,13 @@ public class CuVSVectorsReader extends KnnVectorsReader { private final IndexInput cuvsIndexInput; public CuVSVectorsReader( - SegmentReadState state, CuVSResources resources, FlatVectorsReader flatReader) + SegmentReadState state, CuVSResources resources, FlatVectorsReader flatReader, boolean useHNSW) throws IOException { this.resources = resources; this.flatVectorsReader = flatReader; this.fieldInfos = state.fieldInfos; + this.useHNSW = Boolean.getBoolean("lucene.cuvs.hnsw"); + log.info("CuVSVectorsReader initialized. useHNSW=" + this.useHNSW); this.fields = new IntObjectHashMap<>(); String metaFileName = @@ -239,6 +242,8 @@ private IntObjectHashMap loadCuVSIndices() throws IOException { } private CuVSIndex loadCuVSIndex(FieldEntry fieldEntry) throws IOException { + log.info("Loading CuVS index for field: "); + CagraIndex cagraIndex = null; BruteForceIndex bruteForceIndex = null; HnswIndex hnswIndex = null; @@ -263,7 +268,8 @@ private CuVSIndex loadCuVSIndex(FieldEntry fieldEntry) throws IOException { } len = fieldEntry.hnswIndexLength(); - if (len > 0) { + if (useHNSW && len > 0) { + log.info("Attempting to load HNSW index."); long off = fieldEntry.hnswIndexOffset(); try (var slice = cuvsIndexInput.slice("hnsw index", off, len); var in = new IndexInputInputStream(slice)) { @@ -350,7 +356,23 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits assert topK > 0 : "Expected topK > 0, got:" + topK; Map result; - if (knnCollector.k() <= 1024 && cuvsIndex.getCagraIndex() != null) { + if(useHNSW && cuvsIndex.getHNSWIndex() != null) { + log.info("Searching with HNSW index"); + var hnswQuery = new com.nvidia.cuvs.HnswQuery.Builder() + .withQueryVectors(new float[][] { target }) + .withTopK(knnCollector.k()) + .build(); + List> searchResult = null; + try { + searchResult = cuvsIndex.getHNSWIndex().search(hnswQuery).getResults(); + }catch (Throwable t) { + handleThrowable(t); + } + + assert searchResult.size() == 1; + result = searchResult.getFirst(); + } + else if (knnCollector.k() <= 1024 && cuvsIndex.getCagraIndex() != null) { // log.info("searching cagra index"); CagraSearchParams searchParams = new CagraSearchParams.Builder(resources) diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsWriter.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsWriter.java index 61f77ee26e7c..a206051b149b 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsWriter.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/vectorsearch/CuVSVectorsWriter.java @@ -129,6 +129,8 @@ public boolean hnsw() { return hnsw; } } + + private final boolean useHNSW; public CuVSVectorsWriter( SegmentWriteState state, @@ -138,7 +140,8 @@ public CuVSVectorsWriter( MergeStrategy mergeStrategy, IndexType indexType, CuVSResources resources, - FlatVectorsWriter flatVectorsWriter) + FlatVectorsWriter flatVectorsWriter, + boolean useHNSW) throws IOException { super(); this.mergeStrategy = mergeStrategy; @@ -148,6 +151,8 @@ public CuVSVectorsWriter( this.graphDegree = graphDegree; this.resources = resources; this.flatVectorsWriter = flatVectorsWriter; + this.useHNSW = Boolean.getBoolean("lucene.cuvs.hnsw"); + log.info("CuVSVectorsWriter initialized. useHNSW=" + this.useHNSW); this.infoStream = state.infoStream; String metaFileName = @@ -261,6 +266,14 @@ private void writeBruteForceIndex(OutputStream os, float[][] vectors) throws Thr } private void writeHNSWIndex(OutputStream os, float[][] vectors) throws Throwable { + if (!useHNSW) { // Skip HNSW writing if disabled + log.warning("Skipping HNSW indexing because useHNSW is false."); + return; + } + + if (vectors.length == 0) { + log.warning("HNSW indexing failed because no vectors were provided."); + } if (vectors.length < 2) { throw new IllegalArgumentException(vectors.length + " vectors, less than min [2] required"); } @@ -347,7 +360,7 @@ private void writeFieldInternal(FieldInfo fieldInfo, float[][] vectors) throws I } hnswIndexOffset = cuvsIndex.getFilePointer(); - if (indexType.hnsw()) { + if (useHNSW && indexType.hnsw()) { var hnswIndexOutputStream = new IndexOutputOutputStream(cuvsIndex); if (vectors.length > MIN_CAGRA_INDEX_SIZE) { try {