diff --git a/DEMO/genAI_prototype.ipynb b/DEMO/genAI_prototype.ipynb
new file mode 100644
index 0000000..6c19e21
--- /dev/null
+++ b/DEMO/genAI_prototype.ipynb
@@ -0,0 +1,262 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# GEN AI Workflow Notebook ✨✨"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "ClassifAI allows users to create vector databases from text datasets, search those vector datasets in an ad-hoc retrieval manner, and deploy this pipeline as a restAPI service using FastAPI to perform AI assisted classification tasks.\n",
+ "\n",
+ "\n",
+ "An recent emergent AI field is Retrieval Augmented Generation (RAG) where text generation models that traditionally respond to a user 'prompt', first retrieve relevant infomration from a vector database via adhoc retrieval processes, and use those serach results as context to generate an answer for the original user prompt.\n",
+ "\n",
+ "#### This notebook shows how we use RAG agents to perform classification on our VectorStore semantic search results!\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## ClassifAIs existing retrieval setup\n",
+ "\n",
+ "\n",
+ "\n",
+ "#### The other modules of ClassifAI provide 3 core classes that work together to:\n",
+ "\n",
+ "1. Vectorisers, to create embeddings,\n",
+ "2. Vectorstores, to create datgabsees of vectors and the ability to searc/query those database\n",
+ "3. RestAPI, to deploy a rest api service on a server to allow connections that search the created vector databases"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!uv pip install \"classifai[gcp]\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup\n",
+ "\n",
+ "We need to first set up a traditional ClassifAI pipeline that can provide search/classification results for our generative model to use as context...\n",
+ "\n",
+ "\n",
+ "All of the following code is part of our standard ClassifAI setupm, and is a short demo of how you can create a semantic search classification system. Check out our general_workflow_demo.ipynb notebook for a walkthrough of the content of these cells!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!gcloud auth application-default login"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from classifai.indexers import VectorStore\n",
+ "from classifai.vectorisers import HuggingFaceVectoriser\n",
+ "\n",
+ "# Our embedding model is pulled down from HuggingFace, or used straight away if previously downloaded\n",
+ "# This also works with many different huggingface models!\n",
+ "vectoriser = HuggingFaceVectoriser(model_name=\"sentence-transformers/all-MiniLM-L6-v2\")\n",
+ "\n",
+ "\n",
+ "my_vector_store = VectorStore(\n",
+ " file_name=\"./data/fake_soc_dataset.csv\", # demo csv file from the classifai package repo! (try some of our other DEMO/data datasets)\n",
+ " data_type=\"csv\",\n",
+ " vectoriser=vectoriser,\n",
+ " agent=None,\n",
+ " overwrite=True,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## We've set up our semantic search classification system\n",
+ "\n",
+ "The following cell runs the 'search' method of the vectorstore, which will take a query for our created vectorstore and return the top k (n_results) most similar samples stored in the vectorstore. This is an example of our exisiting retrieval capabilities with the ClassifAI package. \n",
+ "\n",
+ "This retrieved set can then be used to make a final classification decision, by some method such as:\n",
+ "- automatically choosing the top ranked item, \n",
+ "- a human in the loop physically looking at the retrieved candidates making the decision,\n",
+ "- by using a generative AI agent to assess the candidate lists and making a final decison... more on that later"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from classifai.indexers.dataclasses import VectorStoreSearchInput\n",
+ "\n",
+ "input_object = VectorStoreSearchInput(\n",
+ " {\"id\": [1, 2], \"query\": [\"dairy famer that sells milk\", \"happy construction worker\"]}\n",
+ ")\n",
+ "\n",
+ "bb = my_vector_store.search(input_object, n_results=5)\n",
+ "bb"
+ ]
+ },
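+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a minimal sketch of the first option above (assuming, as the agents module does, that the returned `VectorStoreSearchOutput` behaves like a pandas DataFrame with `query_id` and `rank` columns), we could pick the top-ranked candidate per query as a naive baseline classifier:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Naive baseline: keep only the best-ranked (lowest rank value) row for each query_id\n",
+ "top1_baseline = search_results.loc[search_results.groupby(\"query_id\")[\"rank\"].idxmin()]\n",
+ "top1_baseline"
+ ]
+ },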
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Creating an AI Agent\n",
+ "\n",
+ "With our semantic search classification pipeline set up, we can send the top K results (seen above) to a Genereative AI model and ask that LLM to make a final decision on which result is the correct result for the user's original input query. Passing semantic search results to a generatieve model for some task is often referred to as 'Retrieval Augmented Generation'.\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ "\n",
+ "Our `GcpAgent` class agent has a transform() method which accepts a `VectorStoreSearchOutput` object. Passing this `VectorStoreSearchOutput` object to the transform() method will return another `VectorStoreSearchOutput` object, which will modify the results in some way. The classificaion GcpAgent, reduces the semantic search results down to a single result row for each query in the VectorStoreSearchOutput results.\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To instantiate the GcpAgent, the constructor takes:\n",
+ "\n",
+ "- project_id - of the google project associated with the Google Gemini embedding models\n",
+ "- location - corresponding to the Gcloud Project\n",
+ "- model_name - the specific LLM model to use, that is avaiable on Gcloud\n",
+ "- task_type - indicates the kind of work you want the LLM to do on the ClassifAI results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from classifai.agents import GcpAgent\n",
+ "\n",
+ "my_agent = GcpAgent(\n",
+ " project_id=\"xxxxxx\", location=\"europe-west2\", model_name=\"gemini-2.5-flash\", task_type=\"classification\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can pass some VectorStore search results directly to the agents transform() method"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "semantic_search_results = my_vector_store.search(input_object, n_results=5)\n",
+ "\n",
+ "print(type(semantic_search_results))\n",
+ "semantic_search_results"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### To use the If the classification agent can make a decision, it will return a single row for the corresponding query_id. Otherwise it will return the original results\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "my_agent.transform(semantic_search_results)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### To use the agent in our ClassifAI pipeline, we can attach it to the running VectorStore"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "my_vector_store.agent = my_agent"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### We can then run vectorstore.search and it will automatically call the agent to process the results!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "my_vector_store.search(input_object, n_results=5)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Thats it!\n",
+ "\n",
+ "You can see in the above cell that my_agent.transform() takes and returns a VectorStoreSearchOutput object and therefore integrates with the VectorStore search method.\n",
+ "\n",
+ "Check out the general_worfklow notebook to see how VectorStores can be deployed, and you'll find that the Agent can be deployed as an integrated part of the VectorStore."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.13.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/src/classifai/agents/__init__.py b/src/classifai/agents/__init__.py
new file mode 100644
index 0000000..dc9f36d
--- /dev/null
+++ b/src/classifai/agents/__init__.py
@@ -0,0 +1,7 @@
+from .base import AgentBase
+from .gcp import GcpAgent
+
+__all__ = [
+ "AgentBase",
+ "GcpAgent",
+]
diff --git a/src/classifai/agents/base.py b/src/classifai/agents/base.py
new file mode 100644
index 0000000..d206c0b
--- /dev/null
+++ b/src/classifai/agents/base.py
@@ -0,0 +1,19 @@
+from abc import ABC, abstractmethod
+
+from classifai.indexers.dataclasses import VectorStoreSearchOutput
+
+##
+# The following is the abstract base class for all RAG generative models.
+##
+
+
+class AgentBase(ABC):
+ """Abstract base class for all Generative and RAG models."""
+
+ @abstractmethod
+ def transform(
+ self,
+ results: VectorStoreSearchOutput,
+ ) -> VectorStoreSearchOutput:
+ """Passes VectorStoreSearchOutput object, which the Agent manipulates in some way and returns."""
+ pass
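+
+
+# Illustrative sketch (hypothetical, not part of the package API): a minimal
+# concrete agent satisfying this interface would simply return the results
+# unchanged.
+#
+# class PassthroughAgent(AgentBase):
+#     def transform(self, results: VectorStoreSearchOutput) -> VectorStoreSearchOutput:
+#         return results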
diff --git a/src/classifai/agents/gcp.py b/src/classifai/agents/gcp.py
new file mode 100644
index 0000000..b1754f9
--- /dev/null
+++ b/src/classifai/agents/gcp.py
@@ -0,0 +1,200 @@
+import json
+
+import pandas as pd
+from google import genai
+from google.genai.types import GenerateContentConfig, HttpOptions
+from pydantic import BaseModel, Field
+
+from classifai.indexers.dataclasses import VectorStoreSearchOutput
+
+from .base import AgentBase
+
+########################
+#### SYSTEM PROMPTS FOR DIFFERENT TASK TYPES: currently, Classification.
+########################
+
+CLASSIFICATION_SYSTEM_PROMPT = """You are an AI assistant designed to classify a user query based on the provided context. You will be provided with 5 candidate entries retrieved from a knowledge base, each containing an ID and a text description. Your task is to analyze the user query and the text of the context entries to determine which of the entries best matches the user query.
+
+Guidelines:
+1. Always prioritize the provided context when making your classification.
+2. The context will be provided as an XML structure containing multiple entries. Each entry includes an ID and a text description.
+3. The IDs will be integer values from 0 to 4, corresponding to the 5 candidate entries.
+4. Use the text of the entries to determine the most relevant classification for the user query.
+5. Your output must be a JSON object that adheres to the following schema:
+ - The JSON object must contain a single key, `classification`.
+ - The value of `classification` must be an integer between 0 and 4, representing the ID of the best matching entry.
+ - If no classification can be determined due to ambiguity or insufficient information, the value of `classification` must be `-1`.
+
+Example of the required JSON output:
+{
+ "classification": 1
+}
+
+The XML structure for the context and user query will be as follows:
+
+<context>
+  <entry>
+    <id>0</id>
+    <text>[Text from the first entry]</text>
+  </entry>
+  <entry>
+    <id>1</id>
+    <text>[Text from the second entry]</text>
+  </entry>
+  ...
+  <entry>
+    <id>4</id>
+    <text>[Text from the fifth entry]</text>
+  </entry>
+</context>
+
+<query>
+  [The user query will be inserted here]
+</query>
+
+
+Your task is to analyze the context and the user query, and return the classification in the required structured format."""
+
+
+########################
+#### GENERAL FUNCTION FOR FORMATTING THE USER QUERY PROMPT WITH RETRIEVED RESULTS FROM VECTORSTORE
+########################
+
+
+def format_prompt_with_retrieval_results(df: pd.DataFrame) -> str:
+ """Generates a formatted XML prompt for the generative model from a structured DataFrame.
+
+ Args:
+ df (pd.DataFrame): A DataFrame containing columns as per `searchOutputSchema`.
+
+ Returns:
+ str: The formatted XML prompt.
+ """
+ # Extract the user query (assuming all rows have the same query_id and query_text)
+ user_query = df["query_text"].iloc[0]
+
+ # Limit to the top 5 entries based on rank
+ top_entries = df.nsmallest(5, "rank")
+
+ # Build the <context> section; entry IDs are positional (0 to 4) to match the system prompt
+ context_entries = "\n".join(
+ f"  <entry>\n    <id>{idx}</id>\n    <text>{row['doc_text']}</text>\n  </entry>"
+ for idx, (_, row) in enumerate(top_entries.iterrows())
+ )
+
+ # Combine everything into the final prompt
+ formatted_prompt = f"""
+
+{context_entries}
+
+
+
+ {user_query}
+"""
+
+ return formatted_prompt
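+
+
+# For illustration (hypothetical values), a group with two candidate rows and the
+# query "dairy farmer that sells milk" would yield a prompt of the form:
+#
+# <context>
+#   <entry>
+#     <id>0</id>
+#     <text>Dairy and cattle farming</text>
+#   </entry>
+#   <entry>
+#     <id>1</id>
+#     <text>Construction labourer</text>
+#   </entry>
+# </context>
+#
+# <query>
+#   dairy farmer that sells milk
+# </query>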
+
+
+########################
+#### STRUCTURED RESPONSE MODELS FOR DIFFERENT TASK TYPES: Classification
+########################
+
+
+class ClassificationResponseModel(BaseModel):
+ classification: int = Field(description="Chosen ID of the best matching entry, or -1 if no decision can be made.", ge=-1, le=4)
+
+
+########################
+#### FORMATTING FUNCTIONS THAT INTERPRET THE MODEL RAW RESPONSE, FORMATS, and APPLIES TO DF
+########################
+
+
+def format_classification_output(generated_text, result: VectorStoreSearchOutput) -> VectorStoreSearchOutput:
+ # Parse the generated text
+ try:
+ response = json.loads(generated_text)
+ validated_response = ClassificationResponseModel(**response)
+ except (json.JSONDecodeError, ValueError):
+ # If parsing or validation fails, return the original DataFrame
+ return result
+
+ # Extract the classification
+ classification = validated_response.classification
+
+ # Validate the classification value is in the expected range
+ MIN_INDEX = 0
+ MAX_INDEX = 4
+ if int(classification) < MIN_INDEX or int(classification) > MAX_INDEX:
+ return result
+
+ # Otherwise, keep only the classified entry; positions follow the same rank ordering used to build the prompt
+ result = result.nsmallest(MAX_INDEX + 1, "rank").iloc[[classification]].reset_index(drop=True)
+
+ return VectorStoreSearchOutput(result)
+
+
+########################
+#### ACTUAL AGENT CODE
+########################
+
+
+class GcpAgent(AgentBase):
+ def __init__(
+ self,
+ project_id: str,
+ location: str,
+ model_name: str = "gemini-3-flash-preview",
+ task_type: str = "classification",
+ ):
+ self.client = genai.Client(
+ vertexai=True,
+ project=project_id,
+ location=location,
+ http_options=HttpOptions(api_version="v1"),
+ )
+
+ # assign the model name
+ self.model_name = model_name
+
+ # decide logic for classification or reranking
+ # if task_type == "reranking":
+ # self.system_prompt = RERANK_SYSTEM_PROMPT
+ # self.response_formatting_function = format_reranking_output
+ # self.response_schema = RERANK_RESPONSE_SCHEMA
+ if task_type == "classification":
+ self.system_prompt = CLASSIFICATION_SYSTEM_PROMPT
+ self.response_formatting_function = format_classification_output
+ self.response_schema = ClassificationResponseModel.model_json_schema()
+
+ else:
+ raise ValueError(
+ f"Unsupported task_type: {task_type}. Current supported types are 'reranking' and 'classification'."
+ )
+
+ def transform(self, results: VectorStoreSearchOutput) -> VectorStoreSearchOutput:
+ # Group rows by query_id and process individually
+ grouped = list(results.groupby("query_id"))
+ all_results = []
+
+ # Iterate over each group (query_id)
+ for _, group in grouped:
+ # Create a prompt for the current query_id
+ prompt = format_prompt_with_retrieval_results(group)
+
+ # Prompt the model with the single prompt
+ response = self.client.models.generate_content(
+ model=self.model_name,
+ contents=prompt,
+ config=GenerateContentConfig(
+ system_instruction=self.system_prompt,
+ response_mime_type="application/json",
+ response_schema=self.response_schema,
+ ),
+ )
+
+ # Process the raw response from the model
+ formatted_result = self.response_formatting_function(response.text, group)
+ all_results.append(formatted_result)
+
+ # Combine all results into the final DataFrame
+ final_results = pd.concat(all_results, ignore_index=True)
+
+ return VectorStoreSearchOutput(final_results)
diff --git a/src/classifai/indexers/main.py b/src/classifai/indexers/main.py
index 6f87380..9391857 100644
--- a/src/classifai/indexers/main.py
+++ b/src/classifai/indexers/main.py
@@ -12,6 +12,8 @@
- Batch processing of input files to handle large datasets.
- Support for CSV file format (additional formats may be added in future updates).
- Integration with a custom embedder for generating vector embeddings.
+- Support for user-defined hooks for preprocessing and postprocessing.
+- Integrates with an optional AI agent for classifying and transforming search results.
- Logging for tracking progress and handling errors during processing.
Dependencies:
@@ -60,6 +62,7 @@ class VectorStore:
file_name (str): the original file with the knowledgebase to build the vector store
data_type (str): the data type of the original file (curently only csv supported)
vectoriser (object): A Vectoriser object from the corresponding ClassifAI Pacakge module
+ agent (object): An optional generative AI agent from the ClassifAI Agents module to transform candidate classification search results
batch_size (int): the batch size to pass to the vectoriser when embedding
meta_data (dict[str:type]): key-value pairs of metadata to extract from the input file and their correpsonding types
output_dir (str): the path to the output directory where the VectorStore will be saved
@@ -75,6 +78,7 @@ def __init__( # noqa: PLR0913
file_name,
data_type,
vectoriser,
+ agent=None,
batch_size=8,
meta_data=None,
output_dir=None,
@@ -89,6 +93,7 @@ def __init__( # noqa: PLR0913
data_type (str): The type of input data (currently supports only "csv").
vectoriser (object): The vectoriser object used to transform text into
vector embeddings.
+ agent (object): An optional generative AI agent used to transform candidate classification search results. Defaults to None.
batch_size (int, optional): The batch size for processing the input file and batching to
vectoriser. Defaults to 8.
meta_data (dict, optional): key,value pair metadata column names to extract from the input file and their types.
@@ -107,6 +112,7 @@ def __init__( # noqa: PLR0913
self.file_name = file_name
self.data_type = data_type
self.vectoriser = vectoriser
+ self.agent = agent
self.batch_size = batch_size
self.meta_data = meta_data if meta_data is not None else {}
self.output_dir = output_dir
@@ -458,6 +464,15 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V
f"Preprocessing hook returned an invalid VectorStoreSearchOutput object. Error: {e}"
) from e
+ # If an agent is defined, use it to transform the results; a VectorStoreSearchOutput object is expected in return
+ if self.agent is not None:
+ agent_result_df = self.agent.transform(result_df)
+
+ if not isinstance(agent_result_df, VectorStoreSearchOutput):
+ raise ValueError("Agent Prep did not succesfully return a VectorStoreSearchOutput object.")
+
+ return agent_result_df
+
return result_df
@classmethod