Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 67 additions & 17 deletions src/classifai/servers/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,22 @@
import uvicorn

# New imports
from fastapi import FastAPI, Query
from fastapi import FastAPI, Query, Request
from fastapi.responses import RedirectResponse

from ..indexers.dataclasses import VectorStoreEmbedInput, VectorStoreReverseSearchInput, VectorStoreSearchInput
from .pydantic_models import (
ClassifaiData,
EmbeddingsList,
EmbeddingsResponseBody,
ResultsResponseBody,
RevClassifaiData,
RevResultsResponseBody,
convert_dataframe_to_pydantic_response,
convert_dataframe_to_reverse_search_pydantic_response,
)


def start_api(vector_stores, endpoint_names, port=8000): # noqa: C901
"""Initialize and start the FastAPI application with dynamically created endpoints.
def setup_api(vector_stores, endpoint_names, hooks: list[dict] | None = None): # noqa: C901,PLR0915,PLR0912
"""Initialize the FastAPI application with dynamically created endpoints.
This function dynamically registers embedding and search endpoints for each provided
vector store and endpoint name. It also sets up a default route to redirect users to
the API documentation page.
Expand All @@ -41,8 +39,17 @@ def start_api(vector_stores, endpoint_names, port=8000): # noqa: C901
embedding and search operations for a specific endpoint.
endpoint_names (list): A list of endpoint names corresponding to the vector stores.
port (int, optional): The port on which the API server will run. Defaults to 8000.


hooks (list[dict], optional): A list of hook dictionaries (one per VectorStore)
for additional configurations. Defaults to [{}].
Hook dictionaries should be in the format
{
"embed": {
"decorators": [callable_1, callable_2, ...],
"pre_endpt": callable,
"post_endpt": callable,
},
...
}
"""
if len(vector_stores) != len(endpoint_names):
raise ValueError("The number of vector stores must match the number of endpoint names.")
Expand All @@ -60,13 +67,14 @@ def create_embedding_endpoint(app, endpoint_name, vector_store):
app (FastAPI): The FastAPI application instance.
endpoint_name (str): The name of the endpoint to be created.
vector_store: The vector store object responsible for generating embeddings.
hooks (dict, optional): A dictionary of hooks for additional configurations.
Defaults to {}.

The created endpoint accepts POST requests with input data, generates embeddings
for the provided documents, and returns the results in a structured format.
"""

@app.post(f"/{endpoint_name}/embed", description=f"{endpoint_name} embedding endpoint")
async def embedding_endpoint(data: ClassifaiData) -> EmbeddingsResponseBody:
async def embedding_endpoint(request: Request, data: ClassifaiData):
input_ids = [x.id for x in data.entries]
documents = [x.description for x in data.entries]

Expand All @@ -85,6 +93,8 @@ async def embedding_endpoint(data: ClassifaiData) -> EmbeddingsResponseBody:
)
return EmbeddingsResponseBody(data=returnable)

return embedding_endpoint

def create_search_endpoint(app, endpoint_name, vector_store):
"""Create and register a search endpoint for a specific vector store.

Expand All @@ -98,8 +108,8 @@ def create_search_endpoint(app, endpoint_name, vector_store):
the vector store and returns the results in a structured format.
"""

@app.post(f"/{endpoint_name}/search", description=f"{endpoint_name} search endpoint")
async def search_endpoint(
request: Request,
data: ClassifaiData,
n_results: Annotated[
int,
Expand All @@ -108,7 +118,7 @@ async def search_endpoint(
ge=1, # Ensure at least one result is returned
),
] = 10,
) -> ResultsResponseBody:
):
input_ids = [x.id for x in data.entries]
queries = [x.description for x in data.entries]

Expand All @@ -123,6 +133,8 @@ async def search_endpoint(

return formatted_result

return search_endpoint

def create_reverse_search_endpoint(app, endpoint_name, vector_store):
"""Create and register a reverse_search endpoint for a specific vector store.

Expand All @@ -136,16 +148,16 @@ def create_reverse_search_endpoint(app, endpoint_name, vector_store):
the vector store and returns the results in a structured format.
"""

@app.post(f"/{endpoint_name}/reverse_search", description=f"{endpoint_name} reverse query endpoint")
def reverse_search_endpoint(
request: Request,
data: RevClassifaiData,
n_results: Annotated[
int,
Query(
description="The max number of results to return.",
),
] = 100,
) -> RevResultsResponseBody:
):
input_ids = [x.id for x in data.entries]
queries = [x.code for x in data.entries]

Expand All @@ -158,11 +170,44 @@ def reverse_search_endpoint(
)
return formatted_result

for endpoint_name, vector_store in zip(endpoint_names, vector_stores, strict=True):
return reverse_search_endpoint

for endpoint_name, vector_store, hooks_dict in zip(endpoint_names, vector_stores, hooks, strict=True):
logging.info("Registering endpoints for: %s", endpoint_name)
create_embedding_endpoint(app, endpoint_name, vector_store)
create_search_endpoint(app, endpoint_name, vector_store)
create_reverse_search_endpoint(app, endpoint_name, vector_store)
embedding_endpt = create_embedding_endpoint(app, endpoint_name, vector_store)
if hooks_dict is not None:
if hooks_dict.get("embed"):
for decorator in hooks_dict["embed"].get("decorators", []):
embedding_endpt = decorator(embedding_endpt)
embedding_endpt = app.post(
f"/{endpoint_name}/embed", description=f"{endpoint_name} embedding endpoint"
)(embedding_endpt)
if pre_endpt := hooks_dict["embed"].get("pre_endpt"):
embedding_endpt = pre_endpt(embedding_endpt)
if post_endpt := hooks_dict["embed"].get("post_endpt"):
embedding_endpt = post_endpt(embedding_endpt)
search_endpt = create_search_endpoint(app, endpoint_name, vector_store)
if hooks_dict.get("search"):
for decorator in hooks_dict["search"].get("decorators", []):
search_endpt = decorator(search_endpt)
search_endpt = app.post(f"/{endpoint_name}/search", description=f"{endpoint_name} search endpoint")(
search_endpt
)
if pre_endpt := hooks_dict["search"].get("pre_endpt"):
search_endpt = pre_endpt(search_endpt)
if post_endpt := hooks_dict["search"].get("post_endpt"):
search_endpt = post_endpt(search_endpt)
reverse_search_endpt = create_reverse_search_endpoint(app, endpoint_name, vector_store)
if hooks_dict.get("reverse_search"):
for decorator in hooks_dict["reverse_search"].get("decorators", []):
reverse_search_endpt = decorator(reverse_search_endpt)
reverse_search_endpt = app.post(
f"/{endpoint_name}/reverse_search", description=f"{endpoint_name} reverse query endpoint"
)(reverse_search_endpt)
if pre_endpt := hooks_dict["reverse_search"].get("pre_endpt"):
reverse_search_endpt = pre_endpt(reverse_search_endpt)
if post_endpt := hooks_dict["reverse_search"].get("post_endpt"):
reverse_search_endpt = post_endpt(reverse_search_endpt)

@app.get("/", description="UI accessibility")
def docs():
Expand All @@ -174,4 +219,9 @@ def docs():
start_page = RedirectResponse(url="/docs")
return start_page

return app


def start_api(app, port: int = 8000):
"""Start the FastAPI application using Uvicorn."""
uvicorn.run(app, port=port, log_level="info")
47 changes: 47 additions & 0 deletions testbed/server_hooks_slowapi_ratelimiter_poc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# /usr/bin/env -S uv run

# ------------- Run ClassifAI ------------- #

# Load packages
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

from classifai.indexers import VectorStore
from classifai.servers.main import setup_api, start_api

# Initialise a Vectoriser
from classifai.vectorisers import HuggingFaceVectoriser

vectoriser = HuggingFaceVectoriser(model_name="sentence-transformers/all-MiniLM-L6-v2")
vector_store = VectorStore(
file_name="fake_soc_dataset.csv",
data_type="csv",
vectoriser=vectoriser,
meta_data=None,
overwrite=True,
output_dir="outputs",
)


# Define server hooks:
def test_pre_search_hook(search_endpt_function):
print("Pre-search hook executed")
return search_endpt_function


limiter = Limiter(key_func=get_remote_address)
server_hooks = [
{
"search": {
"decorators": [test_pre_search_hook, limiter.limit("2/minute")],
},
},
]

app = setup_api(vector_stores=[vector_store], endpoint_names=["fake_soc"], hooks=server_hooks)

app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

start_api(app, port=8000)