diff --git a/src/classifai/exceptions.py b/src/classifai/exceptions.py
new file mode 100644
index 0000000..962412b
--- /dev/null
+++ b/src/classifai/exceptions.py
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any
+
+
+@dataclass(eq=False)
+class ClassifaiError(Exception):
+    """Base error for the package.
+
+    - message: what happened (human readable)
+    - code: stable identifier (machine readable; optional but useful)
+    - context: small debug hints (counts, ids, model name; avoid secrets / full text)
+    """
+
+    message: str
+    code: str = "classifai_error"
+    context: dict[str, Any] = field(default_factory=dict)
+
+    def __post_init__(self) -> None:
+        super().__init__(self.message)
+
+    def to_dict(self) -> dict[str, Any]:
+        data = {"error": self.code, "detail": self.message}
+        if self.context:
+            data["context"] = self.context
+        return data
+
+
+# ---- Subclasses ----
+
+
+@dataclass(eq=False)
+class ConfigurationError(ClassifaiError):
+    code: str = "configuration_error"
+
+
+@dataclass(eq=False)
+class DependencyError(ClassifaiError):
+    code: str = "dependency_error"
+
+
+@dataclass(eq=False)
+class DataValidationError(ClassifaiError):
+    code: str = "validation_error"
+
+
+@dataclass(eq=False)
+class ExternalServiceError(ClassifaiError):
+    code: str = "external_service_error"
+
+
+@dataclass(eq=False)
+class VectorisationError(ClassifaiError):
+    code: str = "vectorisation_error"
+
+
+@dataclass(eq=False)
+class IndexBuildError(ClassifaiError):
+    code: str = "index_build_error"
+
+
+@dataclass(eq=False)
+class HookError(ClassifaiError):
+    code: str = "hook_error"
diff --git a/src/classifai/indexers/main.py b/src/classifai/indexers/main.py
index d35dc2c..d49a267 100644
--- a/src/classifai/indexers/main.py
+++ b/src/classifai/indexers/main.py
@@ -37,6 +37,16 @@
 import polars as pl
 from tqdm.autonotebook import tqdm

+from classifai.exceptions import (
+    ClassifaiError,
+    ConfigurationError,
+    DataValidationError,
+    HookError,
+    IndexBuildError,
+    VectorisationError,
+)
+
+from ..vectorisers.base import VectoriserBase
 from .dataclasses import (
     VectorStoreEmbedInput,
     VectorStoreEmbedOutput,
@@ -70,7 +80,7 @@ class VectorStore:
         hooks (dict): A dictionary of user-defined hooks for preprocessing and postprocessing.
     """

-    def __init__(  # noqa: PLR0913
+    def __init__(  # noqa: C901, PLR0912, PLR0913, PLR0915
         self,
         file_name,
         data_type,
@@ -100,10 +110,41 @@ def __init__(  # noqa: PLR0913

         Raises:
-            ValueError: If the data type is not supported or if the folder name conflicts with an existing folder.
+            DataValidationError: If input arguments are invalid or if there are issues with the input file.
+            ConfigurationError: If there are configuration issues, such as output directory problems.
+            IndexBuildError: If there are failures during index building or saving outputs.
         """
-        # Run the Pydantic validator first which will raise errors if the inputs are invalid
+        # ---- Input validation (caller mistakes) -> DataValidationError / ConfigurationError
+        if not isinstance(file_name, str) or not file_name.strip():
+            raise DataValidationError("file_name must be a non-empty string.", context={"file_name": file_name})
+
+        if not os.path.exists(file_name):
+            raise DataValidationError("Input file does not exist.", context={"file_name": file_name})
+
+        if data_type not in ["csv"]:
+            raise DataValidationError(
+                "Unsupported data_type. 
Choose from ['csv'].", + context={"data_type": data_type}, + ) + + if not isinstance(vectoriser, VectoriserBase): + raise ConfigurationError( + "vectoriser must be an instance of Vectoriser(Base) with a .transform(texts) method.", + context={"vectoriser_type": type(vectoriser).__name__}, + ) + + if not isinstance(batch_size, int) or batch_size < 1: + raise DataValidationError("batch_size must be an integer >= 1.", context={"batch_size": batch_size}) + if meta_data is not None and not isinstance(meta_data, dict): + raise DataValidationError( + "meta_data must be a dict or None.", context={"meta_data_type": type(meta_data).__name__} + ) + + if hooks is not None and not isinstance(hooks, dict): + raise DataValidationError("hooks must be a dict or None.", context={"hooks_type": type(hooks).__name__}) + + # ---- Assign fields self.file_name = file_name self.data_type = data_type self.vectoriser = vectoriser @@ -116,47 +157,60 @@ def __init__( # noqa: PLR0913 self.vectoriser_class = vectoriser.__class__.__name__ self.hooks = {} if hooks is None else hooks - if self.data_type not in ["csv"]: - raise ValueError(f"Data type '{self.data_type}' not supported. Choose from ['csv'].") - - if self.output_dir is None: - logging.info("No output directory specified, attempting to use input file name as output folder name.") - - # Normalize the file name to ensure it doesn't include relative paths or extensions - normalized_file_name = os.path.basename(os.path.splitext(self.file_name)[0]) - # Check if the folder exists in the specified subdirectory - self.output_dir = os.path.join(normalized_file_name) - if os.path.isdir(self.output_dir): - if overwrite: - shutil.rmtree(self.output_dir) - else: - raise ValueError( - f"The name '{self.output_dir}' is already used as a folder in the subdirectory. Pass overwrite=True to overwrite the folder." - ) - os.makedirs(self.output_dir, exist_ok=True) + # ---- Output directory handling (filesystem problems) -> ConfigurationError + try: + if self.output_dir is None: + logging.info("No output directory specified, attempting to use input file name as output folder name.") + normalized_file_name = os.path.basename(os.path.splitext(self.file_name)[0]) + self.output_dir = os.path.join(normalized_file_name) - else: if os.path.isdir(self.output_dir): if overwrite: shutil.rmtree(self.output_dir) else: - raise ValueError( - f"The name '{self.output_dir}' is already used as a folder in the subdirectory. Pass overwrite=True to overwrite the folder." + raise ConfigurationError( + "Output directory already exists. Pass overwrite=True to overwrite the folder.", + context={"output_dir": self.output_dir}, ) os.makedirs(self.output_dir, exist_ok=True) + except ClassifaiError: + raise + except Exception as e: + raise ConfigurationError( + "Failed to prepare output directory.", + context={"output_dir": self.output_dir}, + ) from e - self._create_vector_store_index() + # ---- Build index (wrap every unexpected failure) -> IndexBuildError + try: + self._create_vector_store_index() + except ClassifaiError: + # preserve already-classified errors (e.g. 
vectoriser raised DataValidationError) + raise + except Exception as e: + raise IndexBuildError( + "Failed to create vector store index.", + context={"file_name": self.file_name, "data_type": self.data_type, "batch_size": self.batch_size}, + ) from e - logging.info("Gathering metadata and saving vector store / metadata...") + # ---- Save + derived metadata (IO/format problems) -> IndexBuildError + try: + logging.info("Gathering metadata and saving vector store / metadata...") - self.vector_shape = self.vectors["embeddings"].to_numpy().shape[1] - self.num_vectors = len(self.vectors) + self.vector_shape = self.vectors["embeddings"].to_numpy().shape[1] + self.num_vectors = len(self.vectors) - ## save everything to the folder etc: metadata, parquet and vectoriser - self.vectors.write_parquet(os.path.join(self.output_dir, "vectors.parquet")) - self._save_metadata(os.path.join(self.output_dir, "metadata.json")) + self.vectors.write_parquet(os.path.join(self.output_dir, "vectors.parquet")) + self._save_metadata(os.path.join(self.output_dir, "metadata.json")) - logging.info("Vector Store created - files saved to %s", self.output_dir) + logging.info("Vector Store created - files saved to %s", self.output_dir) + except ClassifaiError: + raise + except Exception as e: + raise IndexBuildError( + "Vector store was created but saving outputs failed.", + context={"output_dir": self.output_dir}, + ) from e def _save_metadata(self, path): """Saves metadata about the vector store to a JSON file. @@ -165,8 +219,12 @@ def _save_metadata(self, path): path (str): The file path where the metadata JSON file will be saved. Raises: - Exception: If an error occurs while saving the metadata file. + DataValidationError: If the path argument is invalid. + IndexBuildError: If there are failures during serialization or file writing. """ + if not isinstance(path, str) or not path.strip(): + raise DataValidationError("path must be a non-empty string.", context={"path": path}) + try: # Convert meta_data types to strings for JSON serialization serializable_column_meta_data = { @@ -184,11 +242,29 @@ def _save_metadata(self, path): with open(path, "w", encoding="utf-8") as f: json.dump(metadata, f, indent=4) - except Exception: - logging.error("Something went wrong trying to save the metadata file") + + except ClassifaiError: + # Preserve package-specific exceptions unchanged raise + except (TypeError, ValueError) as e: + # Usually means something in `meta_data` isn't JSON-serializable + raise IndexBuildError( + "Failed to serialize vector store metadata to JSON.", + context={"path": path}, + ) from e + except OSError as e: + # Permission denied, invalid path, disk full, etc. + raise IndexBuildError( + "Failed to write metadata file.", + context={"path": path}, + ) from e + except Exception as e: + raise IndexBuildError( + "Unexpected error while saving metadata file.", + context={"path": path}, + ) from e - def _create_vector_store_index(self): + def _create_vector_store_index(self): # noqa: C901 """Processes text strings in batches, generates vector embeddings, and creates the vector store. Called from the constructor once other metadata has been set. @@ -197,48 +273,86 @@ def _create_vector_store_index(self): a Parquet file in the output_dir attribute, and stores in the vectors attribute. Raises: - Exception: If an error occurs during file processing or vector generation. + DataValidationError: If there are issues reading or validating the input file. 
+ IndexBuildError: If there are failures during embedding or building the vectors table. """ - # NOTE: read_excel schema_overrides only allows polars datatypes, not python built-in types - # Excel support disabled until we decide how to handle this. - # - # if self.data_type == "excel": - # self.vectors = pl.read_excel( - # self.file_name, - # has_header=True, - # columns=["id", "text", *self.meta_data.keys()], - # schema_overrides={"id": pl.String, "text": pl.String} | self.meta_data, - # ) - if self.data_type == "csv": - self.vectors = pl.read_csv( - self.file_name, - columns=["id", "text", *self.meta_data.keys()], - dtypes=self.meta_data | {"id": str, "text": str}, - ) - self.vectors = self.vectors.with_columns( - pl.Series("uuid", [str(uuid.uuid4()) for _ in range(self.vectors.height)]) - ) - else: - raise ValueError("File type not supported: {self.data_type}. Choose from ['csv'].") + # ---- Reading source data (validation/format issues) -> DataValidationError / IndexBuildError + try: + if self.data_type == "csv": + self.vectors = pl.read_csv( + self.file_name, + columns=["id", "text", *self.meta_data.keys()], + dtypes=self.meta_data | {"id": str, "text": str}, + ) + self.vectors = self.vectors.with_columns( + pl.Series("uuid", [str(uuid.uuid4()) for _ in range(self.vectors.height)]) + ) + else: + raise DataValidationError( + "File type not supported. Choose from ['csv'].", + context={"data_type": self.data_type}, + ) + except ClassifaiError: + raise + except Exception as e: + raise IndexBuildError( + "Failed to read input file into a table.", + context={"file_name": self.file_name, "data_type": self.data_type}, + ) from e logging.info("Processing file: %s...\n", self.file_name) + + # ---- Embedding / dataframe build (vectoriser failures and mismatches) -> IndexBuildError try: documents = self.vectors["text"].to_list() - embeddings = [] + if not documents: + raise DataValidationError( + "Input file contains no documents in column 'text'.", + context={"file_name": self.file_name}, + ) + + embeddings: list[np.ndarray] = [] for batch_id in tqdm(range(0, len(documents), self.batch_size)): batch = documents[batch_id : (batch_id + self.batch_size)] - embeddings.extend(self.vectoriser.transform(batch)) + try: + batch_embeddings = self.vectoriser.transform(batch) + except ClassifaiError: + # preserve vectoriser classification, but add context by re-wrapping + raise + except Exception as e: + raise IndexBuildError( + "Vectoriser.transform failed during index build.", + context={ + "file_name": self.file_name, + "vectoriser": self.vectoriser_class, + "batch_id": batch_id, + "batch_size": len(batch), + }, + ) from e + + # Basic sanity check: batch should return same number of vectors as texts + if len(batch_embeddings) != len(batch): + raise IndexBuildError( + "Vectoriser returned wrong number of embeddings for batch.", + context={ + "file_name": self.file_name, + "vectoriser": self.vectoriser_class, + "batch_id": batch_id, + "expected": len(batch), + "got": len(batch_embeddings), + }, + ) + + embeddings.extend(batch_embeddings) + self.vectors = self.vectors.with_columns(pl.Series(embeddings).alias("embeddings")) + except ClassifaiError: + raise except Exception as e: - logging.error("Error creating Polars DataFrame") - raise e - - def validate(self): - """Validates the vector store by checking if the loaded vectoriser matches the one used to create the vectors - and testing the search functionality. - """ - # This method is a placeholder for future validation logic. 
- # Currently, it does not perform any validation. + raise IndexBuildError( + "Failed while creating embeddings and building vectors table.", + context={"file_name": self.file_name, "vectoriser": self.vectoriser_class}, + ) from e def embed(self, query: VectorStoreEmbedInput) -> VectorStoreEmbedOutput: """Converts text into vector embeddings using the vectoriser and returns a VectorStoreEmbedOutput dataframe with columns 'id', 'text', and 'embedding'. @@ -248,41 +362,62 @@ def embed(self, query: VectorStoreEmbedInput) -> VectorStoreEmbedOutput: Returns: VectorStoreEmbedOutput: The output object containing the embeddings along with their corresponding ids and texts. + + Raises: + DataValidationError: Raised if invalid arguments are passed. + HookError: Raised if user-defined hooks fail. + ClassifaiError: Raised if embedding operation fails. """ - # Validate the input object + # ---- Validate arguments (caller mistakes) -> DataValidationError if not isinstance(query, VectorStoreEmbedInput): - raise ValueError("Query must be a VectorStoreEmbedInput object.") + raise DataValidationError( + "query must be a VectorStoreEmbedInput object.", + context={"got_type": type(query).__name__}, + ) - # Check if there is a user defined preprocess hook for the VectorStore embed method + # ---- Preprocess hook -> HookError if "embed_preprocess" in self.hooks: - modified_query = self.hooks["embed_preprocess"](query) try: + modified_query = self.hooks["embed_preprocess"](query) query = VectorStoreEmbedInput.validate(modified_query) except Exception as e: - raise ValueError( - f"Preprocessing hook returned an invalid VectorStoreEmbedInput object. Error: {e}" + raise HookError( + "embed_preprocess hook raised an exception.", + context={"hook": "embed_preprocess"}, ) from e - # Generate embeddings using the vectoriser - embeddings = self.vectoriser.transform(query.text.to_list()) + # ---- Main embed operation + try: + # Generate embeddings using the vectoriser + embeddings = self.vectoriser.transform(query.text.to_list()) - # Create a DataFrame with id, text, and embedding fields - results_df = VectorStoreEmbedOutput.from_data( - { - "id": query.id, - "text": query.text, - "embedding": [embeddings[i] for i in range(len(embeddings))], - } - ) + # Create a DataFrame with id, text, and embedding fields + results_df = VectorStoreEmbedOutput.from_data( + { + "id": query.id, + "text": query.text, + "embedding": [embeddings[i] for i in range(len(embeddings))], + } + ) + + except ClassifaiError: + raise + except Exception as e: + raise ClassifaiError( + "Embedding failed.", + code="embed_failed", + context={"n_texts": len(query), "vectoriser": self.vectoriser_class}, + ) from e - # Check if there is a user defined postprocess hook for the VectorStore embed method + # ---- Postprocess hook -> HookError if "embed_postprocess" in self.hooks: - modified_results_df = self.hooks["embed_postprocess"](results_df) try: + modified_results_df = self.hooks["embed_postprocess"](results_df) results_df = VectorStoreEmbedOutput.validate(modified_results_df) except Exception as e: - raise ValueError( - f"Postprocessing hook returned an invalid VectorStoreEmbedOutput object. Error: {e}" + raise HookError( + "embed_postprocess hook raised an exception.", + context={"hook": "embed_postprocess"}, ) from e return results_df @@ -301,53 +436,77 @@ def reverse_search(self, query: VectorStoreReverseSearchInput, n_results=100) -> document ID, document text and any associated metadata columns. 
Raises: - ValueError: Raised if invalid arguments are passed. + DataValidationError: Raised if invalid arguments are passed. + HookError: Raised if user-defined hooks fail. + ClassifaiError: Raised if reverse search operation fails. """ - # Validate the input object + # ---- Validate arguments (caller mistakes) -> DataValidationError if not isinstance(query, VectorStoreReverseSearchInput): - raise ValueError("Query must be a VectorStoreReverseSearchInput object.") + raise DataValidationError( + "query must be a VectorStoreReverseSearchInput object.", + context={"got_type": type(query).__name__}, + ) + + if not isinstance(n_results, int) or n_results < 1: + raise DataValidationError("n_results must be an integer >= 1.", context={"n_results": n_results}) + + if len(query) == 0: + raise DataValidationError("query is empty.", context={"n_queries": 0}) - # Check if there is a user defined preprocess hook for the VectorStore reverse search method + # ---- Preprocess hook -> HookError if "reverse_search_preprocess" in self.hooks: - modified_query = self.hooks["reverse_search_preprocess"](query) try: + modified_query = self.hooks["reverse_search_preprocess"](query) query = VectorStoreReverseSearchInput.validate(modified_query) except Exception as e: - raise ValueError( - f"Preprocessing hook returned an invalid VectorStoreReverseSearchInput object. Error: {e}" + raise HookError( + "reverse_search_preprocess hook raised an exception.", + context={"hook": "reverse_search_preprocess"}, ) from e - # polars conversion - paired_query = pl.DataFrame({"id": query.id, "doc_id": query.doc_id}) + # ---- Main reverse-search operation (wrap unexpected failures) -> ClassifaiError + try: + # polars conversion + paired_query = pl.DataFrame({"id": query.id, "doc_id": query.doc_id}) + + # join query with vdb to get matches + joined_table = paired_query.join(self.vectors.rename({"id": "doc_id"}), on="doc_id", how="inner") - # join query with vdb to get matches - joined_table = paired_query.join(self.vectors.rename({"id": "doc_id"}), on="doc_id", how="inner") + # get formatted table + final_table = joined_table.select( + [ + pl.col("id").cast(str), + pl.col("doc_id").cast(str), + pl.col("text").cast(str).alias("doc_text"), + *[pl.col(key) for key in self.meta_data], + ] + ) - # get formatted table - final_table = joined_table.select( - [ - pl.col("id").cast(str), - pl.col("doc_id").cast(str), - pl.col("text").cast(str).alias("doc_text"), - *[pl.col(key) for key in self.meta_data], - ] - ) + result_df = VectorStoreReverseSearchOutput.from_data(final_table.to_pandas()) - result_df = VectorStoreReverseSearchOutput.from_data(final_table.to_pandas()) + except ClassifaiError: + raise + except Exception as e: + raise ClassifaiError( + "Reverse search failed.", + code="reverse_search_failed", + context={"n_queries": len(query), "n_results": n_results}, + ) from e - # Check if there is a user defined postprocess hook for the VectorStore reverse search method + # ---- Postprocess hook -> HookError if "reverse_search_postprocess" in self.hooks: - modified_result_df = self.hooks["reverse_search_postprocess"](result_df) try: + modified_result_df = self.hooks["reverse_search_postprocess"](result_df) result_df = VectorStoreReverseSearchOutput.validate(modified_result_df) except Exception as e: - raise ValueError( - f"Preprocessing hook returned an invalid VectorStoreReverseSearchOutput object. 
Error: {e}" + raise HookError( + "reverse_search_postprocess hook raised an exception.", + context={"hook": "reverse_search_postprocess"}, ) from e return result_df - def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> VectorStoreSearchOutput: + def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> VectorStoreSearchOutput: # noqa: C901, PLR0912, PLR0915 """Searches the vector store using queries from a VectorStoreSearchInput object and returns ranked results in VectorStoreSearchOutput object. In batches, converts users text queries into vector embeddings, computes cosine similarity with stored document vectors, and retrieves the top results. @@ -362,106 +521,155 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V document ID, document text, rank, score, and any associated metadata columns. Raises: - ValueError: Raised if invalid arguments are passed. + DataValidationError: Raised if invalid arguments are passed. + ConfigurationError: Raised if the vector store is not initialized. + HookError: Raised if user-defined hooks fail. + VectorisationError: Raised if embedding queries fails. """ - # Validate the input object + # ---- Validate arguments (caller mistakes) -> DataValidationError if not isinstance(query, VectorStoreSearchInput): - raise ValueError("Query must be a VectorStoreSearchInput object.") + raise DataValidationError( + "query must be a VectorStoreSearchInput object.", + context={"got_type": type(query).__name__}, + ) + + if not isinstance(n_results, int) or n_results < 1: + raise DataValidationError("n_results must be an integer >= 1.", context={"n_results": n_results}) + + if not isinstance(batch_size, int) or batch_size < 1: + raise DataValidationError("batch_size must be an integer >= 1.", context={"batch_size": batch_size}) + + if self.vectors is None: + raise ConfigurationError("Vector store is not initialized (vectors is None).") - # Check if there is a user defined preprocess hook for the VectorStore search method + if len(query) == 0: + raise DataValidationError("query is empty.", context={"n_queries": 0}) + + # ---- Preprocess hook -> DataValidationError if it returns invalid shape/type if "search_preprocess" in self.hooks: - modified_query = self.hooks["search_preprocess"](query) try: + modified_query = self.hooks["search_preprocess"](query) query = VectorStoreSearchInput.validate(modified_query) except Exception as e: - raise ValueError( - f"Preprocessing hook returned an invalid VectorStoreSearchInput object. 
Error: {e}" + raise HookError( + "search_preprocess hook raised an exception.", + context={"hook": "search_preprocess"}, ) from e - # Initialize an empty list to store results from each batch - all_results = [] - - # Process the queries in batches - for i in tqdm(range(0, len(query), batch_size), desc="Processing query batches"): - # Get the current batch of queries - query_text_batch = query.query.to_list()[i : i + batch_size] - query_ids_batch = query.id.to_list()[i : i + batch_size] - - # Convert the current batch of queries to vectors - query_vectors = self.vectoriser.transform(query_text_batch) - - # Compute cosine similarity between the query batch and document vectors - cosine = query_vectors @ self.vectors["embeddings"].to_numpy().T - - # Get the top n_results indices for each query in the batch - idx = np.argpartition(cosine, -n_results, axis=1)[:, -n_results:] - - # Sort top n_results indices by their scores in descending order - idx_sorted = np.zeros_like(idx) - scores = np.zeros_like(idx, dtype=float) + # ---- Main search (wrap operational failures) -> SearchError / VectorisationError + try: + doc_embeddings = self.vectors["embeddings"].to_numpy() + + all_results: list[pl.DataFrame] = [] + + for i in tqdm(range(0, len(query), batch_size), desc="Processing query batches"): + query_text_batch = query.query.to_list()[i : i + batch_size] + query_ids_batch = query.id.to_list()[i : i + batch_size] + + if len(query_text_batch) == 0: + continue + + # Embed query batch + try: + query_vectors = self.vectoriser.transform(query_text_batch) + except ClassifaiError: + raise + except Exception as e: + raise VectorisationError( + "Failed to embed query batch.", + context={ + "vectoriser": self.vectoriser_class, + "batch_start": i, + "batch_size": len(query_text_batch), + "n_results": n_results, + }, + ) from e + + # Similarity + top-k + cosine = query_vectors @ doc_embeddings.T + + idx = np.argpartition(cosine, -n_results, axis=1)[:, -n_results:] + + idx_sorted = np.zeros_like(idx) + scores = np.zeros_like(idx, dtype=float) + + for j in range(idx.shape[0]): + row_scores = cosine[j, idx[j]] + sorted_indices = np.argsort(row_scores)[::-1] + idx_sorted[j] = idx[j, sorted_indices] + scores[j] = row_scores[sorted_indices] + + # Build batch result table + result_df = pl.DataFrame( + { + "query_id": np.repeat(query_ids_batch, n_results), + "query_text": np.repeat(query_text_batch, n_results), + "rank": np.tile(np.arange(n_results), len(query_text_batch)), + "score": scores.flatten(), + } + ) + + ranked_docs = self.vectors[idx_sorted.flatten().tolist()].select(["id", "text", *self.meta_data.keys()]) + merged_df = result_df.hstack(ranked_docs).rename({"id": "doc_id", "text": "doc_text"}) + + merged_df = merged_df.with_columns( + [ + pl.col("doc_id").cast(str), + pl.col("doc_text").cast(str), + pl.col("rank").cast(int), + pl.col("score").cast(float), + pl.col("query_id").cast(str), + pl.col("query_text").cast(str), + ] + ) + + all_results.append(merged_df) + + if not all_results: + # Shouldn't happen if len(query)>0, but keep it safe. 
+ empty = pl.DataFrame( + schema={ + "query_id": pl.Utf8, + "query_text": pl.Utf8, + "doc_id": pl.Utf8, + "doc_text": pl.Utf8, + "rank": pl.Int64, + "score": pl.Float64, + **dict.fromkeys(self.meta_data.keys(), pl.Utf8), + } + ) + return VectorStoreSearchOutput.from_data(empty.to_pandas()) + + reordered_df = pl.concat(all_results).select( + ["query_id", "query_text", "doc_id", "doc_text", "rank", "score", *self.meta_data.keys()] + ) - for j in range(idx.shape[0]): - row_scores = cosine[j, idx[j]] - sorted_indices = np.argsort(row_scores)[::-1] - idx_sorted[j] = idx[j, sorted_indices] - scores[j] = row_scores[sorted_indices] + result_df = VectorStoreSearchOutput.from_data(reordered_df.to_pandas()) - # Build a DataFrame for the current batch results - result_df = pl.DataFrame( - { - "query_id": np.repeat(query_ids_batch, n_results), - "query_text": np.repeat(query_text_batch, n_results), - "rank": np.tile(np.arange(n_results), len(query_text_batch)), - "score": scores.flatten(), - } - ) + except ClassifaiError: + raise + except Exception as e: + raise ClassifaiError( + "Search failed.", + code="search_failed", + context={"n_queries": len(query), "batch_size": batch_size, "n_results": n_results}, + ) from e - # Get the vector store results for the current batch - ranked_docs = self.vectors[idx_sorted.flatten().tolist()].select(["id", "text", *self.meta_data.keys()]) - merged_df = result_df.hstack(ranked_docs).rename({"id": "doc_id", "text": "doc_text"}) - merged_df = merged_df.with_columns( - [ - pl.col("doc_id").cast(str), - pl.col("doc_text").cast(str), - pl.col("rank").cast(int), - pl.col("score").cast(float), - pl.col("query_id").cast(str), - pl.col("query_text").cast(str), - ] - ) - # Append the current batch results to the list - all_results.append(merged_df) - - # Concatenate all batch results into a single DataFrame - reordered_df = pl.concat(all_results).select( - [ - "query_id", - "query_text", - "doc_id", - "doc_text", - "rank", - "score", - *self.meta_data.keys(), - ] - ) - - # Now that polars has been used for processing convert back to pandas for user familiarity - result_df = VectorStoreSearchOutput.from_data(reordered_df.to_pandas()) - - # Check if there is a user defined postprocess hook for the VectorStore search method + # ---- Postprocess hook -> DataValidationError if it returns invalid shape/type if "search_postprocess" in self.hooks: - modified_result_df = self.hooks["search_postprocess"](result_df) try: + modified_result_df = self.hooks["search_postprocess"](result_df) result_df = VectorStoreSearchOutput.validate(modified_result_df) except Exception as e: - raise ValueError( - f"Preprocessing hook returned an invalid VectorStoreSearchOutput object. Error: {e}" + raise HookError( + "search_postprocessing hook raised an exception.", + context={"hook": "search_postprocess"}, ) from e return result_df @classmethod - def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): + def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): # noqa: C901, PLR0912, PLR0915 """Creates a `VectorStore` instance from stored metadata and Parquet files. This method reads the metadata and vectors from the specified folder, validates the contents, and initializes a `VectorStore` object with the @@ -481,66 +689,121 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): VectorStore: An instance of the `VectorStore` class. 
Raises: - ValueError: If required files or metadata keys are missing, or if the vectoriser class does not match. + DataValidationError: If input arguments are invalid or if there are issues with the metadata or Parquet files. + ConfigurationError: If there are configuration issues, such as vectoriser mismatches. + IndexBuildError: If there are failures during loading or parsing the files. """ - # check that the metadata, vectoiser info and parquet exist + # ---- Validate arguments (caller mistakes) -> DataValidationError / ConfigurationError + if not isinstance(folder_path, str) or not folder_path.strip(): + raise DataValidationError("folder_path must be a non-empty string.", context={"folder_path": folder_path}) - # load the metadata file + if not os.path.isdir(folder_path): + raise DataValidationError( + "folder_path must be an existing directory.", context={"folder_path": folder_path} + ) + + if not hasattr(vectoriser, "transform") or not callable(getattr(vectoriser, "transform", None)): + raise ConfigurationError( + "vectoriser must provide a callable .transform(texts) method.", + context={"vectoriser_type": type(vectoriser).__name__}, + ) + + if hooks is not None and not isinstance(hooks, dict): + raise DataValidationError("hooks must be a dict or None.", context={"hooks_type": type(hooks).__name__}) + + # ---- Load metadata -> IndexBuildError metadata_path = os.path.join(folder_path, "metadata.json") if not os.path.exists(metadata_path): - raise ValueError(f"Metadata file not found in {folder_path}") - with open(metadata_path, encoding="utf-8") as f: - metadata = json.load(f) - - # check that the correct keys exist in metadata - required_keys = [ - "vectoriser_class", - "vector_shape", - "num_vectors", - "created_at", - "meta_data", - ] - for key in required_keys: - if key not in metadata: - raise ValueError(f"Metadata file is missing required key: {key}") - - # get the column metadata and convert types to built-in types - deserialized_column_meta_data = { - key: getattr(__builtins__, value, value) # Use built-in types or keep as-is - for key, value in metadata["meta_data"].items() - } - - # check that the vector shape and num vectors are correct - # load the parquet file + raise DataValidationError( + "Metadata file not found in folder_path.", + context={"folder_path": folder_path, "metadata_path": metadata_path}, + ) + + try: + with open(metadata_path, encoding="utf-8") as f: + metadata = json.load(f) + except Exception as e: + raise IndexBuildError( + "Failed to read metadata.json.", + context={"metadata_path": metadata_path}, + ) from e + + # ---- Validate metadata content -> DataValidationError + if not isinstance(metadata, dict): + raise DataValidationError( + "metadata.json did not contain a JSON object.", + context={"metadata_path": metadata_path, "metadata_type": type(metadata).__name__}, + ) + + required_keys = ["vectoriser_class", "vector_shape", "num_vectors", "created_at", "meta_data"] + missing = [k for k in required_keys if k not in metadata] + if missing: + raise DataValidationError( + "Metadata file is missing required keys.", + context={"metadata_path": metadata_path, "missing_keys": missing}, + ) + + if not isinstance(metadata["meta_data"], dict): + raise DataValidationError( + "metadata.meta_data must be an object/dict.", + context={"metadata_path": metadata_path, "meta_data_type": type(metadata["meta_data"]).__name__}, + ) + + # ---- Deserialize meta_data types safely -> DataValidationError + try: + # get the column metadata and convert types to built-in types + 
deserialized_column_meta_data = {
+                key: getattr(__builtins__, value, value)  # Use built-in types or keep as-is
+                for key, value in metadata["meta_data"].items()
+            }
+        except Exception as e:
+            raise DataValidationError(
+                "Unable to deserialize metadata column types from the metadata file.",
+                context={"metadata_path": metadata_path, "meta_data": metadata["meta_data"]},
+            ) from e
+
+        # ---- Load parquet -> IndexBuildError / DataValidationError
         vectors_path = os.path.join(folder_path, "vectors.parquet")
         if not os.path.exists(vectors_path):
-            raise ValueError(f"Vectors Parquet file not found in {folder_path}")
+            raise DataValidationError(
+                "Vectors Parquet file not found in folder_path.",
+                context={"folder_path": folder_path, "vectors_path": vectors_path},
+            )
+
+        required_columns = ["id", "text", "embeddings", "uuid", *deserialized_column_meta_data.keys()]
+
+        try:
+            df = pl.read_parquet(vectors_path, columns=required_columns)
+        except Exception as e:
+            raise IndexBuildError(
+                "Failed to read vectors.parquet.",
+                context={"vectors_path": vectors_path},
+            ) from e

-        df = pl.read_parquet(
-            vectors_path,
-            columns=["id", "text", "embeddings", "uuid", *deserialized_column_meta_data.keys()],
-        )
         if df.is_empty():
-            raise ValueError(f"Vectors Parquet file is empty in {folder_path}")
-        # check parquet file has the correct columns
-        required_columns = [
-            "id",
-            "text",
-            "embeddings",
-            "uuid",
-            *deserialized_column_meta_data.keys(),
-        ]
-        for col in required_columns:
-            if col not in df.columns:
-                raise ValueError(f"Vectors Parquet file is missing required column: {col}")
-
-        # check that the vectoriser class matches the one provided
+            raise DataValidationError(
+                "Vectors Parquet file is empty.",
+                context={"vectors_path": vectors_path},
+            )
+
+        missing_cols = [c for c in required_columns if c not in df.columns]
+        if missing_cols:
+            raise DataValidationError(
+                "Vectors Parquet file is missing required columns.",
+                context={"vectors_path": vectors_path, "missing_columns": missing_cols},
+            )
+
+        # ---- Validate vectoriser class match -> ConfigurationError
         if metadata["vectoriser_class"] != vectoriser.__class__.__name__:
-            raise ValueError(
-                f"Vectoriser class in metadata ({metadata['vectoriser_class']}) does not match provided vectoriser ({vectoriser.__class__.__name__})"
+            raise ConfigurationError(
+                "Vectoriser class in metadata does not match provided vectoriser.",
+                context={
+                    "metadata_vectoriser_class": metadata["vectoriser_class"],
+                    "provided_vectoriser_class": vectoriser.__class__.__name__,
+                },
             )

-        # create the VectorStore instance and add the new data to the fields
+        # ---- Construct instance without __init__ and assign fields
         vector_store = object.__new__(cls)
         vector_store.file_name = None
         vector_store.data_type = None
diff --git a/src/classifai/servers/main.py b/src/classifai/servers/main.py
index 6ff6596..977367d 100644
--- a/src/classifai/servers/main.py
+++ b/src/classifai/servers/main.py
@@ -12,12 +12,16 @@
 from typing import Annotated

 import uvicorn
-
-# New imports
-from fastapi import FastAPI, Query
-from fastapi.responses import RedirectResponse
-
-from ..indexers.dataclasses import VectorStoreEmbedInput, VectorStoreReverseSearchInput, VectorStoreSearchInput
+from fastapi import FastAPI, Query, Request
+from fastapi.responses import JSONResponse, RedirectResponse
+
+from ..exceptions import ClassifaiError, ConfigurationError, DataValidationError, IndexBuildError
+from ..indexers.main import VectorStore
+from ..indexers.dataclasses import (
+    VectorStoreEmbedInput,
+    
VectorStoreReverseSearchInput, + VectorStoreSearchInput, +) from .pydantic_models import ( ClassifaiData, EmbeddingsList, @@ -30,7 +34,7 @@ ) -def start_api(vector_stores, endpoint_names, port=8000): # noqa: C901 +def start_api(vector_stores, endpoint_names, port=8000): # noqa: C901, PLR0915 """Initialize and start the FastAPI application with dynamically created endpoints. This function dynamically registers embedding and search endpoints for each provided vector store and endpoint name. It also sets up a default route to redirect users to @@ -42,10 +46,52 @@ def start_api(vector_stores, endpoint_names, port=8000): # noqa: C901 endpoint_names (list): A list of endpoint names corresponding to the vector stores. port (int, optional): The port on which the API server will run. Defaults to 8000. + Raises: + DataValidationError: If the input parameters are invalid. + ConfigurationError: If a vector store is missing required methods. """ + # ---- Validate startup args -> DataValidationError / ConfigurationError + if not isinstance(vector_stores, list) or not isinstance(endpoint_names, list): + raise DataValidationError( + "vector_stores and endpoint_names must be lists.", + context={ + "vector_stores_type": type(vector_stores).__name__, + "endpoint_names_type": type(endpoint_names).__name__, + }, + ) + if len(vector_stores) != len(endpoint_names): - raise ValueError("The number of vector stores must match the number of endpoint names.") + raise DataValidationError( + "The number of vector stores must match the number of endpoint names.", + context={"n_vector_stores": len(vector_stores), "n_endpoint_names": len(endpoint_names)}, + ) + + if any(not isinstance(x, str) or not x.strip() for x in endpoint_names): + raise DataValidationError( + "All endpoint_names must be non-empty strings.", + context={"endpoint_names": endpoint_names}, + ) + + if len(set(endpoint_names)) != len(endpoint_names): + raise DataValidationError( + "endpoint_names must be unique.", + context={"endpoint_names": endpoint_names}, + ) + + MAX_PORT, MIN_PORT = 65535, 1 + if not isinstance(port, int) or port < MIN_PORT or port > MAX_PORT: + raise DataValidationError( + "port must be an integer between 1 and 65535.", + context={"port": port}, + ) + + for i, vs in enumerate(vector_stores): + if not isinstance(vs, VectorStore): + raise ConfigurationError( + "vector_store must be an instance of the VectorStore class.", + context={"index": i, "vector_store_type": type(vs).__name__}, + ) logging.info("Starting ClassifAI API") @@ -53,6 +99,36 @@ def start_api(vector_stores, endpoint_names, port=8000): # noqa: C901 app = FastAPI() + # ---- Centralized exception mapping (preferred) + @app.exception_handler(ClassifaiError) + async def classifai_error_handler(request: Request, exc: ClassifaiError): + # If your ClassifaiError exposes .code/.context, include them + status = 400 + if isinstance(exc, ConfigurationError): + status = 500 + if isinstance(exc, IndexBuildError): + status = 500 + if isinstance(exc, DataValidationError): + status = 422 + + logging.warning("ClassifAI error at %s: %s", request.url.path, exc, exc_info=False) + + payload = {"error": {"message": str(exc)}} + if getattr(exc, "code", None): + payload["error"]["code"] = exc.code + if getattr(exc, "context", None): + payload["error"]["context"] = exc.context + + return JSONResponse(status_code=status, content=payload) + + @app.exception_handler(Exception) + async def unhandled_error_handler(request: Request, exc: Exception): + logging.exception("Unhandled error at %s", 
request.url.path) + return JSONResponse( + status_code=500, + content={"error": {"message": "Internal server error."}}, + ) + def create_embedding_endpoint(app, endpoint_name, vector_store): """Create and register an embedding endpoint for a specific vector store. @@ -141,9 +217,7 @@ def reverse_search_endpoint( data: RevClassifaiData, n_results: Annotated[ int, - Query( - description="The max number of results to return.", - ), + Query(description="The max number of results to return.", ge=1), ] = 100, ) -> RevResultsResponseBody: input_ids = [x.id for x in data.entries] diff --git a/src/classifai/vectorisers/gcp.py b/src/classifai/vectorisers/gcp.py index 2b1ae88..49bdbfc 100644 --- a/src/classifai/vectorisers/gcp.py +++ b/src/classifai/vectorisers/gcp.py @@ -7,6 +7,7 @@ import numpy as np from classifai._optional import check_deps +from classifai.exceptions import ConfigurationError, ExternalServiceError, VectorisationError from .base import VectoriserBase @@ -14,6 +15,8 @@ logging.getLogger("google.cloud").setLevel(logging.WARNING) logging.getLogger("google.api_core").setLevel(logging.WARNING) +logger = logging.getLogger(__name__) + class GcpVectoriser(VectoriserBase): """A class for embedding text using Google Cloud Platform's GenAI API. @@ -58,8 +61,7 @@ def __init__( **client_kwargs: Additional keyword arguments to pass to the GenAI client. Raises: - RuntimeError: If the GenAI client fails to initialize. - ValueError: If neither project_id&location nor api_key is provided. + ConfigurationError: If the GenAI client fails to initialize. """ check_deps(["google-genai"], extra="gcp") from google import genai # type: ignore @@ -73,8 +75,9 @@ def __init__( elif api_key and not project_id: client_kwargs.setdefault("api_key", api_key) else: - raise ValueError( - "Provide either 'project_id' and 'location' together, or 'api_key' alone for GCP Vectoriser." + raise ConfigurationError( + "Provide either 'project_id' and 'location' together, or 'api_key' alone for GCP Vectoriser.", + context={"vectoriser": "gcp"}, ) try: @@ -82,7 +85,10 @@ def __init__( **client_kwargs, ) except Exception as e: - raise RuntimeError(f"Failed to initialize GCP Vectoriser. {e}") from e + raise ConfigurationError( + "Failed to initialize GCP GenAI client.", + context={"vectoriser": "gcp"}, + ) from e def transform(self, texts): """Transforms input text(s) into embeddings using the GenAI API. @@ -94,18 +100,31 @@ def transform(self, texts): numpy.ndarray: A 2D array of embeddings, where each row corresponds to an input text. Raises: - TypeError: If the input is not a string or a list of strings. + ExternalServiceError: If the GenAI API request fails. + VectorisationError: If the response format from the GenAI API is unexpected. 
""" # If a single string is passed as arg to texts, convert to list if isinstance(texts, str): texts = [texts] # The Vertex AI call to embed content - embeddings = self.vectoriser.models.embed_content( - model=self.model_name, contents=texts, config=self.model_config - ) + try: + embeddings = self.vectoriser.models.embed_content( + model=self.model_name, contents=texts, config=self.model_config + ) + except Exception as e: + raise ExternalServiceError( + "GCP embedding request failed.", + context={"vectoriser": "gcp", "model": self.model_name, "n_texts": len(texts)}, + ) from e # Extract embeddings from the response object - result = np.array([res.values for res in embeddings.embeddings]) + try: + result = np.array([res.values for res in embeddings.embeddings]) + except Exception as e: + raise VectorisationError( + "Unexpected embedding response format from GCP.", + context={"vectoriser": "gcp", "model": self.model_name}, + ) from e return result diff --git a/src/classifai/vectorisers/huggingface.py b/src/classifai/vectorisers/huggingface.py index 176b183..a23fd52 100644 --- a/src/classifai/vectorisers/huggingface.py +++ b/src/classifai/vectorisers/huggingface.py @@ -1,6 +1,7 @@ """A module that provides a wrapper for Huggingface Transformers models to generate text embeddings.""" from classifai._optional import check_deps +from classifai.exceptions import ConfigurationError, ExternalServiceError, VectorisationError from .base import VectoriserBase @@ -22,23 +23,39 @@ def __init__(self, model_name, device=None, model_revision="main"): model_name (str): The name of the Huggingface model to use. device (torch.device, optional): The device to use for computation. Defaults to GPU if available, otherwise CPU. model_revision (str, optional): The specific model revision to use. Defaults to "main". + + Raises: + ExternalServiceError: If the model or tokenizer cannot be loaded. + ConfigurationError: If the model cannot be initialized on the specified device. """ check_deps(["transformers", "torch"], extra="huggingface") import torch # type: ignore from transformers import AutoModel, AutoTokenizer # type: ignore self.model_name = model_name - self.tokenizer = AutoTokenizer.from_pretrained(model_name, revision=model_revision) # nosec: B615 - self.model = AutoModel.from_pretrained(model_name, revision=model_revision) # nosec: B615 - - # Use GPU if available and not overridden - if device: - self.device = device - else: - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - self.model.to(self.device) - self.model.eval() + try: + self.tokenizer = AutoTokenizer.from_pretrained(model_name, revision=model_revision) # nosec: B615 + self.model = AutoModel.from_pretrained(model_name, revision=model_revision) # nosec: B615 + except Exception as e: + raise ExternalServiceError( + "Failed to load HuggingFace model/tokenizer.", + context={"vectoriser": "huggingface", "model": model_name, "revision": model_revision}, + ) from e + + # Device selection / model placement is local configuration/runtime. 
+ try: + if device is not None: + self.device = device + else: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + self.model.to(self.device) + self.model.eval() + except Exception as e: + raise ConfigurationError( + "Failed to initialize model on device.", + context={"vectoriser": "huggingface", "model": model_name, "device": str(device) if device else "auto"}, + ) from e def transform(self, texts): """Transforms input text(s) into embeddings using the Huggingface model. @@ -50,7 +67,7 @@ def transform(self, texts): numpy.ndarray: A 2D array of embeddings, where each row corresponds to an input text. Raises: - TypeError: If the input is not a string or a list of strings. + VectorisationError: If tokenization, model inference, or embedding extraction fails. """ import torch # type: ignore @@ -58,24 +75,51 @@ def transform(self, texts): if isinstance(texts, str): texts = [texts] - # Tokenise input texts - inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(self.device) - - # Get model outputs - with torch.no_grad(): - outputs = self.model(**inputs) - - # Use mean pooling over the token embeddings - token_embeddings = outputs.last_hidden_state # shape: (batch_size, seq_len, hidden_size) - attention_mask = inputs["attention_mask"] - - # Perform mean pooling manually - mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() - summed = torch.sum(token_embeddings * mask, dim=1) - counts = torch.clamp(mask.sum(dim=1), min=1e-9) - mean_pooled = summed / counts # shape: (batch_size, hidden_size) - - # Convert to numpy array - embeddings = mean_pooled.cpu().numpy() + # Tokenization / tensor move can fail (e.g., device issues, weird tokenizer config) + try: + inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(self.device) + except Exception as e: + raise VectorisationError( + "Tokenization failed.", + context={"vectoriser": "huggingface", "model": self.model_name, "n_texts": len(texts)}, + ) from e + + # Forward pass can fail (OOM, dtype/device mismatch, model bug) + try: + with torch.no_grad(): + outputs = self.model(**inputs) + except RuntimeError as e: + # RuntimeError is common for CUDA OOM etc. 
+            raise VectorisationError(
+                "Model forward pass failed (possible OOM/device issue).",
+                context={
+                    "vectoriser": "huggingface",
+                    "model": self.model_name,
+                    "n_texts": len(texts),
+                    "device": str(self.device),
+                },
+            ) from e
+        except Exception as e:
+            raise VectorisationError(
+                "Model forward pass failed.",
+                context={"vectoriser": "huggingface", "model": self.model_name, "n_texts": len(texts)},
+            ) from e
+
+        # Pooling / output parsing
+        try:
+            token_embeddings = outputs.last_hidden_state
+            attention_mask = inputs["attention_mask"]
+
+            mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+            summed = torch.sum(token_embeddings * mask, dim=1)
+            counts = torch.clamp(mask.sum(dim=1), min=1e-9)
+            mean_pooled = summed / counts
+
+            embeddings = mean_pooled.cpu().numpy()
+        except Exception as e:
+            raise VectorisationError(
+                "Failed to compute embeddings from model outputs.",
+                context={"vectoriser": "huggingface", "model": self.model_name},
+            ) from e

         return embeddings
diff --git a/src/classifai/vectorisers/ollama.py b/src/classifai/vectorisers/ollama.py
index 2be88a8..07d75a3 100644
--- a/src/classifai/vectorisers/ollama.py
+++ b/src/classifai/vectorisers/ollama.py
@@ -3,6 +3,7 @@
 import numpy as np

 from classifai._optional import check_deps
+from classifai.exceptions import ExternalServiceError, VectorisationError

 from .base import VectoriserBase
@@ -37,7 +38,8 @@ def transform(self, texts):
             numpy.ndarray: A 2D array of embeddings, where each row corresponds to an input text.

         Raises:
-            TypeError: If the input is not a string or a list of strings.
+            ExternalServiceError: If the Ollama service fails to generate embeddings.
+            VectorisationError: If embedding extraction from the Ollama response fails.
         """
         import ollama  # type: ignore

         if isinstance(texts, str):
             texts = [texts]

-        response = ollama.embed(model=self.model_name, input=texts)
-
-        return np.array(response.embeddings)
+        try:
+            response = ollama.embed(model=self.model_name, input=texts)
+        except Exception as e:
+            raise ExternalServiceError(
+                "Failed to generate embeddings using Ollama.",
+                context={"vectoriser": "ollama", "model": self.model_name},
+            ) from e
+
+        try:
+            return np.array(response.embeddings)
+        except Exception as e:
+            raise VectorisationError(
+                "Failed to extract embeddings from Ollama response.",
+                context={"vectoriser": "ollama", "model": self.model_name},
+            ) from e
diff --git a/uv.lock b/uv.lock
index 8082439..758d0b8 100644
--- a/uv.lock
+++ b/uv.lock
@@ -302,7 +302,7 @@ wheels = [
 ]

 [[package]]
 name = "classifai"
-version = "0.1.0"
+version = "0.2.0"
 source = { editable = "." }
 dependencies = [
     { name = "fastapi", extra = ["standard"] },
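
A minimal sketch of how downstream code might consume the new exception hierarchy once this lands. The folder name and my_vectoriser are placeholders, not part of the diff; only the ClassifaiError fields (message, code, context), to_dict(), and the VectorStore.from_filespace(folder_path, vectoriser) signature from this change are assumed.

# Hypothetical usage sketch -- not part of the diff.
import logging

from classifai.exceptions import ClassifaiError, DataValidationError
from classifai.indexers.main import VectorStore

my_vectoriser = ...  # placeholder: any vectoriser instance exposing .transform(texts)

try:
    store = VectorStore.from_filespace("my_store", vectoriser=my_vectoriser)
except DataValidationError as exc:
    # Caller-side problem: missing folder, metadata.json / vectors.parquet absent, wrong columns.
    print(exc.to_dict())  # {"error": "validation_error", "detail": "...", "context": {...}}
except ClassifaiError as exc:
    # Anything else the package classified (configuration, index build, external service).
    logging.error("classifai failed: code=%s context=%s", exc.code, exc.context)
    raise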