From b78e47c658a1038bef385b93bf748a5a8ebbea83 Mon Sep 17 00:00:00 2001 From: punAhuja Date: Wed, 19 Nov 2025 16:04:44 +0530 Subject: [PATCH 1/3] Add FAISS integration module for Solr - Backport FaissKnnVectorsFormat from Lucene 11.0 to Lucene 10.3.1 - Add FaissCodecFactory and FaissCodec for Solr integration - Add Solr branch_10_0 submodule for test compatibility --- .gitignore | 3 + .gitmodules | 4 + faiss/README.md | 79 +++ faiss/build.gradle | 192 +++++++ faiss/solr | 1 + .../codecs/faiss/FaissKnnVectorsFormat.java | 120 +++++ .../codecs/faiss/FaissKnnVectorsReader.java | 206 ++++++++ .../codecs/faiss/FaissKnnVectorsWriter.java | 238 +++++++++ .../sandbox/codecs/faiss/FaissLibrary.java | 58 +++ .../codecs/faiss/FaissLibraryNativeImpl.java | 429 ++++++++++++++++ .../codecs/faiss/FaissNativeWrapper.java | 481 ++++++++++++++++++ .../codecs/faiss/Java21Compatibility.java | 187 +++++++ .../sandbox/codecs/faiss/package-info.java | 52 ++ .../org/apache/solr/faiss/FaissCodec.java | 102 ++++ .../apache/solr/faiss/FaissCodecFactory.java | 59 +++ .../org/apache/solr/faiss/package-info.java | 21 + .../org.apache.lucene.codecs.KnnVectorsFormat | 1 + .../solr/collection1/conf/schema.xml | 38 ++ .../solr/collection1/conf/solrconfig.xml | 33 ++ .../solr/faiss/TestFaissIntegration.java | 136 +++++ .../configs/collection1/conf/schema.xml | 38 ++ .../configs/collection1/conf/solrconfig.xml | 33 ++ settings.gradle | 2 + 23 files changed, 2513 insertions(+) create mode 100644 .gitmodules create mode 100644 faiss/README.md create mode 100644 faiss/build.gradle create mode 160000 faiss/solr create mode 100644 faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsFormat.java create mode 100644 faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsReader.java create mode 100644 faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java create mode 100644 faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibrary.java create mode 100644 faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java create mode 100644 faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissNativeWrapper.java create mode 100644 faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/Java21Compatibility.java create mode 100644 faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/package-info.java create mode 100644 faiss/src/main/java/org/apache/solr/faiss/FaissCodec.java create mode 100644 faiss/src/main/java/org/apache/solr/faiss/FaissCodecFactory.java create mode 100644 faiss/src/main/java/org/apache/solr/faiss/package-info.java create mode 100644 faiss/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat create mode 100644 faiss/src/test-files/solr/collection1/conf/schema.xml create mode 100644 faiss/src/test-files/solr/collection1/conf/solrconfig.xml create mode 100644 faiss/src/test/java/org/apache/solr/faiss/TestFaissIntegration.java create mode 100644 faiss/src/test/resources/configs/collection1/conf/schema.xml create mode 100644 faiss/src/test/resources/configs/collection1/conf/solrconfig.xml diff --git a/.gitignore b/.gitignore index ee0b600..506cc5d 100644 --- a/.gitignore +++ b/.gitignore @@ -14,9 +14,12 @@ # Ignore Gradle build output directory build +bin out cluster .gatling logs + +faiss/test-config diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..5ba24ec --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "faiss/solr"] + path = faiss/solr + url = https://github.com/apache/solr.git + branch = branch_10_0 diff --git a/faiss/README.md b/faiss/README.md new file mode 100644 index 0000000..fb136b6 --- /dev/null +++ b/faiss/README.md @@ -0,0 +1,79 @@ +# FAISS Module for Solr + +FAISS integration for Apache Solr. Backported from Lucene 11.0 to work with Lucene 10.3.1. + +## Requirements + +- Java 21+ +- FAISS native library (`libfaiss_c.so` version 1.11.0) +- Solr 9.2.0+ + +## Building + +```bash +./gradlew :faiss:jar +``` + +JAR is created at `faiss/build/libs/solr-faiss.jar`. + +## Installation + +Copy the JAR to Solr's lib directory: + +```bash +cp faiss/build/libs/solr-faiss.jar $SOLR_HOME/server/solr-webapp/webapp/WEB-INF/lib/ +``` + +Set `LD_LIBRARY_PATH` to include the FAISS library directory. + +## Configuration + +### solrconfig.xml + +```xml + + IDMap,HNSW16 + efConstruction=100 + +``` + +### schema.xml + +```xml + + + +``` + +Note: FAISS supports `dot_product` and `euclidean`, but not `cosine`. + +## Testing + +Tests require the Solr branch_10_0 submodule. Clone with `--recursive` or run: + +```bash +git submodule update --init --recursive +./gradlew :faiss:test +``` + +Without the submodule, tests use published Solr 9.2.0 artifacts and may fail due to Lucene version mismatch. + +## Git Submodule + +This module uses a git submodule pointing to Solr branch_10_0 for tests. Only a reference (~200 bytes) is stored, not the actual Solr code. + +To clone with submodule: +```bash +git clone --recursive +``` + +## Troubleshooting + +- **ClassNotFoundException** - Ensure JAR is in Solr's classpath +- **UnsatisfiedLinkError** - Set `LD_LIBRARY_PATH` to include `libfaiss_c.so` +- **LinkageError** - Requires exact FAISS version 1.11.0 +- **Tests fail** - Initialize submodule: `git submodule update --init --recursive` diff --git a/faiss/build.gradle b/faiss/build.gradle new file mode 100644 index 0000000..111731f --- /dev/null +++ b/faiss/build.gradle @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +plugins { + id 'java' + id 'java-library' +} + +description = 'FAISS plugin for Solr' + +repositories { + mavenCentral() +} + +configurations { + provided +} + +sourceSets { + main { compileClasspath += configurations.provided } +} + +def solrSubmodulePath = new File(projectDir, 'solr') +def solrSubmoduleExists = solrSubmodulePath.exists() && solrSubmodulePath.isDirectory() + +dependencies { + implementation "org.apache.lucene:lucene-core:10.3.1" + implementation "org.apache.lucene:lucene-codecs:10.3.1" + provided "org.apache.solr:solr-core:${solrVersion}" + implementation "org.slf4j:slf4j-api:2.0.5" + + testImplementation 'org.slf4j:slf4j-api:2.0.5' + testImplementation 'org.hamcrest:hamcrest:2.2' + testImplementation 'junit:junit:4.13.2' + testImplementation 'org.mockito:mockito-inline:5.2.0' + testImplementation "org.apache.lucene:lucene-test-framework:10.3.1" + testImplementation "org.apache.lucene:lucene-backward-codecs:10.3.1" + testImplementation "org.apache.lucene:lucene-sandbox:10.3.1" + testImplementation "commons-io:commons-io:2.20.0" + + if (solrSubmoduleExists) { + testImplementation fileTree(dir: "${solrSubmodulePath}/solr/test-framework/build/libs", include: 'solr-test-framework-*.jar', excludes: ['*sources*.jar', '*javadoc*.jar']) + testImplementation fileTree(dir: "${solrSubmodulePath}/solr/core/build/libs", include: 'solr-core-*.jar', excludes: ['*sources*.jar', '*javadoc*.jar']) + testImplementation fileTree(dir: "${solrSubmodulePath}/solr/solrj/build/libs", include: 'solr-solrj-*.jar', excludes: ['*sources*.jar', '*javadoc*.jar']) + testImplementation fileTree(dir: "${solrSubmodulePath}/solr/api/build/libs", include: 'solr-api-*.jar', excludes: ['*sources*.jar', '*javadoc*.jar']) + testImplementation fileTree(dir: "${solrSubmodulePath}/solr/solrj-zookeeper/build/libs", include: 'solr-solrj-zookeeper-*.jar', excludes: ['*sources*.jar', '*javadoc*.jar']) + + testImplementation("org.apache.zookeeper:zookeeper:3.9.4") { + exclude group: "org.apache.yetus", module: "audience-annotations" + } + testImplementation("org.apache.zookeeper:zookeeper-jute:3.9.4") { + exclude group: "org.apache.yetus", module: "audience-annotations" + } + testImplementation("org.apache.curator:curator-client:5.9.0") { + exclude group: 'org.apache.zookeeper', module: 'zookeeper' + } + testImplementation("org.apache.curator:curator-framework:5.9.0") { + exclude group: 'org.apache.zookeeper', module: 'zookeeper' + } + testImplementation("org.apache.curator:curator-test:5.9.0") { + exclude group: 'org.apache.zookeeper', module: 'zookeeper' + exclude group: 'com.google.guava', module: 'guava' + exclude group: 'io.dropwizard.metrics', module: 'metrics-core' + } + testImplementation "jakarta.servlet:jakarta.servlet-api:6.0.0" + testImplementation "org.eclipse.jetty:jetty-server:12.0.27" + testImplementation "org.eclipse.jetty:jetty-session:12.0.27" + testImplementation "org.eclipse.jetty.ee10:jetty-ee10-servlet:12.0.27" + testImplementation "org.eclipse.jetty:jetty-util:12.0.27" + testImplementation "org.eclipse.jetty:jetty-client:12.0.27" + testImplementation "org.eclipse.jetty:jetty-alpn-server:12.0.27" + testImplementation "org.eclipse.jetty:jetty-alpn-java-server:12.0.27" + testImplementation "org.eclipse.jetty:jetty-rewrite:12.0.27" + testImplementation "org.eclipse.jetty.http2:jetty-http2-server:12.0.27" + testImplementation "org.eclipse.jetty.http2:jetty-http2-common:12.0.27" + testImplementation "org.slf4j:slf4j-api:2.0.17" + testImplementation "org.apache.logging.log4j:log4j-api:2.21.0" + testImplementation "org.apache.logging.log4j:log4j-core:2.21.0" + testImplementation "io.dropwizard.metrics:metrics-core:4.2.26" + testImplementation "io.dropwizard.metrics:metrics-jetty12-ee10:4.2.26" + testImplementation "commons-cli:commons-cli:1.10.0" + testImplementation "org.apache.httpcomponents:httpclient:4.5.14" + testImplementation "org.apache.httpcomponents:httpcore:4.4.16" + testImplementation "org.apache.httpcomponents:httpmime:4.5.14" + testImplementation "io.opentelemetry:opentelemetry-api:1.53.0" + testImplementation("io.opentelemetry:opentelemetry-exporter-prometheus:1.50.0-alpha") { + transitive = false + } + testImplementation "io.prometheus:prometheus-metrics-model:1.1.0" + testImplementation("io.opentelemetry:opentelemetry-sdk:1.53.0") { + exclude group: "io.opentelemetry", module: "opentelemetry-sdk-logs" + } + testImplementation("io.prometheus:prometheus-metrics-exposition-formats:1.1.0") { + exclude group: "io.prometheus", module: "prometheus-metrics-shaded-protobuf" + exclude group: "io.prometheus", module: "prometheus-metrics-config" + } + testImplementation "com.carrotsearch.randomizedtesting:randomizedtesting-runner:2.8.3" + testImplementation "org.hamcrest:hamcrest:3.0" + } else { + testImplementation group: 'org.apache.solr', name: 'solr-core', version: "${solrVersion}", { + exclude group: "org.eclipse.jetty", module: "jetty-http" + exclude group: "org.eclipse.jetty", module: "jetty-server" + exclude group: "org.eclipse.jetty", module: "jetty-servlet" + } + testImplementation group: 'org.apache.solr', name: 'solr-test-framework', version: "${solrVersion}" + testImplementation group: 'org.apache.solr', name: 'solr-solrj', version: "${solrVersion}" + testImplementation group: 'org.apache.solr', name: 'solr-solrj-zookeeper', version: "${solrVersion}" + } +} + +java { + sourceCompatibility = JavaVersion.VERSION_21 + targetCompatibility = JavaVersion.VERSION_21 +} + +def currentJavaVersion = JavaVersion.current() +if (currentJavaVersion == JavaVersion.VERSION_21) { + tasks.withType(JavaCompile).configureEach { + options.compilerArgs += ['--enable-preview'] + } +} + +jar { + archiveBaseName.set('solr-faiss') +} + +task buildSolrSubmodule(type: Exec) { + description = 'Builds Solr submodule to generate test framework JARs' + group = 'build' + onlyIf { solrSubmoduleExists } + workingDir = solrSubmodulePath + commandLine = ['sh', '-c', './gradlew :solr:test-framework:jar :solr:core:jar :solr:solrj:jar :solr:api:jar :solr:solrj-zookeeper:jar --no-daemon'] + doFirst { + logger.lifecycle("Building Solr submodule at ${solrSubmodulePath.absolutePath}") + } +} + +test.doFirst { + if (!solrSubmoduleExists) { + def separator = "=".multiply(70) + logger.warn("") + logger.warn(separator) + logger.warn("WARNING: Solr submodule not found at ${solrSubmodulePath.absolutePath}") + logger.warn("") + logger.warn("Tests will use published Solr 9.2.0 artifacts (Lucene 9.4.2)") + logger.warn("which may cause test failures due to Lucene version mismatch.") + logger.warn("") + logger.warn("To fix this, initialize the submodule:") + logger.warn(" git submodule update --init --recursive") + logger.warn("") + logger.warn("Or clone the repository with:") + logger.warn(" git clone --recursive") + logger.warn("") + logger.warn("Production code will still compile and work without the submodule.") + logger.warn(separator) + logger.warn("") + } +} + +test.dependsOn buildSolrSubmodule +compileTestJava.dependsOn buildSolrSubmodule + +test { + jvmArgs '-Djava.security.egd=file:/dev/./urandom' + if (currentJavaVersion == JavaVersion.VERSION_21) { + jvmArgs '--enable-preview' + } + if (System.getenv('LD_LIBRARY_PATH') != null) { + jvmArgs "-Djava.library.path=${System.getenv('LD_LIBRARY_PATH')}" + } else if (System.getenv('CONDA_PREFIX') != null) { + def condaLibPath = "${System.getenv('CONDA_PREFIX')}/lib" + jvmArgs "-Djava.library.path=${condaLibPath}" + } + if (solrSubmoduleExists) { + jvmArgs "-Dtests.src.home=${solrSubmodulePath.absolutePath}" + } + testClassesDirs = sourceSets.test.output.classesDirs + enabled = true +} diff --git a/faiss/solr b/faiss/solr new file mode 160000 index 0000000..12f6f48 --- /dev/null +++ b/faiss/solr @@ -0,0 +1 @@ +Subproject commit 12f6f483cfbd45b5ff6bc7c9709a1b4c8a9cb266 diff --git a/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsFormat.java b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsFormat.java new file mode 100644 index 0000000..8cef8d2 --- /dev/null +++ b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsFormat.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_MAX_CONN; + +import java.io.IOException; +import java.util.Locale; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; + +/** + * A Faiss-based format to create and search vector indexes, using {@link FaissLibrary} to interact + * with the native library. + * + *

The Faiss index is configured using its flexible index factory, which + * allows creating arbitrary indexes by "describing" them. These indexes can be tuned by setting + * relevant parameters. + * + *

A separate Faiss index is created per-segment, and uses the following files: + * + *

    + *
  • .faissm (metadata file): stores field number, offset and length of actual + * Faiss index in data file. + *
  • .faissd (data file): stores concatenated Faiss indexes for all fields. + *
  • All files required by {@link Lucene99FlatVectorsFormat} for storing raw vectors. + *
+ * + *

Note: Set the {@code $OMP_NUM_THREADS} environment variable to control internal + * threading. + * + *

TODO: There is no guarantee of backwards compatibility! + * + * @lucene.experimental + */ +public final class FaissKnnVectorsFormat extends KnnVectorsFormat { + public static final String NAME = FaissKnnVectorsFormat.class.getSimpleName(); + static final int VERSION_START = 0; + static final int VERSION_CURRENT = VERSION_START; + static final String META_CODEC_NAME = NAME + "Meta"; + static final String DATA_CODEC_NAME = NAME + "Data"; + static final String META_EXTENSION = "faissm"; + static final String DATA_EXTENSION = "faissd"; + + private final String description; + private final String indexParams; + private final FlatVectorsFormat rawVectorsFormat; + + /** + * Constructs an HNSW-based format using default {@code maxConn}={@value + * org.apache.lucene.util.hnsw.HnswGraphBuilder#DEFAULT_MAX_CONN} and {@code beamWidth}={@value + * org.apache.lucene.util.hnsw.HnswGraphBuilder#DEFAULT_BEAM_WIDTH}. + */ + public FaissKnnVectorsFormat() { + this( + String.format(Locale.ROOT, "IDMap,HNSW%d", DEFAULT_MAX_CONN), + String.format(Locale.ROOT, "efConstruction=%d", DEFAULT_BEAM_WIDTH)); + } + + /** + * Constructs a format using the specified index factory string and index parameters (see class + * docs for more information). + * + * @param description the index factory string to initialize Faiss indexes. + * @param indexParams the index params to set on Faiss indexes. + */ + public FaissKnnVectorsFormat(String description, String indexParams) { + super(NAME); + this.description = description; + this.indexParams = indexParams; + this.rawVectorsFormat = + new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new FaissKnnVectorsWriter( + description, indexParams, state, rawVectorsFormat.fieldsWriter(state)); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new FaissKnnVectorsReader(state, rawVectorsFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return DEFAULT_MAX_DIMENSIONS; + } + + @Override + public String toString() { + return String.format( + Locale.ROOT, "%s(description=%s indexParams=%s)", NAME, description, indexParams); + } +} diff --git a/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsReader.java b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsReader.java new file mode 100644 index 0000000..c192639 --- /dev/null +++ b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsReader.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_CURRENT; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_START; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.search.AcceptDocs; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileTypeHint; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.IOUtils; + +/** + * Read per-segment Faiss indexes and associated metadata. + * + * @lucene.experimental + */ +final class FaissKnnVectorsReader extends KnnVectorsReader { + private final FlatVectorsReader rawVectorsReader; + private final IndexInput data; + private final Map indexMap; + private boolean closed; + + public FaissKnnVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) + throws IOException { + this.rawVectorsReader = rawVectorsReader; + this.indexMap = new HashMap<>(); + + List fieldMetaList = new ArrayList<>(); + String metaFileName = + IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, META_EXTENSION); + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { + Throwable priorE = null; + int versionMeta = -1; + try { + versionMeta = + CodecUtil.checkIndexHeader( + meta, + META_CODEC_NAME, + VERSION_START, + VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + + FieldMeta fieldMeta; + while ((fieldMeta = parseNextField(meta, state)) != null) { + fieldMetaList.add(fieldMeta); + } + } catch (Throwable t) { + priorE = t; + } finally { + CodecUtil.checkFooter(meta, priorE); + } + + String dataFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, state.segmentSuffix, DATA_EXTENSION); + this.data = + state.directory.openInput( + dataFileName, state.context.withHints(FileTypeHint.DATA, DataAccessHint.RANDOM)); + + int versionData = + CodecUtil.checkIndexHeader( + this.data, + DATA_CODEC_NAME, + VERSION_START, + VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + if (versionMeta != versionData) { + throw new CorruptIndexException( + String.format( + Locale.ROOT, + "Format versions mismatch (meta=%d, data=%d)", + versionMeta, + versionData), + data); + } + CodecUtil.retrieveChecksum(data); + + for (FieldMeta fieldMeta : fieldMetaList) { + if (indexMap.containsKey(fieldMeta.name)) { + throw new CorruptIndexException("Duplicate field: " + fieldMeta.name, meta); + } + IndexInput indexInput = data.slice(fieldMeta.name, fieldMeta.offset, fieldMeta.length); + FaissLibrary.Index index = FaissLibrary.INSTANCE.readIndex(indexInput); + indexMap.put(fieldMeta.name, index); + } + } catch (Throwable t) { + IOUtils.closeWhileHandlingException(this); + throw t; + } + } + + private static FieldMeta parseNextField(IndexInput meta, SegmentReadState state) + throws IOException { + int fieldNumber = meta.readInt(); + if (fieldNumber == -1) { + return null; + } + + FieldInfo fieldInfo = state.fieldInfos.fieldInfo(fieldNumber); + if (fieldInfo == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + + long dataOffset = meta.readLong(); + long dataLength = meta.readLong(); + + return new FieldMeta(fieldInfo.name, dataOffset, dataLength); + } + + @Override + public void checkIntegrity() throws IOException { + rawVectorsReader.checkIntegrity(); + // TODO: Evaluate if we need an explicit check for validity of Faiss indexes + CodecUtil.checksumEntireFile(data); + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + return rawVectorsReader.getFloatVectorValues(field); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) { + // TODO: Support using SQ8 quantization, see: + // - https://github.com/opensearch-project/k-NN/pull/2425 + throw new UnsupportedOperationException("Byte vectors not supported"); + } + + @Override + public void search( + String field, float[] vector, KnnCollector knnCollector, AcceptDocs acceptDocs) { + FaissLibrary.Index index = indexMap.get(field); + if (index != null) { + index.search(vector, knnCollector, acceptDocs); + } + } + + @Override + public void search( + String field, byte[] vector, KnnCollector knnCollector, AcceptDocs acceptDocs) { + // TODO: Support using SQ8 quantization, see: + // - https://github.com/opensearch-project/k-NN/pull/2425 + throw new UnsupportedOperationException("Byte vectors not supported"); + } + + @Override + public Map getOffHeapByteSize(FieldInfo fieldInfo) { + // TODO: How to estimate Faiss usage? + return rawVectorsReader.getOffHeapByteSize(fieldInfo); + } + + @Override + public void close() throws IOException { + if (closed == false) { + // Close all indexes + for (FaissLibrary.Index index : indexMap.values()) { + index.close(); + } + indexMap.clear(); + + IOUtils.close(rawVectorsReader, data); + closed = true; + } + } + + private record FieldMeta(String name, long offset, long length) {} +} diff --git a/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java new file mode 100644 index 0000000..ce15c3d --- /dev/null +++ b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissKnnVectorsWriter.java @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.DATA_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_CODEC_NAME; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.META_EXTENSION; +import static org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat.VERSION_CURRENT; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.search.DocIdSet; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.hnsw.IntToIntFunction; + +/** + * Write per-segment Faiss indexes and associated metadata. + * + * @lucene.experimental + */ +final class FaissKnnVectorsWriter extends KnnVectorsWriter { + private final String description, indexParams; + private final FlatVectorsWriter rawVectorsWriter; + private final IndexOutput meta, data; + private final Map> rawFields; + private boolean finished; + + public FaissKnnVectorsWriter( + String description, + String indexParams, + SegmentWriteState state, + FlatVectorsWriter rawVectorsWriter) + throws IOException { + + this.description = description; + this.indexParams = indexParams; + this.rawVectorsWriter = rawVectorsWriter; + this.rawFields = new HashMap<>(); + this.finished = false; + + try { + String metaFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, state.segmentSuffix, META_EXTENSION); + this.meta = state.directory.createOutput(metaFileName, state.context); + CodecUtil.writeIndexHeader( + this.meta, + META_CODEC_NAME, + VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + + String dataFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, state.segmentSuffix, DATA_EXTENSION); + this.data = state.directory.createOutput(dataFileName, state.context); + CodecUtil.writeIndexHeader( + this.data, + DATA_CODEC_NAME, + VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + } catch (Throwable t) { + IOUtils.closeWhileHandlingException(this); + throw t; + } + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + rawVectorsWriter.mergeOneField(fieldInfo, mergeState); + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> + // TODO: Support using SQ8 quantization, see: + // - https://github.com/opensearch-project/k-NN/pull/2425 + throw new UnsupportedOperationException("Byte vectors not supported"); + case FLOAT32 -> { + FloatVectorValues merged = + KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + writeFloatField(fieldInfo, merged, doc -> doc); + } + } + } + + @Override + public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + FlatFieldVectorsWriter rawFieldVectorsWriter = rawVectorsWriter.addField(fieldInfo); + rawFields.put(fieldInfo, rawFieldVectorsWriter); + return rawFieldVectorsWriter; + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + rawVectorsWriter.flush(maxDoc, sortMap); + for (Map.Entry> entry : rawFields.entrySet()) { + FieldInfo fieldInfo = entry.getKey(); + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> + // TODO: Support using SQ8 quantization, see: + // - https://github.com/opensearch-project/k-NN/pull/2425 + throw new UnsupportedOperationException("Byte vectors not supported"); + + case FLOAT32 -> { + @SuppressWarnings("unchecked") + FlatFieldVectorsWriter rawWriter = + (FlatFieldVectorsWriter) entry.getValue(); + + List vectors = rawWriter.getVectors(); + int dimension = fieldInfo.getVectorDimension(); + DocIdSet docIdSet = rawWriter.getDocsWithFieldSet(); + + writeFloatField( + fieldInfo, + new BufferedFloatVectorValues(vectors, dimension, docIdSet), + (sortMap != null) ? sortMap::oldToNew : doc -> doc); + } + } + } + } + + private void writeFloatField( + FieldInfo fieldInfo, FloatVectorValues floatVectorValues, IntToIntFunction oldToNewDocId) + throws IOException { + int number = fieldInfo.number; + meta.writeInt(number); + + try (FaissLibrary.Index index = + FaissLibrary.INSTANCE.createIndex( + description, + indexParams, + fieldInfo.getVectorSimilarityFunction(), + floatVectorValues, + oldToNewDocId)) { + + // Write index + long dataOffset = data.getFilePointer(); + index.write(data); + long dataLength = data.getFilePointer() - dataOffset; + + meta.writeLong(dataOffset); + meta.writeLong(dataLength); + } + } + + @Override + public void finish() throws IOException { + if (finished) { + throw new IllegalStateException("Already finished"); + } + finished = true; + + rawVectorsWriter.finish(); + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + CodecUtil.writeFooter(data); + } + + @Override + public void close() throws IOException { + IOUtils.close(rawVectorsWriter, meta, data); + } + + @Override + public long ramBytesUsed() { + // TODO: How to estimate Faiss usage? + return rawVectorsWriter.ramBytesUsed(); + } + + private static class BufferedFloatVectorValues extends FloatVectorValues { + private final List floats; + private final int dimension; + private final DocIdSet docIdSet; + + public BufferedFloatVectorValues(List floats, int dimension, DocIdSet docIdSet) { + this.floats = floats; + this.dimension = dimension; + this.docIdSet = docIdSet; + } + + @Override + public float[] vectorValue(int ord) { + return floats.get(ord); + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return floats.size(); + } + + @Override + public FloatVectorValues copy() { + throw new AssertionError("Should not be called"); + } + + @Override + public DocIndexIterator iterator() { + try { + return fromDISI(docIdSet.iterator()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } +} diff --git a/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibrary.java b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibrary.java new file mode 100644 index 0000000..ad0715b --- /dev/null +++ b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibrary.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import java.io.Closeable; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.hnsw.IntToIntFunction; + +/** + * Minimal interface to create and query Faiss indexes. + * + * @lucene.experimental + */ +interface FaissLibrary { + FaissLibrary INSTANCE = new FaissLibraryNativeImpl(); + + // TODO: Use SIMD version at runtime. The "faiss_c" library is linked to the main "faiss" library, + // which does not use SIMD instructions. However, there are SIMD versions of "faiss" (like + // "faiss_avx2", "faiss_avx512", "faiss_sve", etc.) available, which can be used by changing the + // dependencies of "faiss_c" using the "patchelf" utility. Figure out how to do this dynamically, + // or via modifications to upstream Faiss. + String NAME = "faiss_c"; + String VERSION = "1.11.0"; + + interface Index extends Closeable { + void search(float[] query, KnnCollector knnCollector, AcceptDocs acceptDocs); + + void write(IndexOutput output); + } + + Index createIndex( + String description, + String indexParams, + VectorSimilarityFunction function, + FloatVectorValues floatVectorValues, + IntToIntFunction oldToNewDocId); + + Index readIndex(IndexInput input); +} diff --git a/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java new file mode 100644 index 0000000..aa6ae96 --- /dev/null +++ b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static java.lang.foreign.ValueLayout.ADDRESS; +import static java.lang.foreign.ValueLayout.JAVA_BYTE; +import static java.lang.foreign.ValueLayout.JAVA_FLOAT; +import static java.lang.foreign.ValueLayout.JAVA_LONG; +import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.sandbox.codecs.faiss.FaissNativeWrapper.Exception.handleException; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.lang.foreign.Arena; +import java.util.Locale; +import java.lang.foreign.FunctionDescriptor; +import java.lang.foreign.Linker; +import java.lang.foreign.MemorySegment; +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.nio.ByteOrder; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.hnsw.IntToIntFunction; + +/** + * A native implementation of {@link FaissLibrary} using {@link FaissNativeWrapper}. + * + * @lucene.experimental + */ +@SuppressWarnings("restricted") // uses unsafe calls +final class FaissLibraryNativeImpl implements FaissLibrary { + private final FaissNativeWrapper wrapper; + + FaissLibraryNativeImpl() { + this.wrapper = new FaissNativeWrapper(); + } + + private static MemorySegment getStub( + Arena arena, MethodHandle target, FunctionDescriptor descriptor) { + return Linker.nativeLinker().upcallStub(target, descriptor, arena); + } + + private static final int BUFFER_SIZE = 256 * 1024 * 1024; // 256 MB + + @SuppressWarnings("unused") // called using a MethodHandle + private static long writeBytes( + IndexOutput output, MemorySegment inputPointer, long itemSize, long numItems) + throws IOException { + long size = itemSize * numItems; + inputPointer = inputPointer.reinterpret(size); + + if (size <= BUFFER_SIZE) { // simple case, avoid buffering + output.writeBytes(inputPointer.toArray(JAVA_BYTE), (int) size); + } else { // copy buffered number of bytes repeatedly + byte[] bytes = new byte[BUFFER_SIZE]; + for (long offset = 0; offset < size; offset += BUFFER_SIZE) { + int length = (int) Math.min(size - offset, BUFFER_SIZE); + MemorySegment.copy(inputPointer, JAVA_BYTE, offset, bytes, 0, length); + output.writeBytes(bytes, length); + } + } + return numItems; + } + + @SuppressWarnings("unused") // called using a MethodHandle + private static long readBytes( + IndexInput input, MemorySegment outputPointer, long itemSize, long numItems) + throws IOException { + long size = itemSize * numItems; + outputPointer = outputPointer.reinterpret(size); + + if (size <= BUFFER_SIZE) { // simple case, avoid buffering + byte[] bytes = new byte[(int) size]; + input.readBytes(bytes, 0, bytes.length); + MemorySegment.copy(bytes, 0, outputPointer, JAVA_BYTE, 0, bytes.length); + } else { // copy buffered number of bytes repeatedly + byte[] bytes = new byte[BUFFER_SIZE]; + for (long offset = 0; offset < size; offset += BUFFER_SIZE) { + int length = (int) Math.min(size - offset, BUFFER_SIZE); + input.readBytes(bytes, 0, length); + MemorySegment.copy(bytes, 0, outputPointer, JAVA_BYTE, offset, length); + } + } + return numItems; + } + + private static final MethodHandle WRITE_BYTES_HANDLE; + private static final MethodHandle READ_BYTES_HANDLE; + + static { + try { + MethodHandles.Lookup lookup = MethodHandles.lookup(); + + WRITE_BYTES_HANDLE = + lookup.findStatic( + FaissLibraryNativeImpl.class, + "writeBytes", + MethodType.methodType( + long.class, IndexOutput.class, MemorySegment.class, long.class, long.class)); + + READ_BYTES_HANDLE = + lookup.findStatic( + FaissLibraryNativeImpl.class, + "readBytes", + MethodType.methodType( + long.class, IndexInput.class, MemorySegment.class, long.class, long.class)); + + } catch (NoSuchMethodException | IllegalAccessException e) { + throw new LinkageError( + "FaissLibraryNativeImpl reader / writer functions are missing or inaccessible", e); + } + } + + private static final Map FUNCTION_TO_METRIC = + Map.of( + // Mapped from faiss/MetricType.h + DOT_PRODUCT, 0, + EUCLIDEAN, 1); + + private static int functionToMetric(VectorSimilarityFunction function) { + Integer metric = FUNCTION_TO_METRIC.get(function); + if (metric == null) { + throw new UnsupportedOperationException("Similarity function not supported: " + function); + } + return metric; + } + + // Invert FUNCTION_TO_METRIC + private static final Map METRIC_TO_FUNCTION = + FUNCTION_TO_METRIC.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey)); + + private static VectorSimilarityFunction metricToFunction(int metric) { + VectorSimilarityFunction function = METRIC_TO_FUNCTION.get(metric); + if (function == null) { + throw new UnsupportedOperationException("Metric not supported: " + metric); + } + return function; + } + + @Override + public FaissLibrary.Index createIndex( + String description, + String indexParams, + VectorSimilarityFunction function, + FloatVectorValues floatVectorValues, + IntToIntFunction oldToNewDocId) { + + try (Arena temp = Arena.ofConfined()) { + int size = floatVectorValues.size(); + int dimension = floatVectorValues.dimension(); + int metric = functionToMetric(function); + + // Create an index + MemorySegment pointer = temp.allocate(ADDRESS); + handleException( + wrapper.faiss_index_factory(pointer, dimension, Java21Compatibility.allocateFrom(temp, description), metric)); + + MemorySegment indexPointer = pointer.get(ADDRESS, 0); + + // Set index params + handleException(wrapper.faiss_ParameterSpace_new(pointer)); + MemorySegment parameterSpacePointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, wrapper::faiss_ParameterSpace_free); + + handleException( + wrapper.faiss_ParameterSpace_set_index_parameters( + parameterSpacePointer, indexPointer, Java21Compatibility.allocateFrom(temp, indexParams))); + + // TODO: Improve memory usage (with a tradeoff in performance) by batched indexing, see: + // - https://github.com/opensearch-project/k-NN/issues/1506 + // - https://github.com/opensearch-project/k-NN/issues/1938 + + // Allocate docs in native memory + MemorySegment docs = temp.allocate(JAVA_FLOAT, (long) size * dimension); + long docsOffset = 0; + long perDocByteSize = dimension * JAVA_FLOAT.byteSize(); + + // Allocate ids in native memory + MemorySegment ids = temp.allocate(JAVA_LONG, size); + int idsIndex = 0; + + KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); + int actualSize = 0; + for (int i = iterator.nextDoc(); i != NO_MORE_DOCS; i = iterator.nextDoc()) { + int id = oldToNewDocId.apply(i); + ids.setAtIndex(JAVA_LONG, idsIndex, id); + idsIndex++; + + float[] vector = floatVectorValues.vectorValue(iterator.index()); + MemorySegment.copy(vector, 0, docs, JAVA_FLOAT, docsOffset, vector.length); + docsOffset += perDocByteSize; + actualSize++; + } + + // Verify we got the expected number of vectors + if (actualSize != size) { + throw new IllegalStateException( + "Vector count mismatch: expected " + size + " but got " + actualSize); + } + + // Train index + int isTrained = wrapper.faiss_Index_is_trained(indexPointer); + if (isTrained == 0) { + handleException(wrapper.faiss_Index_train(indexPointer, size, docs)); + } + + // Add docs to index + handleException(wrapper.faiss_Index_add_with_ids(indexPointer, size, docs, ids)); + + return new Index(indexPointer); + + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + // See flags defined in c_api/index_io_c.h + private static final int FAISS_IO_FLAG_MMAP = 1; + private static final int FAISS_IO_FLAG_READ_ONLY = 2; + + @Override + public FaissLibrary.Index readIndex(IndexInput input) { + try (Arena temp = Arena.ofConfined()) { + MethodHandle readerHandle = READ_BYTES_HANDLE.bindTo(input); + MemorySegment readerStub = + getStub( + temp, readerHandle, FunctionDescriptor.of(JAVA_LONG, ADDRESS, JAVA_LONG, JAVA_LONG)); + + MemorySegment pointer = temp.allocate(ADDRESS); + handleException(wrapper.faiss_CustomIOReader_new(pointer, readerStub)); + MemorySegment customIOReaderPointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, wrapper::faiss_CustomIOReader_free); + + // Read index + handleException( + wrapper.faiss_read_index_custom( + customIOReaderPointer, FAISS_IO_FLAG_MMAP | FAISS_IO_FLAG_READ_ONLY, pointer)); + MemorySegment indexPointer = pointer.get(ADDRESS, 0); + + return new Index(indexPointer); + } + } + + private class Index implements FaissLibrary.Index { + @FunctionalInterface + private interface FloatToFloatFunction { + float scale(float score); + } + + private final Arena arena; + private final MemorySegment indexPointer; + private final FloatToFloatFunction scaler; + private boolean closed; + + private Index(MemorySegment indexPointer) { + this.arena = Arena.ofShared(); + this.indexPointer = + indexPointer + // Ensure timely cleanup + .reinterpret(arena, wrapper::faiss_Index_free); + + // Get underlying function + int metricType = wrapper.faiss_Index_metric_type(indexPointer); + VectorSimilarityFunction function = metricToFunction(metricType); + + // Scale Faiss distances to Lucene scores, see VectorSimilarityFunction.java + this.scaler = + switch (function) { + case DOT_PRODUCT -> + // distance in Faiss === dotProduct in Lucene + distance -> Math.max((1 + distance) / 2, 0); + + case EUCLIDEAN -> + // distance in Faiss === squareDistance in Lucene + distance -> 1 / (1 + distance); + + case COSINE, MAXIMUM_INNER_PRODUCT -> throw new AssertionError("Should not reach here"); + }; + + this.closed = false; + } + + @Override + public void close() { + if (closed == false) { + arena.close(); + closed = true; + } + } + + @Override + public void search(float[] query, KnnCollector knnCollector, AcceptDocs acceptDocs) { + try (Arena temp = Arena.ofConfined()) { + FixedBitSet fixedBitSet = + switch (acceptDocs.bits()) { + case null -> null; + case FixedBitSet bitSet -> bitSet; + // TODO: Add optimized case for SparseFixedBitSet + case Bits bits -> FixedBitSet.copyOf(bits); + }; + + // Allocate queries in native memory + MemorySegment queries = Java21Compatibility.allocateFrom(temp, JAVA_FLOAT, query); + + // Faiss knn search + int k = knnCollector.k(); + MemorySegment distancesPointer = temp.allocate(JAVA_FLOAT, k); + MemorySegment idsPointer = temp.allocate(JAVA_LONG, k); + + MemorySegment localIndex = indexPointer.reinterpret(temp, null); + if (fixedBitSet == null) { + // Search without runtime filters + handleException( + wrapper.faiss_Index_search(localIndex, 1, queries, k, distancesPointer, idsPointer)); + } else { + MemorySegment pointer = temp.allocate(ADDRESS); + + long[] bits = fixedBitSet.getBits(); + MemorySegment nativeBits = + // Use LITTLE_ENDIAN to convert long[] -> uint8_t* + Java21Compatibility.allocateFrom(temp, JAVA_LONG.withOrder(ByteOrder.LITTLE_ENDIAN), bits); + + handleException( + wrapper.faiss_IDSelectorBitmap_new(pointer, fixedBitSet.length(), nativeBits)); + MemorySegment idSelectorBitmapPointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, wrapper::faiss_IDSelectorBitmap_free); + + handleException(wrapper.faiss_SearchParameters_new(pointer, idSelectorBitmapPointer)); + MemorySegment searchParametersPointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, wrapper::faiss_SearchParameters_free); + + // Search with runtime filters + handleException( + wrapper.faiss_Index_search_with_params( + localIndex, + 1, + queries, + k, + searchParametersPointer, + distancesPointer, + idsPointer)); + } + + // Record hits + int numResults = 0; + for (int i = 0; i < k; i++) { + int id = (int) idsPointer.getAtIndex(JAVA_LONG, i); + + // Not enough results + if (id == -1) { + break; + } + + // Collect result + float distance = distancesPointer.getAtIndex(JAVA_FLOAT, i); + knnCollector.collect(id, scaler.scale(distance)); + numResults++; + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void write(IndexOutput output) { + try (Arena temp = Arena.ofConfined()) { + MethodHandle writerHandle = WRITE_BYTES_HANDLE.bindTo(output); + MemorySegment writerStub = + getStub( + temp, + writerHandle, + FunctionDescriptor.of(JAVA_LONG, ADDRESS, JAVA_LONG, JAVA_LONG)); + + MemorySegment pointer = temp.allocate(ADDRESS); + handleException(wrapper.faiss_CustomIOWriter_new(pointer, writerStub)); + MemorySegment customIOWriterPointer = + pointer + .get(ADDRESS, 0) + // Ensure timely cleanup + .reinterpret(temp, wrapper::faiss_CustomIOWriter_free); + + // Write index + handleException( + wrapper.faiss_write_index_custom( + indexPointer, customIOWriterPointer, FAISS_IO_FLAG_MMAP | FAISS_IO_FLAG_READ_ONLY)); + } + } + } +} diff --git a/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissNativeWrapper.java b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissNativeWrapper.java new file mode 100644 index 0000000..4620467 --- /dev/null +++ b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissNativeWrapper.java @@ -0,0 +1,481 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import static java.lang.foreign.ValueLayout.ADDRESS; +import static java.lang.foreign.ValueLayout.JAVA_INT; +import static java.lang.foreign.ValueLayout.JAVA_LONG; + +import java.lang.foreign.FunctionDescriptor; +import java.lang.foreign.Linker; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SymbolLookup; +import java.lang.invoke.MethodHandle; +import java.util.Arrays; +import java.util.Locale; + +/** + * Utility class to wrap necessary functions of the native C API of Faiss + * using Project Panama. + * + * @lucene.experimental + */ +@SuppressWarnings("restricted") // uses unsafe calls +final class FaissNativeWrapper { + static { + System.loadLibrary(FaissLibrary.NAME); + } + + private static MethodHandle getHandle(String functionName, FunctionDescriptor descriptor) { + // Java 21 (JEP 442 preview) uses libraryLookup(), Java 22+ (JEP 454 final) uses loaderLookup() + // Use runtime detection to support both + SymbolLookup lookup; + int javaVersion = Runtime.version().feature(); + if (javaVersion >= 22) { + // Java 22+ final API + lookup = SymbolLookup.loaderLookup(); + } else { + // Java 21 preview API - use reflection since libraryLookup() doesn't exist in Java 22+ + try { + java.lang.foreign.Arena arena = java.lang.foreign.Arena.global(); + java.lang.reflect.Method libraryLookupMethod = + SymbolLookup.class.getMethod("libraryLookup", String.class, java.lang.foreign.Arena.class); + lookup = (SymbolLookup) libraryLookupMethod.invoke(null, FaissLibrary.NAME, arena); + } catch (java.lang.reflect.InvocationTargetException + | IllegalAccessException + | NoSuchMethodException e) { + throw new RuntimeException("Failed to create SymbolLookup for Java 21", e); + } + } + // Use compatibility layer for findOrThrow() vs find() + MemorySegment symbol = Java21Compatibility.findSymbol(lookup, functionName); + return Linker.nativeLinker().downcallHandle(symbol, descriptor); + } + + FaissNativeWrapper() { + // Check Faiss version + String expectedVersion = FaissLibrary.VERSION; + String actualVersion = Java21Compatibility.getString(faiss_get_version().reinterpret(Long.MAX_VALUE), 0); + + if (expectedVersion.equals(actualVersion) == false) { + throw new LinkageError( + String.format( + Locale.ROOT, + "Expected Faiss library version %s, found %s", + expectedVersion, + actualVersion)); + } + } + + private final MethodHandle faiss_get_version$MH = + getHandle("faiss_get_version", FunctionDescriptor.of(ADDRESS)); + + MemorySegment faiss_get_version() { + try { + return (MemorySegment) faiss_get_version$MH.invokeExact(); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_CustomIOReader_free$MH = + getHandle("faiss_CustomIOReader_free", FunctionDescriptor.ofVoid(ADDRESS)); + + void faiss_CustomIOReader_free(MemorySegment customIOReaderPointer) { + try { + faiss_CustomIOReader_free$MH.invokeExact(customIOReaderPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_CustomIOReader_new$MH = + getHandle("faiss_CustomIOReader_new", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS)); + + int faiss_CustomIOReader_new(MemorySegment pointer, MemorySegment readerStub) { + try { + return (int) faiss_CustomIOReader_new$MH.invokeExact(pointer, readerStub); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_CustomIOWriter_free$MH = + getHandle("faiss_CustomIOWriter_free", FunctionDescriptor.ofVoid(ADDRESS)); + + void faiss_CustomIOWriter_free(MemorySegment customIOWriterPointer) { + try { + faiss_CustomIOWriter_free$MH.invokeExact(customIOWriterPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_CustomIOWriter_new$MH = + getHandle("faiss_CustomIOWriter_new", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS)); + + int faiss_CustomIOWriter_new(MemorySegment pointer, MemorySegment writerStub) { + try { + return (int) faiss_CustomIOWriter_new$MH.invokeExact(pointer, writerStub); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_IDSelectorBitmap_free$MH = + getHandle("faiss_IDSelectorBitmap_free", FunctionDescriptor.ofVoid(ADDRESS)); + + void faiss_IDSelectorBitmap_free(MemorySegment idSelectorBitmapPointer) { + try { + faiss_IDSelectorBitmap_free$MH.invokeExact(idSelectorBitmapPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_IDSelectorBitmap_new$MH = + getHandle( + "faiss_IDSelectorBitmap_new", + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS)); + + int faiss_IDSelectorBitmap_new(MemorySegment pointer, long length, MemorySegment bitmapPointer) { + try { + return (int) faiss_IDSelectorBitmap_new$MH.invokeExact(pointer, length, bitmapPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_Index_add_with_ids$MH = + getHandle( + "faiss_Index_add_with_ids", + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS)); + + int faiss_Index_add_with_ids( + MemorySegment indexPointer, long size, MemorySegment docsPointer, MemorySegment idsPointer) { + try { + return (int) + faiss_Index_add_with_ids$MH.invokeExact(indexPointer, size, docsPointer, idsPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_Index_free$MH = + getHandle("faiss_Index_free", FunctionDescriptor.ofVoid(ADDRESS)); + + void faiss_Index_free(MemorySegment indexPointer) { + try { + faiss_Index_free$MH.invokeExact(indexPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_Index_is_trained$MH = + getHandle("faiss_Index_is_trained", FunctionDescriptor.of(JAVA_INT, ADDRESS)); + + int faiss_Index_is_trained(MemorySegment indexPointer) { + try { + return (int) faiss_Index_is_trained$MH.invokeExact(indexPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_Index_metric_type$MH = + getHandle("faiss_Index_metric_type", FunctionDescriptor.of(JAVA_INT, ADDRESS)); + + int faiss_Index_metric_type(MemorySegment indexPointer) { + try { + return (int) faiss_Index_metric_type$MH.invokeExact(indexPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_Index_ntotal$MH = + getHandle("faiss_Index_ntotal", FunctionDescriptor.of(JAVA_LONG, ADDRESS)); + + long faiss_Index_ntotal(MemorySegment indexPointer) { + try { + return (long) faiss_Index_ntotal$MH.invokeExact(indexPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_Index_search$MH = + getHandle( + "faiss_Index_search", + FunctionDescriptor.of( + JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS)); + + int faiss_Index_search( + MemorySegment indexPointer, + long numQueries, + MemorySegment queriesPointer, + long k, + MemorySegment distancesPointer, + MemorySegment idsPointer) { + try { + return (int) + faiss_Index_search$MH.invokeExact( + indexPointer, numQueries, queriesPointer, k, distancesPointer, idsPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_Index_search_with_params$MH = + getHandle( + "faiss_Index_search_with_params", + FunctionDescriptor.of( + JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS, JAVA_LONG, ADDRESS, ADDRESS, ADDRESS)); + + int faiss_Index_search_with_params( + MemorySegment indexPointer, + long numQueries, + MemorySegment queriesPointer, + long k, + MemorySegment searchParametersPointer, + MemorySegment distancesPointer, + MemorySegment idsPointer) { + try { + return (int) + faiss_Index_search_with_params$MH.invokeExact( + indexPointer, + numQueries, + queriesPointer, + k, + searchParametersPointer, + distancesPointer, + idsPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_Index_train$MH = + getHandle("faiss_Index_train", FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS)); + + int faiss_Index_train(MemorySegment indexPointer, long size, MemorySegment docsPointer) { + try { + return (int) faiss_Index_train$MH.invokeExact(indexPointer, size, docsPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_ParameterSpace_free$MH = + getHandle("faiss_ParameterSpace_free", FunctionDescriptor.ofVoid(ADDRESS)); + + void faiss_ParameterSpace_free(MemorySegment parameterSpacePointer) { + try { + faiss_ParameterSpace_free$MH.invokeExact(parameterSpacePointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_ParameterSpace_new$MH = + getHandle("faiss_ParameterSpace_new", FunctionDescriptor.of(JAVA_INT, ADDRESS)); + + int faiss_ParameterSpace_new(MemorySegment pointer) { + try { + return (int) faiss_ParameterSpace_new$MH.invokeExact(pointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_ParameterSpace_set_index_parameters$MH = + getHandle( + "faiss_ParameterSpace_set_index_parameters", + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, ADDRESS)); + + int faiss_ParameterSpace_set_index_parameters( + MemorySegment parameterSpacePointer, + MemorySegment indexPointer, + MemorySegment descriptionPointer) { + try { + return (int) + faiss_ParameterSpace_set_index_parameters$MH.invokeExact( + parameterSpacePointer, indexPointer, descriptionPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_SearchParameters_free$MH = + getHandle("faiss_SearchParameters_free", FunctionDescriptor.ofVoid(ADDRESS)); + + void faiss_SearchParameters_free(MemorySegment searchParametersPointer) { + try { + faiss_SearchParameters_free$MH.invokeExact(searchParametersPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_SearchParameters_new$MH = + getHandle("faiss_SearchParameters_new", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS)); + + int faiss_SearchParameters_new(MemorySegment pointer, MemorySegment idSelectorBitmapPointer) { + try { + return (int) faiss_SearchParameters_new$MH.invokeExact(pointer, idSelectorBitmapPointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_index_factory$MH = + getHandle( + "faiss_index_factory", + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT)); + + int faiss_index_factory( + MemorySegment pointer, int dimension, MemorySegment description, int metric) { + try { + return (int) faiss_index_factory$MH.invokeExact(pointer, dimension, description, metric); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_read_index_custom$MH = + getHandle( + "faiss_read_index_custom", FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_INT, ADDRESS)); + + int faiss_read_index_custom( + MemorySegment customIOReaderPointer, int ioFlags, MemorySegment pointer) { + try { + return (int) faiss_read_index_custom$MH.invokeExact(customIOReaderPointer, ioFlags, pointer); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private final MethodHandle faiss_write_index_custom$MH = + getHandle( + "faiss_write_index_custom", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); + + int faiss_write_index_custom( + MemorySegment indexPointer, MemorySegment customIOWriterPointer, int ioFlags) { + try { + return (int) + faiss_write_index_custom$MH.invokeExact(indexPointer, customIOWriterPointer, ioFlags); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + /** + * Exception used to rethrow handled Faiss errors in native code. + * + * @lucene.experimental + */ + static class Exception extends RuntimeException { + // See error codes defined in c_api/error_c.h + enum ErrorCode { + /// No error + OK(0), + /// Any exception other than Faiss or standard C++ library exceptions + UNKNOWN_EXCEPT(-1), + /// Faiss library exception + FAISS_EXCEPT(-2), + /// Standard C++ library exception + STD_EXCEPT(-4); + + private final int code; + + ErrorCode(int code) { + this.code = code; + } + + static ErrorCode fromCode(int code) { + return Arrays.stream(ErrorCode.values()) + .filter(errorCode -> errorCode.code == code) + .findFirst() + .orElseThrow(); + } + } + + private Exception(int code) { + super( + String.format( + Locale.ROOT, + "%s[%s(%d)]", + Exception.class.getName(), + ErrorCode.fromCode(code), + code)); + } + + static void handleException(int code) { + if (code < 0) { + throw new Exception(code); + } + } + } +} diff --git a/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/Java21Compatibility.java b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/Java21Compatibility.java new file mode 100644 index 0000000..f45e729 --- /dev/null +++ b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/Java21Compatibility.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.faiss; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SymbolLookup; +import java.lang.foreign.ValueLayout; +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.util.Optional; + +/** + * Compatibility layer for Java 21 and Java 22+ Foreign Function & Memory API differences. + * + * @lucene.experimental + */ +final class Java21Compatibility { + private static final int JAVA_VERSION = Runtime.version().feature(); + private static final boolean IS_JAVA_21 = JAVA_VERSION == 21; + + private static final MethodHandle SYMBOL_LOOKUP_FIND; + private static final MethodHandle MEMORY_SEGMENT_GET_UTF8_STRING; + private static final MethodHandle MEMORY_SEGMENT_GET_STRING; + private static final MethodHandle ARENA_ALLOCATE; + private static final MethodHandle ARENA_ALLOCATE_FROM_STRING; + private static final MethodHandle ARENA_ALLOCATE_FROM_FLOAT; + private static final MethodHandle ARENA_ALLOCATE_FROM_LONG; + + static { + try { + MethodHandles.Lookup lookup = MethodHandles.lookup(); + + SYMBOL_LOOKUP_FIND = + lookup.findVirtual( + SymbolLookup.class, + "find", + MethodType.methodType(Optional.class, String.class)); + + if (IS_JAVA_21) { + MEMORY_SEGMENT_GET_UTF8_STRING = + lookup.findVirtual( + MemorySegment.class, "getUtf8String", MethodType.methodType(String.class, long.class)); + MEMORY_SEGMENT_GET_STRING = null; + + ARENA_ALLOCATE = + lookup.findVirtual( + Arena.class, "allocate", MethodType.methodType(MemorySegment.class, long.class)); + ARENA_ALLOCATE_FROM_STRING = null; + ARENA_ALLOCATE_FROM_FLOAT = null; + ARENA_ALLOCATE_FROM_LONG = null; + } else { + MEMORY_SEGMENT_GET_UTF8_STRING = null; + MEMORY_SEGMENT_GET_STRING = + lookup.findVirtual( + MemorySegment.class, "getString", MethodType.methodType(String.class, long.class)); + + ARENA_ALLOCATE = null; + ARENA_ALLOCATE_FROM_STRING = + lookup.findVirtual( + Arena.class, + "allocateFrom", + MethodType.methodType(MemorySegment.class, String.class)); + ARENA_ALLOCATE_FROM_FLOAT = + lookup.findVirtual( + Arena.class, + "allocateFrom", + MethodType.methodType(MemorySegment.class, ValueLayout.OfFloat.class, float[].class)); + ARENA_ALLOCATE_FROM_LONG = + lookup.findVirtual( + Arena.class, + "allocateFrom", + MethodType.methodType(MemorySegment.class, ValueLayout.OfLong.class, long[].class)); + } + } catch (NoSuchMethodException | IllegalAccessException e) { + throw new RuntimeException("Failed to initialize Java compatibility layer", e); + } + } + + private Java21Compatibility() {} + + /** + * Find a symbol in SymbolLookup. + * Both Java 21 and Java 22 use find() which returns Optional. + */ + static MemorySegment findSymbol(SymbolLookup lookup, String name) { + try { + @SuppressWarnings("unchecked") + Optional optional = (Optional) SYMBOL_LOOKUP_FIND.invokeExact(lookup, name); + return optional.orElseThrow(() -> new UnsatisfiedLinkError("Symbol not found: " + name)); + } catch (Throwable e) { + if (e instanceof UnsatisfiedLinkError) { + throw (UnsatisfiedLinkError) e; + } + throw new RuntimeException("Failed to find symbol: " + name, e); + } + } + + /** + * Get a UTF-8 string from MemorySegment. + */ + static String getString(MemorySegment segment, long offset) { + try { + if (IS_JAVA_21) { + return (String) MEMORY_SEGMENT_GET_UTF8_STRING.invokeExact(segment, offset); + } else { + return (String) MEMORY_SEGMENT_GET_STRING.invokeExact(segment, offset); + } + } catch (Throwable e) { + throw new RuntimeException("Failed to get string from MemorySegment", e); + } + } + + /** + * Allocate a MemorySegment from a string. + */ + static MemorySegment allocateFrom(Arena arena, String str) { + try { + if (IS_JAVA_21) { + byte[] bytes = str.getBytes(java.nio.charset.StandardCharsets.UTF_8); + MemorySegment segment = (MemorySegment) ARENA_ALLOCATE.invokeExact(arena, (long) (bytes.length + 1)); + MemorySegment.copy( + MemorySegment.ofArray(bytes), ValueLayout.JAVA_BYTE, 0, segment, ValueLayout.JAVA_BYTE, 0, (long) bytes.length); + segment.set(ValueLayout.JAVA_BYTE, bytes.length, (byte) 0); + return segment; + } else { + return (MemorySegment) ARENA_ALLOCATE_FROM_STRING.invokeExact(arena, str); + } + } catch (Throwable e) { + throw new RuntimeException("Failed to allocate string in native memory", e); + } + } + + /** + * Allocate a MemorySegment from a float array. + */ + static MemorySegment allocateFrom(Arena arena, ValueLayout.OfFloat layout, float[] array) { + try { + if (IS_JAVA_21) { + long size = (long) array.length * layout.byteSize(); + MemorySegment segment = (MemorySegment) ARENA_ALLOCATE.invokeExact(arena, size); + MemorySegment.copy( + MemorySegment.ofArray(array), layout, 0, segment, layout, 0, (long) array.length); + return segment; + } else { + return (MemorySegment) ARENA_ALLOCATE_FROM_FLOAT.invokeExact(arena, layout, array); + } + } catch (Throwable e) { + throw new RuntimeException("Failed to allocate float array in native memory", e); + } + } + + /** + * Allocate a MemorySegment from a long array. + */ + static MemorySegment allocateFrom(Arena arena, ValueLayout.OfLong layout, long[] array) { + try { + if (IS_JAVA_21) { + long size = (long) array.length * layout.byteSize(); + MemorySegment segment = (MemorySegment) ARENA_ALLOCATE.invokeExact(arena, size); + MemorySegment.copy( + MemorySegment.ofArray(array), layout, 0, segment, layout, 0, (long) array.length); + return segment; + } else { + return (MemorySegment) ARENA_ALLOCATE_FROM_LONG.invokeExact(arena, layout, array); + } + } catch (Throwable e) { + throw new RuntimeException("Failed to allocate long array in native memory", e); + } + } +} + diff --git a/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/package-info.java b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/package-info.java new file mode 100644 index 0000000..bd4a5b8 --- /dev/null +++ b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/package-info.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Faiss is "a library for efficient + * similarity search and clustering of dense vectors", with support for various vector + * transforms, indexing algorithms, quantization techniques, etc. This package provides a pluggable + * Faiss-based format to perform vector searches in Lucene, via {@link + * org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat}. + * + *

To use this format: Install pytorch/faiss-cpu from Conda and place shared libraries (including + * dependencies) on the {@code $LD_LIBRARY_PATH} environment variable or {@code -Djava.library.path} + * JVM argument. + * + *

Important: Ensure that the license of the Conda distribution and channels is applicable to + * you. pytorch and conda-forge are community-maintained channels with + * permissive licenses! + * + *

Sample setup: + * + *

    + *
  • Install micromamba (an open-source Conda + * package manager) or similar + *
  • Install dependencies using {@code micromamba create -n faiss-env -c pytorch -c conda-forge + * -y faiss-cpu=}{@value org.apache.lucene.sandbox.codecs.faiss.FaissLibrary#VERSION} + *
  • Activate environment using {@code micromamba activate faiss-env} + *
  • Add shared libraries to runtime using {@code export LD_LIBRARY_PATH=$CONDA_PREFIX/lib} + * (verify that the {@value org.apache.lucene.sandbox.codecs.faiss.FaissLibrary#NAME} library + * is present here) + *
  • And you're good to go! (add the {@code -Dtests.faiss.run=true} JVM argument to ensure Faiss + * tests are run) + *
+ * + * @lucene.experimental + */ +package org.apache.lucene.sandbox.codecs.faiss; diff --git a/faiss/src/main/java/org/apache/solr/faiss/FaissCodec.java b/faiss/src/main/java/org/apache/solr/faiss/FaissCodec.java new file mode 100644 index 0000000..82b0a34 --- /dev/null +++ b/faiss/src/main/java/org/apache/solr/faiss/FaissCodec.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.faiss; + +import java.lang.invoke.MethodHandles; +import java.util.Locale; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat; +import org.apache.lucene.util.hnsw.HnswGraphBuilder; +import org.apache.solr.common.SolrException; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.core.SolrCore; +import org.apache.solr.schema.DenseVectorField; +import org.apache.solr.schema.FieldType; +import org.apache.solr.schema.SchemaField; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Codec that uses FaissKnnVectorsFormat for FAISS-based vector search. + * + * @since 10.0.0 + */ +public class FaissCodec extends FilterCodec { + + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + private static final String FAISS_ALGORITHM = "faiss"; + private static final String DEFAULT_FAISS_DESCRIPTION = + String.format(Locale.ROOT, "IDMap,HNSW%d", HnswGraphBuilder.DEFAULT_MAX_CONN); + private static final String DEFAULT_FAISS_INDEX_PARAMS = + String.format(Locale.ROOT, "efConstruction=%d", HnswGraphBuilder.DEFAULT_BEAM_WIDTH); + + private final SolrCore core; + private final Codec fallbackCodec; + private final FaissKnnVectorsFormat faissKnnVectorsFormat; + + public FaissCodec(SolrCore core, Codec fallback, NamedList args) { + super(fallback.getName(), fallback); + this.core = core; + this.fallbackCodec = fallback; + + String descriptionStr = (String) args.get("faissDescription"); + String description = descriptionStr != null ? descriptionStr : DEFAULT_FAISS_DESCRIPTION; + String indexParamsStr = (String) args.get("faissIndexParams"); + String indexParams = indexParamsStr != null ? indexParamsStr : DEFAULT_FAISS_INDEX_PARAMS; + + faissKnnVectorsFormat = new FaissKnnVectorsFormat(description, indexParams); + + if (log.isInfoEnabled()) { + log.info( + "FaissKnnVectorsFormat initialized with parameter values: description={}, indexParams={}", + description, + indexParams); + } + } + + @Override + public KnnVectorsFormat knnVectorsFormat() { + return perFieldKnnVectorsFormat; + } + + private PerFieldKnnVectorsFormat perFieldKnnVectorsFormat = + new PerFieldKnnVectorsFormat() { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + final SchemaField schemaField = core.getLatestSchema().getFieldOrNull(field); + FieldType fieldType = (schemaField == null ? null : schemaField.getType()); + if (fieldType instanceof DenseVectorField vectorType) { + String knnAlgorithm = vectorType.getKnnAlgorithm(); + if (FAISS_ALGORITHM.equals(knnAlgorithm)) { + return faissKnnVectorsFormat; + } else if (DenseVectorField.HNSW_ALGORITHM.equals(knnAlgorithm)) { + return fallbackCodec.knnVectorsFormat(); + } else { + throw new SolrException( + SolrException.ErrorCode.SERVER_ERROR, + knnAlgorithm + " KNN algorithm is not supported"); + } + } + return fallbackCodec.knnVectorsFormat(); + } + }; +} + diff --git a/faiss/src/main/java/org/apache/solr/faiss/FaissCodecFactory.java b/faiss/src/main/java/org/apache/solr/faiss/FaissCodecFactory.java new file mode 100644 index 0000000..b43cf23 --- /dev/null +++ b/faiss/src/main/java/org/apache/solr/faiss/FaissCodecFactory.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.faiss; + +import org.apache.lucene.codecs.Codec; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.core.CodecFactory; +import org.apache.solr.core.SchemaCodecFactory; +import org.apache.solr.core.SolrCore; +import org.apache.solr.util.plugin.SolrCoreAware; + +public class FaissCodecFactory extends CodecFactory implements SolrCoreAware { + + private final SchemaCodecFactory fallback; + private SolrCore core; + private NamedList args; + private Codec fallbackCodec; + private FaissCodec codec; + + public FaissCodecFactory() { + this.fallback = new SchemaCodecFactory(); + } + + @Override + public Codec getCodec() { + if (codec == null) { + fallbackCodec = fallback.getCodec(); + codec = new FaissCodec(core, fallbackCodec, args); + } + return codec; + } + + @Override + public void inform(SolrCore solrCore) { + fallback.inform(solrCore); + this.core = solrCore; + } + + @Override + public void init(NamedList args) { + fallback.init(args); + this.args = args; + } +} + diff --git a/faiss/src/main/java/org/apache/solr/faiss/package-info.java b/faiss/src/main/java/org/apache/solr/faiss/package-info.java new file mode 100644 index 0000000..44e08ad --- /dev/null +++ b/faiss/src/main/java/org/apache/solr/faiss/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Contains the {@link org.apache.solr.faiss.FaissCodec} to enable FAISS-based vector search + */ +package org.apache.solr.faiss; \ No newline at end of file diff --git a/faiss/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/faiss/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat new file mode 100644 index 0000000..2ab1096 --- /dev/null +++ b/faiss/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -0,0 +1 @@ +org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat diff --git a/faiss/src/test-files/solr/collection1/conf/schema.xml b/faiss/src/test-files/solr/collection1/conf/schema.xml new file mode 100644 index 0000000..6b40dd1 --- /dev/null +++ b/faiss/src/test-files/solr/collection1/conf/schema.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + + + + + + id + + diff --git a/faiss/src/test-files/solr/collection1/conf/solrconfig.xml b/faiss/src/test-files/solr/collection1/conf/solrconfig.xml new file mode 100644 index 0000000..934f9d6 --- /dev/null +++ b/faiss/src/test-files/solr/collection1/conf/solrconfig.xml @@ -0,0 +1,33 @@ + + + + + + ${tests.luceneMatchVersion:LATEST} + ${solr.data.dir:} + + + + + IDMap,HNSW16 + efConstruction=100 + + + + + diff --git a/faiss/src/test/java/org/apache/solr/faiss/TestFaissIntegration.java b/faiss/src/test/java/org/apache/solr/faiss/TestFaissIntegration.java new file mode 100644 index 0000000..0f822cf --- /dev/null +++ b/faiss/src/test/java/org/apache/solr/faiss/TestFaissIntegration.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.faiss; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Random; +import org.apache.commons.io.file.PathUtils; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.tests.mockfile.FilterPath; +import org.apache.solr.SolrTestCaseJ4; +import org.apache.solr.common.SolrInputDocument; +import org.apache.solr.core.SolrConfig; +import org.apache.solr.core.SolrCore; +import org.apache.solr.search.SolrIndexSearcher; +import org.apache.solr.util.RefCounted; +import org.junit.BeforeClass; +import org.junit.Test; +import static org.junit.Assume.assumeTrue; + +/** + * Integration test for FAISS module using Solr's test framework. + */ +public class TestFaissIntegration extends SolrTestCaseJ4 { + + private static Random random; + private static final int DATASET_SIZE = 100; + private static final int DATASET_DIMENSION = 8; + private static final int TOPK = 5; + private static final String ID_FIELD = "id"; + private static final String VECTOR_FIELD1 = "vector_field1"; + private static final String VECTOR_FIELD2 = "vector_field2"; + private static final String SOLRCONFIG_XML = "solrconfig.xml"; + private static final String SCHEMA_XML = "schema.xml"; + private static final String COLLECTION = "collection1"; + private static final String CONF_DIR = COLLECTION + "/conf"; + + @BeforeClass + public static void beforeClass() throws Exception { + boolean faissAvailable = false; + try { + System.loadLibrary("faiss_c"); + faissAvailable = true; + } catch (UnsatisfiedLinkError e) { + // FAISS library not available + } + assumeTrue("FAISS native library not available", faissAvailable); + + Path tmpSolrHome = createTempDir(); + Path tmpConfDir = FilterPath.unwrap(tmpSolrHome.resolve(CONF_DIR)); + Path testHomeConfDir = TEST_HOME().resolve(CONF_DIR); + Files.createDirectories(tmpConfDir); + PathUtils.copyFileToDirectory(testHomeConfDir.resolve(SOLRCONFIG_XML), tmpConfDir); + PathUtils.copyFileToDirectory(testHomeConfDir.resolve(SCHEMA_XML), tmpConfDir); + + initCore(SOLRCONFIG_XML, SCHEMA_XML, tmpSolrHome); + random = new Random(222); + } + + @Test + public void testFaissCodecIsLoaded() { + SolrCore solrCore = h.getCore(); + SolrConfig config = solrCore.getSolrConfig(); + String codecFactory = config.get("codecFactory").attr("class"); + assertEquals( + "Unexpected solrconfig codec factory", + "org.apache.solr.faiss.FaissCodecFactory", + codecFactory); + assertEquals("Unexpected core codec", "Lucene103", solrCore.getCodec().getName()); + assertTrue("Codec should be FaissCodec", solrCore.getCodec() instanceof FaissCodec); + } + + @Test + public void testIndexAndSearch() throws IOException { + SolrCore solrCore = h.getCore(); + for (int i = 0; i < DATASET_SIZE; i++) { + SolrInputDocument doc = new SolrInputDocument(); + doc.addField(ID_FIELD, String.valueOf(i)); + List vector1 = generateRandomVector(random, DATASET_DIMENSION); + List vector2 = generateRandomVector(random, DATASET_DIMENSION); + doc.addField(VECTOR_FIELD1, vector1); + doc.addField(VECTOR_FIELD2, vector2); + assertU(adoc(doc)); + } + assertU(commit()); + + final RefCounted refCountedSearcher = solrCore.getSearcher(); + IndexSearcher searcher = refCountedSearcher.get(); + + float[] queryVector = generateRandomVectorArray(random, DATASET_DIMENSION); + KnnFloatVectorQuery q1 = + new KnnFloatVectorQuery(VECTOR_FIELD1, queryVector, TOPK); + TopDocs results1 = searcher.search(q1, TOPK); + + assertTrue("Should return at least some results", results1.scoreDocs.length > 0); + assertTrue("Should return at most TOPK results", results1.scoreDocs.length <= TOPK); + + refCountedSearcher.decref(); + } + + private static List generateRandomVector(Random random, int dimensions) { + List vector = new java.util.ArrayList<>(); + for (int j = 0; j < dimensions; j++) { + vector.add(random.nextFloat() * 100); + } + return vector; + } + + private static float[] generateRandomVectorArray(Random random, int dimension) { + List vector = generateRandomVector(random, dimension); + float[] query = new float[dimension]; + for (int i = 0; i < dimension; i++) { + query[i] = vector.get(i); + } + return query; + } +} + diff --git a/faiss/src/test/resources/configs/collection1/conf/schema.xml b/faiss/src/test/resources/configs/collection1/conf/schema.xml new file mode 100644 index 0000000..6b40dd1 --- /dev/null +++ b/faiss/src/test/resources/configs/collection1/conf/schema.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + + + + + + id + + diff --git a/faiss/src/test/resources/configs/collection1/conf/solrconfig.xml b/faiss/src/test/resources/configs/collection1/conf/solrconfig.xml new file mode 100644 index 0000000..934f9d6 --- /dev/null +++ b/faiss/src/test/resources/configs/collection1/conf/solrconfig.xml @@ -0,0 +1,33 @@ + + + + + + ${tests.luceneMatchVersion:LATEST} + ${solr.data.dir:} + + + + + IDMap,HNSW16 + efConstruction=100 + + + + + diff --git a/settings.gradle b/settings.gradle index 39abf37..4f7e3f2 100644 --- a/settings.gradle +++ b/settings.gradle @@ -17,3 +17,5 @@ include 'encryption' include 'gatling-data-prep' include 'gatling-simulations' + +include 'faiss' From 296520a4a81b3ba1b4957c88191dfd2fee454127 Mon Sep 17 00:00:00 2001 From: punAhuja Date: Thu, 20 Nov 2025 12:39:41 +0530 Subject: [PATCH 2/3] Renamed a file and refined build.gradle --- faiss/build.gradle | 9 +++------ .../faiss/{Java21Compatibility.java => FFMUtils.java} | 6 +++--- .../sandbox/codecs/faiss/FaissLibraryNativeImpl.java | 8 ++++---- .../lucene/sandbox/codecs/faiss/FaissNativeWrapper.java | 4 ++-- 4 files changed, 12 insertions(+), 15 deletions(-) rename faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/{Java21Compatibility.java => FFMUtils.java} (97%) diff --git a/faiss/build.gradle b/faiss/build.gradle index 111731f..77d2754 100644 --- a/faiss/build.gradle +++ b/faiss/build.gradle @@ -42,14 +42,13 @@ dependencies { provided "org.apache.solr:solr-core:${solrVersion}" implementation "org.slf4j:slf4j-api:2.0.5" - testImplementation 'org.slf4j:slf4j-api:2.0.5' - testImplementation 'org.hamcrest:hamcrest:2.2' - testImplementation 'junit:junit:4.13.2' - testImplementation 'org.mockito:mockito-inline:5.2.0' testImplementation "org.apache.lucene:lucene-test-framework:10.3.1" testImplementation "org.apache.lucene:lucene-backward-codecs:10.3.1" testImplementation "org.apache.lucene:lucene-sandbox:10.3.1" testImplementation "commons-io:commons-io:2.20.0" + testImplementation "junit:junit:4.13.2" + testImplementation "org.mockito:mockito-inline:5.2.0" + testImplementation "org.hamcrest:hamcrest:3.0" if (solrSubmoduleExists) { testImplementation fileTree(dir: "${solrSubmodulePath}/solr/test-framework/build/libs", include: 'solr-test-framework-*.jar', excludes: ['*sources*.jar', '*javadoc*.jar']) @@ -86,7 +85,6 @@ dependencies { testImplementation "org.eclipse.jetty:jetty-rewrite:12.0.27" testImplementation "org.eclipse.jetty.http2:jetty-http2-server:12.0.27" testImplementation "org.eclipse.jetty.http2:jetty-http2-common:12.0.27" - testImplementation "org.slf4j:slf4j-api:2.0.17" testImplementation "org.apache.logging.log4j:log4j-api:2.21.0" testImplementation "org.apache.logging.log4j:log4j-core:2.21.0" testImplementation "io.dropwizard.metrics:metrics-core:4.2.26" @@ -108,7 +106,6 @@ dependencies { exclude group: "io.prometheus", module: "prometheus-metrics-config" } testImplementation "com.carrotsearch.randomizedtesting:randomizedtesting-runner:2.8.3" - testImplementation "org.hamcrest:hamcrest:3.0" } else { testImplementation group: 'org.apache.solr', name: 'solr-core', version: "${solrVersion}", { exclude group: "org.eclipse.jetty", module: "jetty-http" diff --git a/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/Java21Compatibility.java b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FFMUtils.java similarity index 97% rename from faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/Java21Compatibility.java rename to faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FFMUtils.java index f45e729..a31069d 100644 --- a/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/Java21Compatibility.java +++ b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FFMUtils.java @@ -30,7 +30,7 @@ * * @lucene.experimental */ -final class Java21Compatibility { +final class FFMUtils { private static final int JAVA_VERSION = Runtime.version().feature(); private static final boolean IS_JAVA_21 = JAVA_VERSION == 21; @@ -88,11 +88,11 @@ final class Java21Compatibility { MethodType.methodType(MemorySegment.class, ValueLayout.OfLong.class, long[].class)); } } catch (NoSuchMethodException | IllegalAccessException e) { - throw new RuntimeException("Failed to initialize Java compatibility layer", e); + throw new RuntimeException("Failed to initialize FFM utils", e); } } - private Java21Compatibility() {} + private FFMUtils() {} /** * Find a symbol in SymbolLookup. diff --git a/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java index aa6ae96..35e995f 100644 --- a/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java +++ b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissLibraryNativeImpl.java @@ -181,7 +181,7 @@ public FaissLibrary.Index createIndex( // Create an index MemorySegment pointer = temp.allocate(ADDRESS); handleException( - wrapper.faiss_index_factory(pointer, dimension, Java21Compatibility.allocateFrom(temp, description), metric)); + wrapper.faiss_index_factory(pointer, dimension, FFMUtils.allocateFrom(temp, description), metric)); MemorySegment indexPointer = pointer.get(ADDRESS, 0); @@ -195,7 +195,7 @@ public FaissLibrary.Index createIndex( handleException( wrapper.faiss_ParameterSpace_set_index_parameters( - parameterSpacePointer, indexPointer, Java21Compatibility.allocateFrom(temp, indexParams))); + parameterSpacePointer, indexPointer, FFMUtils.allocateFrom(temp, indexParams))); // TODO: Improve memory usage (with a tradeoff in performance) by batched indexing, see: // - https://github.com/opensearch-project/k-NN/issues/1506 @@ -334,7 +334,7 @@ public void search(float[] query, KnnCollector knnCollector, AcceptDocs acceptDo }; // Allocate queries in native memory - MemorySegment queries = Java21Compatibility.allocateFrom(temp, JAVA_FLOAT, query); + MemorySegment queries = FFMUtils.allocateFrom(temp, JAVA_FLOAT, query); // Faiss knn search int k = knnCollector.k(); @@ -352,7 +352,7 @@ public void search(float[] query, KnnCollector knnCollector, AcceptDocs acceptDo long[] bits = fixedBitSet.getBits(); MemorySegment nativeBits = // Use LITTLE_ENDIAN to convert long[] -> uint8_t* - Java21Compatibility.allocateFrom(temp, JAVA_LONG.withOrder(ByteOrder.LITTLE_ENDIAN), bits); + FFMUtils.allocateFrom(temp, JAVA_LONG.withOrder(ByteOrder.LITTLE_ENDIAN), bits); handleException( wrapper.faiss_IDSelectorBitmap_new(pointer, fixedBitSet.length(), nativeBits)); diff --git a/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissNativeWrapper.java b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissNativeWrapper.java index 4620467..674ff66 100644 --- a/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissNativeWrapper.java +++ b/faiss/src/main/java/org/apache/lucene/sandbox/codecs/faiss/FaissNativeWrapper.java @@ -63,14 +63,14 @@ private static MethodHandle getHandle(String functionName, FunctionDescriptor de } } // Use compatibility layer for findOrThrow() vs find() - MemorySegment symbol = Java21Compatibility.findSymbol(lookup, functionName); + MemorySegment symbol = FFMUtils.findSymbol(lookup, functionName); return Linker.nativeLinker().downcallHandle(symbol, descriptor); } FaissNativeWrapper() { // Check Faiss version String expectedVersion = FaissLibrary.VERSION; - String actualVersion = Java21Compatibility.getString(faiss_get_version().reinterpret(Long.MAX_VALUE), 0); + String actualVersion = FFMUtils.getString(faiss_get_version().reinterpret(Long.MAX_VALUE), 0); if (expectedVersion.equals(actualVersion) == false) { throw new LinkageError( From ba586e1be043d279091e4aae2020dd0d5baae2b8 Mon Sep 17 00:00:00 2001 From: punAhuja Date: Thu, 20 Nov 2025 13:49:28 +0530 Subject: [PATCH 3/3] Whitespace fixes --- faiss/build.gradle | 6 +++--- .../java/org/apache/solr/faiss/TestFaissIntegration.java | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/faiss/build.gradle b/faiss/build.gradle index 77d2754..c8ede8a 100644 --- a/faiss/build.gradle +++ b/faiss/build.gradle @@ -41,7 +41,7 @@ dependencies { implementation "org.apache.lucene:lucene-codecs:10.3.1" provided "org.apache.solr:solr-core:${solrVersion}" implementation "org.slf4j:slf4j-api:2.0.5" - + testImplementation "org.apache.lucene:lucene-test-framework:10.3.1" testImplementation "org.apache.lucene:lucene-backward-codecs:10.3.1" testImplementation "org.apache.lucene:lucene-sandbox:10.3.1" @@ -49,14 +49,14 @@ dependencies { testImplementation "junit:junit:4.13.2" testImplementation "org.mockito:mockito-inline:5.2.0" testImplementation "org.hamcrest:hamcrest:3.0" - + if (solrSubmoduleExists) { testImplementation fileTree(dir: "${solrSubmodulePath}/solr/test-framework/build/libs", include: 'solr-test-framework-*.jar', excludes: ['*sources*.jar', '*javadoc*.jar']) testImplementation fileTree(dir: "${solrSubmodulePath}/solr/core/build/libs", include: 'solr-core-*.jar', excludes: ['*sources*.jar', '*javadoc*.jar']) testImplementation fileTree(dir: "${solrSubmodulePath}/solr/solrj/build/libs", include: 'solr-solrj-*.jar', excludes: ['*sources*.jar', '*javadoc*.jar']) testImplementation fileTree(dir: "${solrSubmodulePath}/solr/api/build/libs", include: 'solr-api-*.jar', excludes: ['*sources*.jar', '*javadoc*.jar']) testImplementation fileTree(dir: "${solrSubmodulePath}/solr/solrj-zookeeper/build/libs", include: 'solr-solrj-zookeeper-*.jar', excludes: ['*sources*.jar', '*javadoc*.jar']) - + testImplementation("org.apache.zookeeper:zookeeper:3.9.4") { exclude group: "org.apache.yetus", module: "audience-annotations" } diff --git a/faiss/src/test/java/org/apache/solr/faiss/TestFaissIntegration.java b/faiss/src/test/java/org/apache/solr/faiss/TestFaissIntegration.java index 0f822cf..0415cd9 100644 --- a/faiss/src/test/java/org/apache/solr/faiss/TestFaissIntegration.java +++ b/faiss/src/test/java/org/apache/solr/faiss/TestFaissIntegration.java @@ -63,7 +63,7 @@ public static void beforeClass() throws Exception { // FAISS library not available } assumeTrue("FAISS native library not available", faissAvailable); - + Path tmpSolrHome = createTempDir(); Path tmpConfDir = FilterPath.unwrap(tmpSolrHome.resolve(CONF_DIR)); Path testHomeConfDir = TEST_HOME().resolve(CONF_DIR); @@ -109,7 +109,7 @@ public void testIndexAndSearch() throws IOException { KnnFloatVectorQuery q1 = new KnnFloatVectorQuery(VECTOR_FIELD1, queryVector, TOPK); TopDocs results1 = searcher.search(q1, TOPK); - + assertTrue("Should return at least some results", results1.scoreDocs.length > 0); assertTrue("Should return at most TOPK results", results1.scoreDocs.length <= TOPK);