diff --git a/src/classifai/servers/main.py b/src/classifai/servers/main.py
index 6ff6596..cb0f405 100644
--- a/src/classifai/servers/main.py
+++ b/src/classifai/servers/main.py
@@ -14,7 +14,7 @@
 import uvicorn
 
 # New imports
-from fastapi import FastAPI, Query
+from fastapi import FastAPI, Query, Request
 from fastapi.responses import RedirectResponse
 
 from ..indexers.dataclasses import VectorStoreEmbedInput, VectorStoreReverseSearchInput, VectorStoreSearchInput
@@ -22,14 +22,12 @@
     ClassifaiData,
     EmbeddingsList,
     EmbeddingsResponseBody,
-    ResultsResponseBody,
     RevClassifaiData,
-    RevResultsResponseBody,
     convert_dataframe_to_pydantic_response,
     convert_dataframe_to_reverse_search_pydantic_response,
 )
 
 
-def start_api(vector_stores, endpoint_names, port=8000):  # noqa: C901
-    """Initialize and start the FastAPI application with dynamically created endpoints.
+def setup_api(vector_stores, endpoint_names, hooks: list[dict] | None = None):  # noqa: C901,PLR0915,PLR0912
+    """Initialize the FastAPI application with dynamically created endpoints.
 
@@ -42,7 +40,22 @@ def start_api(vector_stores, endpoint_names, port=8000): # noqa: C901
         endpoint_names (list): A list of endpoint names corresponding to the vector stores.
-        port (int, optional): The port on which the API server will run. Defaults to 8000.
-
-
+        hooks (list[dict], optional): A list of hook dictionaries (one per VectorStore)
+            for additional configurations. Defaults to None, which applies no hooks.
+            Hook dictionaries should be in the format
+            {
+                "embed": {
+                    "decorators": [callable_1, callable_2, ...],
+                    "pre_endpt": callable,
+                    "post_endpt": callable,
+                },
+                ...
+            }
+
+    Returns:
+        FastAPI: The configured application; pass it to ``start_api`` to serve it.
+
+    Raises:
+        ValueError: If the numbers of vector stores, endpoint names and (when given)
+            hooks do not all match.
     """
     if len(vector_stores) != len(endpoint_names):
         raise ValueError("The number of vector stores must match the number of endpoint names.")
@@ -66,7 +79,6 @@ def create_embedding_endpoint(app, endpoint_name, vector_store):
         """
 
-        @app.post(f"/{endpoint_name}/embed", description=f"{endpoint_name} embedding endpoint")
-        async def embedding_endpoint(data: ClassifaiData) -> EmbeddingsResponseBody:
+        async def embedding_endpoint(request: Request, data: ClassifaiData):
             input_ids = [x.id for x in data.entries]
             documents = [x.description for x in data.entries]
 
@@ -85,6 +97,8 @@ async def embedding_endpoint(data: ClassifaiData) -> EmbeddingsResponseBody:
             )
             return EmbeddingsResponseBody(data=returnable)
 
+        return embedding_endpoint
+
     def create_search_endpoint(app, endpoint_name, vector_store):
         """Create and register a search endpoint for a specific vector store.
 
@@ -98,8 +112,8 @@ def create_search_endpoint(app, endpoint_name, vector_store):
         the vector store and returns the results in a structured format.
         """
 
-        @app.post(f"/{endpoint_name}/search", description=f"{endpoint_name} search endpoint")
         async def search_endpoint(
+            request: Request,
             data: ClassifaiData,
             n_results: Annotated[
                 int,
@@ -108,7 +122,7 @@ async def search_endpoint(
                     ge=1,  # Ensure at least one result is returned
                 ),
             ] = 10,
-        ) -> ResultsResponseBody:
+        ):
             input_ids = [x.id for x in data.entries]
             queries = [x.description for x in data.entries]
 
@@ -123,6 +137,8 @@ async def search_endpoint(
 
             return formatted_result
 
+        return search_endpoint
+
     def create_reverse_search_endpoint(app, endpoint_name, vector_store):
         """Create and register a reverse_search endpoint for a specific vector store.
 
@@ -136,8 +152,8 @@ def create_reverse_search_endpoint(app, endpoint_name, vector_store):
         the vector store and returns the results in a structured format.
         """
 
-        @app.post(f"/{endpoint_name}/reverse_search", description=f"{endpoint_name} reverse query endpoint")
         def reverse_search_endpoint(
+            request: Request,
             data: RevClassifaiData,
             n_results: Annotated[
                 int,
@@ -145,7 +161,7 @@ def reverse_search_endpoint(
                     description="The max number of results to return.",
                 ),
             ] = 100,
-        ) -> RevResultsResponseBody:
+        ):
             input_ids = [x.id for x in data.entries]
             queries = [x.code for x in data.entries]
 
@@ -158,11 +174,34 @@ def reverse_search_endpoint(
             )
             return formatted_result
 
-    for endpoint_name, vector_store in zip(endpoint_names, vector_stores, strict=True):
+        return reverse_search_endpoint
+
+    # (route suffix, endpoint factory, description suffix) for every per-store endpoint.
+    endpoint_specs = (
+        ("embed", create_embedding_endpoint, "embedding endpoint"),
+        ("search", create_search_endpoint, "search endpoint"),
+        ("reverse_search", create_reverse_search_endpoint, "reverse query endpoint"),
+    )
+
+    # `hooks` is optional: without it every vector store is registered with no hooks.
+    hooks_per_store = [{}] * len(vector_stores) if hooks is None else hooks
+
+    for endpoint_name, vector_store, hooks_dict in zip(endpoint_names, vector_stores, hooks_per_store, strict=True):
         logging.info("Registering endpoints for: %s", endpoint_name)
-        create_embedding_endpoint(app, endpoint_name, vector_store)
-        create_search_endpoint(app, endpoint_name, vector_store)
-        create_reverse_search_endpoint(app, endpoint_name, vector_store)
+        for route, factory, label in endpoint_specs:
+            endpoint = factory(app, endpoint_name, vector_store)
+            # A missing or None hook entry means "no hooks" for this route.
+            route_hooks = (hooks_dict or {}).get(route) or {}
+            for decorator in route_hooks.get("decorators", []):
+                endpoint = decorator(endpoint)
+            # Register unconditionally so the route exists whether or not hooks were given.
+            endpoint = app.post(f"/{endpoint_name}/{route}", description=f"{endpoint_name} {label}")(endpoint)
+            # NOTE(review): pre_endpt/post_endpt run once at setup, after registration; their
+            # return value is not re-registered with FastAPI. Confirm the intended semantics.
+            if pre_endpt := route_hooks.get("pre_endpt"):
+                endpoint = pre_endpt(endpoint)
+            if post_endpt := route_hooks.get("post_endpt"):
+                endpoint = post_endpt(endpoint)
 
     @app.get("/", description="UI accessibility")
     def docs():
@@ -174,4 +213,14 @@ def docs():
         start_page = RedirectResponse(url="/docs")
         return start_page
 
+    return app
+
+
+def start_api(app, port: int = 8000):
+    """Start the FastAPI application using Uvicorn.
+
+    Args:
+        app (FastAPI): The application returned by ``setup_api``.
+        port (int, optional): The port on which the API server will run. Defaults to 8000.
+    """
     uvicorn.run(app, port=port, log_level="info")
diff --git a/testbed/server_hooks_slowapi_ratelimiter_poc.py b/testbed/server_hooks_slowapi_ratelimiter_poc.py
new file mode 100755
index 0000000..70624dd
--- /dev/null
+++ b/testbed/server_hooks_slowapi_ratelimiter_poc.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env -S uv run
+
+# ------------- Run ClassifAI ------------- #
+
+# Load packages
+from slowapi import Limiter, _rate_limit_exceeded_handler
+from slowapi.errors import RateLimitExceeded
+from slowapi.util import get_remote_address
+
+from classifai.indexers import VectorStore
+from classifai.servers.main import setup_api, start_api
+
+# Initialise a Vectoriser
+from classifai.vectorisers import HuggingFaceVectoriser
+
+vectoriser = HuggingFaceVectoriser(model_name="sentence-transformers/all-MiniLM-L6-v2")
+vector_store = VectorStore(
+    file_name="fake_soc_dataset.csv",
+    data_type="csv",
+    vectoriser=vectoriser,
+    meta_data=None,
+    overwrite=True,
+    output_dir="outputs",
+)
+
+
+# Define server hooks:
+def test_pre_search_hook(search_endpt_function):
+    """Example decorator hook: prints once at setup and returns the endpoint unchanged."""
+    print("Pre-search hook executed")
+    return search_endpt_function
+
+
+limiter = Limiter(key_func=get_remote_address)
+server_hooks = [
+    {
+        "search": {
+            "decorators": [test_pre_search_hook, limiter.limit("2/minute")],
+        },
+    },
+]
+
+app = setup_api(vector_stores=[vector_store], endpoint_names=["fake_soc"], hooks=server_hooks)
+
+# NOTE(review): slowapi expects the limiter on app.state plus a RateLimitExceeded handler.
+app.state.limiter = limiter
+app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
+
+start_api(app, port=8000)