diff --git a/README.md b/README.md index 52984f1..1c6674f 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ [![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Rust](https://img.shields.io/badge/rust-1.80%2B-orange.svg)](https://www.rust-lang.org) [![Tests](https://github.com/mmailhos/vectorlite/actions/workflows/rust.yml/badge.svg?branch=main)](https://github.com/mmailhos/vectorlite/actions) +[![OpenAPI](https://img.shields.io/badge/OpenAPI-3.0.3-green.svg)](docs/openapi.yaml) **A tiny, in-process Rust vector store with built-in embeddings for sub-millisecond semantic search.** @@ -61,21 +62,25 @@ docker build \ ``` ## HTTP API Overview + | Operation | Method & Endpoint | Body | | --------------------- | ----------------------------------------- | ------------------------------------------------------------------ | | **Health** | `GET /health` | – | | **List collections** | `GET /collections` | – | -| **Create collection** | `POST /collections` | `{"name": "docs", "index_type": "hnsw"}` | +| **Create collection** | `POST /collections` | `{"name": "docs", "index_type": "hnsw", "metric": "cosine"}` | | **Delete collection** | `DELETE /collections/{name}` | – | | **Add text** | `POST /collections/{name}/text` | `{"text": "Hello world", "metadata": {...}}`| -| **Search (text)** | `POST /collections/{name}/search/text` | `{"query": "hello", "k": 5}` | +| **Search (text)** | `POST /collections/{name}/search/text` | `{"query": "hello", "k": 5}` | | **Get vector** | `GET /collections/{name}/vectors/{id}` | – | | **Delete vector** | `DELETE /collections/{name}/vectors/{id}` | – | | **Save collection** | `POST /collections/{name}/save` | `{"file_path": "./collection.vlc"}` | | **Load collection** | `POST /collections/load` | `{"file_path": "./collection.vlc", "collection_name": "restored"}` | + ## Index Types +VectorLite supports two index types: **Flat** and **HNSW**. 
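To make the flat-index trade-off concrete before the comparison table, here is a minimal, self-contained sketch of what an O(n) exact (flat) search does — brute-force scoring of every stored vector against the query. This is an illustration only, not VectorLite's actual `FlatIndex` implementation:

```rust
// Hedged sketch of an O(n) flat (exact) search with cosine scoring;
// independent of VectorLite's real FlatIndex internals.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (na * nb)
}

/// Brute-force top-k: score every vector, sort descending, truncate.
/// Cost is linear in the collection size for every query.
fn flat_search(vectors: &[(u64, Vec<f32>)], query: &[f32], k: usize) -> Vec<(u64, f32)> {
    let mut scored: Vec<(u64, f32)> = vectors
        .iter()
        .map(|(id, v)| (*id, cosine(v, query)))
        .collect();
    scored.sort_by(|a, b| b.1.total_cmp(&a.1));
    scored.truncate(k);
    scored
}

fn main() {
    let vectors = vec![
        (0, vec![1.0, 0.0]),
        (1, vec![0.0, 1.0]),
        (2, vec![0.7, 0.7]),
    ];
    let top = flat_search(&vectors, &[1.0, 0.1], 2);
    println!("{:?}", top); // best match first
    assert_eq!(top[0].0, 0);
}
```

An HNSW index avoids this full scan by walking a layered proximity graph, which is why it trades exactness for O(log n) search.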
+ | Index | Search Complexity | Insert | Use Case | | -------- | ----------------- | -------- | ------------------------------------- | | **Flat** | O(n) | O(1) | Small datasets (<10K) or exact search | | **HNSW** | O(log n) | O(log n) | Larger datasets or approximate search | @@ -97,6 +102,9 @@ cargo build --features memory-optimized ### Similarity Metrics + +A flat index is the most flexible: it supports every similarity metric at search time. An HNSW index, by contrast, is built for a single distance metric, which is then used for all searches against it. When creating an HNSW index, provide a `metric` value with one of `cosine`, `euclidean`, `manhattan`, or `dotproduct`. + - **Cosine**: Default for normalized embeddings, scale-invariant - **Euclidean**: Geometric distance, sensitive to vector magnitude - **Manhattan**: L1 norm, robust to outliers @@ -109,9 +117,9 @@ use vectorlite::{VectorLiteClient, EmbeddingGenerator, IndexType, SimilarityMetr use serde_json::json; fn main() -> Result<(), Box> { - let client = VectorLiteClient::new(Box::new(EmbeddingGenerator::new()?)); + let mut client = VectorLiteClient::new(Box::new(EmbeddingGenerator::new()?)); - client.create_collection("quotes", IndexType::HNSW)?; + client.create_collection("quotes", IndexType::HNSW, Some(SimilarityMetric::Cosine))?; let id = client.add_text_to_collection( "quotes", @@ -123,11 +131,12 @@ fn main() -> Result<(), Box> { })) )?; + // Metric optional - auto-detected from HNSW index let results = client.search_text_in_collection( "quotes", "beach games", 3, - SimilarityMetric::Cosine, + None, )?; for result in &results { diff --git a/docs/openapi.yaml b/docs/openapi.yaml new file mode 100644 index 0000000..3736909 --- /dev/null +++ b/docs/openapi.yaml @@ -0,0 +1,839 @@ +openapi: 3.0.3 +info: + title: VectorLite API + description: | + A high-performance, in-memory vector database optimized for AI agent and edge workloads. + VectorLite provides sub-millisecond semantic search with built-in embedding generation. 
+ + ## Features + + - **Sub-millisecond search**: In-memory HNSW or flat search + - **Built-in embeddings**: Runs all-MiniLM-L6-v2 locally using Candle + - **Thread-safe**: Concurrent read access with atomic ID generation + - **Persistence**: Save and restore collections to/from disk + - **Flexible metrics**: Cosine, Euclidean, Manhattan, and Dot Product similarity + + ## Index Types + + | Index | Search Complexity | Use Case | + |-------|------------------|----------| + | **Flat** | O(n) | Small datasets (<10K) or exact search | + | **HNSW** | O(log n) | Larger datasets or approximate search | + + ## Similarity Metrics + + | Metric | Description | Range | + |--------|-------------|-------| + | **Cosine** | Scale-invariant, good for normalized embeddings | [-1, 1] | + | **Euclidean** | Geometric distance, sensitive to magnitude | [0, 1] | + | **Manhattan** | L1 norm, robust to outliers | [0, 1] | + | **Dot Product** | Raw similarity, requires consistent scaling | unbounded | + version: 0.1.5 + contact: + name: VectorLite Support + url: https://github.com/mmailhos/vectorlite + +servers: + - url: http://localhost:3001 + description: Local development server + - url: http://localhost:3002 + description: Alternative port + +tags: + - name: Health + description: Health check endpoints + - name: Collections + description: Collection management operations + - name: Vectors + description: Vector operations (add, search, get, delete) + - name: Persistence + description: Save and load collection operations + +paths: + /health: + get: + tags: + - Health + summary: Health check + description: Returns the health status of the server + operationId: healthCheck + responses: + '200': + description: Server is healthy + content: + application/json: + schema: + type: object + properties: + status: + type: string + example: healthy + service: + type: string + example: vectorlite + + /collections: + get: + tags: + - Collections + summary: List all collections + description: Returns a 
list of all collection names + operationId: listCollections + responses: + '200': + description: List of collections + content: + application/json: + schema: + $ref: '#/components/schemas/ListCollectionsResponse' + + post: + tags: + - Collections + summary: Create a new collection + description: Creates a new collection with the specified index type and similarity metric + operationId: createCollection + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/CreateCollectionRequest' + examples: + flat_index: + summary: Create a flat index + value: + name: small_docs + index_type: flat + metric: "" + hnsw_cosine: + summary: Create HNSW index with cosine metric + value: + name: large_docs + index_type: hnsw + metric: cosine + hnsw_euclidean: + summary: Create HNSW index with euclidean metric + value: + name: geo_docs + index_type: hnsw + metric: euclidean + responses: + '200': + description: Collection created successfully + content: + application/json: + schema: + $ref: '#/components/schemas/CreateCollectionResponse' + '400': + description: Bad request (invalid index type, metric, or missing required metric) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + examples: + invalid_index: + value: + message: "Invalid index type: invalid. Must be 'flat' or 'hnsw'" + invalid_metric: + value: + message: "Invalid similarity metric: invalid. Must be 'cosine', 'euclidean', 'manhattan', or 'dotproduct'" + metric_required: + value: + message: "HNSW index requires an explicit similarity metric. 
Add field 'metric' with one of the following: ['cosine', 'euclidean', 'manhattan', 'dotproduct']" + '409': + description: Collection already exists + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + example: + message: "Collection 'docs' already exists" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /collections/{name}: + get: + tags: + - Collections + summary: Get collection information + description: Returns detailed information about a specific collection + operationId: getCollectionInfo + parameters: + - name: name + in: path + required: true + description: The name of the collection + schema: + type: string + example: my_docs + responses: + '200': + description: Collection information + content: + application/json: + schema: + $ref: '#/components/schemas/CollectionInfoResponse' + '404': + description: Collection not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + example: + message: "Collection 'my_docs' not found" + + delete: + tags: + - Collections + summary: Delete a collection + description: Deletes a collection and all its vectors + operationId: deleteCollection + parameters: + - name: name + in: path + required: true + description: The name of the collection + schema: + type: string + example: my_docs + responses: + '200': + description: Collection deleted successfully + content: + application/json: + schema: + $ref: '#/components/schemas/CreateCollectionResponse' + '404': + description: Collection not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /collections/{name}/text: + post: + tags: + - Vectors + summary: Add text to collection + description: | + Adds text to a collection. 
The text is automatically converted to an embedding, + and the vector is added to the collection. Returns the ID of the newly created vector. + + **Note**: The embedding is generated using the all-MiniLM-L6-v2 model (384 dimensions by default). + operationId: addText + parameters: + - name: name + in: path + required: true + description: The name of the collection + schema: + type: string + example: my_docs + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/AddTextRequest' + examples: + simple: + summary: Add text without metadata + value: + text: "Hello, world! This is a sample document." + with_metadata: + summary: Add text with metadata + value: + text: "AI agents are revolutionizing software development" + metadata: + author: "John Doe" + tags: + - ai + - agents + - ml + published: "2024-01-15" + views: 1234 + responses: + '200': + description: Text added successfully + content: + application/json: + schema: + $ref: '#/components/schemas/AddTextResponse' + '404': + description: Collection not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '400': + description: Bad request (dimension mismatch or embedding generation error) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + examples: + dimension_mismatch: + value: + message: "Vector dimension mismatch: expected 384, got 256" + '409': + description: Duplicate vector ID + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error (e.g., embedding generation failed) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /collections/{name}/search/text: + post: + tags: + - Vectors + summary: Search by text query + description: | + Searches the collection for vectors similar to the query text. 
+ The query is automatically converted to an embedding, then searched using + the collection's index and similarity metric. + operationId: searchText + parameters: + - name: name + in: path + required: true + description: The name of the collection + schema: + type: string + example: my_docs + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/SearchTextRequest' + examples: + default: + summary: Search with default settings + value: + query: "machine learning" + k: 5 + custom_metric: + summary: Search with custom similarity metric + value: + query: "artificial intelligence" + k: 10 + similarity_metric: euclidean + max_results: + summary: Get top 20 results + value: + query: "neural networks" + k: 20 + responses: + '200': + description: Search completed successfully + content: + application/json: + schema: + $ref: '#/components/schemas/SearchResponse' + '400': + description: Bad request (invalid similarity metric or metric mismatch) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + examples: + invalid_metric: + value: + message: "Invalid similarity metric: invalid. 
Must be 'cosine', 'euclidean', 'manhattan', or 'dotproduct'" + metric_mismatch: + value: + message: "Metric mismatch: search requested Euclidean but index was built for Cosine" + '404': + description: Collection not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /collections/{name}/vectors/{id}: + get: + tags: + - Vectors + summary: Get vector by ID + description: Retrieves a specific vector from a collection by its ID + operationId: getVector + parameters: + - name: name + in: path + required: true + description: The name of the collection + schema: + type: string + example: my_docs + - name: id + in: path + required: true + description: The ID of the vector + schema: + type: integer + format: int64 + example: 123 + responses: + '200': + description: Vector retrieved successfully + content: + application/json: + schema: + type: object + properties: + vector: + $ref: '#/components/schemas/Vector' + '404': + description: Vector or collection not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + delete: + tags: + - Vectors + summary: Delete a vector + description: Deletes a specific vector from a collection by its ID + operationId: deleteVector + parameters: + - name: name + in: path + required: true + description: The name of the collection + schema: + type: string + example: my_docs + - name: id + in: path + required: true + description: The ID of the vector + schema: + type: integer + format: int64 + example: 123 + responses: + '200': + description: Vector deleted successfully + content: + application/json: + schema: + type: object + '404': + description: Vector or collection not found + content: + 
application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /collections/{name}/save: + post: + tags: + - Persistence + summary: Save collection to file + description: | + Saves the entire collection to disk in a binary format (.vlc file). + This includes all vectors, the index structure, and metadata. + + **Note**: The file will be created if it doesn't exist, and overwritten if it does. + operationId: saveCollection + parameters: + - name: name + in: path + required: true + description: The name of the collection to save + schema: + type: string + example: my_docs + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/SaveCollectionRequest' + example: + file_path: "./backups/my_docs.vlc" + responses: + '200': + description: Collection saved successfully + content: + application/json: + schema: + $ref: '#/components/schemas/SaveCollectionResponse' + '404': + description: Collection not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error (e.g., disk I/O failure) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /collections/load: + post: + tags: + - Persistence + summary: Load collection from file + description: | + Loads a collection from a previously saved file (.vlc format). + The collection name can be specified, or it will use the name from the saved file. + + **Note**: This operation will fail if a collection with the same name already exists. + You must delete the existing collection first if you want to replace it. 
+ operationId: loadCollection + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/LoadCollectionRequest' + examples: + load_with_name: + summary: Load with custom collection name + value: + file_path: "./backups/my_docs.vlc" + collection_name: restored_docs + load_without_name: + summary: Load with original collection name + value: + file_path: "./backups/my_docs.vlc" + responses: + '200': + description: Collection loaded successfully + content: + application/json: + schema: + $ref: '#/components/schemas/LoadCollectionResponse' + '404': + description: File not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + example: + message: "File not found: ./backups/my_docs.vlc" + '409': + description: Collection already exists + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + example: + message: "Collection 'restored_docs' already exists" + '500': + description: Internal server error (e.g., invalid file format) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + +components: + schemas: + # Request schemas + CreateCollectionRequest: + type: object + required: + - name + - index_type + properties: + name: + type: string + description: Unique name for the collection + example: my_docs + index_type: + type: string + enum: [flat, hnsw] + description: Type of index to use (flat for exact search, hnsw for approximate) + example: hnsw + metric: + type: string + default: "" + description: Similarity metric (cosine, euclidean, manhattan, dotproduct). Empty string for default. + example: cosine + + AddTextRequest: + type: object + required: + - text + properties: + text: + type: string + description: The text to add to the collection + example: "This is a sample document about machine learning." 
+ metadata: + type: object + description: Optional metadata associated with this text (any JSON object) + additionalProperties: true + example: + author: "John Doe" + category: "machine-learning" + tags: + - ai + - ml + - research + + SearchTextRequest: + type: object + required: + - query + properties: + query: + type: string + description: The search query text + example: "neural networks" + k: + type: integer + format: int32 + default: 10 + description: Number of results to return + minimum: 1 + maximum: 1000 + example: 5 + similarity_metric: + type: string + enum: [cosine, euclidean, manhattan, dotproduct] + description: Override the collection's default similarity metric for this search + example: cosine + + SaveCollectionRequest: + type: object + required: + - file_path + properties: + file_path: + type: string + description: Path where the collection should be saved (including .vlc extension) + example: "./backups/my_docs.vlc" + + LoadCollectionRequest: + type: object + required: + - file_path + properties: + file_path: + type: string + description: Path to the collection file to load + example: "./backups/my_docs.vlc" + collection_name: + type: string + description: Optional name for the loaded collection. If not provided, uses the name from the file. 
+ example: restored_docs + + # Response schemas + CreateCollectionResponse: + type: object + properties: + name: + type: string + description: Name of the created collection + example: my_docs + + AddTextResponse: + type: object + properties: + id: + type: integer + format: int64 + nullable: true + description: ID of the newly added vector + example: 42 + + SearchResponse: + type: object + properties: + results: + type: array + nullable: true + description: Array of search results sorted by similarity (highest first) + items: + $ref: '#/components/schemas/SearchResult' + + SearchResult: + type: object + properties: + id: + type: integer + format: int64 + description: ID of the matching vector + example: 42 + score: + type: number + format: float + description: Similarity score (higher is more similar) + example: 0.9234 + text: + type: string + description: Original text that was embedded + example: "This document is about machine learning algorithms" + metadata: + type: object + nullable: true + description: Metadata associated with this vector (if any) + additionalProperties: true + example: + author: "John Doe" + category: "machine-learning" + + Vector: + type: object + properties: + id: + type: integer + format: int64 + description: Unique identifier for the vector + example: 123 + values: + type: array + items: + type: number + format: float + description: The embedding vector values (typically 384 dimensions for all-MiniLM-L6-v2) + example: [0.1, 0.2, -0.3, 0.4, 0.5, -0.6] + text: + type: string + description: The original text that was embedded to create this vector + example: "Sample document text" + metadata: + type: object + nullable: true + description: Optional metadata associated with this vector (any JSON object) + additionalProperties: true + example: + author: "John Doe" + tags: + - tutorial + - example + + ListCollectionsResponse: + type: object + properties: + collections: + type: array + items: + type: string + description: List of collection 
names + example: [my_docs, other_collection] + + CollectionInfoResponse: + type: object + properties: + info: + nullable: true + $ref: '#/components/schemas/CollectionInfo' + + CollectionInfo: + type: object + properties: + name: + type: string + description: Name of the collection + example: my_docs + count: + type: integer + format: int32 + description: Number of vectors in the collection + example: 1234 + is_empty: + type: boolean + description: Whether the collection is empty + example: false + dimension: + type: integer + format: int32 + description: Dimension of vectors in this collection (typically 384 for all-MiniLM-L6-v2) + example: 384 + + SaveCollectionResponse: + type: object + properties: + file_path: + type: string + nullable: true + description: Path where the collection was saved + example: "./backups/my_docs.vlc" + + LoadCollectionResponse: + type: object + properties: + collection_name: + type: string + nullable: true + description: Name of the loaded collection + example: restored_docs + + ErrorResponse: + type: object + properties: + message: + type: string + description: Error message + example: "Collection 'my_docs' not found" + + responses: + NotFound: + description: Resource not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + BadRequest: + description: Bad request (validation error) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + Conflict: + description: Resource conflict (e.g., already exists) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + InternalServerError: + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + diff --git a/src/client.rs b/src/client.rs index 19679fb..8a98b14 100644 --- a/src/client.rs +++ b/src/client.rs @@ -17,7 +17,7 @@ //! let mut client = VectorLiteClient::new(Box::new(EmbeddingGenerator::new()?)); //! //! 
// Create a collection -//! client.create_collection("documents", IndexType::HNSW)?; +//! client.create_collection("documents", IndexType::HNSW, Some(SimilarityMetric::Cosine))?; //! //! // Add text (auto-generates embedding) //! let id = client.add_text_to_collection("documents", "Hello world", None)?; @@ -27,7 +27,7 @@ //! "documents", //! "hello", //! 5, -//! SimilarityMetric::Cosine +//! None // Auto-detects from HNSW index metric //! )?; //! # Ok(()) //! # } @@ -54,11 +54,11 @@ use crate::errors::{VectorLiteError, VectorLiteResult}; /// # Examples /// /// ```rust -/// use vectorlite::{VectorLiteClient, EmbeddingGenerator, IndexType}; +/// use vectorlite::{VectorLiteClient, EmbeddingGenerator, IndexType, SimilarityMetric}; /// /// # fn example() -> Result<(), Box> { /// let mut client = VectorLiteClient::new(Box::new(EmbeddingGenerator::new()?)); -/// client.create_collection("docs", IndexType::HNSW)?; +/// client.create_collection("docs", IndexType::HNSW, Some(SimilarityMetric::Cosine))?; /// # Ok(()) /// # } /// ``` @@ -80,15 +80,22 @@ impl VectorLiteClient { } } - pub fn create_collection(&mut self, name: &str, index_type: IndexType) -> VectorLiteResult<()> { + pub fn create_collection(&mut self, name: &str, index_type: IndexType, metric: Option) -> VectorLiteResult<()> { if self.collections.contains_key(name) { return Err(VectorLiteError::CollectionAlreadyExists { name: name.to_string() }); } let dimension = self.embedding_function.dimension(); let index = match index_type { - IndexType::Flat => VectorIndexWrapper::Flat(crate::FlatIndex::new(dimension, Vec::new())), - IndexType::HNSW => VectorIndexWrapper::HNSW(Box::new(crate::HNSWIndex::new(dimension))), + IndexType::Flat => { + VectorIndexWrapper::Flat(crate::FlatIndex::new(dimension, Vec::new())) + }, + IndexType::HNSW => { + // HNSW requires a metric to build the graph structure + // No default is provided to force explicit specification + let used_metric = 
metric.ok_or(VectorLiteError::MetricRequired)?; + VectorIndexWrapper::HNSW(Box::new(crate::HNSWIndex::new(dimension, used_metric))) + }, }; let collection = Collection { @@ -129,11 +136,25 @@ impl VectorLiteClient { } - pub fn search_text_in_collection(&self, collection_name: &str, query_text: &str, k: usize, similarity_metric: SimilarityMetric) -> VectorLiteResult> { + pub fn search_text_in_collection(&self, collection_name: &str, query_text: &str, k: usize, similarity_metric: Option) -> VectorLiteResult> { let collection = self.collections.get(collection_name) .ok_or_else(|| VectorLiteError::CollectionNotFound { name: collection_name.to_string() })?; - collection.search_text(query_text, k, similarity_metric, self.embedding_function.as_ref()) + let metric = match similarity_metric { + Some(m) => m, + None => { + let index_guard = collection.index.read().map_err(|_| { + VectorLiteError::LockError("Failed to acquire read lock for metric detection".to_string()) + })?; + + match index_guard.metric() { + Some(m) => m, + None => SimilarityMetric::Cosine, + } + } + }; + + collection.search_text(query_text, k, metric, self.embedding_function.as_ref()) } @@ -178,16 +199,17 @@ impl VectorLiteClient { /// # Examples /// /// ```rust -/// use vectorlite::{VectorLiteClient, EmbeddingGenerator, IndexType}; +/// use vectorlite::{VectorLiteClient, EmbeddingGenerator, IndexType, SimilarityMetric}; /// /// # fn example() -> Result<(), Box> { /// let mut client = VectorLiteClient::new(Box::new(EmbeddingGenerator::new()?)); /// /// // For small datasets with exact search requirements -/// client.create_collection("small_data", IndexType::Flat)?; +/// // Flat index - metric is optional +/// client.create_collection("small_data", IndexType::Flat, None)?; /// /// // For large datasets with approximate search tolerance -/// client.create_collection("large_data", IndexType::HNSW)?; +/// client.create_collection("large_data", IndexType::HNSW, Some(SimilarityMetric::Cosine))?; /// # Ok(()) 
/// # } /// ``` @@ -235,11 +257,11 @@ type CollectionRef = Arc; /// # Examples /// /// ```rust -/// use vectorlite::{VectorLiteClient, EmbeddingGenerator, IndexType}; +/// use vectorlite::{VectorLiteClient, EmbeddingGenerator, IndexType, SimilarityMetric}; /// /// # fn example() -> Result<(), Box> { /// let mut client = VectorLiteClient::new(Box::new(EmbeddingGenerator::new()?)); -/// client.create_collection("docs", IndexType::HNSW)?; +/// client.create_collection("docs", IndexType::HNSW, Some(SimilarityMetric::Cosine))?; /// /// let info = client.get_collection_info("docs")?; /// println!("Collection '{}' has {} vectors of dimension {}", @@ -374,13 +396,14 @@ impl Collection { // Acquire read lock for search let index = self.index.read().map_err(|_| VectorLiteError::LockError("Failed to acquire read lock for search_text".to_string()))?; - Ok(index.search(&query_embedding, k, similarity_metric)) + + index.search(&query_embedding, k, similarity_metric) } pub fn get_vector(&self, id: u64) -> VectorLiteResult> { let index = self.index.read().map_err(|_| VectorLiteError::LockError("Failed to acquire read lock for get_vector".to_string()))?; - Ok(index.get_vector(id).cloned()) + Ok(index.get_vector(id)) } pub fn get_info(&self) -> VectorLiteResult { @@ -424,12 +447,12 @@ impl Collection { /// # Examples /// /// ```rust - /// use vectorlite::{VectorLiteClient, EmbeddingGenerator, IndexType}; + /// use vectorlite::{VectorLiteClient, EmbeddingGenerator, IndexType, SimilarityMetric}; /// use std::path::Path; /// /// # fn example() -> Result<(), Box> { /// let mut client = VectorLiteClient::new(Box::new(EmbeddingGenerator::new()?)); - /// client.create_collection("docs", IndexType::HNSW)?; + /// client.create_collection("docs", IndexType::HNSW, Some(SimilarityMetric::Cosine))?; /// client.add_text_to_collection("docs", "Hello world", None)?; /// /// let collection = client.get_collection("docs").unwrap(); @@ -514,7 +537,7 @@ mod tests { let mut client = 
VectorLiteClient::new(Box::new(embedding_fn));

         // Create collection
-        let result = client.create_collection("test_collection", IndexType::Flat);
+        let result = client.create_collection("test_collection", IndexType::Flat, None);
         assert!(result.is_ok());

         // Check collection exists
@@ -528,10 +551,10 @@ mod tests {
         let mut client = VectorLiteClient::new(Box::new(embedding_fn));

         // Create first collection
-        client.create_collection("test_collection", IndexType::Flat).unwrap();
+        client.create_collection("test_collection", IndexType::Flat, None).unwrap();

         // Try to create duplicate
-        let result = client.create_collection("test_collection", IndexType::Flat);
+        let result = client.create_collection("test_collection", IndexType::Flat, None);
         assert!(result.is_err());
         assert!(matches!(result.unwrap_err(), VectorLiteError::CollectionAlreadyExists { .. }));
     }
@@ -542,7 +565,7 @@ mod tests {
         let mut client = VectorLiteClient::new(Box::new(embedding_fn));

         // Create collection
-        client.create_collection("test_collection", IndexType::Flat).unwrap();
+        client.create_collection("test_collection", IndexType::Flat, None).unwrap();

         // Get collection
         let collection = client.get_collection("test_collection");
@@ -560,7 +583,7 @@ mod tests {
         let mut client = VectorLiteClient::new(Box::new(embedding_fn));

         // Create collection
-        client.create_collection("test_collection", IndexType::Flat).unwrap();
+        client.create_collection("test_collection", IndexType::Flat, None).unwrap();
         assert!(client.has_collection("test_collection"));

         // Delete collection
@@ -579,7 +602,7 @@ mod tests {
         let mut client = VectorLiteClient::new(Box::new(embedding_fn));

         // Create collection
-        client.create_collection("test_collection", IndexType::Flat).unwrap();
+        client.create_collection("test_collection", IndexType::Flat, None).unwrap();

         // Add text
         let result = client.add_text_to_collection("test_collection", "Hello world", None);
@@ -615,7 +638,7 @@ mod tests {
         let mut client = VectorLiteClient::new(Box::new(embedding_fn));

         // Create collection
-        client.create_collection("test_collection", IndexType::Flat).unwrap();
+        client.create_collection("test_collection", IndexType::Flat, None).unwrap();

         // Test initial state
         let info = client.get_collection_info("test_collection").unwrap();
@@ -639,7 +662,7 @@ mod tests {
         assert_eq!(info.count, 2);

         // Test search
-        let results = client.search_text_in_collection("test_collection", "Hello", 1, SimilarityMetric::Cosine).unwrap();
+        let results = client.search_text_in_collection("test_collection", "Hello", 1, Some(SimilarityMetric::Cosine)).unwrap();
         assert_eq!(results.len(), 1);
         assert_eq!(results[0].id, 0);
@@ -665,7 +688,7 @@ mod tests {
         let mut client = VectorLiteClient::new(Box::new(embedding_fn));

         // Create HNSW collection
-        client.create_collection("hnsw_collection", IndexType::HNSW).unwrap();
+        client.create_collection("hnsw_collection", IndexType::HNSW, Some(SimilarityMetric::Euclidean)).unwrap();

         // Add some text
         let id1 = client.add_text_to_collection("hnsw_collection", "First document", None).unwrap();
@@ -677,12 +700,26 @@ mod tests {
         let info = client.get_collection_info("hnsw_collection").unwrap();
         assert_eq!(info.count, 2);

-        // Test search
-        let results = client.search_text_in_collection("hnsw_collection", "First", 1, SimilarityMetric::Cosine).unwrap();
+        // Test search with Euclidean (must match the index metric)
+        let results = client.search_text_in_collection("hnsw_collection", "First", 1, Some(SimilarityMetric::Euclidean)).unwrap();
         assert_eq!(results.len(), 1);
     }

+    #[test]
+    fn test_hnsw_requires_metric() {
+        let embedding_fn = MockEmbeddingFunction::new(3);
+        let mut client = VectorLiteClient::new(Box::new(embedding_fn));
+
+        // Creating HNSW without metric should fail
+        let result = client.create_collection("hnsw_collection", IndexType::HNSW, None);
+        assert!(result.is_err());
+        match result {
+            Err(VectorLiteError::MetricRequired) => {
+                // Expected error
+            }
+            _ => panic!("Expected MetricRequired error when creating HNSW without metric"),
+        }
+    }

     #[test]
     fn test_collection_save_and_load() {
@@ -690,7 +727,7 @@ mod tests {
         let mut client = VectorLiteClient::new(Box::new(embedding_fn));

         // Create collection and add some data
-        client.create_collection("test_collection", IndexType::Flat).unwrap();
+        client.create_collection("test_collection", IndexType::Flat, None).unwrap();
         client.add_text_to_collection("test_collection", "Hello world", None).unwrap();
         client.add_text_to_collection("test_collection", "Another text", None).unwrap();
@@ -727,7 +764,7 @@ mod tests {
         let mut client = VectorLiteClient::new(Box::new(embedding_fn));

         // Create HNSW collection and add some data
-        client.create_collection("test_hnsw_collection", IndexType::HNSW).unwrap();
+        client.create_collection("test_hnsw_collection", IndexType::HNSW, Some(SimilarityMetric::Euclidean)).unwrap();
         client.add_text_to_collection("test_hnsw_collection", "First document", None).unwrap();
         client.add_text_to_collection("test_hnsw_collection", "Second document", None).unwrap();
@@ -741,8 +778,8 @@ mod tests {
         // Create a separate embedding function for testing
         let test_embedding_fn = MockEmbeddingFunction::new(3);

-        // Test search on original collection using text search (like the working test)
-        let results = collection.search_text("First", 1, SimilarityMetric::Cosine, &test_embedding_fn).unwrap();
+        // Test search on original collection using text search
+        let results = collection.search_text("First", 1, SimilarityMetric::Euclidean, &test_embedding_fn).unwrap();
         assert_eq!(results.len(), 1);

         // Save to temporary file
@@ -765,7 +802,7 @@ mod tests {
         assert!(!info.is_empty);

         // Test search functionality using text search
-        let results = loaded_collection.search_text("First", 1, SimilarityMetric::Cosine, &test_embedding_fn).unwrap();
+        let results = loaded_collection.search_text("First", 1, SimilarityMetric::Euclidean, &test_embedding_fn).unwrap();
         assert_eq!(results.len(), 1);
     }
@@ -774,7 +811,8 @@ mod tests {
         let embedding_fn = MockEmbeddingFunction::new(3);
         let mut client = VectorLiteClient::new(Box::new(embedding_fn));

-        client.create_collection("test_collection", IndexType::Flat).unwrap();
+        // Flat indexes don't need a metric parameter
+        client.create_collection("test_collection", IndexType::Flat, None).unwrap();
         client.add_text_to_collection("test_collection", "Hello world", None).unwrap();

         let collection = client.get_collection("test_collection").unwrap();
@@ -792,10 +830,10 @@ mod tests {
     fn test_collection_load_nonexistent_file() {
         let temp_dir = tempfile::TempDir::new().unwrap();
         let file_path = temp_dir.path().join("nonexistent.vlc");
-        
+
         let result = Collection::load_from_file(&file_path);
         assert!(result.is_err());
-        assert!(matches!(result.unwrap_err(), PersistenceError::Io(_)));
+        assert!(matches!(result.unwrap_err(), PersistenceError::FileNotFound(_)));
     }

     #[test]
diff --git a/src/embeddings.rs b/src/embeddings.rs
index ecae792..a59f162 100644
--- a/src/embeddings.rs
+++ b/src/embeddings.rs
@@ -400,7 +400,7 @@ mod tests {
     #[test]
     fn test_batch_embedding_generation() {
         let generator = create_test_generator();
-        let texts = vec![
+        let texts = [
             "first text".to_string(),
             "second text".to_string(),
             "third text".to_string(),
diff --git a/src/errors.rs b/src/errors.rs
index 30aac6b..db51c89 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -37,6 +37,14 @@ pub enum VectorLiteError {
     #[error("Invalid similarity metric: {metric}. Must be 'cosine', 'euclidean', 'manhattan', or 'dotproduct'")]
     InvalidSimilarityMetric { metric: String },

+    /// Metric mismatch between search request and index configuration
+    #[error("Metric mismatch: search requested {requested:?} but index was built for {index:?}")]
+    MetricMismatch { requested: crate::SimilarityMetric, index: crate::SimilarityMetric },
+
+    /// Metric required for HNSW index but not provided
+    #[error("HNSW index requires an explicit similarity metric. Add field 'metric' with one of the following: ['cosine', 'euclidean', 'manhattan', 'dotproduct']")]
+    MetricRequired,
+
     /// Embedding generation error
     #[error("Embedding generation failed: {0}")]
     EmbeddingError(#[from] crate::embeddings::EmbeddingError),
@@ -70,8 +78,13 @@ impl VectorLiteError {
             VectorLiteError::CollectionAlreadyExists { .. } => StatusCode::CONFLICT,
             VectorLiteError::InvalidIndexType { .. } => StatusCode::BAD_REQUEST,
             VectorLiteError::InvalidSimilarityMetric { .. } => StatusCode::BAD_REQUEST,
+            VectorLiteError::MetricMismatch { .. } => StatusCode::BAD_REQUEST,
+            VectorLiteError::MetricRequired => StatusCode::BAD_REQUEST,
             VectorLiteError::EmbeddingError(_) => StatusCode::INTERNAL_SERVER_ERROR,
-            VectorLiteError::PersistenceError(_) => StatusCode::INTERNAL_SERVER_ERROR,
+            VectorLiteError::PersistenceError(e) => match e {
+                crate::persistence::PersistenceError::FileNotFound(_) => StatusCode::NOT_FOUND,
+                _ => StatusCode::INTERNAL_SERVER_ERROR,
+            },
             VectorLiteError::LockError(_) => StatusCode::INTERNAL_SERVER_ERROR,
             VectorLiteError::InternalError(_) => StatusCode::INTERNAL_SERVER_ERROR,
         }
diff --git a/src/index/flat.rs b/src/index/flat.rs
index 46e0be8..5bb1200 100644
--- a/src/index/flat.rs
+++ b/src/index/flat.rs
@@ -95,7 +95,14 @@ impl VectorIndex for FlatIndex {
         Ok(())
     }

-    fn search(&self, query: &[f64], k: usize, similarity_metric: SimilarityMetric) -> Vec<SearchResult> {
+    fn search(&self, query: &[f64], k: usize, similarity_metric: SimilarityMetric) -> Result<Vec<SearchResult>, crate::errors::VectorLiteError> {
+        if !self.data.is_empty() && query.len() != self.dim {
+            return Err(crate::errors::VectorLiteError::DimensionMismatch {
+                expected: self.dim,
+                actual: query.len()
+            });
+        }
+
         let mut similarities: Vec<_> = self.data
             .iter()
             .map(|e| SearchResult {
@@ -108,7 +115,7 @@ impl VectorIndex for FlatIndex {
         similarities.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
         similarities.truncate(k);
-        similarities
+        Ok(similarities)
     }

     fn len(&self) -> usize {
@@ -119,8 +126,8 @@ impl VectorIndex for FlatIndex {
         self.data.is_empty()
     }

-    fn get_vector(&self, id: u64) -> Option<&Vector> {
-        self.data.iter().find(|e| e.id == id)
+    fn get_vector(&self, id: u64) -> Option<Vector> {
+        self.data.iter().find(|e| e.id == id).cloned()
     }

     fn dimension(&self) -> usize {
@@ -162,7 +169,7 @@ mod tests {
         // Verify search works on the deserialized index
         let query = vec![1.1, 0.1, 0.1];
-        let results = deserialized.search(&query, 2, SimilarityMetric::Cosine);
+        let results = deserialized.search(&query, 2, SimilarityMetric::Cosine).unwrap();
         assert_eq!(results.len(), 2);

         // Results should be sorted by score (highest first)
@@ -186,7 +193,7 @@ mod tests {
         let index = FlatIndex::new(3, vectors);

         let query = vec![1.0, 0.0, 0.0];
-        let results = index.search(&query, 2, SimilarityMetric::Cosine);
+        let results = index.search(&query, 2, SimilarityMetric::Cosine).unwrap();
         assert_eq!(results.len(), 2);
         assert_eq!(results[0].id, 1); // Most similar (identical)
@@ -203,7 +210,7 @@ mod tests {
         let index = FlatIndex::new(2, vectors);

         let query = vec![0.0, 0.0];
-        let results = index.search(&query, 2, SimilarityMetric::Euclidean);
+        let results = index.search(&query, 2, SimilarityMetric::Euclidean).unwrap();
         assert_eq!(results.len(), 2);
         assert_eq!(results[0].id, 1); // Most similar (identical)
@@ -220,7 +227,7 @@ mod tests {
         let index = FlatIndex::new(2, vectors);

         let query = vec![0.0, 0.0];
-        let results = index.search(&query, 2, SimilarityMetric::Manhattan);
+        let results = index.search(&query, 2, SimilarityMetric::Manhattan).unwrap();
         assert_eq!(results.len(), 2);
         assert_eq!(results[0].id, 1); // Most similar (identical)
@@ -237,7 +244,7 @@ mod tests {
         let index = FlatIndex::new(2, vectors);

         let query = vec![1.0, 2.0];
-        let results = index.search(&query, 2, SimilarityMetric::DotProduct);
+        let results = index.search(&query, 2, SimilarityMetric::DotProduct).unwrap();
         assert_eq!(results.len(), 2);
         assert_eq!(results[0].id, 1); // Most similar (identical)
@@ -255,11 +262,11 @@ mod tests {
         let query = vec![1.0, 2.0];

         // Test with cosine similarity
-        let results_cosine = index.search(&query, 1, SimilarityMetric::Cosine);
+        let results_cosine = index.search(&query, 1, SimilarityMetric::Cosine).unwrap();
         assert_eq!(results_cosine[0].id, 1);

         // Test with dot product
-        let results_dot = index.search(&query, 1, SimilarityMetric::DotProduct);
+        let results_dot = index.search(&query, 1, SimilarityMetric::DotProduct).unwrap();
         assert_eq!(results_dot[0].id, 1);

         // Scores should be different
diff --git a/src/index/hnsw.rs b/src/index/hnsw.rs
index c7c4d34..3c90b56 100644
--- a/src/index/hnsw.rs
+++ b/src/index/hnsw.rs
@@ -30,11 +30,11 @@
 //! use vectorlite::{HNSWIndex, Vector, SimilarityMetric, VectorIndex};
 //!
 //! # fn example() -> Result<(), Box<dyn std::error::Error>> {
-//! let mut index = HNSWIndex::new(384);
+//! let mut index = HNSWIndex::new(384, SimilarityMetric::Euclidean);
 //! let vector = Vector { id: 1, values: vec![0.1; 384], text: "test".to_string(), metadata: None };
 //!
 //! index.add(vector)?;
-//! let results = index.search(&[0.1; 384], 5, SimilarityMetric::Cosine);
+//! let results = index.search(&[0.1; 384], 5, SimilarityMetric::Euclidean)?;
 //! # Ok(())
 //! # }
 //! ```
@@ -46,9 +46,52 @@ use serde::{Deserialize, Serialize, Deserializer};
 use space::{Metric, Neighbor};
 use hnsw::{Hnsw, Searcher};
 use crate::{Vector, VectorIndex, SearchResult, SimilarityMetric};

+/// Convert distance to similarity score for the given metric
+fn convert_distance_to_similarity(distance: f64, metric: SimilarityMetric) -> f64 {
+    match metric {
+        SimilarityMetric::Euclidean => {
+            // For Euclidean: similarity = 1 / (1 + distance)
+            1.0 / (1.0 + distance)
+        },
+        SimilarityMetric::Cosine => {
+            // For Cosine: distance = 1 - similarity, so similarity = 1 - distance.
+            // The Metric impl scales distances to [0, 2000], so divide by 1000 first.
+            let cos_distance = distance / 1000.0;
+            1.0 - cos_distance
+        },
+        SimilarityMetric::Manhattan => {
+            // For Manhattan: similarity = 1 / (1 + distance)
+            1.0 / (1.0 + distance)
+        },
+        SimilarityMetric::DotProduct => {
+            // For DotProduct: distance = 1000 - dot_product (clamped),
+            // so dot_product = 1000 - distance. Normalize to [0, 1] where a
+            // higher dot product means a higher similarity.
+            ((1000.0 - distance) / 1000.0).clamp(0.0, 1.0)
+        },
+    }
+}
+
+/// VectorMetadata holds a vector's text and JSON metadata without the embedding values
+#[derive(Debug, Clone, Serialize, Deserialize)]
+struct VectorMetadata {
+    text: String,
+    metadata: Option<serde_json::Value>,
+}

 #[derive(Default, Clone)]
 struct Euclidean;

+#[derive(Default, Clone)]
+struct Cosine;
+
+#[derive(Default, Clone)]
+struct Manhattan;
+
+#[derive(Default, Clone)]
+struct DotProduct;
+
 const MAXIMUM_NUMBER_CONNECTIONS: usize = if cfg!(feature = "memory-optimized") {
     8
 } else if cfg!(feature = "high-accuracy") {
@@ -79,42 +122,150 @@ impl Metric<Vec<f64>> for Euclidean {
     }
 }

+impl Metric<Vec<f64>> for Cosine {
+    type Unit = u64;
+
+    fn distance(&self, a: &Vec<f64>, b: &Vec<f64>) -> Self::Unit {
+        // Cosine distance = 1 - cosine_similarity
+        let (dot, norm_a_sq, norm_b_sq) = a.iter()
+            .zip(b.iter())
+            .fold((0.0, 0.0, 0.0), |(dot, a_sq, b_sq), (&x, &y)| {
+                (dot + x * y, a_sq + x * x, b_sq + y * y)
+            });
+
+        let norm_a = norm_a_sq.sqrt();
+        let norm_b = norm_b_sq.sqrt();
+
+        if norm_a == 0.0 || norm_b == 0.0 {
+            return 1000; // Maximum distance for zero vectors
+        }
+
+        let cosine_sim = dot / (norm_a * norm_b);
+        // Convert to distance: (1 - similarity) * 1000
+        let distance = (1.0 - cosine_sim) * 1000.0;
+        distance as u64
+    }
+}
+
+impl Metric<Vec<f64>> for Manhattan {
+    type Unit = u64;
+
+    fn distance(&self, a: &Vec<f64>, b: &Vec<f64>) -> Self::Unit {
+        let dist = a.iter()
+            .zip(b.iter())
+            .map(|(&x, &y)| (x - y).abs())
+            .sum::<f64>();
+        (dist * 1000.0) as u64
+    }
+}
+
+impl Metric<Vec<f64>> for DotProduct {
+    type Unit = u64;
+
+    fn distance(&self, a: &Vec<f64>, b: &Vec<f64>) -> Self::Unit {
+        // Dot product as distance (inverted: higher dot product = smaller distance)
+        let dot = a.iter()
+            .zip(b.iter())
+            .map(|(&x, &y)| x * y)
+            .sum::<f64>();
+        // Convert to positive distance: 1000 - dot (clamped)
+        (1000.0 - dot.clamp(-1000.0, 1000.0)) as u64
+    }
+}
+
+/// Enum to hold different HNSW index types for different metrics
+#[derive(Clone)]
+enum HNSWIndexInternal {
+    Euclidean {
+        hnsw: Hnsw<Euclidean, Vec<f64>, StdRng, MAXIMUM_NUMBER_CONNECTIONS, MAXIMUM_NUMBER_CONNECTIONS_0>,
+        searcher: Searcher<u64>,
+    },
+    Cosine {
+        hnsw: Hnsw<Cosine, Vec<f64>, StdRng, MAXIMUM_NUMBER_CONNECTIONS, MAXIMUM_NUMBER_CONNECTIONS_0>,
+        searcher: Searcher<u64>,
+    },
+    Manhattan {
+        hnsw: Hnsw<Manhattan, Vec<f64>, StdRng, MAXIMUM_NUMBER_CONNECTIONS, MAXIMUM_NUMBER_CONNECTIONS_0>,
+        searcher: Searcher<u64>,
+    },
+    DotProduct {
+        hnsw: Hnsw<DotProduct, Vec<f64>, StdRng, MAXIMUM_NUMBER_CONNECTIONS, MAXIMUM_NUMBER_CONNECTIONS_0>,
+        searcher: Searcher<u64>,
+    },
+}
+
 #[derive(Clone, Serialize)]
 pub struct HNSWIndex {
     #[serde(skip)]
-    hnsw: Hnsw<Euclidean, Vec<f64>, StdRng, MAXIMUM_NUMBER_CONNECTIONS, MAXIMUM_NUMBER_CONNECTIONS_0>,
-    #[serde(skip)]
-    searcher: Searcher<u64>,
+    index_internal: HNSWIndexInternal,
     dim: usize,
+    // The similarity metric this index was optimized for
+    metric: SimilarityMetric,
     // Mapping from custom ID to internal HNSW index
     id_to_index: HashMap<u64, usize>,
     // Mapping from internal HNSW index to custom ID
     index_to_id: HashMap<usize, u64>,
-    // Store vectors for retrieval by ID.
-    vectors: HashMap<u64, Vector>,
+    // Store only metadata (text + JSON), not the full Vector
+    metadata: HashMap<u64, VectorMetadata>,
+    // Store vector values separately
+    vector_values: HashMap<u64, Vec<f64>>,
 }

 impl HNSWIndex {
-    pub fn new(dim: usize) -> Self {
+    pub fn new(dim: usize, metric: SimilarityMetric) -> Self {
         if dim == 0 {
             panic!("HNSW index dimension cannot be 0");
         }
-        let hnsw: Hnsw<Euclidean, Vec<f64>, StdRng, MAXIMUM_NUMBER_CONNECTIONS, MAXIMUM_NUMBER_CONNECTIONS_0> = Hnsw::new(Euclidean);
-        let searcher = Searcher::new();
+
+        // Create HNSW with the specific metric for its graph structure.
+        // This ensures the HNSW graph is optimized for the intended similarity metric.
+        let index_internal = match metric {
+            SimilarityMetric::Euclidean => HNSWIndexInternal::Euclidean {
+                hnsw: Hnsw::new(Euclidean),
+                searcher: Searcher::new(),
+            },
+            SimilarityMetric::Cosine => HNSWIndexInternal::Cosine {
+                hnsw: Hnsw::new(Cosine),
+                searcher: Searcher::new(),
+            },
+            SimilarityMetric::Manhattan => HNSWIndexInternal::Manhattan {
+                hnsw: Hnsw::new(Manhattan),
+                searcher: Searcher::new(),
+            },
+            SimilarityMetric::DotProduct => HNSWIndexInternal::DotProduct {
+                hnsw: Hnsw::new(DotProduct),
+                searcher: Searcher::new(),
+            },
+        };
+
         Self {
-            hnsw,
-            searcher,
+            index_internal,
             dim,
+            metric,
             id_to_index: HashMap::new(),
             index_to_id: HashMap::new(),
-            vectors: HashMap::new(),
+            metadata: HashMap::new(),
+            vector_values: HashMap::new(),
         }
     }
+
+    /// Get the metric this index was built for
+    pub fn metric(&self) -> SimilarityMetric {
+        self.metric
+    }

     /// Get the maximum ID from the stored vectors
     pub fn max_id(&self) -> Option<u64> {
-        self.vectors.keys().max().copied()
+        self.metadata.keys().max().copied()
     }
 }
@@ -126,41 +277,84 @@ impl<'de> Deserialize<'de> for HNSWIndex {
         #[derive(Deserialize)]
         struct Temp {
             dim: usize,
-            vectors: HashMap<u64, Vector>,
+            metric: SimilarityMetric,
+            metadata: HashMap<u64, VectorMetadata>,
+            vector_values: HashMap<u64, Vec<f64>>,
         }

         let data = Temp::deserialize(deserializer)?;

-        let mut hnsw = Hnsw::new(Euclidean);
-        let mut searcher = Searcher::new();
+        if data.dim == 0 {
+            return Err(serde::de::Error::custom("Invalid dimension: cannot be 0"));
+        }
+
+        // Create the appropriate HNSW index based on the metric
+        let mut index_internal = match data.metric {
+            SimilarityMetric::Euclidean => HNSWIndexInternal::Euclidean {
+                hnsw: Hnsw::new(Euclidean),
+                searcher: Searcher::new(),
+            },
+            SimilarityMetric::Cosine => HNSWIndexInternal::Cosine {
+                hnsw: Hnsw::new(Cosine),
+                searcher: Searcher::new(),
+            },
+            SimilarityMetric::Manhattan => HNSWIndexInternal::Manhattan {
+                hnsw: Hnsw::new(Manhattan),
+                searcher: Searcher::new(),
+            },
+            SimilarityMetric::DotProduct => HNSWIndexInternal::DotProduct {
+                hnsw: Hnsw::new(DotProduct),
+                searcher: Searcher::new(),
+            },
+        };

         let mut new_id_to_index = HashMap::new();
         let mut new_index_to_id = HashMap::new();

-        for (id, vector) in &data.vectors {
-            if vector.values.len() != data.dim {
+        // Insert all vectors into the appropriate HNSW index
+        for (id, values) in &data.vector_values {
+            if values.len() != data.dim {
                 return Err(serde::de::Error::custom(format!(
                     "Vector dimension mismatch: expected {}, got {}",
-                    data.dim, vector.values.len()
+                    data.dim, values.len()
                 )));
             }
-            let internal_index = hnsw.insert(vector.values.clone(), &mut searcher);
+
+            let internal_index = match &mut index_internal {
+                HNSWIndexInternal::Euclidean { hnsw, searcher } => hnsw.insert(values.clone(), searcher),
+                HNSWIndexInternal::Cosine { hnsw, searcher } => hnsw.insert(values.clone(), searcher),
+                HNSWIndexInternal::Manhattan { hnsw, searcher } => hnsw.insert(values.clone(), searcher),
+                HNSWIndexInternal::DotProduct { hnsw, searcher } => hnsw.insert(values.clone(), searcher),
+            };
+
             new_id_to_index.insert(*id, internal_index);
             new_index_to_id.insert(internal_index, *id);
         }

-        // Verify that the HNSW index was created with the correct dimension
-        if data.dim == 0 {
-            return Err(serde::de::Error::custom("Invalid dimension: cannot be 0"));
-        }
-
         Ok(HNSWIndex {
-            hnsw,
-            searcher,
+            index_internal,
             dim: data.dim,
+            metric: data.metric,
             id_to_index: new_id_to_index,
             index_to_id: new_index_to_id,
-            vectors: data.vectors,
+            metadata: data.metadata,
+            vector_values: data.vector_values,
         })
     }
 }
@@ -175,11 +369,31 @@ impl VectorIndex for HNSWIndex {
             return Err(format!("Vector ID {} already exists", vector.id));
         }

-        let internal_index = self.hnsw.insert(vector.values.clone(), &mut self.searcher);
+        let internal_index = match &mut self.index_internal {
+            HNSWIndexInternal::Euclidean { hnsw, searcher } => hnsw.insert(vector.values.clone(), searcher),
+            HNSWIndexInternal::Cosine { hnsw, searcher } => hnsw.insert(vector.values.clone(), searcher),
+            HNSWIndexInternal::Manhattan { hnsw, searcher } => hnsw.insert(vector.values.clone(), searcher),
+            HNSWIndexInternal::DotProduct { hnsw, searcher } => hnsw.insert(vector.values.clone(), searcher),
+        };
+
+        // Store metadata and values separately
+        let vector_metadata = VectorMetadata {
+            text: vector.text,
+            metadata: vector.metadata,
+        };

         self.id_to_index.insert(vector.id, internal_index);
         self.index_to_id.insert(internal_index, vector.id);
-        self.vectors.insert(vector.id, vector);
+        self.metadata.insert(vector.id, vector_metadata);
+        self.vector_values.insert(vector.id, vector.values);
         Ok(())
     }
@@ -193,29 +407,36 @@ impl VectorIndex for HNSWIndex {
         // Since HNSW doesn't support deletion, we just remove the reference to the node in the mapping
         self.id_to_index.remove(&id);
         self.index_to_id.remove(&internal_index);
-        self.vectors.remove(&id);
+        self.metadata.remove(&id);
+        self.vector_values.remove(&id);
         Ok(())
     }

-    fn search(&self, query: &[f64], k: usize, similarity_metric: SimilarityMetric) -> Vec<SearchResult> {
+    fn search(&self, query: &[f64], k: usize, similarity_metric: SimilarityMetric) -> Result<Vec<SearchResult>, crate::errors::VectorLiteError> {
         if query.len() != self.dim {
-            eprintln!("Warning: Query dimension mismatch. Expected {}, got {}. Returning empty results.", self.dim, query.len());
-            return Vec::new();
+            return Err(crate::errors::VectorLiteError::DimensionMismatch {
+                expected: self.dim,
+                actual: query.len()
+            });
         }

-        if self.vectors.is_empty() {
-            return Vec::new();
+        // Reject searches that don't match the metric the index was built for:
+        // HNSW's graph structure is optimized for a specific distance metric.
+        if similarity_metric != self.metric {
+            return Err(crate::errors::VectorLiteError::MetricMismatch {
+                requested: similarity_metric,
+                index: self.metric
+            });
         }

-        let query_vec = query.to_vec();
+        if self.metadata.is_empty() {
+            return Ok(Vec::new());
+        }

-        let mut searcher: Searcher<u64> = Searcher::new();
-        // HNSW searches for k*2 candidates to improve accuracy in approximate search.
-        // This compensates for the graph structure limitations and ensures we find
-        // the best k results after recalculating with the requested similarity metric.
-        let max_candidates = std::cmp::min(k * 2, self.vectors.len());
+        let query_vec = query.to_vec();
+        let max_candidates = std::cmp::min(k, self.metadata.len());
         if max_candidates == 0 {
-            return Vec::new();
+            return Ok(Vec::new());
         }

         let mut neighbors = vec![
@@ -226,20 +447,42 @@ impl VectorIndex for HNSWIndex {
             max_candidates
         ];

-        let results = self.hnsw.nearest(&query_vec, max_candidates, &mut searcher, &mut neighbors);
+        // Use the appropriate HNSW index based on the metric
+        let results = match &self.index_internal {
+            HNSWIndexInternal::Euclidean { hnsw, .. } => {
+                let mut searcher: Searcher<u64> = Searcher::new();
+                hnsw.nearest(&query_vec, max_candidates, &mut searcher, &mut neighbors)
+            },
+            HNSWIndexInternal::Cosine { hnsw, .. } => {
+                let mut searcher: Searcher<u64> = Searcher::new();
+                hnsw.nearest(&query_vec, max_candidates, &mut searcher, &mut neighbors)
+            },
+            HNSWIndexInternal::Manhattan { hnsw, .. } => {
+                let mut searcher: Searcher<u64> = Searcher::new();
+                hnsw.nearest(&query_vec, max_candidates, &mut searcher, &mut neighbors)
+            },
+            HNSWIndexInternal::DotProduct { hnsw, .. } => {
+                let mut searcher: Searcher<u64> = Searcher::new();
+                hnsw.nearest(&query_vec, max_candidates, &mut searcher, &mut neighbors)
+            },
+        };

-        // Get candidate vectors and recalculate with the requested similarity metric
+        // Convert HNSW distances to similarity scores.
+        // The HNSW returns distances in its native Unit (u64 scaled by 1000).
         let mut search_results: Vec<SearchResult> = results.iter()
             .filter(|n| n.index != !0) // Filter out invalid results
             .filter_map(|n| {
                 self.index_to_id.get(&n.index).and_then(|&custom_id| {
-                    self.vectors.get(&custom_id).map(|vector| {
-                        let score = similarity_metric.calculate(&vector.values, query);
+                    self.metadata.get(&custom_id).map(|meta| {
+                        // Convert the u64 distance back to f64, then to a similarity.
+                        // The Euclidean/Manhattan Metric impls scale by 1000, so descale
+                        // here; the Cosine/DotProduct conversions expect the raw scaled value.
+                        let distance = match self.metric {
+                            SimilarityMetric::Euclidean | SimilarityMetric::Manhattan => n.distance as f64 / 1000.0,
+                            _ => n.distance as f64,
+                        };
+                        let score = convert_distance_to_similarity(distance, similarity_metric);
+
                         SearchResult {
                             id: custom_id,
                             score,
-                            text: vector.text.clone(),
-                            metadata: vector.metadata.clone()
+                            text: meta.text.clone(),
+                            metadata: meta.metadata.clone()
                         }
                     })
                 })
@@ -249,16 +492,25 @@ impl VectorIndex for HNSWIndex {
         // Sort by similarity score and take top k
         search_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
         search_results.truncate(k);
-        search_results
+        Ok(search_results)
     }

     fn len(&self) -> usize {
-        self.vectors.len()
+        self.metadata.len()
     }

     fn is_empty(&self) -> bool {
-        self.vectors.is_empty()
+        self.metadata.is_empty()
     }

-    fn get_vector(&self, id: u64) -> Option<&Vector> {
-        self.vectors.get(&id)
+    fn get_vector(&self, id: u64) -> Option<Vector> {
+        self.metadata.get(&id).and_then(|meta| {
+            self.vector_values.get(&id).map(|values| {
+                Vector {
+                    id,
+                    values: values.clone(),
+                    text: meta.text.clone(),
+                    metadata: meta.metadata.clone(),
+                }
+            })
+        })
     }

     fn dimension(&self) -> usize {
         self.dim
@@ -268,22 +520,22 @@ impl VectorIndex for HNSWIndex {
 impl Debug for HNSWIndex {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("HNSWIndex")
-            .field("len", &self.vectors.len())
-            .field("is_empty", &self.vectors.is_empty())
+            .field("len", &self.metadata.len())
+            .field("is_empty", &self.metadata.is_empty())
             .field("dimension", &self.dim)
             .finish()
     }
 }

 #[test]
 fn test_create_hnswindex() {
-    let hnsw = HNSWIndex::new(3);
+    let hnsw = HNSWIndex::new(3, SimilarityMetric::Euclidean);
     assert!(hnsw.is_empty());
     assert_eq!(hnsw.dimension(), 3);
 }

 #[test]
 fn test_add_vector() {
-    let mut hnsw = HNSWIndex::new(3);
+    let mut hnsw = HNSWIndex::new(3, SimilarityMetric::Euclidean);
     let vector = Vector {
         id: 1,
         values: vec![1.0, 2.0, 3.0],
@@ -298,7 +550,7 @@ fn test_add_vector() {

 #[test]
 fn test_add_vector_dimension_mismatch() {
-    let mut hnsw = HNSWIndex::new(3);
+    let mut hnsw = HNSWIndex::new(3, SimilarityMetric::Euclidean);
     let vector = Vector {
         id: 1,
         values: vec![1.0, 2.0], // Wrong dimension
@@ -312,7 +564,7 @@ fn test_add_vector_dimension_mismatch() {

 #[test]
 fn test_search_basic() {
-    let mut hnsw = HNSWIndex::new(3);
+    let mut hnsw = HNSWIndex::new(3, SimilarityMetric::Euclidean);

     let vectors = vec![
         Vector { id: 1, values: vec![1.0, 0.0, 0.0], text: "test".to_string(), metadata: None },
@@ -329,7 +581,7 @@ fn test_search_basic() {

     // Search for vector similar to [1.0, 0.0, 0.0]
     let query = vec![1.1, 0.1, 0.1];
-    let results = hnsw.search(&query, 2, SimilarityMetric::Euclidean);
+    let results = hnsw.search(&query, 2, SimilarityMetric::Euclidean).unwrap();

     assert!(!results.is_empty());
     assert!(results.len() <= 2);
@@ -340,18 +592,18 @@ fn test_search_basic() {
     }
 }

-#[test]
-fn test_search_empty_index() {
-    let hnsw = HNSWIndex::new(3);
-    let query = vec![1.0, 2.0, 3.0];
-    let results = hnsw.search(&query, 5, SimilarityMetric::Euclidean);
-
-    assert!(results.is_empty());
-}
+#[test]
+fn test_search_empty_index() {
+    let hnsw = HNSWIndex::new(3, SimilarityMetric::Euclidean);
+    let query = vec![1.0, 2.0, 3.0];
+    let results = hnsw.search(&query, 5, SimilarityMetric::Euclidean).unwrap();
+
+    assert!(results.is_empty());
+}

 #[test]
 fn test_id_mapping() {
-    let mut hnsw = HNSWIndex::new(3);
+    let mut hnsw = HNSWIndex::new(3, SimilarityMetric::Euclidean);

     // Add vectors with custom IDs
     let vectors = vec![
@@ -374,7 +626,7 @@ fn test_id_mapping() {

     // Test that search returns the correct custom IDs
     let query = vec![1.1, 0.1, 0.1];
-    let results = hnsw.search(&query, 2, SimilarityMetric::Euclidean);
+    let results = hnsw.search(&query, 2, SimilarityMetric::Euclidean).unwrap();

     assert!(!results.is_empty());
     // The first result should be the vector with ID 100 (most similar to [1.0, 0.0, 0.0])
@@ -383,7 +635,7 @@ fn test_id_mapping() {

 #[test]
 fn test_duplicate_id_error() {
-    let mut hnsw = HNSWIndex::new(3);
+    let mut hnsw = HNSWIndex::new(3, SimilarityMetric::Euclidean);
     let vector1 = Vector { id: 1, values: vec![1.0, 2.0, 3.0], text: "test".to_string(), metadata: None };
     let vector2 = Vector { id: 1, values: vec![4.0, 5.0, 6.0], text: "test".to_string(), metadata: None }; // Same ID
@@ -394,7 +646,7 @@ fn test_duplicate_id_error() {

 #[test]
 fn test_delete_vector() {
-    let mut hnsw = HNSWIndex::new(3);
+    let mut hnsw = HNSWIndex::new(3, SimilarityMetric::Euclidean);
     let vector = Vector { id: 42, values: vec![1.0, 2.0, 3.0], text: "test".to_string(), metadata: None };

     assert!(hnsw.add(vector).is_ok());
@@ -413,7 +665,7 @@ fn test_delete_vector() {
 fn test_feature_flags() {
     // Test that the constants are properly set based on features
     // This test will only pass if the correct feature is enabled
-    let hnsw = HNSWIndex::new(3);
+    let hnsw = HNSWIndex::new(3, SimilarityMetric::Euclidean);

     // Verify the HNSW was created successfully
     assert!(hnsw.is_empty());
@@ -428,7 +680,7 @@ fn test_serialization_deserialization() {
     use serde_json;

     // Create an HNSW index with some data
-    let mut hnsw = HNSWIndex::new(3);
+    let mut hnsw = HNSWIndex::new(3, SimilarityMetric::Euclidean);
     let vectors = vec![
         Vector { id: 1, values: vec![1.0, 0.0, 0.0], text: "test".to_string(), metadata: None },
         Vector { id: 2, values: vec![0.0, 1.0, 0.0], text: "test".to_string(), metadata: None },
@@ -469,7 +721,7 @@ fn test_serialization_deserialization() {

     let query = vec![1.1, 0.1, 0.1];
-    let results = deserialized.search(&query, 2, SimilarityMetric::Euclidean);
+    let results = deserialized.search(&query, 2, SimilarityMetric::Euclidean).unwrap();

     assert!(!results.is_empty(), "Search should return some results");
     assert!(results.len() <= 2, "Should return at most 2 results as requested");
@@ -501,7 +753,7 @@ fn test_empty_hnsw_serialization_deserialization() {
     use serde_json;

     // Create an empty HNSW index
-    let empty_hnsw = HNSWIndex::new(3);
+    let empty_hnsw = HNSWIndex::new(3, SimilarityMetric::Euclidean);
     assert!(empty_hnsw.is_empty());
     assert_eq!(empty_hnsw.dimension(), 3);
@@ -522,7 +774,7 @@ fn test_empty_hnsw_serialization_deserialization() {

 #[test]
 fn test_search_with_limited_vectors() {
-    let mut hnsw = HNSWIndex::new(3);
+    let mut hnsw = HNSWIndex::new(3, SimilarityMetric::Euclidean);

     // Add only 3 vectors
     let vectors = vec![
@@ -540,7 +792,7 @@ fn test_search_with_limited_vectors() {

     // Test searching for k=4 (more than we have vectors)
     // This should not panic and should return at most 3 results
     let query = vec![1.1, 0.1, 0.1];
-    let results = hnsw.search(&query, 4, SimilarityMetric::Euclidean);
+    let results = hnsw.search(&query, 4, SimilarityMetric::Euclidean).unwrap();

     // Should return at most 3 results (all available vectors)
     assert!(results.len() <= 3);
@@ -552,3 +804,230 @@ fn test_search_with_limited_vectors() {
     }
 }

+/// Tests for distance to similarity conversion functions
+#[cfg(test)]
+mod conversion_tests {
+    use super::{convert_distance_to_similarity, SimilarityMetric};
+
+    #[test]
+    fn test_euclidean_distance_conversion() {
+        // Test zero distance (identical vectors)
+        let distance = 0.0;
+        let similarity = convert_distance_to_similarity(distance, SimilarityMetric::Euclidean);
+        assert_eq!(similarity, 1.0, "Zero distance should give similarity of 1.0");
+
+        // Test small distance
+        let distance = 0.5;
+        let similarity = convert_distance_to_similarity(distance, SimilarityMetric::Euclidean);
+        let expected = 1.0 / (1.0 + 0.5);
+        assert!((similarity - expected).abs() < 1e-10, "Small distance conversion should be correct");
+
+        // Test medium distance
+        let distance = 1.0;
+        let similarity = convert_distance_to_similarity(distance, SimilarityMetric::Euclidean);
+        let expected = 1.0 / (1.0 + 1.0);
+        assert!((similarity - expected).abs() < 1e-10, "Medium distance conversion should be correct");
+
+        // Test large distance
+        let distance = 10.0;
+        let similarity = convert_distance_to_similarity(distance, SimilarityMetric::Euclidean);
+        let expected = 1.0 / (1.0 + 10.0);
+        assert!((similarity - expected).abs() < 1e-10, "Large distance conversion should be correct");
+
+        // Test very large distance
+        let distance = 100.0;
+        let similarity = convert_distance_to_similarity(distance, SimilarityMetric::Euclidean);
+        assert!(similarity > 0.0 && similarity < 0.01, "Very large distance should give very low similarity");
+    }
+
+    #[test]
+    fn test_cosine_distance_conversion() {
+        // Test zero distance (identical vectors)
+        let distance = 0.0;
+        let similarity = convert_distance_to_similarity(distance, SimilarityMetric::Cosine);
+        assert_eq!(similarity, 1.0, "Zero cosine distance should give similarity of 1.0");
+
+        // Test small distance (similar vectors)
+        let distance = 100.0; // cosine distance of 0.1 in scaled units
+        let similarity = convert_distance_to_similarity(distance, SimilarityMetric::Cosine);
+        let expected = 1.0 - (100.0 / 1000.0);
+        assert!((similarity - expected).abs() < 1e-10, "Small cosine distance conversion should be correct");
+
+        // Test medium distance
+        let distance = 500.0; // cosine distance of 0.5 in scaled units
+        let similarity = convert_distance_to_similarity(distance, SimilarityMetric::Cosine);
+        let expected = 1.0 - (500.0 / 1000.0);
+        assert!((similarity - expected).abs() < 1e-10, "Medium cosine distance conversion should be correct");
+
+        // Test maximum distance (opposite vectors)
+        let distance = 2000.0; // cosine distance of 2.0 in scaled units
+        let similarity = convert_distance_to_similarity(distance, SimilarityMetric::Cosine);
+        let expected = 1.0 - (2000.0 / 1000.0);
+        assert!((similarity - expected).abs() < 1e-10, "Maximum cosine distance conversion should be correct");
+
+        // Verify similarity is bounded
+        assert!((-1.0..=1.0).contains(&similarity), "Cosine similarity should be bounded");
+    }
+
+    #[test]
+    fn test_manhattan_distance_conversion() {
+        // Test zero distance (identical vectors)
+        let distance = 0.0;
+        let similarity = convert_distance_to_similarity(distance, SimilarityMetric::Manhattan);
+        assert_eq!(similarity, 1.0, "Zero Manhattan distance should give similarity of 1.0");
+
+        // Test small distance
+        let distance = 1.0;
+        let similarity = convert_distance_to_similarity(distance, SimilarityMetric::Manhattan);
+        let expected = 1.0 / (1.0 + 1.0);
+        assert!((similarity - expected).abs() < 1e-10, "Small Manhattan distance conversion should be correct");
+
+        // Test medium distance
+        let distance = 5.0;
+        let similarity = convert_distance_to_similarity(distance, SimilarityMetric::Manhattan);
+        let expected = 1.0 / (1.0 + 5.0);
+        assert!((similarity - expected).abs() < 1e-10, "Medium Manhattan distance conversion should be correct");
+
+        // Test large distance
+        let distance = 20.0;
+        let similarity = convert_distance_to_similarity(distance, SimilarityMetric::Manhattan);
+        let expected = 1.0 / (1.0 + 20.0);
+        assert!((similarity - expected).abs() < 1e-10, "Large Manhattan distance conversion should be correct");
+    }
+
+    #[test]
+    fn test_dotproduct_distance_conversion() {
+        // Test zero distance (maximum dot product)
+        let distance = 0.0;
+        let similarity = convert_distance_to_similarity(distance, SimilarityMetric::DotProduct);
+        assert_eq!(similarity, 1.0, "Zero dot product distance should give similarity of 1.0");
+
+        // Test small distance
+        let distance = 100.0_f64;
+        let similarity = convert_distance_to_similarity(distance, SimilarityMetric::DotProduct);
+        let ratio: f64 = (1000.0 - 100.0) / 1000.0;
+        let expected: f64 = ratio.clamp(0.0, 1.0);
+        assert!((similarity - expected).abs() < 1e-10, "Small dot product distance conversion should be correct");
+
+        // Test medium distance
+        let distance = 500.0_f64;
+        let similarity = convert_distance_to_similarity(distance, SimilarityMetric::DotProduct);
+        let ratio: f64 = (1000.0 - 500.0) / 1000.0;
+        let expected: f64 = ratio.clamp(0.0, 1.0);
+        assert!((similarity - expected).abs() < 1e-10, "Medium dot product distance conversion should be correct");
+
+        // Test maximum distance
+        let distance = 2000.0_f64; // negative dot product clamped
+        let similarity = convert_distance_to_similarity(distance, SimilarityMetric::DotProduct);
+        assert_eq!(similarity, 0.0, "Maximum dot product distance should give similarity of 0.0");
+
+        // Test that similarity is bounded [0, 1]
+        let distances = vec![0.0_f64, 100.0, 500.0, 1000.0, 1500.0, 2000.0];
+        for dist in distances {
+            let sim = convert_distance_to_similarity(dist, SimilarityMetric::DotProduct);
+            assert!((0.0..=1.0).contains(&sim), "DotProduct similarity should be in [0, 1]");
+        }
+    }
+
+    #[test]
+    fn test_conversion_monotonicity() {
+        // Test that similarity decreases monotonically with increasing distance
+        // for all metrics
+
+        let distances = vec![0.0, 0.5, 1.0, 2.0, 5.0, 10.0];
+
+        for metric in &[
+            SimilarityMetric::Euclidean,
+            SimilarityMetric::Cosine,
+            SimilarityMetric::Manhattan,
+        ] {
+            let mut prev_sim = 1.0;
+            for &dist in &distances {
+                let sim = convert_distance_to_similarity(dist, *metric);
+                assert!(sim <= prev_sim,
+                    "Similarity should decrease as distance increases for {:?}",
+                    metric);
+                prev_sim = sim;
+            }
+        }
+    }
+
+    #[test]
+    fn test_conversion_edge_cases() {
+        // Test with extremely small distances
+        for metric in &[
+            SimilarityMetric::Euclidean,
+            SimilarityMetric::Cosine,
+            SimilarityMetric::Manhattan,
+            SimilarityMetric::DotProduct,
+        ] {
+            let sim = convert_distance_to_similarity(0.0001, *metric);
+            assert!(sim > 0.9, "Extremely small distance should give high similarity");
+            assert!(sim <= 1.0, "Similarity should not exceed 1.0");
+        }
+
+        // Test with extremely large distances
+        for metric in &[
+            SimilarityMetric::Euclidean,
+            SimilarityMetric::Manhattan,
+        ] {
+            let sim = convert_distance_to_similarity(100000.0, *metric);
+            assert!(sim > 0.0, "Even very large distance should give non-zero similarity");
+            assert!(sim < 0.01, "Very large distance should give very low similarity");
+        }
+    }
+
+    #[test]
+    fn test_conversion_known_vectors() {
+        // Test with actual vector calculations
+
+        // Two identical vectors should have high similarity
+        let identical_distance_euclidean = 0.0;
+        let identical_distance_cosine = 0.0;
+
+        let euclidean_sim = convert_distance_to_similarity(identical_distance_euclidean, SimilarityMetric::Euclidean);
+        let cosine_sim = convert_distance_to_similarity(identical_distance_cosine, SimilarityMetric::Cosine);
+
+        assert_eq!(euclidean_sim, 1.0);
+        assert_eq!(cosine_sim, 1.0);
+
+        // Opposite vectors (for cosine): [1,0,0] and [-1,0,0]
+        // Cosine distance = 2 (in raw form) = 2000 (scaled by 1000)
+        let opposite_distance = 2000.0;
+        let cosine_sim = convert_distance_to_similarity(opposite_distance, SimilarityMetric::Cosine);
+        assert!((cosine_sim - (-1.0)).abs() < 0.01, "Opposite vectors should have negative cosine similarity");
+
+        // Perpendicular vectors: [1,0] and [0,1]
+        // Cosine distance ≈ 1 = 1000 (scaled)
+        let perpendicular_distance = 1000.0;
+        let cosine_sim = convert_distance_to_similarity(perpendicular_distance, SimilarityMetric::Cosine);
+        assert!((cosine_sim - 0.0).abs() < 0.01, "Perpendicular vectors should have cosine similarity ≈ 0");
+    }
+
+    #[test]
+    fn test_scaling_factor_documentation() {
+        // Verify the scaling factors used in the conversions
+        // This helps document the behavior for future reference
+
+        // Cosine: scaled by 1000.0
+        // Maximum cosine distance is 2.0 (raw) = 2000 (scaled)
+        let max_cosine_distance = 2000.0;
+        let min_cosine_sim = convert_distance_to_similarity(max_cosine_distance, SimilarityMetric::Cosine);
+        assert_eq!(min_cosine_sim, -1.0, "Maximum cosine distance should yield similarity of -1.0");
+
+        // DotProduct: scaled by 1000.0
+        // Maximum distance is when dot product is at minimum (negative)
+        let max_dotproduct_distance = 2000.0;
+        let min_dotproduct_sim = convert_distance_to_similarity(max_dotproduct_distance, SimilarityMetric::DotProduct);
+        assert_eq!(min_dotproduct_sim, 0.0, "Maximum dot product distance should yield similarity of 0.0");
+
+        // Euclidean and Manhattan use 1/(1+distance) formula
+        // They don't have a hard upper bound but should approach 0
+        let large_distance = 1000.0;
+        let large_euclidean_sim = convert_distance_to_similarity(large_distance, SimilarityMetric::Euclidean);
+        let large_manhattan_sim = convert_distance_to_similarity(large_distance, SimilarityMetric::Manhattan);
+        assert!(large_euclidean_sim < 0.01 && large_euclidean_sim > 0.0);
+        assert!(large_manhattan_sim < 0.01 && large_manhattan_sim > 0.0);
+    }
+}
+
diff --git a/src/lib.rs b/src/lib.rs
index 55b773b..e9a9b3e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -45,7 +45,8 @@
 //! fn main() -> Result<(), Box<dyn std::error::Error>> {
 //!     let mut client = VectorLiteClient::new(Box::new(EmbeddingGenerator::new()?));
 //!
-//!     client.create_collection("quotes", IndexType::HNSW)?;
+//!     // Create HNSW collection with specific metric
+//!     client.create_collection("quotes", IndexType::HNSW, Some(SimilarityMetric::Cosine))?;
 //!
 //!     let id = client.add_text_to_collection(
 //!         "quotes",
@@ -57,11 +58,12 @@
 //!         }))
 //!     )?;
 //!
+//!     // Search without specifying metric - automatically uses the index's metric
 //!     let results = client.search_text_in_collection(
 //!         "quotes",
 //!         "beach games",
 //!         3,
-//!         SimilarityMetric::Cosine,
+//!         None, // Auto-detects from HNSW index
 //!     )?;
 //!
 //!     for result in &results {
@@ -78,11 +80,13 @@
 //! - **Complexity**: O(n) search, O(1) insert
 //! - **Memory**: Linear with dataset size
 //! - **Use Case**: Small datasets (< 10K vectors) or exact search requirements
+//! - **Metric Flexibility**: Supports all similarity metrics dynamically
 //!
 //! ### HNSWIndex
 //! - **Complexity**: O(log n) search, O(log n) insert
 //! - **Memory**: ~2-3x vector size due to graph structure
 //! - **Use Case**: Large datasets with approximate search tolerance
+//! - **Metric Flexibility**: Built for a specific metric; searches automatically use the index's metric
 //!
 //! ## Similarity Metrics
 //!
@@ -130,6 +134,7 @@
 pub use embeddings::{EmbeddingGenerator, EmbeddingFunction};
 pub use client::{VectorLiteClient, Collection, Settings, IndexType};
 pub use server::{create_app, start_server};
 pub use persistence::{PersistenceError, save_collection_to_file, load_collection_from_file};
+pub use errors::{VectorLiteError, VectorLiteResult};

 use serde::{Serialize, Deserialize};

@@ -224,7 +229,7 @@
     fn delete(&mut self, id: u64) -> Result<(), String>;

     /// Search for the k most similar vectors
-    fn search(&self, query: &[f64], k: usize, similarity_metric: SimilarityMetric) -> Vec<SearchResult>;
+    fn search(&self, query: &[f64], k: usize, similarity_metric: SimilarityMetric) -> VectorLiteResult<Vec<SearchResult>>;

     /// Get the number of vectors in the index
     fn len(&self) -> usize;
@@ -233,7 +238,7 @@
     fn is_empty(&self) -> bool;

     /// Get a vector by its ID
-    fn get_vector(&self, id: u64) -> Option<&Vector>;
+    fn get_vector(&self, id: u64) -> Option<Vector>;

     /// Get the dimension of vectors in this index
     fn dimension(&self) -> usize;
@@ -285,7 +290,7 @@ impl VectorIndex for VectorIndexWrapper {
         }
     }

-    fn search(&self, query: &[f64], k: usize, s: SimilarityMetric) -> Vec<SearchResult> {
+    fn search(&self, query: &[f64], k: usize, s: SimilarityMetric) -> VectorLiteResult<Vec<SearchResult>> {
         match self {
             VectorIndexWrapper::Flat(index) => index.search(query, k, s),
             VectorIndexWrapper::HNSW(index) => index.search(query, k, s),
@@ -306,7 +311,7 @@ impl VectorIndex for VectorIndexWrapper {
         }
     }

-    fn get_vector(&self, id: u64) -> Option<&Vector> {
+    fn get_vector(&self, id: u64) -> Option<Vector> {
         match self {
             VectorIndexWrapper::Flat(index) => index.get_vector(id),
             VectorIndexWrapper::HNSW(index) => index.get_vector(id),
@@ -321,6 +326,25 @@ impl VectorIndex for VectorIndexWrapper {
         }
     }
 }

+impl VectorIndexWrapper {
+    /// Get the similarity metric this index was built for (HNSW only)
+    /// Returns None for Flat indexes (which support all metrics)
+    pub fn metric(&self) -> Option<SimilarityMetric> {
+        match self {
+            VectorIndexWrapper::Flat(_) => None,
+            VectorIndexWrapper::HNSW(index) => Some(index.metric()),
+        }
+    }
+
+    /// Get the index type
+    pub fn index_type(&self) -> IndexType {
+        match self {
+            VectorIndexWrapper::Flat(_) => IndexType::Flat,
+            VectorIndexWrapper::HNSW(_) => IndexType::HNSW,
+        }
+    }
+}
+
 /// Similarity metrics for vector comparison
 ///
 /// Different metrics are suitable for different use cases and vector characteristics.
@@ -336,7 +360,7 @@ impl VectorIndex for VectorIndexWrapper {
 /// let cosine_score = SimilarityMetric::Cosine.calculate(&a, &b);
 /// let euclidean_score = SimilarityMetric::Euclidean.calculate(&a, &b);
 /// ```
-#[derive(Debug, Clone, Copy, PartialEq, Default)]
+#[derive(Debug, Clone, Copy, PartialEq, Default, Serialize, Deserialize)]
 pub enum SimilarityMetric {
     /// Cosine similarity - scale-invariant, good for normalized embeddings
     /// Range: [-1, 1], where 1 is identical
@@ -662,7 +686,7 @@ mod tests {
         ];
         let store = FlatIndex::new(3, vectors);
         let query = vec![1.0, 0.0, 0.0];
-        let results = store.search(&query, 2, SimilarityMetric::Cosine);
+        let results = store.search(&query, 2, SimilarityMetric::Cosine).unwrap();

         assert_eq!(results.len(), 2);
         assert_eq!(results[0].id, 0);
@@ -694,7 +718,7 @@ mod tests {

         // Test search through the wrapper
         let query = vec![1.1, 0.1, 0.1];
-        let results = deserialized.search(&query, 1, SimilarityMetric::Cosine);
+        let results = deserialized.search(&query, 1, SimilarityMetric::Cosine).unwrap();
         assert_eq!(results.len(), 1);
         assert_eq!(results[0].id, 1);
     }
diff --git a/src/persistence.rs b/src/persistence.rs
index 29a708c..4a52623 100644
--- a/src/persistence.rs
+++ b/src/persistence.rs
@@ -35,21 +35,30 @@ use crate::{VectorIndexWrapper, Collection, VectorIndex};
 #[derive(Error, Debug)]
 pub enum PersistenceError {
     #[error("IO error: {0}")]
-    Io(#[from] std::io::Error),
-
+    Io(std::io::Error),
+
+    #[error("File not found: {0}")]
+    FileNotFound(String),
+
     #[error("Serialization error: {0}")]
     Serialization(#[from] serde_json::Error),
-
+
     #[error("Invalid file format: {0}")]
     InvalidFormat(String),
-
+
     #[error("Version mismatch: expected {expected}, got {actual}")]
     VersionMismatch { expected: String, actual: String },
-
+
     #[error("Collection error: {0}")]
     Collection(String),
 }

+impl From<std::io::Error> for PersistenceError {
+    fn from(error: std::io::Error) -> Self {
+        PersistenceError::Io(error)
+    }
+}
+
 /// File header containing version and format information
 #[derive(Debug, Serialize, Deserialize)]
 pub struct FileHeader {
@@ -138,9 +147,15 @@ pub fn save_collection_to_file(collection: &Collection, path: &Path) -> Result<(

 /// Load a collection from a file
 pub fn load_collection_from_file(path: &Path) -> Result<Collection, PersistenceError> {
-    let json_data = fs::read_to_string(path)?;
+    let json_data = fs::read_to_string(path).map_err(|e| {
+        if e.kind() == std::io::ErrorKind::NotFound {
+            PersistenceError::FileNotFound(path.display().to_string())
+        } else {
+            PersistenceError::Io(e)
+        }
+    })?;
     let collection_data: CollectionData = serde_json::from_str(&json_data)?;
-
+
     // Validate version compatibility
     if collection_data.header.version != "1.0.0" {
         return Err(PersistenceError::VersionMismatch {
@@ -148,7 +163,7 @@ pub fn load_collection_from_file(path: &Path) -> Result Result VectorLiteResult {
     }
 }

diff --git a/src/server.rs b/src/server.rs
--- a/src/server.rs
+++ b/src/server.rs
+// Implement IntoResponse for VectorLiteError to enable automatic error responses
+impl IntoResponse for VectorLiteError {
+    fn into_response(self) -> Response {
+        let status = self.status_code();
+        let error_message = self.to_string();
+
+        let body = Json(ErrorResponse {
+            message: error_message,
+        });
+
+        (status, body).into_response()
+    }
+}
+
 // Handlers
 async fn health_check() -> Json<serde_json::Value> {
     Json(serde_json::json!({
@@ -173,8 +188,8 @@ async fn health_check() -> Json<serde_json::Value> {
 async fn list_collections(
     State(state): State,
-) -> Result<Json<ListCollectionsResponse>, StatusCode> {
-    let client = state.read().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
+) -> Result<Json<ListCollectionsResponse>, VectorLiteError> {
+    let client = state.read().map_err(|_| VectorLiteError::LockError("Failed to acquire read lock for list_collections".to_string()))?;
     let collections = client.list_collections();
     Ok(Json(ListCollectionsResponse {
         collections,
@@ -184,75 +199,58 @@ async fn list_collections(
 async fn create_collection(
     State(state): State,
     Json(payload): Json,
-) -> Result<Json<CreateCollectionResponse>, StatusCode> {
-    let index_type = match parse_index_type(&payload.index_type) {
-        Ok(t) => t,
-        Err(e) => {
-            return Err(e.status_code());
-        }
+) -> Result<Json<CreateCollectionResponse>, VectorLiteError> {
+    let index_type = parse_index_type(&payload.index_type)?;
+
+    // Parse metric - optional for Flat index, required for HNSW
+    let metric = if payload.metric.is_empty() {
+        None // No metric specified
+    } else {
+        Some(parse_similarity_metric(&payload.metric)?)
     };
-    let mut client = state.write().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
-    match client.create_collection(&payload.name, index_type) {
-        Ok(_) => {
-            info!("Created collection: {}", payload.name);
-            Ok(Json(CreateCollectionResponse {
-                name: payload.name,
-            }))
-        }
-        Err(e) => {
-            error!("Failed to create collection '{}': {}", payload.name, e);
-            Err(e.status_code())
-        }
-    }
+    let mut client = state.write().map_err(|_| VectorLiteError::LockError("Failed to acquire write lock for create_collection".to_string()))?;
+    client.create_collection(&payload.name, index_type, metric)?;
+    info!("Created collection: {}", payload.name);
+    Ok(Json(CreateCollectionResponse {
+        name: payload.name,
+    }))
 }

 async fn get_collection_info(
     State(state): State,
     Path(collection_name): Path<String>,
-) -> Result<Json<CollectionInfoResponse>, StatusCode> {
-    let client = state.read().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
-    match client.get_collection_info(&collection_name) {
-        Ok(info) => Ok(Json(CollectionInfoResponse {
-            info: Some(info),
-        })),
-        Err(e) => Err(e.status_code()),
-    }
+) -> Result<Json<CollectionInfoResponse>, VectorLiteError> {
+    let client = state.read().map_err(|_| VectorLiteError::LockError("Failed to acquire read lock for get_collection_info".to_string()))?;
+    let info = client.get_collection_info(&collection_name)?;
+    Ok(Json(CollectionInfoResponse {
+        info: Some(info),
+    }))
 }

 async fn delete_collection(
     State(state): State,
     Path(collection_name): Path<String>,
-) -> Result<Json<CreateCollectionResponse>, StatusCode> {
-    let mut client = state.write().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
-    match client.delete_collection(&collection_name) {
-        Ok(_) => {
-            info!("Deleted collection: {}", collection_name);
-            Ok(Json(CreateCollectionResponse {
-                name: collection_name,
-            }))
-        }
-        Err(e) => Err(e.status_code()),
-    }
+) -> Result<Json<CreateCollectionResponse>, VectorLiteError> {
+    let mut client = state.write().map_err(|_| VectorLiteError::LockError("Failed to acquire write lock for delete_collection".to_string()))?;
+    client.delete_collection(&collection_name)?;
+    info!("Deleted collection: {}", collection_name);
+    Ok(Json(CreateCollectionResponse {
+        name: collection_name,
+    }))
 }

 async fn add_text(
     State(state): State,
     Path(collection_name): Path<String>,
     Json(payload): Json,
-) -> Result<Json<AddTextResponse>, StatusCode> {
-    let client = state.read().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
-    match client.add_text_to_collection(&collection_name, &payload.text, payload.metadata) {
-        Ok(id) => {
-            info!("Added text to collection '{}' with ID: {}", collection_name, id);
-            Ok(Json(AddTextResponse {
-                id: Some(id),
-            }))
-        }
-        Err(e) => {
-            Err(e.status_code())
-        }
-    }
+) -> Result<Json<AddTextResponse>, VectorLiteError> {
+    let client = state.read().map_err(|_| VectorLiteError::LockError("Failed to acquire read lock for add_text".to_string()))?;
+    let id = client.add_text_to_collection(&collection_name, &payload.text, payload.metadata)?;
+    info!("Added text to collection '{}' with ID: {}", collection_name, id);
+    Ok(Json(AddTextResponse {
+        id: Some(id),
+    }))
 }

@@ -261,144 +259,99 @@ async fn search_text(
     State(state): State,
     Path(collection_name): Path<String>,
     Json(payload): Json,
-) -> Result<Json<SearchResponse>, StatusCode> {
+) -> Result<Json<SearchResponse>, VectorLiteError> {
     let k = payload.k.unwrap_or(10);
     let similarity_metric = match payload.similarity_metric {
-        Some(metric) => match parse_similarity_metric(&metric) {
-            Ok(m) => m,
-            Err(e) => {
-                return Err(e.status_code());
-            }
-        },
-        None => SimilarityMetric::Cosine,
+        Some(metric) => Some(parse_similarity_metric(&metric)?),
+        None => None, // No metric specified - will auto-detect
     };

-    let client = state.read().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
-    match client.search_text_in_collection(&collection_name, &payload.query, k, similarity_metric) {
-        Ok(results) => {
-            info!("Search completed for collection '{}' with {} results", collection_name, results.len());
-            Ok(Json(SearchResponse {
-                results: Some(results),
-            }))
-        }
-        Err(e) => Err(e.status_code()),
-    }
+    let client = state.read().map_err(|_| VectorLiteError::LockError("Failed to acquire read lock for search_text".to_string()))?;
+    let results = client.search_text_in_collection(&collection_name, &payload.query, k, similarity_metric)?;
+    info!("Search completed for collection '{}' with {} results", collection_name, results.len());
+    Ok(Json(SearchResponse {
+        results: Some(results),
+    }))
 }

 async fn get_vector(
     State(state): State,
     Path((collection_name, vector_id)): Path<(String, u64)>,
-) -> Result<Json<serde_json::Value>, StatusCode> {
-    let client = state.read().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
-    match client.get_vector_from_collection(&collection_name, vector_id) {
-        Ok(Some(vector)) => {
-            Ok(Json(serde_json::json!({
-                "vector": vector
-            })))
-        }
-        Ok(None) => {
-            Err(StatusCode::NOT_FOUND)
-        }
-        Err(e) => {
-            Err(e.status_code())
-        }
-    }
+) -> Result<Json<serde_json::Value>, VectorLiteError> {
+    let client = state.read().map_err(|_| VectorLiteError::LockError("Failed to acquire read lock for get_vector".to_string()))?;
+    let vector = client.get_vector_from_collection(&collection_name, vector_id)?
+        .ok_or(VectorLiteError::VectorNotFound { id: vector_id })?;
+    Ok(Json(serde_json::json!({
+        "vector": vector
+    })))
 }

 async fn delete_vector(
     State(state): State,
     Path((collection_name, vector_id)): Path<(String, u64)>,
-) -> Result<Json<serde_json::Value>, StatusCode> {
-    let client = state.read().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
-    match client.delete_from_collection(&collection_name, vector_id) {
-        Ok(_) => {
-            info!("Deleted vector {} from collection '{}'", vector_id, collection_name);
-            Ok(Json(serde_json::json!({})))
-        }
-        Err(e) => {
-            Err(e.status_code())
-        }
-    }
+) -> Result<Json<serde_json::Value>, VectorLiteError> {
+    let client = state.read().map_err(|_| VectorLiteError::LockError("Failed to acquire read lock for delete_vector".to_string()))?;
+    client.delete_from_collection(&collection_name, vector_id)?;
+    info!("Deleted vector {} from collection '{}'", vector_id, collection_name);
+    Ok(Json(serde_json::json!({})))
 }

 async fn save_collection(
     State(state): State,
     Path(collection_name): Path<String>,
     Json(payload): Json,
-) -> Result<Json<SaveCollectionResponse>, StatusCode> {
-    let client = state.read().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
-
+) -> Result<Json<SaveCollectionResponse>, VectorLiteError> {
+    let client = state.read().map_err(|_| VectorLiteError::LockError("Failed to acquire read lock for save_collection".to_string()))?;
+
     // Get the collection
-    let collection = match client.get_collection(&collection_name) {
-        Some(collection) => collection,
-        None => {
-            return Err(StatusCode::NOT_FOUND);
-        }
-    };
+    let collection = client.get_collection(&collection_name)
+        .ok_or_else(|| VectorLiteError::CollectionNotFound { name: collection_name.clone() })?;

     // Convert file path to PathBuf
     let file_path = PathBuf::from(&payload.file_path);
-
+
     // Save the collection
-    match collection.save_to_file(&file_path) {
-        Ok(_) => {
-            info!("Saved collection '{}' to file: {}", collection_name, payload.file_path);
-            Ok(Json(SaveCollectionResponse {
-                file_path: Some(payload.file_path),
-            }))
-        }
-        Err(_) => {
-            Err(StatusCode::INTERNAL_SERVER_ERROR)
-        }
-    }
+    collection.save_to_file(&file_path)?;
+    info!("Saved collection '{}' to file: {}", collection_name, payload.file_path);
+    Ok(Json(SaveCollectionResponse {
+        file_path: Some(payload.file_path),
+    }))
 }

 async fn load_collection(
     State(state): State,
     Json(payload): Json,
-) -> Result<Json<LoadCollectionResponse>, StatusCode> {
+) -> Result<Json<LoadCollectionResponse>, VectorLiteError> {
     // Convert file path to PathBuf
     let file_path = PathBuf::from(&payload.file_path);
-
+
     // Load the collection from file
-    let collection = match crate::Collection::load_from_file(&file_path) {
-        Ok(collection) => collection,
-        Err(e) => {
-            // Check if it's a file not found error
-            if let crate::persistence::PersistenceError::Io(io_err) = &e
-                && io_err.kind() == std::io::ErrorKind::NotFound {
-                return Err(VectorLiteError::FileNotFound(format!("File not found: {}", payload.file_path)).status_code());
-            }
-            return Err(VectorLiteError::from(e).status_code());
-        }
-    };
+    let collection = crate::Collection::load_from_file(&file_path)?;

     // Determine the collection name to use
     let collection_name = payload.collection_name.unwrap_or_else(|| collection.name().to_string());
-
+
     // Add the collection to the client
-    let mut client = state.write().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
-
+    let mut client = state.write().map_err(|_| VectorLiteError::LockError("Failed to acquire write lock for load_collection".to_string()))?;
+
     // Check if collection already exists
     if client.has_collection(&collection_name) {
-        return Err(StatusCode::CONFLICT);
+        return Err(VectorLiteError::CollectionAlreadyExists { name: collection_name });
     }

     // Extract the index from the loaded collection
     let index = {
-        let index_guard = collection.index_read().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
+        let index_guard = collection.index_read().map_err(|_| VectorLiteError::LockError("Failed to acquire read lock for collection index".to_string()))?;
         (*index_guard).clone()
     };
-
+
     // Create a new collection with the loaded data
     let new_collection = crate::Collection::new(collection_name.clone(), index);
-
+
     // Add the collection to the client
-    if client.add_collection(new_collection).is_err() {
-        return Err(StatusCode::INTERNAL_SERVER_ERROR);
-    }
-
+    client.add_collection(new_collection)?;
+
     info!("Loaded collection '{}' from file: {}", collection_name, payload.file_path);
     Ok(Json(LoadCollectionResponse {
         collection_name: Some(collection_name),
diff --git a/tests/http_integration_test.rs b/tests/http_integration_test.rs
index 21116ec..9e5026c 100644
--- a/tests/http_integration_test.rs
+++ b/tests/http_integration_test.rs
@@ -133,7 +133,7 @@ async fn test_create_duplicate_collection() {
 #[tokio::test]
 async fn test_get_collection_info() {
     let mut client = create_test_client();
-    client.create_collection("test_collection", vectorlite::IndexType::Flat).unwrap();
+    client.create_collection("test_collection", vectorlite::IndexType::Flat, None).unwrap();
     let app = create_app(std::sync::Arc::new(std::sync::RwLock::new(client)));

     let request = Request::builder()
@@ -155,7 +155,7 @@ async fn test_get_collection_info() {
 #[tokio::test]
 async fn test_add_text_to_collection() {
     let mut client = create_test_client();
-    client.create_collection("test_collection", vectorlite::IndexType::Flat).unwrap();
+    client.create_collection("test_collection", vectorlite::IndexType::Flat, None).unwrap();
     let app = create_app(std::sync::Arc::new(std::sync::RwLock::new(client)));

     let payload = json!({
@@ -181,7 +181,7 @@ async fn test_add_text_to_collection() {
 #[tokio::test]
 async fn test_search_text() {
     let mut client = create_test_client();
-    client.create_collection("test_collection", vectorlite::IndexType::Flat).unwrap();
+    client.create_collection("test_collection", vectorlite::IndexType::Flat, None).unwrap();
     client.add_text_to_collection("test_collection", "Hello world", None).unwrap();
     let app = create_app(std::sync::Arc::new(std::sync::RwLock::new(client)));

@@ -212,7 +212,7 @@ async fn test_search_text() {
 #[tokio::test]
 async fn test_get_vector() {
     let mut client = create_test_client();
-    client.create_collection("test_collection", vectorlite::IndexType::Flat).unwrap();
+    client.create_collection("test_collection", vectorlite::IndexType::Flat, None).unwrap();
     client.add_text_to_collection("test_collection", "Hello world", None).unwrap();
     let app = create_app(std::sync::Arc::new(std::sync::RwLock::new(client)));

@@ -234,7 +234,7 @@ async fn test_get_vector() {
 #[tokio::test]
 async fn test_delete_vector() {
     let mut client = create_test_client();
-    client.create_collection("test_collection", vectorlite::IndexType::Flat).unwrap();
+    client.create_collection("test_collection", vectorlite::IndexType::Flat, None).unwrap();
     client.add_text_to_collection("test_collection", "Hello world", None).unwrap();
     let app = create_app(std::sync::Arc::new(std::sync::RwLock::new(client)));

@@ -255,7 +255,7 @@ async fn test_delete_vector() {
 #[tokio::test]
 async fn test_delete_collection() {
     let mut client = create_test_client();
-    client.create_collection("test_collection", vectorlite::IndexType::Flat).unwrap();
+    client.create_collection("test_collection", vectorlite::IndexType::Flat, None).unwrap();
    let app = create_app(std::sync::Arc::new(std::sync::RwLock::new(client)));

     let request = Request::builder()