Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ public CuVSCodec(String name, Codec delegate) {
super(name, delegate);
KnnVectorsFormat format;
try {
format = new CuVSVectorsFormat(1, 128, 64, MergeStrategy.NON_TRIVIAL_MERGE, IndexType.CAGRA);
boolean useHNSW = Boolean.parseBoolean(System.getProperty("lucene.cuvs.hnsw", "true"));
format =
new CuVSVectorsFormat(
1, 128, 64, MergeStrategy.NON_TRIVIAL_MERGE, IndexType.CAGRA, useHNSW);
setKnnFormat(format);
} catch (LibraryException ex) {
Logger log = Logger.getLogger(CuVSCodec.class.getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public class CuVSVectorsFormat extends KnnVectorsFormat {
public static final int DEFAULT_GRAPH_DEGREE = 64;
public static final MergeStrategy DEFAULT_MERGE_STRATEGY = MergeStrategy.NON_TRIVIAL_MERGE;
public static final IndexType DEFAULT_INDEX_TYPE = IndexType.CAGRA;
public static final boolean DEFAULT_USE_HNSW = true;

static CuVSResources resources = cuVSResourcesOrNull();

Expand All @@ -61,6 +62,7 @@ public class CuVSVectorsFormat extends KnnVectorsFormat {
final int graphDegree;
final MergeStrategy mergeStrategy;
final CuVSVectorsWriter.IndexType indexType; // the index type to build, when writing
final boolean useHNSW;

/**
* Creates a CuVSVectorsFormat, with default values.
Expand All @@ -73,7 +75,8 @@ public CuVSVectorsFormat() {
DEFAULT_INTERMEDIATE_GRAPH_DEGREE,
DEFAULT_GRAPH_DEGREE,
DEFAULT_MERGE_STRATEGY,
DEFAULT_INDEX_TYPE);
DEFAULT_INDEX_TYPE,
DEFAULT_USE_HNSW);
}

/**
Expand All @@ -86,15 +89,22 @@ public CuVSVectorsFormat(
int intGraphDegree,
int graphDegree,
MergeStrategy mergeStrategy,
IndexType indexType) {
IndexType indexType,
boolean useHNSW)
throws LibraryException {
super("CuVSVectorsFormat");
this.mergeStrategy = mergeStrategy;
this.cuvsWriterThreads = cuvsWriterThreads;
this.intGraphDegree = intGraphDegree;
this.graphDegree = graphDegree;
this.indexType = indexType;
this.useHNSW = useHNSW;
}

public boolean isHNSWEnabled() {
return useHNSW;
}

private static CuVSResources cuVSResourcesOrNull() {
try {
resources = CuVSResources.create();
Expand Down Expand Up @@ -133,14 +143,15 @@ public CuVSVectorsWriter fieldsWriter(SegmentWriteState state) throws IOExceptio
mergeStrategy,
indexType,
resources,
flatWriter);
flatWriter,
useHNSW);
}

@Override
public CuVSVectorsReader fieldsReader(SegmentReadState state) throws IOException {
checkSupported();
var flatReader = flatVectorsFormat.fieldsReader(state);
return new CuVSVectorsReader(state, resources, flatReader);
return new CuVSVectorsReader(state, resources, flatReader, useHNSW);
}

@Override
Expand All @@ -156,6 +167,7 @@ public String toString() {
sb.append("graphDegree=").append(graphDegree);
sb.append("mergeStrategy=").append(mergeStrategy);
sb.append("resources=").append(resources);
sb.append("useHNSW=").append(useHNSW);
sb.append(")");
return sb.toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,21 @@ public class CuVSVectorsReader extends KnnVectorsReader {
private static final Logger log = Logger.getLogger(CuVSVectorsReader.class.getName());

private final CuVSResources resources;
private final boolean useHNSW;
private final FlatVectorsReader flatVectorsReader; // for reading the raw vectors
private final FieldInfos fieldInfos;
private final IntObjectHashMap<FieldEntry> fields;
private final IntObjectHashMap<CuVSIndex> cuvsIndices;
private final IndexInput cuvsIndexInput;

public CuVSVectorsReader(
SegmentReadState state, CuVSResources resources, FlatVectorsReader flatReader)
SegmentReadState state, CuVSResources resources, FlatVectorsReader flatReader, boolean useHNSW)
throws IOException {
this.resources = resources;
this.flatVectorsReader = flatReader;
this.fieldInfos = state.fieldInfos;
this.useHNSW = Boolean.getBoolean("lucene.cuvs.hnsw");
log.info("CuVSVectorsReader initialized. useHNSW=" + this.useHNSW);
this.fields = new IntObjectHashMap<>();

String metaFileName =
Expand Down Expand Up @@ -239,6 +242,8 @@ private IntObjectHashMap<CuVSIndex> loadCuVSIndices() throws IOException {
}

private CuVSIndex loadCuVSIndex(FieldEntry fieldEntry) throws IOException {
log.info("Loading CuVS index for field: ");

CagraIndex cagraIndex = null;
BruteForceIndex bruteForceIndex = null;
HnswIndex hnswIndex = null;
Expand All @@ -263,7 +268,8 @@ private CuVSIndex loadCuVSIndex(FieldEntry fieldEntry) throws IOException {
}

len = fieldEntry.hnswIndexLength();
if (len > 0) {
if (useHNSW && len > 0) {
log.info("Attempting to load HNSW index.");
long off = fieldEntry.hnswIndexOffset();
try (var slice = cuvsIndexInput.slice("hnsw index", off, len);
var in = new IndexInputInputStream(slice)) {
Expand Down Expand Up @@ -350,7 +356,23 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
assert topK > 0 : "Expected topK > 0, got:" + topK;

Map<Integer, Float> result;
if (knnCollector.k() <= 1024 && cuvsIndex.getCagraIndex() != null) {
if(useHNSW && cuvsIndex.getHNSWIndex() != null) {
log.info("Searching with HNSW index");
var hnswQuery = new com.nvidia.cuvs.HnswQuery.Builder()
.withQueryVectors(new float[][] { target })
.withTopK(knnCollector.k())
.build();
List<Map<Integer, Float>> searchResult = null;
try {
searchResult = cuvsIndex.getHNSWIndex().search(hnswQuery).getResults();
}catch (Throwable t) {
handleThrowable(t);
}

assert searchResult.size() == 1;
result = searchResult.getFirst();
}
else if (knnCollector.k() <= 1024 && cuvsIndex.getCagraIndex() != null) {
// log.info("searching cagra index");
CagraSearchParams searchParams =
new CagraSearchParams.Builder(resources)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ public boolean hnsw() {
return hnsw;
}
}

private final boolean useHNSW;

public CuVSVectorsWriter(
SegmentWriteState state,
Expand All @@ -138,7 +140,8 @@ public CuVSVectorsWriter(
MergeStrategy mergeStrategy,
IndexType indexType,
CuVSResources resources,
FlatVectorsWriter flatVectorsWriter)
FlatVectorsWriter flatVectorsWriter,
boolean useHNSW)
throws IOException {
super();
this.mergeStrategy = mergeStrategy;
Expand All @@ -148,6 +151,8 @@ public CuVSVectorsWriter(
this.graphDegree = graphDegree;
this.resources = resources;
this.flatVectorsWriter = flatVectorsWriter;
this.useHNSW = Boolean.getBoolean("lucene.cuvs.hnsw");
log.info("CuVSVectorsWriter initialized. useHNSW=" + this.useHNSW);
this.infoStream = state.infoStream;

String metaFileName =
Expand Down Expand Up @@ -261,6 +266,14 @@ private void writeBruteForceIndex(OutputStream os, float[][] vectors) throws Thr
}

private void writeHNSWIndex(OutputStream os, float[][] vectors) throws Throwable {
if (!useHNSW) { // Skip HNSW writing if disabled
log.warning("Skipping HNSW indexing because useHNSW is false.");
return;
}

if (vectors.length == 0) {
log.warning("HNSW indexing failed because no vectors were provided.");
}
if (vectors.length < 2) {
throw new IllegalArgumentException(vectors.length + " vectors, less than min [2] required");
}
Expand Down Expand Up @@ -347,7 +360,7 @@ private void writeFieldInternal(FieldInfo fieldInfo, float[][] vectors) throws I
}

hnswIndexOffset = cuvsIndex.getFilePointer();
if (indexType.hnsw()) {
if (useHNSW && indexType.hnsw()) {
var hnswIndexOutputStream = new IndexOutputOutputStream(cuvsIndex);
if (vectors.length > MIN_CAGRA_INDEX_SIZE) {
try {
Expand Down
Loading