2 changes: 2 additions & 0 deletions src/classifai/indexers/__init__.py
@@ -9,8 +9,10 @@
VectorStoreSearchOutput,
)
from .main import VectorStore
from .types import MetricSettings

__all__ = [
"MetricSettings",
"VectorStore",
"VectorStoreEmbedInput",
"VectorStoreEmbedOutput",
177 changes: 135 additions & 42 deletions src/classifai/indexers/main.py
@@ -26,17 +26,22 @@
vector databases from your own text data.
"""

from __future__ import annotations

import json
import logging
import os
import shutil
import time
import uuid
from typing import TYPE_CHECKING

import numpy as np
import polars as pl
from tqdm.autonotebook import tqdm

if TYPE_CHECKING:
from ..vectorisers import VectoriserBase
from .dataclasses import (
VectorStoreEmbedInput,
VectorStoreEmbedOutput,
@@ -45,6 +50,7 @@
VectorStoreSearchInput,
VectorStoreSearchOutput,
)
from .types import MetricSettings

# Configure logging for your application
logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
@@ -53,13 +59,28 @@
logging.getLogger("urllib3.connectionpool").setLevel(logging.WARNING)


def metricvalid(metric: MetricSettings | str):
"""Check that the given metric is a valid option.

Args:
metric (MetricSettings | str): The selected metric for the VectorStore

Raises:
ValueError: If the value is not in MetricSettings

"""
if metric not in MetricSettings:
raise ValueError(f"The scoring metric input '{metric}' is not in the valid metrics {list(MetricSettings)}")


class VectorStore:
"""A class to model and create 'VectorStore' objects for building and searching vector databases from CSV text files.

Attributes:
file_name (str): the original file with the knowledgebase to build the vector store
data_type (str): the data type of the original file (currently only csv supported)
vectoriser (object): A Vectoriser object from the corresponding ClassifAI Package module
vectoriser (VectoriserBase): A Vectoriser object from the corresponding ClassifAI Package module
scoring_metric (MetricSettings | str): The metric to use for scoring
batch_size (int): the batch size to pass to the vectoriser when embedding
meta_data (dict[str, type]): key-value pairs of metadata to extract from the input file and their corresponding types
output_dir (str): the path to the output directory where the VectorStore will be saved
@@ -68,27 +89,31 @@ class VectorStore:
num_vectors (int): how many vectors are in the vector store
vectoriser_class (str): the type of vectoriser used to create embeddings
hooks (dict): A dictionary of user-defined hooks for preprocessing and postprocessing.
normalize (bool): Flag to choose whether to normalize vectors.
"""

def __init__( # noqa: PLR0913
self,
file_name,
data_type,
vectoriser,
vectoriser: VectoriserBase,
scoring_metric: MetricSettings | str = MetricSettings.DOT_PRODUCT,
batch_size=8,
meta_data=None,
output_dir=None,
overwrite=False,
hooks=None,
normalize=False,
):
"""Initializes the VectorStore object by processing the input CSV file and generating
vector embeddings.

Args:
file_name (str): The name of the input CSV file.
data_type (str): The type of input data (currently supports only "csv").
vectoriser (object): The vectoriser object used to transform text into
vectoriser (VectoriserBase): The vectoriser object used to transform text into
vector embeddings.
scoring_metric (MetricSettings | str, optional): The metric to use for scoring. Defaults to MetricSettings.DOT_PRODUCT.
batch_size (int, optional): The batch size for processing the input file and batching to
vectoriser. Defaults to 8.
meta_data (dict, optional): key,value pair metadata column names to extract from the input file and their types.
@@ -97,7 +122,7 @@ def __init__( # noqa: PLR0913
Defaults to None, where input file name will be used.
overwrite (bool, optional): If True, allows overwriting existing folders with the same name. Defaults to false to prevent accidental overwrites.
hooks (dict, optional): A dictionary of user-defined hooks for preprocessing and postprocessing. Defaults to None.

normalize (bool, optional): If True, L2-normalize the vector store's embeddings. Defaults to False.

Raises:
ValueError: If the data type is not supported or if the folder name conflicts with an existing folder.
@@ -107,6 +132,7 @@ def __init__( # noqa: PLR0913
self.file_name = file_name
self.data_type = data_type
self.vectoriser = vectoriser
self.scoring_metric = scoring_metric
self.batch_size = batch_size
self.meta_data = meta_data if meta_data is not None else {}
self.output_dir = output_dir
@@ -115,10 +141,14 @@ def __init__( # noqa: PLR0913
self.num_vectors = None
self.vectoriser_class = vectoriser.__class__.__name__
self.hooks = {} if hooks is None else hooks
self.normalize = normalize

if self.data_type not in ["csv"]:
raise ValueError(f"Data type '{self.data_type}' not supported. Choose from ['csv'].")

## validate scoring metric
metricvalid(self.scoring_metric)

if self.output_dir is None:
logging.info("No output directory specified, attempting to use input file name as output folder name.")

@@ -147,6 +177,13 @@ def __init__( # noqa: PLR0913

self._create_vector_store_index()

## init normalization
if normalize:
embeddings = self.vectors["embeddings"].to_numpy()
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

# Assign the result: with_columns returns a new DataFrame rather than modifying in place
self.vectors = self.vectors.with_columns(pl.Series("embeddings", embeddings))

logging.info("Gathering metadata and saving vector store / metadata...")

self.vector_shape = self.vectors["embeddings"].to_numpy().shape[1]
@@ -157,6 +194,8 @@ def __init__( # noqa: PLR0913
self._save_metadata(os.path.join(self.output_dir, "metadata.json"))

logging.info("Vector Store created - files saved to %s", self.output_dir)
## will norm in memory if using cosine metrics
self._check_norm_vdb()

def _save_metadata(self, path):
"""Saves metadata about the vector store to a JSON file.
@@ -180,6 +219,7 @@ def _save_metadata(self, path):
"num_vectors": self.num_vectors,
"created_at": time.time(),
"meta_data": serializable_column_meta_data,
"normalized": self.normalize,
}

with open(path, "w", encoding="utf-8") as f:
@@ -347,10 +387,81 @@ def reverse_search(self, query: VectorStoreReverseSearchInput, n_results=100) ->

return result_df

def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> VectorStoreSearchOutput:
def _check_norm_vdb(self):
Collaborator:
I like this functionality a lot, but I think it should be the vectoriser's job to output embeddings in the desired form, not the vector store changing them after the fact.
My preference would be to update the Vectorisers' .transform() methods to take an optional (default False) normalise argument, which applies this normalisation if set to True.
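A minimal sketch of the normalisation such a `normalise=True` flag would apply inside `.transform()` (the flag itself and the `_encode` call are illustrative here, not an existing argument or method):

```python
import numpy as np


def l2_normalise(embeddings: np.ndarray) -> np.ndarray:
    """Row-wise L2 normalisation, as a proposed normalise=True flag would apply."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # guard against all-zero rows
    return embeddings / norms


# e.g. inside a Vectoriser.transform(texts, normalise=False) implementation:
# vectors = self._encode(texts)  # hypothetical internal embedding call
# return l2_normalise(vectors) if normalise else vectors
```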

Contributor:
I disagree slightly here.

An informed user who knows their embedding model already outputs normalized embeddings should then be able to just use the dot-product metric, which would give them the effect of cosine similarity without the extra norm checks and steps they would need if they chose a cosine metric.

I also think it's a good idea to keep the vectorisers pure and not overcomplicate their argument logic, whereas the vector store, which is responsible for housing, reloading, and metric calculations over the vectors, probably should keep a note of whether the vectors are normalised or not.
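For reference, the equivalence being relied on here is that the dot product and cosine similarity coincide on unit-norm vectors:

$$
\cos(q, d) = \frac{q \cdot d}{\lVert q \rVert \, \lVert d \rVert} = q \cdot d \quad \text{when } \lVert q \rVert = \lVert d \rVert = 1 .
$$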

Collaborator:
> An informed user who knows their embedding model already outputs normalized embeddings should then be able to just use the dot-product metric, which would give them the effect of cosine similarity without the extra norm checks and steps they would need if they chose a cosine metric.

I'm not sure I follow what you mean; if a user knows their embedding model already outputs normalised embeddings, they could just not set the normalise flag when creating the Vectoriser.

> I also think it's a good idea to keep the vectorisers pure and not overcomplicate their argument logic

This is an operation that happens directly on the vectors, a step before any use in a vector store or scoring. I think it fits in well with the task of the Vectoriser, and avoids the other issues you discussed - such as any need to duplicate vectors in the vector store and set/read metadata flags about whether the vector store is normalised.

Let's talk about it in our call later 👍

"""Normalise Vdb if using cosine similarity."""
if "cosine" in self.scoring_metric and not self.normalize:
logging.warning(
"Note: you are using metrics that require norms with un-normed vdb data, this will be normed for search but vdb file will not be changed"
)
embeddings = self.vectors["embeddings"].to_numpy()
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

self.vectors.with_columns(pl.Series("embeddings", embeddings))

def score(
self, query: np.ndarray, n_results: int, query_ids_batch: list[str], query_text_batch: list[str]
) -> tuple[pl.DataFrame, np.ndarray]:
"""Perform Scoring and return Top Values.

Args:
query(np.ndarray): query for search
n_results(int): number of results to return
query_ids_batch(list[str]): ids of query batch
query_text_batch(list[str]): source text of query batch

Returns:
pl.DataFrame: The Polars DataFrame containing the top n most similar results to the query
"""
docs = self.vectors["embeddings"].to_numpy()
if self.scoring_metric == MetricSettings.DOT_PRODUCT:
result = query @ docs.T
elif self.scoring_metric == MetricSettings.L2_DISTANCE:
# Dot products (n_queries, n_docs)
dots = query @ docs.T

# Squared norms
q_sq = np.sum(query * query, axis=1, keepdims=True) # (n_queries, 1)
d_sq = np.sum(docs * docs, axis=1, keepdims=True).T # (1, n_docs)

# Squared distances
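# using the identity ||q - d||^2 = ||q||^2 + ||d||^2 - 2 (q . d)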
dist_sq = q_sq + d_sq - 2.0 * dots

# Numerical safety: tiny negatives -> 0
np.maximum(dist_sq, 0.0, out=dist_sq)

# True L2 distances, negated so that larger values mean closer matches and the
# shared top-k selection below stays valid (the score column will hold negative distances)
result = -np.sqrt(dist_sq) # (n_queries, n_docs)

# Get the top n_results indices for each query in the batch
idx = np.argpartition(result, -n_results, axis=1)[:, -n_results:]

# Sort top n_results indices by their scores in descending order
idx_sorted = np.zeros_like(idx)
scores = np.zeros_like(idx, dtype=float)

for j in range(idx.shape[0]):
row_scores = result[j, idx[j]]
sorted_indices = np.argsort(row_scores)[::-1]
idx_sorted[j] = idx[j, sorted_indices]
scores[j] = row_scores[sorted_indices]

# Build a DataFrame for the current batch results
result_df = pl.DataFrame(
{
"query_id": np.repeat(query_ids_batch, n_results),
"query_text": np.repeat(query_text_batch, n_results),
"rank": np.tile(np.arange(n_results), len(query_text_batch)),
"score": scores.flatten(),
}
)
return result_df, idx_sorted

def search(
self, query: VectorStoreSearchInput, n_results: int = 10, batch_size: int = 8
) -> VectorStoreSearchOutput:
"""Searches the vector store using queries from a VectorStoreSearchInput object and returns
ranked results in VectorStoreSearchOutput object. In batches, converts users text queries into vector embeddings,
computes cosine similarity with stored document vectors, and retrieves the top results.
computes similarity scoring with stored document vectors, and retrieves the top results.

Args:
query (VectorStoreSearchInput): A VectorStoreSearchInput object containing the text query or list of queries to search for with ids.
@@ -386,35 +497,11 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V
# Get the current batch of queries
query_text_batch = query.query.to_list()[i : i + batch_size]
query_ids_batch = query.id.to_list()[i : i + batch_size]

# Convert the current batch of queries to vectors
query_vectors = self.vectoriser.transform(query_text_batch)

# Compute cosine similarity between the query batch and document vectors
cosine = query_vectors @ self.vectors["embeddings"].to_numpy().T

# Get the top n_results indices for each query in the batch
idx = np.argpartition(cosine, -n_results, axis=1)[:, -n_results:]

# Sort top n_results indices by their scores in descending order
idx_sorted = np.zeros_like(idx)
scores = np.zeros_like(idx, dtype=float)

for j in range(idx.shape[0]):
row_scores = cosine[j, idx[j]]
sorted_indices = np.argsort(row_scores)[::-1]
idx_sorted[j] = idx[j, sorted_indices]
scores[j] = row_scores[sorted_indices]

# Build a DataFrame for the current batch results
result_df = pl.DataFrame(
{
"query_id": np.repeat(query_ids_batch, n_results),
"query_text": np.repeat(query_text_batch, n_results),
"rank": np.tile(np.arange(n_results), len(query_text_batch)),
"score": scores.flatten(),
}
)
# perform scoring and return frame and ids
result_df, idx_sorted = self.score(query_vectors, n_results, query_ids_batch, query_text_batch)

# Get the vector store results for the current batch
ranked_docs = self.vectors[idx_sorted.flatten().tolist()].select(["id", "text", *self.meta_data.keys()])
@@ -461,7 +548,13 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V
return result_df

@classmethod
def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None):
def from_filespace(
cls,
folder_path,
vectoriser: VectoriserBase,
scoring_metric: MetricSettings | str = MetricSettings.DOT_PRODUCT,
hooks: dict | None = None,
):
"""Creates a `VectorStore` instance from stored metadata and Parquet files.
This method reads the metadata and vectors from the specified folder,
validates the contents, and initializes a `VectorStore` object with the
@@ -474,7 +567,8 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None):

Args:
folder_path (str): The folder path containing the metadata and Parquet files.
vectoriser (object): The vectoriser object used to transform text into vector embeddings.
vectoriser (VectoriserBase): The vectoriser object used to transform text into vector embeddings.
scoring_metric (MetricSettings | str, optional): The metric to use for scoring. Defaults to MetricSettings.DOT_PRODUCT.
hooks (dict, optional): A dictionary of user-defined hooks for preprocessing and postprocessing. Defaults to None.

Returns:
@@ -492,14 +586,11 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None):
with open(metadata_path, encoding="utf-8") as f:
metadata = json.load(f)

## validate scoring metric
metricvalid(scoring_metric)

# check that the correct keys exist in metadata
required_keys = [
"vectoriser_class",
"vector_shape",
"num_vectors",
"created_at",
"meta_data",
]
required_keys = ["vectoriser_class", "vector_shape", "num_vectors", "created_at", "meta_data", "normalized"]
for key in required_keys:
if key not in metadata:
raise ValueError(f"Metadata file is missing required key: {key}")
@@ -545,12 +636,14 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None):
vector_store.file_name = None
vector_store.data_type = None
vector_store.vectoriser = vectoriser
vector_store.scoring_metric = scoring_metric
vector_store.batch_size = None
vector_store.meta_data = deserialized_column_meta_data
vector_store.vectors = df
vector_store.vector_shape = metadata["vector_shape"]
vector_store.num_vectors = metadata["num_vectors"]
vector_store.vectoriser_class = metadata["vectoriser_class"]
vector_store.normalize = metadata["normalized"]
vector_store.hooks = {} if hooks is None else hooks

vector_store._check_norm_vdb()
Contributor:
I'd favour a 'normalise once' approach:

1. When the VDB is being constructed by _create_vector_store_index(), it checks whether the user specified a metric that requires normalised vectors, normalises the created collection, and then saves it to the polars df/parquet file.
2. Then we'd record the 'metric' used in the metadata file.
3. When the parquet is loaded back in with from_filespace(), we already know to use the appropriate metric as it's stored in the metadata file, and there's no need to redo the normalisation.

So I'd also take the 'metric_setting' parameter out of the class method from_filespace() and rely just on the metadata file, as sketched below.

This would mean fewer operations every time we load the vector store in after initial creation, potentially at the cost of losing the magnitude information and not being able to get it back without running the build step again with a different metric.
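A sketch of what that could look like on the loading side, with the metric read back from the metadata file instead of being passed to `from_filespace()` (the `"metric"` key is part of this proposal, not the current metadata schema):

```python
import json
import os

from classifai.indexers import MetricSettings


def load_metric_from_metadata(folder_path: str) -> MetricSettings:
    """Read the scoring metric recorded at build time (proposal only)."""
    with open(os.path.join(folder_path, "metadata.json"), encoding="utf-8") as f:
        metadata = json.load(f)
    # "metric" would be written by _save_metadata() under this proposal
    return MetricSettings(metadata["metric"])
```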

Collaborator Author (@rileyok-ons), Jan 28, 2026:
Resolved by adding a normalize metadata field; if cosine is chosen with an un-normed store, it will be normed in memory, but the user will be warned.

return vector_store
6 changes: 6 additions & 0 deletions src/classifai/indexers/types.py
Contributor:
What if we scrapped all 6 of these and just had ['IP', 'L2'].

Contributor:
I think L2 squared and IP squared should be a downstream postprocessing hook, as it's just a common scoring operation.

Collaborator Author:
Seen this suggested previously; if we want this, I can sort it.

Collaborator:
I'm happy with that plan 👍

I'd like it if we added an example to one of the notebooks showing a way of wrapping one of the Vectorisers to add normalisation, though, to tide users over until we properly offer normalisation as an option.

I can add that to this PR tomorrow, if nobody objects.
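A rough sketch of the kind of wrapper being discussed, assuming only that the wrapped vectoriser exposes a `transform(texts) -> np.ndarray` method (which is all the VectorStore in this diff calls); the wrapper class itself is illustrative and not part of the package:

```python
import numpy as np


class NormalisingVectoriser:
    """Wraps any vectoriser and L2-normalises its output embeddings (illustrative only)."""

    def __init__(self, inner):
        self.inner = inner  # e.g. one of the package's existing vectorisers

    def transform(self, texts: list[str]) -> np.ndarray:
        embeddings = self.inner.transform(texts)
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1.0  # guard against all-zero rows
        return embeddings / norms
```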

Collaborator:
> What if we scrapped all 6 of these and just had ['IP', 'L2'].

Would you be okay with renaming 'IP' -> 'dot' for this? I think 'dot' would be more easily understood by users via docstrings, without needing to explore documentation etc. to find out / confirm that IP = Inner Product.

Contributor:
I would suggest we just leave it; if users really want it, they can make their own custom vectoriser that wraps the Hugging Face vectoriser. But if you really wanted to, you could update the custom_vectoriser demo notebook to have a section on this and show how to do it with the Hugging Face class?

Collaborator:
We do already have one user group requesting this functionality (and currently using a custom wrapped HF Vectoriser to achieve it), so I think it is worth adding to the docs.

Contributor:
So... they did use a custom vectoriser? 😀

Contributor:
Not sure what the correct answer is. We definitely don't want to be adding a variant of every Vectoriser called VectoriserX_normalised, or a wrapper for each class. Maybe one utility wrapper that wraps around all our Vectoriser class implementations... but what are the benefits/trade-offs of that new class, which we'd have to document and keep compatible forever, versus guiding users on how to do it with our existing custom vectoriser / base class architecture?

Collaborator:
Yes, that's one I made for them as a one-off solution, as the package doesn't yet offer that. I'm saying it would be useful to have that knowledge made accessible in the documentation for other users.

@@ -0,0 +1,6 @@
from enum import Enum


class MetricSettings(str, Enum):
DOT_PRODUCT = "dot_product"
L2_DISTANCE = "L2_distance"