From 428a559ee91c5cf1bd5c9c1e89bec7229eaee3d3 Mon Sep 17 00:00:00 2001 From: frayle-ons <194791647+frayle-ons@users.noreply.github.com> Date: Mon, 19 Jan 2026 17:51:11 +0000 Subject: [PATCH 1/2] added GcpAgent for classifcation and genai base classes --- DEMO/genAI_prototype.ipynb | 262 +++++++++++++++++++++++++++++++ src/classifai/agents/__init__.py | 7 + src/classifai/agents/base.py | 19 +++ src/classifai/agents/gcp.py | 200 +++++++++++++++++++++++ src/classifai/indexers/main.py | 15 ++ 5 files changed, 503 insertions(+) create mode 100644 DEMO/genAI_prototype.ipynb create mode 100644 src/classifai/agents/__init__.py create mode 100644 src/classifai/agents/base.py create mode 100644 src/classifai/agents/gcp.py diff --git a/DEMO/genAI_prototype.ipynb b/DEMO/genAI_prototype.ipynb new file mode 100644 index 0000000..d0cb569 --- /dev/null +++ b/DEMO/genAI_prototype.ipynb @@ -0,0 +1,262 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GEN AI Workflow Notebook ✨✨" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "ClassifAI allows users to create vector databases from text datasets, search those vector datasets in an ad-hoc retrieval manner, and deploy this pipeline as a restAPI service using FastAPI to perform AI assisted classification tasks.\n", + "\n", + "\n", + "An recent emergent AI field is Retrieval Augmented Generation (RAG) where text generation models that traditionally respond to a user 'prompt', first retrieve relevant infomration from a vector database via adhoc retrieval processes, and use those serach results as context to generate an answer for the original user prompt.\n", + "\n", + "#### This notebook shows how we use RAG agents to perform classification on our VectorStore semantic search results!\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ClassifAIs existing retrieval setup\n", + "\n", + "![Server_Image](files/servers.png)\n", + "\n", + "#### The 
other modules of ClassifAI provide 3 core classes that work together to:\n", + "\n", + "1. Vectorisers, to create embeddings,\n", + "2. Vectorstores, to create datgabsees of vectors and the ability to searc/query those database\n", + "3. RestAPI, to deploy a rest api service on a server to allow connections that search the created vector databases" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!uv pip install \"classifai[gcp]\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "We need to first set up a traditional ClassifAI pipeline that can provide search/classification results for our generative model to use as context...\n", + "\n", + "\n", + "All of the following code is part of our standard ClassifAI setupm, and is a short demo of how you can create a semantic search classification system. Check out our general_workflow_demo.ipynb notebook for a walkthrough of the content of these cells!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!gcloud auth application-default login" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from classifai.indexers import VectorStore\n", + "from classifai.vectorisers import HuggingFaceVectoriser\n", + "\n", + "# Our embedding model is pulled down from HuggingFace, or used straight away if previously downloaded\n", + "# This also works with many different huggingface models!\n", + "vectoriser = HuggingFaceVectoriser(model_name=\"sentence-transformers/all-MiniLM-L6-v2\")\n", + "\n", + "\n", + "my_vector_store = VectorStore(\n", + " file_name=\"./data/fake_soc_dataset.csv\", # demo csv file from the classifai package repo! 
(try some of our other DEMO/data datasets)\n", + " data_type=\"csv\",\n", + " vectoriser=vectoriser,\n", + " agent=None,\n", + " overwrite=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## We've set up our semantic search classification system\n", + "\n", + "The following cell runs the 'search' method of the vectorstore, which will take a query for our created vectorstore and return the top k (n_results) most similar samples stored in the vectorstore. This is an example of our exisiting retrieval capabilities with the ClassifAI package. \n", + "\n", + "This retrieved set can then be used to make a final classification decision, by some method such as:\n", + "- automatically choosing the top ranked item, \n", + "- a human in the loop physically looking at the retrieved candidates making the decision,\n", + "- by using a generative AI agent to assess the candidate lists and making a final decison... more on that later" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from classifai.indexers.dataclasses import VectorStoreSearchInput\n", + "\n", + "input_object = VectorStoreSearchInput(\n", + " {\"id\": [1, 2], \"query\": [\"dairy famer that sells milk\", \"happy construction worker\"]}\n", + ")\n", + "\n", + "bb = my_vector_store.search(input_object, n_results=5)\n", + "bb" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating an AI Agent\n", + "\n", + "With our semantic search classification pipeline set up, we can send the top K results (seen above) to a Genereative AI model and ask that LLM to make a final decision on which result is the correct result for the user's original input query. 
Passing semantic search results to a generatieve model for some task is often referred to as 'Retrieval Augmented Generation'.\n", + "\n", + "\n", + "![Rag_Image](files/agent.png) \n", + "\n", + "\n", + "Our `GcpAgent` class agent has a transform() method which accepts a `VectorStoreSearchOutput` object. Passing this `VectorStoreSearchOutput` object to the transform() method will return another `VectorStoreSearchOutput` object, which will modify the results in some way. The classificaion GcpAgent, reduces the semantic search results down to a single result row for each query in the VectorStoreSearchOutput results.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To instantiate the GcpAgent, the constructor takes:\n", + "\n", + "- project_id - of the google project associated with the Google Gemini embedding models\n", + "- location - corresponding to the Gcloud Project\n", + "- model_name - the specific LLM model to use, that is avaiable on Gcloud\n", + "- task_type - indicates the kind of work you want the LLM to do on the ClassifAI results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from classifai.agents import GcpAgent\n", + "\n", + "my_agent = GcpAgent(\n", + " project_id=\"platforms-sandbox\", location=\"europe-west2\", model_name=\"gemini-2.5-flash\", task_type=\"classification\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can pass some VectorStore search results directly to the agents transform() method" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "semantic_search_results = my_vector_store.search(input_object, n_results=5)\n", + "\n", + "print(type(semantic_search_results))\n", + "semantic_search_results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### To use the If the classification agent can make 
a decision, it will return a single row for the corresponding query_id. Otherwise it will return the original results\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "my_agent.transform(semantic_search_results)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### To use the agent in our ClassifAI pipeline, we can attach it to the running VectorStore" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "my_vector_store.agent = my_agent" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### We can then run vectorstore.search and it will automatically call the agent to process the results!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "my_vector_store.search(input_object, n_results=5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Thats it!\n", + "\n", + "You can see in the above cell that my_agent.transform() takes and returns a VectorStoreSearchOutput object and therefore integrates with the VectorStore search method.\n", + "\n", + "Check out the general_worfklow notebook to see how VectorStores can be deployed, and you'll find that the Agent can be deployed as an integrated part of the VectorStore." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/classifai/agents/__init__.py b/src/classifai/agents/__init__.py new file mode 100644 index 0000000..f354377 --- /dev/null +++ b/src/classifai/agents/__init__.py @@ -0,0 +1,7 @@ +from .base import GeneratorBase +from .gcp import GcpAgent + +__all__ = [ + "GcpAgent", + "GeneratorBase", +] diff --git a/src/classifai/agents/base.py b/src/classifai/agents/base.py new file mode 100644 index 0000000..3c9e1c4 --- /dev/null +++ b/src/classifai/agents/base.py @@ -0,0 +1,19 @@ +from abc import ABC, abstractmethod + +from classifai.indexers.dataclasses import VectorStoreSearchOutput + +## +# The following is the abstract base class for all RAG generative models. +## + + +class GeneratorBase(ABC): + """Abstract base class for all Generative RAG models.""" + + @abstractmethod + def transform( + self, + results: VectorStoreSearchOutput, + ) -> VectorStoreSearchOutput: + """Passes prompt(s) to the generator and returns the generated text(s) and RAG ranking.""" + pass diff --git a/src/classifai/agents/gcp.py b/src/classifai/agents/gcp.py new file mode 100644 index 0000000..f005d64 --- /dev/null +++ b/src/classifai/agents/gcp.py @@ -0,0 +1,200 @@ +import json + +import pandas as pd +from google import genai +from google.genai.types import GenerateContentConfig, HttpOptions +from pydantic import BaseModel, Field + +from classifai.indexers.dataclasses import VectorStoreSearchOutput + +from .base import GeneratorBase + +######################## +#### SYSTEM PROMPTS FOR DIFFERENT TASK TYPES: currently, Classification or Summarization. 
+######################## + +CLASSIFICATION_SYSTEM_PROMPT = """You are an AI assistant designed to classify a user query based on the provided context. You will be provided with 5 candidate entries retrieved from a knowledge base, each containing an ID and a text description. Your task is to analyze the user query and the text of the context entries to determine which of the entries best matches the user query. + +Guidelines: +1. Always prioritize the provided context when making your classification. +2. The context will be provided as an XML structure containing multiple entries. Each entry includes an ID and a text description. +3. The IDs will be integer values from 0 to 4, corresponding to the 5 candidate entries. +4. Use the text of the entries to determine the most relevant classification for the user query. +5. Your output must be a JSON object that adheres to the following schema: + - The JSON object must contain a single key, `classification`. + - The value of `classification` must be an integer between 0 and 4, representing the ID of the best matching entry. + - If no classification can be determined due to ambiguity or insufficient information, the value of `classification` must be `-1`. + +Example of the required JSON output: +{ + "classification": 1 +} + +The XML structure for the context and user query will be as follows: + + + 0 + [Text from the first entry] + + + 1 + [Text from the second entry] + + ... + + 4 + [Text from the fifth entry] + + + + + [The user query will be inserted here] + + +Your task is to analyze the context and the user query, and return the classification in the required structured format.""" + + +######################## +#### GENERAL FUNCTION FOR FORMATTING THE USER QUERY PROMPT WITH RETRIEVED RESULTS FROM VECTORSTORE +######################## + + +def format_prompt_with_retrieval_results(df: pd.DataFrame) -> str: + """Generates a formatted XML prompt for the generative model from a structured DataFrame. 
+ + Args: + df (pd.DataFrame): A DataFrame containing columns as per `searchOutputSchema`. + + Returns: + str: The formatted XML prompt. + """ + # Extract the user query (assuming all rows have the same query_id and query_text) + user_query = df["query_text"].iloc[0] + + # Limit to the top 5 entries based on rank + top_entries = df.nsmallest(5, "rank") + + # Build the section + context_entries = "\n".join( + f" \n {idx}\n {row['doc_text']}\n " + for idx, row in top_entries.iterrows() + ) + + # Combine everything into the final prompt + formatted_prompt = f""" + +{context_entries} + + + + {user_query} +""" + + return formatted_prompt + + +######################## +#### SYSTEM PROMPTS FOR DIFFERENT TASK TYPES: Classification +######################## + + +class ClassificationResponseModel(BaseModel): + classification: int = Field(description="Chosen ID of the best matching entry.", ge=-1) + + +######################## +#### FORMATTING FUNCTIONS THAT INTERPRET THE MODEL RAW RESPONSE, FORMATS, and APPLIES TO DF +######################## + + +def format_classification_output(generated_text, result: VectorStoreSearchOutput) -> VectorStoreSearchOutput: + # Parse the generated text + try: + response = json.loads(generated_text) + validated_response = ClassificationResponseModel(**response) + except (json.JSONDecodeError, ValueError): + # If parsing or validation fails, return the original DataFrame + return result + + # Extract the classification + classification = validated_response.classification + + # Validate the classification value is in the expected range + MIN_INDEX = 0 + MAX_INDEX = 4 + if int(classification) < MIN_INDEX or int(classification) > MAX_INDEX: + return result + + # Otherwise, filter to only keep the row with the classified doc_id + result = result.iloc[[classification]].reset_index(drop=True) + + return VectorStoreSearchOutput(result) + + +######################## +#### ACTUAL AGENT CODE +######################## + + +class GcpAgent(GeneratorBase): + 
def __init__( + self, + project_id: str, + location: str, + model_name: str = "gemini-3-flash-preview", + task_type: str = "classification", + ): + self.client = genai.Client( + vertexai=True, + project=project_id, + location=location, + http_options=HttpOptions(api_version="v1"), + ) + + # assign model name and vectorstore isntance + self.model_name = model_name + + # decide logic for classification or reranking + # if task_type == "reranking": + # self.system_prompt = RERANK_SYSTEM_PROMPT + # self.response_formatting_function = format_reranking_output + # self.response_schema = RERANK_RESPONSE_SCHEMA + if task_type == "classification": + self.system_prompt = CLASSIFICATION_SYSTEM_PROMPT + self.response_formatting_function = format_classification_output + self.response_schema = ClassificationResponseModel.model_json_schema() + + else: + raise ValueError( + f"Unsupported task_type: {task_type}. Current supported types are 'reranking' and 'classification'." + ) + + def transform(self, results: VectorStoreSearchOutput) -> VectorStoreSearchOutput: + # Group rows by query_id and process individually + grouped = list(results.groupby("query_id")) + all_results = [] + + # Iterate over each group (query_id) + for _, group in grouped: + # Create a prompt for the current query_id + prompt = format_prompt_with_retrieval_results(group) + + # Prompt the model with the single prompt + response = self.client.models.generate_content( + model=self.model_name, + contents=prompt, + config=GenerateContentConfig( + system_instruction=self.system_prompt, + response_mime_type="application/json", + response_schema=self.response_schema, + ), + ) + + # Process the response from the genai + formatted_result = self.response_formatting_function(response.text, group) + all_results.append(formatted_result) + + # Combine all results into the final DataFrame + final_results = pd.concat(all_results, ignore_index=True) + + return VectorStoreSearchOutput(final_results) diff --git 
a/src/classifai/indexers/main.py b/src/classifai/indexers/main.py index 6f87380..c4954e1 100644 --- a/src/classifai/indexers/main.py +++ b/src/classifai/indexers/main.py @@ -12,6 +12,8 @@ - Batch processing of input files to handle large datasets. - Support for CSV file format (additional formats may be added in future updates). - Integration with a custom embedder for generating vector embeddings. +- Support for user-defined hooks for preprocessing and postprocessing. +- Integreates with an optional AI agent for classifying and transforming search results. - Logging for tracking progress and handling errors during processing. Dependencies: @@ -60,6 +62,7 @@ class VectorStore: file_name (str): the original file with the knowledgebase to build the vector store data_type (str): the data type of the original file (curently only csv supported) vectoriser (object): A Vectoriser object from the corresponding ClassifAI Pacakge module + agent (object): An optional generate AI agent from the ClassifAI Agents module to transform candidate classification search results batch_size (int): the batch size to pass to the vectoriser when embedding meta_data (dict[str:type]): key-value pairs of metadata to extract from the input file and their correpsonding types output_dir (str): the path to the output directory where the VectorStore will be saved @@ -75,6 +78,7 @@ def __init__( # noqa: PLR0913 file_name, data_type, vectoriser, + agent, batch_size=8, meta_data=None, output_dir=None, @@ -89,6 +93,7 @@ def __init__( # noqa: PLR0913 data_type (str): The type of input data (currently supports only "csv"). vectoriser (object): The vectoriser object used to transform text into vector embeddings. + agent (object): The generate AI agent used to transform candidate classification search results. batch_size (int, optional): The batch size for processing the input file and batching to vectoriser. Defaults to 8. 
meta_data (dict, optional): key,value pair metadata column names to extract from the input file and their types. @@ -107,6 +112,7 @@ def __init__( # noqa: PLR0913 self.file_name = file_name self.data_type = data_type self.vectoriser = vectoriser + self.agent = agent if agent is not None else None self.batch_size = batch_size self.meta_data = meta_data if meta_data is not None else {} self.output_dir = output_dir @@ -458,6 +464,15 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V f"Preprocessing hook returned an invalid VectorStoreSearchOutput object. Error: {e}" ) from e + # If an agent is defined, use it to transform the results, expecxts a VectorStoreSearchOutput object in return + if self.agent is not None: + agent_result_df = self.agent.transform(result_df) + + if not isinstance(agent_result_df, VectorStoreSearchOutput): + raise ValueError("Agent Prep did not succesfully return a VectorStoreSearchOutput object.") + + return agent_result_df + return result_df @classmethod From c4003d482781c20fe216d6f013d67b0a1849cb1e Mon Sep 17 00:00:00 2001 From: frayle-ons <194791647+frayle-ons@users.noreply.github.com> Date: Tue, 20 Jan 2026 13:40:37 +0000 Subject: [PATCH 2/2] renaming baseclass and some minor code fixes --- DEMO/genAI_prototype.ipynb | 2 +- src/classifai/agents/__init__.py | 4 ++-- src/classifai/agents/base.py | 6 +++--- src/classifai/agents/gcp.py | 4 ++-- src/classifai/indexers/main.py | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/DEMO/genAI_prototype.ipynb b/DEMO/genAI_prototype.ipynb index d0cb569..6c19e21 100644 --- a/DEMO/genAI_prototype.ipynb +++ b/DEMO/genAI_prototype.ipynb @@ -155,7 +155,7 @@ "from classifai.agents import GcpAgent\n", "\n", "my_agent = GcpAgent(\n", - " project_id=\"platforms-sandbox\", location=\"europe-west2\", model_name=\"gemini-2.5-flash\", task_type=\"classification\"\n", + " project_id=\"xxxxxx\", location=\"europe-west2\", model_name=\"gemini-2.5-flash\", 
task_type=\"classification\"\n", ")" ] }, diff --git a/src/classifai/agents/__init__.py b/src/classifai/agents/__init__.py index f354377..dc9f36d 100644 --- a/src/classifai/agents/__init__.py +++ b/src/classifai/agents/__init__.py @@ -1,7 +1,7 @@ -from .base import GeneratorBase +from .base import AgentBase from .gcp import GcpAgent __all__ = [ + "AgentBase", "GcpAgent", - "GeneratorBase", ] diff --git a/src/classifai/agents/base.py b/src/classifai/agents/base.py index 3c9e1c4..d206c0b 100644 --- a/src/classifai/agents/base.py +++ b/src/classifai/agents/base.py @@ -7,13 +7,13 @@ ## -class GeneratorBase(ABC): - """Abstract base class for all Generative RAG models.""" +class AgentBase(ABC): + """Abstract base class for all Generative and RAG models.""" @abstractmethod def transform( self, results: VectorStoreSearchOutput, ) -> VectorStoreSearchOutput: - """Passes prompt(s) to the generator and returns the generated text(s) and RAG ranking.""" + """Passes VectorStoreSearchOutput object, which the Agent manipulates in some way and returns.""" pass diff --git a/src/classifai/agents/gcp.py b/src/classifai/agents/gcp.py index f005d64..b1754f9 100644 --- a/src/classifai/agents/gcp.py +++ b/src/classifai/agents/gcp.py @@ -7,7 +7,7 @@ from classifai.indexers.dataclasses import VectorStoreSearchOutput -from .base import GeneratorBase +from .base import AgentBase ######################## #### SYSTEM PROMPTS FOR DIFFERENT TASK TYPES: currently, Classification or Summarization. 
@@ -136,7 +136,7 @@ def format_classification_output(generated_text, result: VectorStoreSearchOutput ######################## -class GcpAgent(GeneratorBase): +class GcpAgent(AgentBase): def __init__( self, project_id: str, diff --git a/src/classifai/indexers/main.py b/src/classifai/indexers/main.py index c4954e1..9391857 100644 --- a/src/classifai/indexers/main.py +++ b/src/classifai/indexers/main.py @@ -78,7 +78,7 @@ def __init__( # noqa: PLR0913 file_name, data_type, vectoriser, - agent, + agent=None, batch_size=8, meta_data=None, output_dir=None,