diff --git a/.github/workflows/build_docs.yml b/.github/workflows/build_docs.yml deleted file mode 100644 index 5f8f8d1a..00000000 --- a/.github/workflows/build_docs.yml +++ /dev/null @@ -1,42 +0,0 @@ -# Simple workflow for deploying static content to GitHub Pages -name: Build docs based on Sphinx - -on: - # Runs on pushes targeting the default branch - push: - branches: ["main"] - - # Allows you to run this workflow manually from the Actions tab - workflow_dispatch: - -jobs: - build: - environment: - name: build-docs - runs-on: ubuntu-22.04 - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Set working directory - uses: ./.github/actions/set-working-directory - with: - working_directory: "${{ github.workspace }}" - - - name: Setup python environment - uses: ./.github/actions/setup-python-env - with: - working_directory: "${{ env.WORKING_DIRECTORY }}" - - - name: Build docs - run: | - cd docs - make html - - - uses: EndBug/add-and-commit@v9.1.1 - with: - add: "${{ env.WORKING_DIRECTORY }}/docs/_build/html" - default_author: github_actor - fetch: true - message: "[skip ci] Update docs" - pathspec_error_handling: ignore diff --git a/.github/workflows/static_docs.yml b/.github/workflows/static_docs.yml deleted file mode 100644 index 467f7c6a..00000000 --- a/.github/workflows/static_docs.yml +++ /dev/null @@ -1,42 +0,0 @@ -# Simple workflow for deploying static content to GitHub Pages -name: Deploy docs to GitHub Pages - -on: - # Runs on pushes targeting the default branch - release: - types: [published] - - # Allows you to run this workflow manually from the Actions tab - workflow_dispatch: - -# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages -permissions: - contents: read - pages: write - id-token: write - -# Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. -# However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. 
-concurrency: - group: "pages" - cancel-in-progress: false - -jobs: - deploy: - environment: - name: github-pages - url: ${{ steps.deployment.outputs.page_url }} - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v4 - - name: Setup Pages - uses: actions/configure-pages@v5 - - name: Upload artifact - uses: actions/upload-pages-artifact@v3 - with: - # Upload entire repository - path: './docs/_build/html' - - name: Deploy to GitHub Pages - id: deployment - uses: actions/deploy-pages@v4 diff --git a/docs/Makefile b/docs/Makefile deleted file mode 100644 index d4bb2cbb..00000000 --- a/docs/Makefile +++ /dev/null @@ -1,20 +0,0 @@ -# Minimal makefile for Sphinx documentation -# - -# You can set these variables from the command line, and also -# from the environment for the first two. -SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build -SOURCEDIR = . -BUILDDIR = _build - -# Put it first so that "make" without argument is like "make help". -help: - @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -.PHONY: help Makefile - -# Catch-all target: route all unknown targets to Sphinx using the new -# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
-%: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_build/doctrees/environment.pickle b/docs/_build/doctrees/environment.pickle deleted file mode 100644 index b5586057..00000000 Binary files a/docs/_build/doctrees/environment.pickle and /dev/null differ diff --git a/docs/_build/doctrees/index.doctree b/docs/_build/doctrees/index.doctree deleted file mode 100644 index 6e3a720a..00000000 Binary files a/docs/_build/doctrees/index.doctree and /dev/null differ diff --git a/docs/_build/doctrees/source/examples.doctree b/docs/_build/doctrees/source/examples.doctree deleted file mode 100644 index 490fd20d..00000000 Binary files a/docs/_build/doctrees/source/examples.doctree and /dev/null differ diff --git a/docs/_build/doctrees/source/intro.doctree b/docs/_build/doctrees/source/intro.doctree deleted file mode 100644 index 0a395c3c..00000000 Binary files a/docs/_build/doctrees/source/intro.doctree and /dev/null differ diff --git a/docs/_build/doctrees/source/modules.doctree b/docs/_build/doctrees/source/modules.doctree deleted file mode 100644 index 65a16513..00000000 Binary files a/docs/_build/doctrees/source/modules.doctree and /dev/null differ diff --git a/docs/_build/doctrees/source/notdiamond.doctree b/docs/_build/doctrees/source/notdiamond.doctree deleted file mode 100644 index a6bb7356..00000000 Binary files a/docs/_build/doctrees/source/notdiamond.doctree and /dev/null differ diff --git a/docs/_build/doctrees/source/notdiamond.llms.doctree b/docs/_build/doctrees/source/notdiamond.llms.doctree deleted file mode 100644 index 36a4bd29..00000000 Binary files a/docs/_build/doctrees/source/notdiamond.llms.doctree and /dev/null differ diff --git a/docs/_build/doctrees/source/notdiamond.metrics.doctree b/docs/_build/doctrees/source/notdiamond.metrics.doctree deleted file mode 100644 index 2b5f464a..00000000 Binary files a/docs/_build/doctrees/source/notdiamond.metrics.doctree and /dev/null differ diff --git 
a/docs/_build/doctrees/source/notdiamond.prompts.doctree b/docs/_build/doctrees/source/notdiamond.prompts.doctree deleted file mode 100644 index ea386fc9..00000000 Binary files a/docs/_build/doctrees/source/notdiamond.prompts.doctree and /dev/null differ diff --git a/docs/_build/doctrees/source/notdiamond.toolkit.doctree b/docs/_build/doctrees/source/notdiamond.toolkit.doctree deleted file mode 100644 index a7bb0342..00000000 Binary files a/docs/_build/doctrees/source/notdiamond.toolkit.doctree and /dev/null differ diff --git a/docs/_build/doctrees/source/notdiamond.toolkit.litellm.doctree b/docs/_build/doctrees/source/notdiamond.toolkit.litellm.doctree deleted file mode 100644 index 7756d09a..00000000 Binary files a/docs/_build/doctrees/source/notdiamond.toolkit.litellm.doctree and /dev/null differ diff --git a/docs/_build/doctrees/source/notdiamond.toolkit.rag.doctree b/docs/_build/doctrees/source/notdiamond.toolkit.rag.doctree deleted file mode 100644 index 6e46cb95..00000000 Binary files a/docs/_build/doctrees/source/notdiamond.toolkit.rag.doctree and /dev/null differ diff --git a/docs/_build/html/.buildinfo b/docs/_build/html/.buildinfo deleted file mode 100644 index f02d9643..00000000 --- a/docs/_build/html/.buildinfo +++ /dev/null @@ -1,4 +0,0 @@ -# Sphinx build info version 1 -# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. 
-config: 204a5c462923e8038e5918346d8d885f -tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/docs/_build/html/.doctrees/environment.pickle b/docs/_build/html/.doctrees/environment.pickle deleted file mode 100644 index d514ea30..00000000 Binary files a/docs/_build/html/.doctrees/environment.pickle and /dev/null differ diff --git a/docs/_build/html/.doctrees/index.doctree b/docs/_build/html/.doctrees/index.doctree deleted file mode 100644 index b9d2fedb..00000000 Binary files a/docs/_build/html/.doctrees/index.doctree and /dev/null differ diff --git a/docs/_build/html/.doctrees/source/intro.doctree b/docs/_build/html/.doctrees/source/intro.doctree deleted file mode 100644 index 7fd60461..00000000 Binary files a/docs/_build/html/.doctrees/source/intro.doctree and /dev/null differ diff --git a/docs/_build/html/.doctrees/source/notdiamond.doctree b/docs/_build/html/.doctrees/source/notdiamond.doctree deleted file mode 100644 index fd2fc794..00000000 Binary files a/docs/_build/html/.doctrees/source/notdiamond.doctree and /dev/null differ diff --git a/docs/_build/html/.doctrees/source/notdiamond.llms.doctree b/docs/_build/html/.doctrees/source/notdiamond.llms.doctree deleted file mode 100644 index ad317ab0..00000000 Binary files a/docs/_build/html/.doctrees/source/notdiamond.llms.doctree and /dev/null differ diff --git a/docs/_build/html/.doctrees/source/notdiamond.metrics.doctree b/docs/_build/html/.doctrees/source/notdiamond.metrics.doctree deleted file mode 100644 index c6d491fa..00000000 Binary files a/docs/_build/html/.doctrees/source/notdiamond.metrics.doctree and /dev/null differ diff --git a/docs/_build/html/.doctrees/source/notdiamond.toolkit.doctree b/docs/_build/html/.doctrees/source/notdiamond.toolkit.doctree deleted file mode 100644 index bd17188b..00000000 Binary files a/docs/_build/html/.doctrees/source/notdiamond.toolkit.doctree and /dev/null differ diff --git a/docs/_build/html/.doctrees/source/notdiamond.toolkit.rag.doctree 
b/docs/_build/html/.doctrees/source/notdiamond.toolkit.rag.doctree deleted file mode 100644 index 5e35558b..00000000 Binary files a/docs/_build/html/.doctrees/source/notdiamond.toolkit.rag.doctree and /dev/null differ diff --git a/docs/_build/html/.nojekyll b/docs/_build/html/.nojekyll deleted file mode 100644 index e69de29b..00000000 diff --git a/docs/_build/html/_modules/index.html b/docs/_build/html/_modules/index.html deleted file mode 100644 index efe356cf..00000000 --- a/docs/_build/html/_modules/index.html +++ /dev/null @@ -1,115 +0,0 @@ - - -
- - -
-import os
-from typing import Dict, List, Union
-
-from notdiamond.toolkit._retry import (
- AsyncRetryWrapper,
- ClientType,
- ModelType,
- OpenAIMessagesType,
- RetryManager,
- RetryWrapper,
-)
-
-
-
-[docs]
-def init(
- client: Union[ClientType, List[ClientType]],
- models: ModelType,
- max_retries: Union[int, Dict[str, int]] = 1,
- timeout: Union[float, Dict[str, float]] = 60.0,
- model_messages: Dict[str, OpenAIMessagesType] = None,
- api_key: Union[str, None] = None,
- async_mode: bool = False,
- backoff: Union[float, Dict[str, float]] = 2.0,
-) -> RetryManager:
- """
- Entrypoint for fallback and retry features without changing existing code.
-
- Add this to existing codebase without other modifications to enable the following capabilities:
-
- - Fallback to a different model if a model invocation fails.
- - If configured, fallback to a different *provider* if a model invocation fails
- (eg. azure/gpt-4o fails -> invoke openai/gpt-4o)
- - Load-balance between models and providers, if specified.
- - Pass timeout and retry configurations to each invoke, optionally configured per model.
- - Pass model-specific messages on each retry (prepended to the provided `messages` parameter)
-
- Parameters:
- client (Union[ClientType, List[ClientType]]): Clients to apply retry/fallback logic to.
- models (Union[Dict[str, float], List[str]]):
- Models to use of the format <provider>/<model>.
- Supports two formats:
- - List of models, eg. ["openai/gpt-4o", "azure/gpt-4o"]. Models will be prioritized as listed.
- - Dict of models to weights for load balancing, eg. {"openai/gpt-4o": 0.9, "azure/gpt-4o": 0.1}.
- If a model invocation fails, the next model is selected by sampling using the *remaining* weights.
- max_retries (Union[int, Dict[str, int]]):
- Maximum number of retries. Can be configured globally or per model.
- timeout (Union[float, Dict[str, float]]):
- Timeout in seconds per model. Can be configured globally or per model.
- model_messages (Dict[str, OpenAIMessagesType]):
- Model-specific messages to prepend to `messages` on each invocation, formatted OpenAI-style. Can be
- configured using any role which is valid as an initial message (eg. "system" or "user", but not "assistant").
- api_key (Optional[str]):
- Not Diamond API key for authentication. Unused for now - will offer logging and metrics in the future.
- async_mode (bool):
- Whether to manage clients as async.
- backoff (Union[float, Dict[str, float]]):
- Backoff factor for exponential backoff per each retry. Can be configured globally or per model.
-
- Returns:
- RetryManager: Manager object that handles retries and fallbacks. Not required for usage.
-
- Model Fallback Prioritization
- -----------------------------
-
- - If models is a list, the fallback model is selected in order after removing the failed model.
- eg. If "openai/gpt-4o" fails for the list:
- - ["openai/gpt-4o", "azure/gpt-4o"], "azure/gpt-4o" will be tried next
- - ["openai/gpt-4o-mini", "openai/gpt-4o", "azure/gpt-4o"], "openai/gpt-4o-mini" will be tried next.
- - If models is a dict, the next model is selected by sampling using the *remaining* weights.
- eg. If "openai/gpt-4o" fails for the dict:
- - {"openai/gpt-4o": 0.9, "azure/gpt-4o": 0.1}, "azure/gpt-4o" will be invoked 100% of the time
- - {"openai/gpt-4o": 0.5, "azure/gpt-4o": 0.25, "openai/gpt-4o-mini": 0.25}, then "azure/gpt-4o" and
- "openai/gpt-4o-mini" can be invoked with 50% probability each.
-
- Usage
- -----
-
- Please refer to tests/test_init.py for more examples on how to use notdiamond.init.
-
- .. code-block:: python
-
- # ...existing workflow code, including client initialization...
- openai_client = OpenAI(...)
- azure_client = AzureOpenAI(...)
-
- # Add `notdiamond.init` to the workflow.
- notdiamond.init(
- [openai_client, azure_client],
- models={"openai/gpt-4o": 0.9, "azure/gpt-4o": 0.1},
- max_retries={"openai/gpt-4o": 3, "azure/gpt-4o": 1},
- timeout={"openai/gpt-4o": 10.0, "azure/gpt-4o": 5.0},
- model_messages={
- "openai/gpt-4o": [{"role": "user", "content": "Here is a prompt for OpenAI."}],
- "azure/gpt-4o": [{"role": "user", "content": "Here is a prompt for Azure."}],
- },
- api_key="sk-...",
- backoff=2.0,
- )
-
- # ...continue existing workflow code...
- response = openai_client.chat.completions.create(
- model="notdiamond",
- messages=[{"role": "user", "content": "Hello!"}]
- )
-
- """
- api_key = api_key or os.getenv("NOTDIAMOND_API_KEY")
-
- if async_mode:
- wrapper_cls = AsyncRetryWrapper
- else:
- wrapper_cls = RetryWrapper
-
- for model in models:
- if len(model.split("/")) != 2:
- raise ValueError(
- f"Model {model} must be in the format <provider>/<model>."
- )
-
- if not isinstance(client, List):
- client_wrappers = [
- wrapper_cls(
- client=client,
- models=models,
- max_retries=max_retries,
- timeout=timeout,
- model_messages=model_messages,
- api_key=api_key,
- backoff=backoff,
- )
- ]
- else:
- client_wrappers = [
- wrapper_cls(
- client=cc,
- models=models,
- max_retries=max_retries,
- timeout=timeout,
- model_messages=model_messages,
- api_key=api_key,
- backoff=backoff,
- )
- for cc in client
- ]
- retry_manager = RetryManager(models, client_wrappers)
-
- return retry_manager
-
-
-from typing import Any
-
-from langchain_core.callbacks.base import BaseCallbackHandler
-
-from notdiamond.llms.provider import NDLLMProvider
-
-
-
-[docs]
-class NDLLMBaseCallbackHandler(BaseCallbackHandler):
- """
- Base callback handler for NotDiamond LLMs.
- Accepts all of the langchain_core callbacks and adds new ones.
- """
-
-
-[docs]
- def on_model_select(
- self, model_provider: NDLLMProvider, model_name: str
- ) -> Any:
- """
- Called when a model is selected.
- """
-
-
-
-[docs]
- def on_latency_tracking(
- self,
- session_id: str,
- model_provider: NDLLMProvider,
- tokens_per_second: float,
- ):
- """
- Called when latency tracking is enabled.
- """
-
-
-
-[docs]
- def on_api_error(self, error_message: str):
- """
- Called when an NDLLM API error occurs.
- """
-
-
-
-
-[docs]
-class UnsupportedLLMProvider(Exception):
- """The exception class for unsupported LLM provider"""
-
-
-
-
-[docs]
-class UnsupportedEmbeddingProvider(Exception):
- """The exception class for unsupported Embedding provider"""
-
-
-
-
-
-
-
-
-
-
-
-
-[docs]
-class MissingLLMConfigs(Exception):
- """The exception class for empty LLM provider configs array"""
-
-
-
-
-
-
-
-
-
-
-"""NotDiamond client class"""
-
-
-import inspect
-import logging
-import time
-import warnings
-from enum import Enum
-from typing import (
- Any,
- AsyncIterator,
- Callable,
- Dict,
- Iterator,
- List,
- Optional,
- Sequence,
- Tuple,
- Type,
- Union,
-)
-
-# Details: https://python.langchain.com/v0.1/docs/guides/development/pydantic_compatibility/
-from pydantic import BaseModel
-from pydantic_partial import create_partial_model
-
-from notdiamond import settings
-from notdiamond._utils import _module_check, token_counter
-from notdiamond.exceptions import (
- ApiError,
- CreateUnavailableError,
- MissingLLMConfigs,
-)
-from notdiamond.llms.config import LLMConfig
-from notdiamond.llms.providers import is_o1_model
-from notdiamond.llms.request import (
- amodel_select,
- create_preference_id,
- model_select,
- report_latency,
-)
-from notdiamond.metrics.metric import Metric
-from notdiamond.prompts import (
- _curly_escape,
- inject_system_prompt,
- o1_system_prompt_translate,
-)
-from notdiamond.types import NDApiKeyValidator
-
-LOGGER = logging.getLogger(__name__)
-LOGGER.setLevel(logging.INFO)
-
-
-class _NDClientTarget(Enum):
- ROUTER = "router"
- INVOKER = "invoker"
-
-
-def _ndllm_factory(import_target: _NDClientTarget = None):
- _invoke_error_msg_tmpl = (
- "{fn_name} is not available. `notdiamond` can generate LLM responses after "
- "installing additional dependencies via `pip install notdiamond[create]`."
- )
-
- _default_llm_config_invalid_warning = "The default LLMConfig set is invalid. Defaulting to {provider}/{model}"
-
- _no_default_llm_config_warning = (
- "No default LLMConfig set. Defaulting to {provider}/{model}"
- )
-
- class _NDRouterClient(BaseModel):
- api_key: str
- llm_configs: Optional[List[Union[LLMConfig, str]]]
- default: Union[LLMConfig, int, str]
- max_model_depth: Optional[int]
- latency_tracking: bool
- hash_content: bool
- tradeoff: Optional[str]
- preference_id: Optional[str]
- tools: Optional[Sequence[Union[Dict[str, Any], Callable]]]
- callbacks: Optional[List]
- nd_api_url: Optional[str]
- user_agent: Union[str, None]
- max_retries: Optional[int]
- timeout: Optional[Union[float, int]]
-
- class Config:
- arbitrary_types_allowed = True
-
- def __init__(
- self,
- llm_configs: Optional[List[Union[LLMConfig, str]]] = None,
- api_key: Optional[str] = None,
- default: Union[LLMConfig, int, str] = 0,
- max_model_depth: Optional[int] = None,
- latency_tracking: bool = True,
- hash_content: bool = False,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- callbacks: Optional[List] = None,
- tools: Optional[Sequence[Union[Dict[str, Any], Callable]]] = None,
- nd_api_url: Optional[str] = settings.NOTDIAMOND_API_URL,
- user_agent: Union[str, None] = None,
- max_retries: Optional[int] = 3,
- timeout: Optional[Union[float, int]] = 60.0,
- **kwargs,
- ):
- if api_key is None:
- api_key = settings.NOTDIAMOND_API_KEY
- NDApiKeyValidator(api_key=api_key)
-
- if user_agent is None:
- user_agent = settings.DEFAULT_USER_AGENT
-
- if llm_configs is not None:
- llm_configs = self._parse_llm_configs_data(llm_configs)
-
- if max_model_depth is None:
- max_model_depth = len(llm_configs)
-
- if max_model_depth > len(llm_configs):
- LOGGER.warning(
- "WARNING: max_model_depth cannot be bigger than the number of LLMs."
- )
- max_model_depth = len(llm_configs)
-
- if tradeoff is not None:
- if tradeoff not in ["cost", "latency"]:
- raise ValueError(
- "Invalid tradeoff. Accepted values: cost, latency."
- )
-
- if tradeoff is not None:
- warnings.warn(
- "The tradeoff constructor parameter is deprecated and will be removed in a "
- "future version. Please specify the tradeoff when using model_select or invocation methods.",
- DeprecationWarning,
- stacklevel=2,
- )
-
- super().__init__(
- api_key=api_key,
- llm_configs=llm_configs,
- default=default,
- max_model_depth=max_model_depth,
- latency_tracking=latency_tracking,
- hash_content=hash_content,
- tradeoff=tradeoff,
- preference_id=preference_id,
- tools=tools,
- callbacks=callbacks,
- nd_api_url=nd_api_url,
- user_agent=user_agent,
- max_retries=max_retries,
- timeout=timeout,
- **kwargs,
- )
- self.user_agent = user_agent
- assert (
- self.api_key is not None
- ), "API key is not set. Please set a Not Diamond API key."
-
- @property
- def chat(self):
- return self
-
- @property
- def completions(self):
- return self
-
- def create_preference_id(self, name: Optional[str] = None) -> str:
- return create_preference_id(self.api_key, name, self.nd_api_url)
-
- async def amodel_select(
- self,
- messages: List[Dict[str, str]],
- input: Optional[Dict[str, Any]] = None,
- model: Optional[List[LLMConfig]] = None,
- default: Optional[Union[LLMConfig, int, str]] = None,
- max_model_depth: Optional[int] = None,
- latency_tracking: Optional[bool] = None,
- hash_content: Optional[bool] = None,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- metric: Metric = Metric("accuracy"),
- previous_session: Optional[str] = None,
- timeout: Optional[Union[float, int]] = None,
- max_retries: Optional[int] = None,
- **kwargs,
- ) -> tuple[str, Optional[LLMConfig]]:
- """
- This function calls the NotDiamond backend to fetch the most suitable model for the given prompt,
- and leaves the execution of the LLM call to the developer.
- The function is async, so it's suitable for async codebases.
-
- Parameters:
- messages (List[Dict[str, str]]): List of messages, OpenAI style.
- input (Optional[Dict[str, Any]], optional): If the prompt_template contains variables, use input to specify
- the values for those variables. Defaults to None, assuming no
- variables.
- model (Optional[List[LLMConfig]]): List of models to choose from.
- default (Optional[Union[LLMConfig, int, str]]): Default LLM.
- max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
- of routing you're willing to go.
- latency_tracking (Optional[bool]): Latency tracking flag.
- hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
- tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
- for the router to determine the best LLM for a given query.
- preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
- Defaults to None.
- metric (Metric, optional): Metric used by NotDiamond router to choose the best LLM.
- Defaults to Metric("accuracy").
- previous_session (Optional[str], optional): The session ID of a previous session, allow you to link requests.
- timeout (int): The number of seconds to wait before terminating the API call to Not Diamond backend.
- max_retries (int): The number of retries to attempt before giving up.
- nd_api_url (Optional[str]): The URL of the NotDiamond API. Defaults to settings.NOTDIAMOND_API_URL.
- **kwargs: Any other arguments that are supported by Langchain's invoke method, will be passed through.
- Returns:
- tuple[str, Optional[LLMConfig]]: returns the session_id and the chosen LLM
- """
- if input is None:
- input = {}
-
- if model is not None:
- llm_configs = self._parse_llm_configs_data(model)
- self.llm_configs = llm_configs
-
- self.validate_params(
- default=default,
- max_model_depth=max_model_depth,
- latency_tracking=latency_tracking,
- hash_content=hash_content,
- tradeoff=tradeoff,
- preference_id=preference_id,
- )
-
- best_llm, session_id = await amodel_select(
- messages=messages,
- llm_configs=self.llm_configs,
- metric=metric,
- notdiamond_api_key=self.api_key,
- max_model_depth=self.max_model_depth,
- hash_content=self.hash_content,
- tradeoff=self.tradeoff,
- preference_id=self.preference_id,
- tools=self.tools,
- previous_session=previous_session,
- timeout=timeout or self.timeout,
- max_retries=max_retries or self.max_retries,
- nd_api_url=self.nd_api_url,
- _user_agent=self.user_agent,
- )
-
- if not best_llm:
- LOGGER.warning(
- f"ND API error. Falling back to default provider={self.default_llm.provider}/{self.default_llm.model}."
- )
- best_llm = self.default_llm
- self.call_callbacks("on_model_select", best_llm, best_llm.model)
-
- return session_id, best_llm
-
- def model_select(
- self,
- messages: List[Dict[str, str]],
- input: Optional[Dict[str, Any]] = None,
- model: Optional[List[LLMConfig]] = None,
- default: Optional[Union[LLMConfig, int, str]] = None,
- max_model_depth: Optional[int] = None,
- latency_tracking: Optional[bool] = None,
- hash_content: Optional[bool] = None,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- metric: Metric = Metric("accuracy"),
- previous_session: Optional[str] = None,
- timeout: Optional[Union[float, int]] = None,
- max_retries: Optional[int] = None,
- **kwargs,
- ) -> tuple[str, Optional[LLMConfig]]:
- """
- This function calls the NotDiamond backend to fetch the most suitable model for the given prompt,
- and leaves the execution of the LLM call to the developer.
-
- Parameters:
- messages (List[Dict[str, str]]): List of messages OpenAI style.
- input (Optional[Dict[str, Any]], optional): If the prompt_template contains variables, use input to specify
- the values for those variables. Defaults to None, assuming no
- variables.
- model (Optional[List[LLMConfig]]): List of models to choose from.
- default (Optional[Union[LLMConfig, int, str]]): Default LLM.
- max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
- of routing you're willing to go.
- latency_tracking (Optional[bool]): Latency tracking flag.
- hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
- tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
- for the router to determine the best LLM for a given query.
- preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
- Defaults to None.
- metric (Metric, optional): Metric used by NotDiamond router to choose the best LLM.
- Defaults to Metric("accuracy").
- previous_session (Optional[str], optional): The session ID of a previous session, allow you to link requests.
- timeout (int): The number of seconds to wait before terminating the API call to Not Diamond backend.
- max_retries (int): The number of retries to attempt before giving up.
- **kwargs: Any other arguments that are supported by Langchain's invoke method, will be passed through.
-
- Returns:
- tuple[str, Optional[LLMConfig]]: returns the session_id and the chosen LLM
- """
- if input is None:
- input = {}
-
- if model is not None:
- llm_configs = self._parse_llm_configs_data(model)
- self.llm_configs = llm_configs
-
- self.validate_params(
- default=default,
- max_model_depth=max_model_depth,
- latency_tracking=latency_tracking,
- hash_content=hash_content,
- tradeoff=tradeoff,
- preference_id=preference_id,
- )
-
- best_llm, session_id = model_select(
- messages=messages,
- llm_configs=self.llm_configs,
- metric=metric,
- notdiamond_api_key=self.api_key,
- max_model_depth=self.max_model_depth,
- hash_content=self.hash_content,
- tradeoff=self.tradeoff,
- preference_id=self.preference_id,
- tools=self.tools,
- previous_session=previous_session,
- timeout=timeout or self.timeout,
- max_retries=max_retries or self.max_retries,
- nd_api_url=self.nd_api_url,
- _user_agent=self.user_agent,
- )
-
- if not best_llm:
- LOGGER.warning(
- f"ND API error. Falling back to default provider={self.default_llm.provider}/{self.default_llm.model}."
- )
- best_llm = self.default_llm
- self.call_callbacks("on_model_select", best_llm, best_llm.model)
-
- return session_id, best_llm
-
- @staticmethod
- def _parse_llm_configs_data(
- llm_configs: list,
- ) -> List[LLMConfig]:
- providers = []
- for llm_config in llm_configs:
- if isinstance(llm_config, LLMConfig):
- providers.append(llm_config)
- continue
- parsed_provider = LLMConfig.from_string(llm_config)
- providers.append(parsed_provider)
- return providers
-
- def validate_params(
- self,
- default: Optional[Union[LLMConfig, int, str]] = None,
- max_model_depth: Optional[int] = None,
- latency_tracking: Optional[bool] = None,
- hash_content: Optional[bool] = None,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- ):
- self.default = default
-
- if max_model_depth is not None:
- self.max_model_depth = max_model_depth
-
- if self.llm_configs is None or len(self.llm_configs) == 0:
- raise MissingLLMConfigs(
- "No LLM config speficied. Specify at least one."
- )
-
- if self.max_model_depth is None:
- self.max_model_depth = len(self.llm_configs)
-
- if self.max_model_depth == 0:
- raise ValueError("max_model_depth has to be bigger than 0.")
-
- if self.max_model_depth > len(self.llm_configs):
- LOGGER.warning(
- "WARNING: max_model_depth cannot be bigger than the number of LLMs."
- )
- self.max_model_depth = len(self.llm_configs)
-
- if tradeoff is not None:
- if tradeoff not in ["cost", "latency"]:
- raise ValueError(
- "Invalid tradeoff. Accepted values: cost, latency."
- )
- self.tradeoff = tradeoff
-
- if preference_id is not None:
- self.preference_id = preference_id
-
- if latency_tracking is not None:
- self.latency_tracking = latency_tracking
-
- if hash_content is not None:
- self.hash_content = hash_content
-
- def bind_tools(
- self, tools: Sequence[Union[Dict[str, Any], Callable]]
- ) -> "NotDiamond":
- """
- Bind tools to the LLM object. The tools will be passed to the LLM object when invoking it.
- Results in the tools being available in the LLM object.
- You can access the tool_calls in the result via `result.tool_calls`.
- """
-
- for provider in self.llm_configs:
- if provider.model not in settings.PROVIDERS[
- provider.provider
- ].get("support_tools", []):
- raise ApiError(
- f"{provider.provider}/{provider.model} does not support function calling."
- )
- self.tools = tools
-
- return self
-
- def call_callbacks(self, function_name: str, *args, **kwargs) -> None:
- """
- Call all callbacks with a specific function name.
- """
-
- if self.callbacks is None:
- return
-
- for callback in self.callbacks:
- if hasattr(callback, function_name):
- getattr(callback, function_name)(*args, **kwargs)
-
- def create(*args, **kwargs):
- format_str = f"`{inspect.stack()[0].function}`"
- raise CreateUnavailableError(
- _invoke_error_msg_tmpl.format(fn_name=format_str)
- )
-
- async def acreate(*args, **kwargs):
- format_str = f"`{inspect.stack()[0].function}`"
- raise CreateUnavailableError(
- _invoke_error_msg_tmpl.format(fn_name=format_str)
- )
-
- def invoke(*args, **kwargs):
- format_str = f"`{inspect.stack()[0].function}`"
- raise CreateUnavailableError(
- _invoke_error_msg_tmpl.format(fn_name=format_str)
- )
-
- async def ainvoke(*args, **kwargs):
- format_str = f"`{inspect.stack()[0].function}`"
- raise CreateUnavailableError(
- _invoke_error_msg_tmpl.format(fn_name=format_str)
- )
-
- def stream(*args, **kwargs):
- raise CreateUnavailableError(
- _invoke_error_msg_tmpl.format(
- fn_name=inspect.stack()[0].function
- )
- )
-
- async def astream(*args, **kwargs):
- raise CreateUnavailableError(
- _invoke_error_msg_tmpl.format(
- fn_name=inspect.stack()[0].function
- )
- )
-
- @property
- def default_llm(self) -> LLMConfig:
- """
- Return the default LLM that's set on the NotDiamond client class.
- """
- if isinstance(self.default, int):
- if self.default < len(self.llm_configs):
- return self.llm_configs[self.default]
-
- if isinstance(self.default, str):
- try:
- default = LLMConfig.from_string(self.default)
- if default in self.llm_configs:
- return default
- except Exception as e:
- LOGGER.debug(f"Error setting default llm: {e}")
-
- if isinstance(self.default, LLMConfig):
- return self.default
-
- default = self.llm_configs[0]
- if self.default is None:
- LOGGER.info(
- _no_default_llm_config_warning.format(
- provider=default.provider, model=default.model
- )
- )
- else:
- LOGGER.info(
- _default_llm_config_invalid_warning.format(
- provider=default.provider, model=default.model
- )
- )
- return default
-
- # Do not import from langchain_core directly, as it is now an optional SDK dependency
- try:
- LLM = _module_check("langchain_core.language_models.llms", "LLM")
- BaseMessageChunk = _module_check(
- "langchain_core.messages", "BaseMessageChunk"
- )
- JsonOutputParser = _module_check(
- "langchain_core.output_parsers", "JsonOutputParser"
- )
- ChatPromptTemplate = _module_check(
- "langchain_core.prompts", "ChatPromptTemplate"
- )
- except (ModuleNotFoundError, ImportError) as ierr:
- msg = _invoke_error_msg_tmpl.format(fn_name="NotDiamond creation")
- if import_target == _NDClientTarget.INVOKER:
- msg += " Create was requested, however - raising..."
- raise ImportError(msg) from ierr
- else:
- LOGGER.debug(msg)
- return _NDRouterClient
-
- class _NDInvokerClient(_NDRouterClient, LLM):
- """
- Implementation of NotDiamond class, the main class responsible for creating and invoking LLM prompts.
- The class inherits from Langchain's LLM class. Starting reference is from here:
- https://python.langchain.com/docs/modules/model_io/llms/custom_llm
-
- It's mandatory to have an API key set. If the api_key is not explicitly specified,
- it will check for NOTDIAMOND_API_KEY in the .env file.
-
- Raises:
- MissingLLMProviders: you must specify at least one LLM provider for the router to work
- ApiError: error raised when the NotDiamond API call fails.
- Ensure to set a default LLM provider to not break the code.
- """
-
- api_key: str
- llm_configs: Optional[List[Union[LLMConfig, str]]]
- default: Union[LLMConfig, int, str]
- max_model_depth: Optional[int]
- latency_tracking: bool
- hash_content: bool
- tradeoff: Optional[str]
- preference_id: Optional[str]
- tools: Optional[Sequence[Union[Dict[str, Any], Callable]]]
- callbacks: Optional[List]
- nd_api_url: Optional[str]
- user_agent: Union[str, None]
-
- def __init__(
- self,
- llm_configs: Optional[List[Union[LLMConfig, str]]] = None,
- api_key: Optional[str] = None,
- default: Union[LLMConfig, int, str] = 0,
- max_model_depth: Optional[int] = None,
- latency_tracking: bool = True,
- hash_content: bool = False,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- tools: Optional[Sequence[Union[Dict[str, Any], Callable]]] = None,
- callbacks: Optional[List] = None,
- nd_api_url: Optional[str] = settings.NOTDIAMOND_API_URL,
- user_agent: Union[str, None] = None,
- timeout: Optional[Union[float, int]] = 60.0,
- max_retries: Optional[int] = 3,
- **kwargs,
- ) -> None:
- super().__init__(
- api_key=api_key,
- llm_configs=llm_configs,
- default=default,
- max_model_depth=max_model_depth,
- latency_tracking=latency_tracking,
- hash_content=hash_content,
- tradeoff=tradeoff,
- preference_id=preference_id,
- tools=tools,
- callbacks=callbacks,
- nd_api_url=nd_api_url,
- user_agent=user_agent,
- timeout=timeout,
- max_retries=max_retries,
- **kwargs,
- )
- if user_agent is None:
- user_agent = settings.DEFAULT_USER_AGENT
-
- if tradeoff is not None:
- warnings.warn(
- "The tradeoff constructor parameter is deprecated and will be removed in a "
- "future version. Please specify the tradeoff when using model_select or invocation methods.",
- DeprecationWarning,
- stacklevel=2,
- )
-
- self.user_agent = user_agent
- assert (
- self.api_key is not None
- ), "API key is not set. Please set a Not Diamond API key."
-
- def __repr__(self) -> str:
- class_name = self.__class__.__name__
- address = hex(id(self)) # Gets the memory address of the object
- return f"<{class_name} object at {address}>"
-
- @property
- def _llm_type(self) -> str:
- return "NotDiamond LLM"
-
- @staticmethod
- def _inject_model_instruction(messages, parser):
- format_instructions = parser.get_format_instructions()
- format_instructions = format_instructions.replace(
- "{", "{{"
- ).replace("}", "}}")
- messages[0]["content"] = (
- format_instructions + "\n" + messages[0]["content"]
- )
- return messages
-
- def _call(
- self,
- prompt: str,
- stop: Optional[List[str]] = None,
- run_manager: Optional[Any] = None,
- **kwargs: Any,
- ) -> str:
- if stop is not None:
- raise ValueError("stop kwargs are not permitted.")
- return "This function is deprecated for the latest LangChain version, use invoke instead"
-
- def create(
- self,
- messages: List[Dict[str, str]],
- model: Optional[List[LLMConfig]] = None,
- default: Optional[Union[LLMConfig, int, str]] = None,
- max_model_depth: Optional[int] = None,
- latency_tracking: Optional[bool] = None,
- hash_content: Optional[bool] = None,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- metric: Metric = Metric("accuracy"),
- previous_session: Optional[str] = None,
- response_model: Optional[Type[BaseModel]] = None,
- timeout: Optional[Union[float, int]] = None,
- max_retries: Optional[int] = None,
- **kwargs,
- ) -> tuple[str, str, LLMConfig]:
- """
- Function call to invoke the LLM, with the same interface
- as the OpenAI Python library.
-
- Parameters:
- messages (Optional[List[Dict[str, str]], optional): Can be used instead of prompt_template to pass
- the messages OpenAI style.
- model (Optional[List[LLMConfig]]): List of models to choose from.
- default (Optional[Union[LLMConfig, int, str]]): Default LLM.
- max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
- of routing you're willing to go.
- latency_tracking (Optional[bool]): Latency tracking flag.
- hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
- tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
- for the router to determine the best LLM for a given query.
- preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
- Defaults to None.
- metric (Metric, optional): Metric used by NotDiamond router to choose the best LLM.
- Defaults to Metric("accuracy").
- previous_session (Optional[str], optional): The session ID of a previous session, allow you to link requests.
- response_model (Optional[Type[BaseModel]], optional): If present, will use JsonOutputParser to parse the
- response into the given model. In which case result will a
- dict.
- timeout (int): The number of seconds to wait before terminating the API call to Not Diamond backend.
- max_retries (int): The number of retries to attempt before giving up.
- nd_api_url (Optional[str]): The URL of the NotDiamond API. Defaults to settings.NOTDIAMOND_API_URL.
- **kwargs: Any other arguments that are supported by Langchain's invoke method, will be passed through.
-
- Raises:
- ApiError: when the NotDiamond API fails
-
- Returns:
- tuple[Union[AIMessage, BaseModel], str, LLMConfig]:
- result: response type defined by Langchain, contains the response from the LLM.
- or object of the response_model
- str: session_id returned by the NotDiamond API
- LLMConfig: the best LLM selected by the router
- """
-
- return self.invoke(
- messages=messages,
- model=model,
- default=default,
- max_model_depth=max_model_depth,
- latency_tracking=latency_tracking,
- hash_content=hash_content,
- tradeoff=tradeoff,
- preference_id=preference_id,
- metric=metric,
- previous_session=previous_session,
- response_model=response_model,
- timeout=timeout,
- max_retries=max_retries,
- **kwargs,
- )
-
- async def acreate(
- self,
- messages: List[Dict[str, str]],
- model: Optional[List[LLMConfig]] = None,
- default: Optional[Union[LLMConfig, int, str]] = None,
- max_model_depth: Optional[int] = None,
- latency_tracking: Optional[bool] = None,
- hash_content: Optional[bool] = None,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- metric: Metric = Metric("accuracy"),
- previous_session: Optional[str] = None,
- response_model: Optional[Type[BaseModel]] = None,
- timeout: Optional[Union[float, int]] = None,
- max_retries: Optional[int] = None,
- **kwargs,
- ) -> tuple[str, str, LLMConfig]:
- """
- Async function call to invoke the LLM, with the same interface
- as the OpenAI Python library.
-
- Parameters:
- messages (Optional[List[Dict[str, str]], optional): Can be used instead of prompt_template to pass
- the messages OpenAI style.
- model (Optional[List[LLMConfig]]): List of models to choose from.
- default (Optional[Union[LLMConfig, int, str]]): Default LLM.
- max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
- of routing you're willing to go.
- latency_tracking (Optional[bool]): Latency tracking flag.
- hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
- tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
- for the router to determine the best LLM for a given query.
- preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
- Defaults to None.
- metric (Metric, optional): Metric used by NotDiamond router to choose the best LLM.
- Defaults to Metric("accuracy").
- previous_session (Optional[str], optional): The session ID of a previous session, allow you to link requests.
- response_model (Optional[Type[BaseModel]], optional): If present, will use JsonOutputParser to parse the
- response into the given model. In which case result will a
- dict.
- timeout (int): The number of seconds to wait before terminating the API call to Not Diamond backend.
- max_retries (int): The number of retries to attempt before giving up.
- **kwargs: Any other arguments that are supported by Langchain's invoke method, will be passed through.
-
- Raises:
- ApiError: when the NotDiamond API fails
-
- Returns:
- tuple[Union[AIMessage, BaseModel], str, LLMConfig]:
- result: response type defined by Langchain, contains the response from the LLM.
- or object of the response_model
- str: session_id returned by the NotDiamond API
- LLMConfig: the best LLM selected by the router
- """
-
- result = await self.ainvoke(
- messages=messages,
- model=model,
- default=default,
- max_model_depth=max_model_depth,
- latency_tracking=latency_tracking,
- hash_content=hash_content,
- tradeoff=tradeoff,
- preference_id=preference_id,
- metric=metric,
- previous_session=previous_session,
- response_model=response_model,
- timeout=timeout,
- max_retries=max_retries,
- **kwargs,
- )
- return result
-
- def invoke(
- self,
- messages: List[Dict[str, str]],
- model: Optional[List[LLMConfig]] = None,
- default: Optional[Union[LLMConfig, int, str]] = None,
- max_model_depth: Optional[int] = None,
- latency_tracking: Optional[bool] = None,
- hash_content: Optional[bool] = None,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- metric: Metric = Metric("accuracy"),
- previous_session: Optional[str] = None,
- response_model: Optional[Type[BaseModel]] = None,
- timeout: Optional[Union[float, int]] = None,
- max_retries: Optional[int] = None,
- input: Optional[Dict[str, Any]] = None,
- **kwargs,
- ) -> tuple[str, str, LLMConfig]:
- """
- Function to invoke the LLM. Behind the scenes what happens:
- 1. API call to NotDiamond backend to get the most suitable LLM for the given prompt
- 2. Invoke the returned LLM client side
- 3. Return the response
-
- Parameters:
- prompt_template (Optional(Union[ NDPromptTemplate, NDChatPromptTemplate, str, ])):
- the prompt template defined by the user. It also supports Langchain prompt template types.
- messages (Optional[List[Dict[str, str]], optional): Can be used instead of prompt_template to pass
- the messages OpenAI style.
- model (Optional[List[LLMConfig]]): List of models to choose from.
- default (Optional[Union[LLMConfig, int, str]]): Default LLM.
- max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
- of routing you're willing to go.
- latency_tracking (Optional[bool]): Latency tracking flag.
- hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
- tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
- for the router to determine the best LLM for a given query.
- preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
- Defaults to None.
- metric (Metric, optional): Metric used by NotDiamond router to choose the best LLM.
- Defaults to Metric("accuracy").
- previous_session (Optional[str], optional): The session ID of a previous session, allow you to link requests.
- response_model (Optional[Type[BaseModel]], optional): If present, will use JsonOutputParser to parse the
- response into the given model. In which case result will a
- dict.
- timeout (int): The number of seconds to wait before terminating the API call to Not Diamond backend.
- max_retries (int): The number of retries to attempt before giving up.
- input (Optional[Dict[str, Any]], optional): If the prompt_template contains variables, use input to specify
- the values for those variables. Defaults to None, assuming no
- variables.
- **kwargs: Any other arguments that are supported by Langchain's invoke method, will be passed through.
-
- Raises:
- ApiError: when the NotDiamond API fails
-
- Returns:
- tuple[Union[AIMessage, BaseModel], str, LLMConfig]:
- result: response type defined by Langchain, contains the response from the LLM.
- or object of the response_model
- str: session_id returned by the NotDiamond API
- LLMConfig: the best LLM selected by the router
- """
-
- if model is not None:
- llm_configs = self._parse_llm_configs_data(model)
- self.llm_configs = llm_configs
-
- self.validate_params(
- default=default,
- max_model_depth=max_model_depth,
- latency_tracking=latency_tracking,
- hash_content=hash_content,
- tradeoff=tradeoff,
- preference_id=preference_id,
- )
-
- # If response_model is present, we will parse the response into the given model
- # doing this here so that if validation errors occur, we can raise them before making the API call
- response_model_parser = None
- if response_model is not None:
- self.verify_against_response_model()
- response_model_parser = JsonOutputParser(
- pydantic_object=response_model
- )
-
- if input is None:
- input = {}
-
- best_llm, session_id = model_select(
- messages=messages,
- llm_configs=self.llm_configs,
- metric=metric,
- notdiamond_api_key=self.api_key,
- max_model_depth=self.max_model_depth,
- hash_content=self.hash_content,
- tradeoff=self.tradeoff,
- preference_id=self.preference_id,
- tools=self.tools,
- previous_session=previous_session,
- timeout=timeout or self.timeout,
- max_retries=max_retries or self.max_retries,
- nd_api_url=self.nd_api_url,
- )
-
- is_default = False
- if not best_llm:
- LOGGER.warning(
- f"ND API error. Falling back to default provider={self.default_llm.provider}/{self.default_llm.model}."
- )
- best_llm = self.default_llm
- is_default = True
-
- if best_llm.system_prompt is not None:
- messages = inject_system_prompt(
- messages, best_llm.system_prompt
- )
-
- messages = o1_system_prompt_translate(messages, best_llm)
-
- self.call_callbacks("on_model_select", best_llm, best_llm.model)
-
- llm = self._llm_from_config(best_llm, callbacks=self.callbacks)
-
- if self.tools and not is_o1_model(best_llm):
- llm = llm.bind_tools(self.tools)
-
- if response_model is not None:
- messages = _NDInvokerClient._inject_model_instruction(
- messages, response_model_parser
- )
- chain_messages = [
- (msg["role"], _curly_escape(msg["content"]))
- for msg in messages
- ]
- prompt_template = ChatPromptTemplate.from_messages(chain_messages)
- chain = prompt_template | llm
- accepted_errors = _get_accepted_invoke_errors(best_llm.provider)
-
- try:
- if self.latency_tracking:
- result = self._invoke_with_latency_tracking(
- session_id=session_id,
- chain=chain,
- llm_config=best_llm,
- is_default=is_default,
- input=input,
- **kwargs,
- )
- else:
- result = chain.invoke(input, **kwargs)
- except accepted_errors as e:
- if best_llm.provider == "google":
- LOGGER.warning(
- f"Submitted chat messages are violating Google requirements with error {e}. "
- "If you see this message, `notdiamond` has returned a Google model as the best option, "
- "but the LLM call will fail. If possible, `notdiamond` will fall back to a non-Google model."
- )
-
- non_google_llm = next(
- (
- llm_config
- for llm_config in self.llm_configs
- if llm_config.provider != "google"
- ),
- None,
- )
-
- if non_google_llm is not None:
- best_llm = non_google_llm
- llm = self._llm_from_config(
- best_llm, callbacks=self.callbacks
- )
- if response_model is not None:
- messages = (
- _NDInvokerClient._inject_model_instruction(
- messages, response_model_parser
- )
- )
- chain_messages = [
- (msg["role"], _curly_escape(msg["content"]))
- for msg in messages
- ]
- prompt_template = ChatPromptTemplate.from_messages(
- chain_messages
- )
- chain = prompt_template | llm
-
- if self.latency_tracking:
- result = self._invoke_with_latency_tracking(
- session_id=session_id,
- chain=chain,
- llm_config=best_llm,
- is_default=is_default,
- input=input,
- **kwargs,
- )
- else:
- result = chain.invoke(input, **kwargs)
- else:
- raise e
- else:
- raise e
-
- if response_model is not None:
- parsed_dict = response_model_parser.parse(result.content)
- result = response_model.parse_obj(parsed_dict)
-
- return result, session_id, best_llm
-
- async def ainvoke(
- self,
- messages: List[Dict[str, str]],
- model: Optional[List[LLMConfig]] = None,
- default: Optional[Union[LLMConfig, int, str]] = None,
- max_model_depth: Optional[int] = None,
- latency_tracking: Optional[bool] = None,
- hash_content: Optional[bool] = None,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- metric: Metric = Metric("accuracy"),
- previous_session: Optional[str] = None,
- response_model: Optional[Type[BaseModel]] = None,
- timeout: Optional[Union[float, int]] = None,
- max_retries: Optional[int] = None,
- input: Optional[Dict[str, Any]] = None,
- **kwargs,
- ) -> tuple[str, str, LLMConfig]:
- """
- Function to invoke the LLM. Behind the scenes what happens:
- 1. API call to NotDiamond backend to get the most suitable LLM for the given prompt
- 2. Invoke the returned LLM client side
- 3. Return the response
-
- Parameters:
- messages (List[Dict[str, str]]): List of messages, OpenAI style
- model (Optional[List[LLMConfig]]): List of models to choose from.
- default (Optional[Union[LLMConfig, int, str]]): Default LLM.
- max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
- of routing you're willing to go.
- latency_tracking (Optional[bool]): Latency tracking flag.
- hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
- tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
- for the router to determine the best LLM for a given query.
- preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
- Defaults to None.
- metric (Metric, optional): Metric used by NotDiamond router to choose the best LLM.
- Defaults to Metric("accuracy").
- previous_session (Optional[str], optional): The session ID of a previous session, allow you to link requests.
- response_model (Optional[Type[BaseModel]], optional): If present, will use JsonOutputParser to parse the
- response into the given model. In which case result will a dict.
- timeout (int): The number of seconds to wait before terminating the API call to Not Diamond backend.
- max_retries (int): The number of retries to attempt before giving up.
- input (Optional[Dict[str, Any]], optional): If the prompt_template contains variables, use input to specify
- the values for those variables. Defaults to None, assuming no
- variables.
- **kwargs: Any other arguments that are supported by Langchain's invoke method, will be passed through.
-
- Raises:
- ApiError: when the NotDiamond API fails
-
- Returns:
- tuple[Union[AIMessage, BaseModel], str, LLMConfig]:
- result: response type defined by Langchain, contains the response from the LLM.
- or object of the response_model
- str: session_id returned by the NotDiamond API
- LLMConfig: the best LLM selected by the router
- """
-
- if model is not None:
- llm_configs = self._parse_llm_configs_data(model)
- self.llm_configs = llm_configs
-
- self.validate_params(
- default=default,
- max_model_depth=max_model_depth,
- latency_tracking=latency_tracking,
- hash_content=hash_content,
- tradeoff=tradeoff,
- preference_id=preference_id,
- )
-
- response_model_parser = None
- if response_model is not None:
- self.verify_against_response_model()
- response_model_parser = JsonOutputParser(
- pydantic_object=response_model
- )
-
- if input is None:
- input = {}
-
- best_llm, session_id = await amodel_select(
- messages=messages,
- llm_configs=self.llm_configs,
- metric=metric,
- notdiamond_api_key=self.api_key,
- max_model_depth=self.max_model_depth,
- hash_content=self.hash_content,
- tradeoff=self.tradeoff,
- preference_id=self.preference_id,
- tools=self.tools,
- previous_session=previous_session,
- timeout=timeout or self.timeout,
- max_retries=max_retries or self.max_retries,
- nd_api_url=self.nd_api_url,
- )
-
- is_default = False
- if not best_llm:
- LOGGER.warning(
- f"ND API error. Falling back to default provider={self.default_llm.provider}/{self.default_llm.model}."
- )
- best_llm = self.default_llm
- is_default = True
-
- if best_llm.system_prompt is not None:
- messages = inject_system_prompt(
- messages, best_llm.system_prompt
- )
-
- messages = o1_system_prompt_translate(messages, best_llm)
-
- self.call_callbacks("on_model_select", best_llm, best_llm.model)
-
- llm = self._llm_from_config(best_llm, callbacks=self.callbacks)
-
- if self.tools and not is_o1_model(best_llm):
- llm = llm.bind_tools(self.tools)
-
- if response_model is not None:
- messages = _NDInvokerClient._inject_model_instruction(
- messages, response_model_parser
- )
- chain_messages = [
- (msg["role"], _curly_escape(msg["content"]))
- for msg in messages
- ]
- prompt_template = ChatPromptTemplate.from_messages(chain_messages)
- chain = prompt_template | llm
- accepted_errors = _get_accepted_invoke_errors(best_llm.provider)
-
- try:
- if self.latency_tracking:
- result = await self._async_invoke_with_latency_tracking(
- session_id=session_id,
- chain=chain,
- llm_config=best_llm,
- is_default=is_default,
- input=input,
- **kwargs,
- )
- else:
- result = await chain.ainvoke(input, **kwargs)
- except accepted_errors as e:
- if best_llm.provider == "google":
- LOGGER.warning(
- f"Submitted chat messages are violating Google requirements with error {e}. "
- "If you see this message, `notdiamond` has returned a Google model as the best option, "
- "but the LLM call will fail. If possible, `notdiamond` will fall back to a non-Google model."
- )
-
- non_google_llm = next(
- (
- llm_config
- for llm_config in self.llm_configs
- if llm_config.provider != "google"
- ),
- None,
- )
-
- if non_google_llm is not None:
- best_llm = non_google_llm
- llm = self._llm_from_config(
- best_llm, callbacks=self.callbacks
- )
- if response_model is not None:
- messages = (
- _NDInvokerClient._inject_model_instruction(
- messages, response_model_parser
- )
- )
- chain_messages = [
- (msg["role"], _curly_escape(msg["content"]))
- for msg in messages
- ]
- prompt_template = ChatPromptTemplate.from_messages(
- chain_messages
- )
- chain = prompt_template | llm
-
- if self.latency_tracking:
- result = (
- await self._async_invoke_with_latency_tracking(
- session_id=session_id,
- chain=chain,
- llm_config=best_llm,
- is_default=is_default,
- input=input,
- **kwargs,
- )
- )
- else:
- result = await chain.ainvoke(input, **kwargs)
- else:
- raise e
- else:
- raise e
-
- if response_model is not None:
- parsed_dict = response_model_parser.parse(result.content)
- result = response_model.parse_obj(parsed_dict)
-
- return result, session_id, best_llm
-
- def stream(
- self,
- messages: List[Dict[str, str]],
- model: Optional[List[LLMConfig]] = None,
- default: Optional[Union[LLMConfig, int, str]] = None,
- max_model_depth: Optional[int] = None,
- latency_tracking: Optional[bool] = None,
- hash_content: Optional[bool] = None,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- metric: Metric = Metric("accuracy"),
- previous_session: Optional[str] = None,
- response_model: Optional[Type[BaseModel]] = None,
- timeout: Optional[Union[float, int]] = None,
- max_retries: Optional[int] = None,
- **kwargs,
- ) -> Iterator[Union[BaseMessageChunk, BaseModel]]:
- """
- This function calls the NotDiamond backend to fetch the most suitable model for the given prompt,
- and calls the LLM client side to stream the response.
-
- Parameters:
- messages (Optional[List[Dict[str, str]], optional): List of messages, OpenAI style
- model (Optional[List[LLMConfig]]): List of models to choose from.
- default (Optional[Union[LLMConfig, int, str]]): Default LLM.
- max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
- of routing you're willing to go.
- latency_tracking (Optional[bool]): Latency tracking flag.
- hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
- tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
- for the router to determine the best LLM for a given query.
- preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
- Defaults to None.
- metric (Metric, optional): Metric used by NotDiamond router to choose the best LLM.
- Defaults to Metric("accuracy").
- previous_session (Optional[str], optional): The session ID of a previous session, allow you to link requests.
- response_model (Optional[Type[BaseModel]], optional): If present, will use JsonOutputParser to parse the
- response into the given model. In which case result will a
- dict.
- timeout (int): The number of seconds to wait before terminating the API call to Not Diamond backend.
- max_retries (int): The number of retries to attempt before giving up.
- **kwargs: Any other arguments that are supported by Langchain's invoke method, will be passed through.
-
- Raises:
- ApiError: when the NotDiamond API fails
-
- Yields:
- Iterator[Union[BaseMessageChunk, BaseModel]]: returns the response in chunks.
- If response_model is present, it will return the partial model object
- """
-
- if model is not None:
- llm_configs = self._parse_llm_configs_data(model)
- self.llm_configs = llm_configs
-
- self.validate_params(
- default=default,
- max_model_depth=max_model_depth,
- latency_tracking=latency_tracking,
- hash_content=hash_content,
- tradeoff=tradeoff,
- preference_id=preference_id,
- )
-
- response_model_parser = None
- if response_model is not None:
- self.verify_against_response_model()
- response_model_parser = JsonOutputParser(
- pydantic_object=response_model
- )
-
- best_llm, session_id = model_select(
- messages=messages,
- llm_configs=self.llm_configs,
- metric=metric,
- notdiamond_api_key=self.api_key,
- max_model_depth=self.max_model_depth,
- hash_content=self.hash_content,
- tradeoff=self.tradeoff,
- preference_id=self.preference_id,
- tools=self.tools,
- previous_session=previous_session,
- timeout=timeout or self.timeout,
- max_retries=max_retries or self.max_retries,
- nd_api_url=self.nd_api_url,
- )
-
- if not best_llm:
- LOGGER.warning(
- f"ND API error. Falling back to default provider={self.default_llm.provider}/{self.default_llm.model}."
- )
- best_llm = self.default_llm
-
- if best_llm.system_prompt is not None:
- messages = inject_system_prompt(
- messages, best_llm.system_prompt
- )
-
- if response_model is not None:
- messages = _NDInvokerClient._inject_model_instruction(
- messages, response_model_parser
- )
-
- self.call_callbacks("on_model_select", best_llm, best_llm.model)
-
- llm = self._llm_from_config(best_llm, callbacks=self.callbacks)
- if self.tools:
- llm = llm.bind_tools(self.tools)
-
- if response_model is not None:
- chain = llm | response_model_parser
- else:
- chain = llm
-
- for chunk in chain.stream(messages, **kwargs):
- if response_model is None:
- yield chunk
- else:
- partial_model = create_partial_model(response_model)
- yield partial_model(**chunk)
-
- async def astream(
- self,
- messages: List[Dict[str, str]],
- model: Optional[List[LLMConfig]] = None,
- default: Optional[Union[LLMConfig, int, str]] = None,
- max_model_depth: Optional[int] = None,
- latency_tracking: Optional[bool] = None,
- hash_content: Optional[bool] = None,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- metric: Metric = Metric("accuracy"),
- previous_session: Optional[str] = None,
- response_model: Optional[Type[BaseModel]] = None,
- timeout: Optional[Union[float, int]] = None,
- max_retries: Optional[int] = None,
- **kwargs,
- ) -> AsyncIterator[Union[BaseMessageChunk, BaseModel]]:
- """
- This function calls the NotDiamond backend to fetch the most suitable model for the given prompt,
- and calls the LLM client side to stream the response. The function is async, so it's suitable for async codebases.
-
- Parameters:
- messages (Optional[List[Dict[str, str]], optional): List of messages, OpenAI style
- model (Optional[List[LLMConfig]]): List of models to choose from.
- default (Optional[Union[LLMConfig, int, str]]): Default LLM.
- max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
- of routing you're willing to go.
- latency_tracking (Optional[bool]): Latency tracking flag.
- hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
- tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
- for the router to determine the best LLM for a given query.
- preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
- Defaults to None.
- metric (Metric, optional): Metric used by NotDiamond router to choose the best LLM.
- Defaults to Metric("accuracy").
- previous_session (Optional[str], optional): The session ID of a previous session, allow you to link requests.
- response_model (Optional[Type[BaseModel]], optional): If present, will use JsonOutputParser to parse the
- response into the given model. In which case result will a dict.
- timeout (int): The number of seconds to wait before terminating the API call to Not Diamond backend.
- max_retries (int): The number of retries to attempt before giving up.
- **kwargs: Any other arguments that are supported by Langchain's invoke method, will be passed through.
-
- Raises:
- ApiError: when the NotDiamond API fails
-
- Yields:
- AsyncIterator[Union[BaseMessageChunk, BaseModel]]: returns the response in chunks.
- If response_model is present, it will return the partial model object
- """
-
- if model is not None:
- llm_configs = self._parse_llm_configs_data(model)
- self.llm_configs = llm_configs
-
- self.validate_params(
- default=default,
- max_model_depth=max_model_depth,
- latency_tracking=latency_tracking,
- hash_content=hash_content,
- tradeoff=tradeoff,
- preference_id=preference_id,
- )
-
- response_model_parser = None
- if response_model is not None:
- self.verify_against_response_model()
- response_model_parser = JsonOutputParser(
- pydantic_object=response_model
- )
-
- best_llm, session_id = await amodel_select(
- messages=messages,
- llm_configs=self.llm_configs,
- metric=metric,
- notdiamond_api_key=self.api_key,
- max_model_depth=self.max_model_depth,
- hash_content=self.hash_content,
- tradeoff=self.tradeoff,
- preference_id=self.preference_id,
- tools=self.tools,
- previous_session=previous_session,
- timeout=timeout or self.timeout,
- max_retries=max_retries or self.max_retries,
- nd_api_url=self.nd_api_url,
- )
-
- if not best_llm:
- LOGGER.warning(
- f"ND API error. Falling back to default provider={self.default_llm.provider}/{self.default_llm.model}."
- )
- best_llm = self.default_llm
-
- if best_llm.system_prompt is not None:
- messages = inject_system_prompt(
- messages, best_llm.system_prompt
- )
- if response_model is not None:
- messages = _NDInvokerClient._inject_model_instruction(
- messages, response_model_parser
- )
-
- self.call_callbacks("on_model_select", best_llm, best_llm.model)
-
- llm = self._llm_from_config(best_llm, callbacks=self.callbacks)
- if self.tools:
- llm = llm.bind_tools(self.tools)
-
- if response_model is not None:
- chain = llm | response_model_parser
- else:
- chain = llm
-
- async for chunk in chain.astream(messages, **kwargs):
- if response_model is None:
- yield chunk
- else:
- partial_model = create_partial_model(response_model)
- yield partial_model(**chunk)
-
- async def _async_invoke_with_latency_tracking(
- self,
- session_id: str,
- chain: Any,
- llm_config: LLMConfig,
- input: Optional[Dict[str, Any]] = {},
- is_default: bool = True,
- **kwargs,
- ):
- if session_id in ("NO-SESSION-ID", "") and not is_default:
- error_message = (
- "ND session_id is not valid for latency tracking."
- + "Please check the API response."
- )
- self.call_callbacks("on_api_error", error_message)
- raise ApiError(error_message)
-
- start_time = time.time()
-
- result = await chain.ainvoke(input, **kwargs)
-
- end_time = time.time()
-
- tokens_completed = token_counter(
- model=llm_config.model,
- messages=[{"role": "assistant", "content": result.content}],
- )
- tokens_per_second = tokens_completed / (end_time - start_time)
-
- report_latency(
- session_id=session_id,
- llm_config=llm_config,
- tokens_per_second=tokens_per_second,
- notdiamond_api_key=self.api_key,
- nd_api_url=self.nd_api_url,
- _user_agent=self.user_agent,
- )
- self.call_callbacks(
- "on_latency_tracking",
- session_id,
- llm_config,
- tokens_per_second,
- )
-
- return result
-
- def _invoke_with_latency_tracking(
- self,
- session_id: str,
- chain: Any,
- llm_config: LLMConfig,
- input: Optional[Dict[str, Any]] = {},
- is_default: bool = True,
- **kwargs,
- ):
- LOGGER.debug(f"Latency tracking enabled, session_id={session_id}")
- if session_id in ("NO-SESSION-ID", "") and not is_default:
- error_message = (
- "ND session_id is not valid for latency tracking."
- + "Please check the API response."
- )
- self.call_callbacks("on_api_error", error_message)
- raise ApiError(error_message)
-
- start_time = time.time()
- result = chain.invoke(input, **kwargs)
- end_time = time.time()
-
- tokens_completed = token_counter(
- model=llm_config.model,
- messages=[{"role": "assistant", "content": result.content}],
- )
- tokens_per_second = tokens_completed / (end_time - start_time)
-
- report_latency(
- session_id=session_id,
- llm_config=llm_config,
- tokens_per_second=tokens_per_second,
- notdiamond_api_key=self.api_key,
- nd_api_url=self.nd_api_url,
- _user_agent=self.user_agent,
- )
- self.call_callbacks(
- "on_latency_tracking",
- session_id,
- llm_config,
- tokens_per_second,
- )
-
- return result
-
- @staticmethod
- def _llm_from_config(
- provider: LLMConfig,
- callbacks: Optional[List] = None,
- ) -> Any:
- default_kwargs = {"max_retries": 5, "timeout": 120}
- passed_kwargs = {**default_kwargs, **provider.kwargs}
-
- if provider.provider == "openai":
- ChatOpenAI = _module_check(
- "langchain_openai.chat_models",
- "ChatOpenAI",
- provider.provider,
- )
- if is_o1_model(provider):
- passed_kwargs["temperature"] = 1.0
-
- return ChatOpenAI(
- openai_api_key=provider.api_key,
- model_name=provider.model,
- callbacks=callbacks,
- **passed_kwargs,
- )
- if provider.provider == "anthropic":
- ChatAnthropic = _module_check(
- "langchain_anthropic", "ChatAnthropic", provider.provider
- )
- return ChatAnthropic(
- anthropic_api_key=provider.api_key,
- model=provider.model,
- callbacks=callbacks,
- **passed_kwargs,
- )
- if provider.provider == "google":
- ChatGoogleGenerativeAI = _module_check(
- "langchain_google_genai",
- "ChatGoogleGenerativeAI",
- provider.provider,
- )
- return ChatGoogleGenerativeAI(
- google_api_key=provider.api_key,
- model=provider.model,
- convert_system_message_to_human=True,
- callbacks=callbacks,
- **passed_kwargs,
- )
- if provider.provider == "cohere":
- ChatCohere = _module_check(
- "langchain_cohere.chat_models",
- "ChatCohere",
- provider.provider,
- )
- return ChatCohere(
- cohere_api_key=provider.api_key,
- model=provider.model,
- callbacks=callbacks,
- **passed_kwargs,
- )
- if provider.provider == "mistral":
- ChatMistralAI = _module_check(
- "langchain_mistralai.chat_models",
- "ChatMistralAI",
- provider.provider,
- )
- return ChatMistralAI(
- mistral_api_key=provider.api_key,
- model=provider.model,
- callbacks=callbacks,
- **passed_kwargs,
- )
- if provider.provider == "togetherai":
- provider_settings = settings.PROVIDERS.get(
- provider.provider, None
- )
- model_prefixes = provider_settings.get("model_prefix", None)
- model_prefix = model_prefixes.get(provider.model, None)
- del passed_kwargs["max_retries"]
- del passed_kwargs["timeout"]
-
- if model_prefix is not None:
- model = f"{model_prefix}/{provider.model}"
- ChatTogether = _module_check(
- "langchain_together", "ChatTogether", provider.provider
- )
- return ChatTogether(
- together_api_key=provider.api_key,
- model=model,
- callbacks=callbacks,
- **passed_kwargs,
- )
- if provider.provider == "perplexity":
- del passed_kwargs["max_retries"]
- passed_kwargs["request_timeout"] = passed_kwargs["timeout"]
- del passed_kwargs["timeout"]
- ChatPerplexity = _module_check(
- "langchain_community.chat_models",
- "ChatPerplexity",
- provider.provider,
- )
- return ChatPerplexity(
- pplx_api_key=provider.api_key,
- model=provider.model,
- callbacks=callbacks,
- **passed_kwargs,
- )
- if provider.provider == "replicate":
- provider_settings = settings.PROVIDERS.get(
- provider.provider, None
- )
- model_prefixes = provider_settings.get("model_prefix", None)
- model_prefix = model_prefixes.get(provider.model, None)
- passed_kwargs["request_timeout"] = passed_kwargs["timeout"]
- del passed_kwargs["timeout"]
-
- if model_prefix is not None:
- model = f"replicate/{model_prefix}/{provider.model}"
- ChatLiteLLM = _module_check(
- "langchain_community.chat_models",
- "ChatLiteLLM",
- provider.provider,
- )
- return ChatLiteLLM(
- model=model,
- callbacks=callbacks,
- replicate_api_key=provider.api_key,
- **passed_kwargs,
- )
- raise ValueError(f"Unsupported provider: {provider.provider}")
-
- def verify_against_response_model(self) -> bool:
- """
- Verify that the LLMs support response modeling.
- """
-
- for provider in self.llm_configs:
- if provider.model not in settings.PROVIDERS[
- provider.provider
- ].get("support_response_model", []):
- raise ApiError(
- f"{provider.provider}/{provider.model} does not support response modeling."
- )
-
- return True
-
- if import_target is _NDClientTarget.ROUTER:
- return _NDRouterClient
- return _NDInvokerClient
-
-
-_NDClient = _ndllm_factory()
-
-
-
-[docs]
-class NotDiamond(_NDClient):
- api_key: str
- """
- API key required for making calls to NotDiamond.
- You can get an API key via our dashboard: https://app.notdiamond.ai
- If an API key is not set, it will check for NOTDIAMOND_API_KEY in .env file.
- """
-
- llm_configs: Optional[List[Union[LLMConfig, str]]]
- """The list of LLMs that are available to route between."""
-
- default: Union[LLMConfig, int, str]
- """
- Set a default LLM, so in case anything goes wrong in the flow,
- as for example NotDiamond API call fails, your code won't break and you have
- a fallback model. There are various ways to configure a default model:
-
- - Integer, specifying the index of the default provider from the llm_configs list
- - String, similar how you can specify llm_configs, of structure 'provider_name/model_name'
- - LLMConfig, just directly specify the object of the provider
-
- By default, we will set your first LLM in the list as the default.
- """
-
- max_model_depth: Optional[int]
- """
- If your top recommended model is down, specify up to which depth of routing you're willing to go.
- If max_model_depth is not set, it defaults to the length of the llm_configs list.
- If max_model_depth is set to 0, the init will fail.
- If the value is larger than the llm_configs list length, we reset the value to len(llm_configs).
- """
-
- latency_tracking: bool
- """
- Tracking and sending latency of LLM call to NotDiamond server as feedback, so we can improve our router.
- By default this is turned on, set it to False to turn off.
- """
-
- hash_content: bool
- """
- Hashing the content before being sent to the NotDiamond API.
- By default this is False.
- """
-
- tradeoff: Optional[str]
- """
- [DEPRECATED] The tradeoff constructor parameter is deprecated and will be removed in a future version.
- Please specify the tradeoff when using model_select or invocation methods.
-
- Define tradeoff between "cost" and "latency" for the router to determine the best LLM for a given query.
- If None is specified, then the router will not consider either cost or latency.
-
- The supported values: "cost", "latency"
-
- Defaults to None.
- """
-
- preference_id: Optional[str]
- """The ID of the router preference that was configured via the Dashboard. Defaults to None."""
-
- tools: Optional[Sequence[Union[Dict[str, Any], Callable]]]
- """Bind tools to the LLM object. The tools will be passed to the LLM object when invoking it."""
-
- nd_api_url: Optional[str]
- """The URL of the NotDiamond API. Defaults to settings.NOTDIAMOND_API_URL."""
-
- user_agent: Union[str, None]
-
- max_retries: int
- """The maximum number of retries to make when calling the Not Diamond API."""
-
- timeout: float
- """The timeout for the Not Diamond API call."""
-
-
-
-
- def __init__(
- self,
- nd_api_url: Optional[str] = settings.NOTDIAMOND_API_URL,
- user_agent: Union[str, None] = settings.DEFAULT_USER_AGENT,
- *args,
- **kwargs,
- ):
- super().__init__(
- nd_api_url=nd_api_url, user_agent=user_agent, *args, **kwargs
- )
- self.nd_api_url = nd_api_url
-
- if kwargs.get("tradeoff") is not None:
- warnings.warn(
- "The tradeoff constructor parameter is deprecated and will be removed in a "
- "future version. Please specify the tradeoff when using model_select or invocation methods.",
- DeprecationWarning,
- stacklevel=2,
- )
-
-
-
-def _get_accepted_invoke_errors(provider: str) -> Tuple:
- if provider == "google":
- ChatGoogleGenerativeAIError = _module_check(
- "langchain_google_genai.chat_models",
- "ChatGoogleGenerativeAIError",
- provider,
- )
- accepted_errors = (ChatGoogleGenerativeAIError, ValueError)
- else:
- accepted_errors = (ValueError,)
- return accepted_errors
-
-import logging
-from typing import Optional
-
-from notdiamond import settings
-from notdiamond.exceptions import (
- UnsupportedEmbeddingProvider,
- UnsupportedLLMProvider,
-)
-
-POSSIBLE_PROVIDERS = list(settings.PROVIDERS.keys())
-POSSIBLE_MODELS = list(
- model
- for provider_values in settings.PROVIDERS.values()
- for values in provider_values.values()
- if isinstance(values, list)
- for model in values
-)
-
-POSSIBLE_EMBEDDING_PROVIDERS = [
- *list(settings.EMBEDDING_PROVIDERS.keys()),
- "huggingface",
-]
-POSSIBLE_EMBEDDING_MODELS = list(
- model
- for provider_values in settings.EMBEDDING_PROVIDERS.values()
- for values in provider_values.values()
- if isinstance(values, list)
- for model in values
-)
-
-LOGGER = logging.getLogger(__name__)
-LOGGER.setLevel(logging.INFO)
-
-
-
-[docs]
-class LLMConfig:
- """
- A NotDiamond LLM provider config (or LLMConfig) is represented by a combination of provider and model.
- Provider refers to the company of the foundational model, such as openai, anthropic, google.
- The model represents the model name as defined by the owner company, such as gpt-3.5-turbo
- Beside this you can also specify the API key for each provider, specify extra arguments
- that are also supported by Langchain (eg. temperature), and a system prmopt to be used
- with the provider. If the provider is selected during routing, then the system prompt will
- be used, replacing the one in the message array if there are any.
-
- All supported providers and models can be found in our docs.
-
- If the API key it's not specified, it will try to pick it up from an .env file before failing.
- As example for OpenAI it will look for OPENAI_API_KEY.
-
- Attributes:
- provider (str): The name of the LLM provider (e.g., "openai", "anthropic"). Must be one of the
- predefined providers in `POSSIBLE_PROVIDERS`.
- model (str): The name of the LLM model to use (e.g., "gpt-3.5-turbo").
- Must be one of the predefined models in `POSSIBLE_MODELS`.
- system_prompt (Optional[str], optional): The system prompt to use for the provider. Defaults to None.
- api_key (Optional[str], optional): The API key for accessing the LLM provider's services.
- Defaults to None, in which case it tries to fetch from the settings.
- openrouter_model (str): The OpenRouter model equivalent for this provider / model
- **kwargs: Additional keyword arguments that might be necessary for specific providers or models.
-
- Raises:
- UnsupportedLLMProvider: If the `provider` or `model` specified is not supported.
- """
-
-
-[docs]
- def __init__(
- self,
- provider: str,
- model: str,
- is_custom: bool = False,
- system_prompt: Optional[str] = None,
- context_length: Optional[int] = None,
- input_price: Optional[float] = None,
- custom_input_price: Optional[float] = None,
- output_price: Optional[float] = None,
- custom_output_price: Optional[float] = None,
- latency: Optional[float] = None,
- custom_latency: Optional[float] = None,
- api_key: Optional[str] = None,
- **kwargs,
- ):
- """_summary_
-
- Args:
- provider (str): The name of the LLM provider (e.g., "openai", "anthropic").
- model (str): The name of the LLM model to use (e.g., "gpt-3.5-turbo").
- is_custom (bool): Whether this is a custom model. Defaults to False.
- system_prompt (Optional[str], optional): The system prompt to use for the provider. Defaults to None.
- context_length (Optional[int], optional): Custom context window length for the provider/model.
- custom_input_price (Optional[float], optional): Custom input price (USD) per million tokens for this
- provider/model; will default to public input price if available.
- custom_output_price (Optional[float], optional): Custom output price (USD) per million tokens for this
- provider/model; will default to public output price if available.
- custom_latency (Optional[float], optional): Custom latency (time to first token) for provider/model.
- api_key (Optional[str], optional): The API key for accessing the LLM provider's services.
- Defaults to None.
- **kwargs: Additional keyword arguments that might be necessary for specific providers or models.
-
- Raises:
- UnsupportedLLMProvider: If the `provider` or `model` specified is not supported.
- """
- if is_custom:
- self._openrouter_model = None
- self.api_key = api_key
- self.default_input_price = custom_input_price or input_price
- self.default_output_price = custom_output_price or output_price
- else:
- if provider not in POSSIBLE_PROVIDERS:
- raise UnsupportedLLMProvider(
- f"Given LLM provider {provider} is not in the list of supported providers."
- )
- if model not in POSSIBLE_MODELS:
- raise UnsupportedLLMProvider(
- f"Given LLM model {model} is not in the list of supported models."
- )
- self._openrouter_model = settings.PROVIDERS[provider][
- "openrouter_identifier"
- ].get(model, None)
-
- self.api_key = (
- api_key
- if api_key is not None
- else settings.PROVIDERS[provider]["api_key"]
- )
-
- self.default_input_price = settings.PROVIDERS[provider]["price"][
- model
- ]["input"]
- self.default_output_price = settings.PROVIDERS[provider]["price"][
- model
- ]["output"]
-
- self.provider = provider
- self.model = model
- self.system_prompt = system_prompt
-
- self.is_custom = is_custom
- self.context_length = context_length
- self.input_price = custom_input_price or input_price
- self.output_price = custom_output_price or output_price
- self.latency = custom_latency or latency
-
- self.kwargs = kwargs
-
-
- def __str__(self) -> str:
- return f"{self.provider}/{self.model}"
-
- def __repr__(self) -> str:
- return f"LLMConfig({self.provider}/{self.model})"
-
- def __eq__(self, other):
- if isinstance(other, LLMConfig):
- return (
- self.provider == other.provider and self.model == other.model
- )
- return False
-
- def __hash__(self):
- return hash(str(self))
-
- @property
- def openrouter_model(self):
- if self._openrouter_model is None:
- LOGGER.warning(
- f"Configured model {str(self)} is not available via OpenRouter. Please try another model."
- )
- return self._openrouter_model
-
-
-[docs]
- def prepare_for_request(self):
- """
- Converts the LLMConfig object to a dict in the format accepted by
- the NotDiamond API.
-
- Returns:
- dict
- """
- return {
- "provider": self.provider,
- "model": self.model,
- "is_custom": self.is_custom,
- "context_length": self.context_length,
- "input_price": self.input_price,
- "output_price": self.output_price,
- "latency": self.latency,
- }
-
-
-
-[docs]
- def set_api_key(self, api_key: str) -> "LLMConfig":
- self.api_key = api_key
-
- return self
-
-
-
-[docs]
- @classmethod
- def from_string(cls, llm_provider: str):
- """
- We allow our users to specify LLM providers for NotDiamond in the string format 'provider_name/model_name',
- as example 'openai/gpt-3.5-turbo'. Underlying our workflows we want to ensure we use LLMConfig as
- the base type, so this class method converts a string specification of an LLM provider into an
- LLMConfig object.
-
- Args:
- llm_provider (str): this is the string definition of the LLM provider
-
- Returns:
- LLMConfig: initialized object with correct provider and model
- """
- split_items = llm_provider.split("/")
- if len(split_items) not in [2, 3]:
- raise ValueError(
- f"Expected string of format 'provider/model' or 'prefix/provider/model' but got {llm_provider}"
- )
- elif len(split_items) == 3:
- _, provider, model = split_items
- else:
- provider = split_items[0]
- model = split_items[1]
- return cls(provider=provider, model=model)
-
-
-
-
-
-[docs]
-class EmbeddingConfig:
- """
- A NotDiamond embedding provider config (or EmbeddingConfig) is represented by a combination of provider and model.
- Provider refers to the company of the foundational model, such as openai, anthropic, google.
- The model represents the model name as defined by the owner company, such as text-embedding-3-large
- Beside this you can also specify the API key for each provider or extra arguments
- that are also supported by Langchain.
-
- All supported providers and models can be found in our docs.
-
- If the API key is not specified, the Config will try to read the key from an .env file before failing.
- For example, the Config will look for `OPENAI_API_KEY` to authenticate any OpenAI provider.
-
- Attributes:
- provider (str): The name of the LLM provider (e.g., "openai", "anthropic"). Must be one of the
- predefined providers in `POSSIBLE_EMBEDDING_PROVIDERS`.
- model (str): The name of the LLM model to use (e.g., "gpt-3.5-turbo").
- Must be one of the predefined models in `POSSIBLE_MODELS`.
- api_key (Optional[str], optional): The API key for accessing the LLM provider's services.
- Defaults to None, in which case it tries to fetch from the environment.
- **kwargs: Additional keyword arguments that might be necessary for specific providers or models.
-
- Raises:
- UnsupportedLLMProvider: If the `provider` or `model` specified is not supported.
- """
-
-
-[docs]
- def __init__(
- self,
- provider: str,
- model: str,
- api_key: Optional[str] = None,
- **kwargs,
- ):
- """_summary_
-
- Args:
- provider (str): The name of the embedding provider (e.g., "openai", "anthropic").
- model (str): The name of the embedding model to use (e.g., "text-embedding-3-large").
- api_key (Optional[str], optional): The API key for accessing the embedding provider's services.
- Defaults to None.
- **kwargs: Additional keyword arguments that might be necessary for specific providers or models.
-
- Raises:
- UnsupportedEmbeddingProvider: If the `provider` or `model` specified is not supported.
- """
- if provider not in POSSIBLE_EMBEDDING_PROVIDERS:
- raise UnsupportedEmbeddingProvider(
- f"Given embedding provider {provider} is not in the list of supported providers."
- )
-
- if (
- model not in POSSIBLE_EMBEDDING_MODELS
- and provider != "huggingface"
- ):
- raise UnsupportedEmbeddingProvider(
- f"Given embedding model {model} is not in the list of supported models."
- )
-
- self.api_key = (
- api_key
- if api_key is not None
- else settings.PROVIDERS[provider]["api_key"]
- )
-
- self.provider = provider
- self.model = model
- self.kwargs = kwargs
-
-
- def __str__(self) -> str:
- return f"{self.provider}/{self.model}"
-
- def __repr__(self) -> str:
- return f"EmbeddingConfig({self.provider}/{self.model})"
-
- def __eq__(self, other):
- if isinstance(other, EmbeddingConfig):
- return (
- self.provider == other.provider and self.model == other.model
- )
- return False
-
- def __hash__(self):
- return hash(str(self))
-
-
-[docs]
- def set_api_key(self, api_key: str) -> "EmbeddingConfig":
- self.api_key = api_key
-
- return self
-
-
-
-[docs]
- @classmethod
- def from_string(cls, llm_provider: str):
- """
- We allow our users to specify LLM providers for NotDiamond in the string format 'provider_name/model_name',
- for example 'openai/gpt-3.5-turbo'. Our workflows expect LLMConfig as
- the base type, so this class method converts a string specification of an LLM provider into an
- LLMConfig object.
-
- Args:
- llm_provider (str): this is the string definition of the LLM provider
-
- Returns:
- LLMConfig: initialized object with correct provider and model
- """
- split_items = llm_provider.split("/")
- if len(split_items) not in [2, 3]:
- raise ValueError(
- f"Expected string of format 'provider/model' or 'prefix/provider/model' but got {llm_provider}"
- )
- elif len(split_items) == 3:
- _, provider, model = split_items
- else:
- provider = split_items[0]
- model = split_items[1]
- return cls(provider=provider, model=model)
-
-
-
-"""NDLLM Class"""
-
-import time
-from typing import (
- Any,
- AsyncIterator,
- Callable,
- Dict,
- Iterator,
- List,
- Optional,
- Sequence,
- Type,
- Union,
-)
-
-from langchain.prompts import PromptTemplate
-from langchain_anthropic import ChatAnthropic
-from langchain_cohere.chat_models import ChatCohere
-from langchain_community.chat_models import ChatLiteLLM, ChatPerplexity
-from langchain_core.callbacks.base import BaseCallbackHandler
-from langchain_core.callbacks.manager import CallbackManagerForLLMRun
-from langchain_core.language_models.llms import LLM
-from langchain_core.messages import AIMessage, BaseMessage, BaseMessageChunk
-from langchain_core.output_parsers import JsonOutputParser
-from langchain_core.prompt_values import StringPromptValue
-from langchain_core.prompts import ChatPromptTemplate
-from langchain_google_genai import ChatGoogleGenerativeAI
-from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError
-from langchain_mistralai.chat_models import ChatMistralAI
-from langchain_openai import ChatOpenAI
-from langchain_together import Together
-from litellm import token_counter
-from pydantic import BaseModel
-from pydantic_partial import create_partial_model
-
-from notdiamond import settings
-from notdiamond.callbacks import NDLLMBaseCallbackHandler
-from notdiamond.exceptions import ApiError, MissingLLMProviders
-from notdiamond.llms.provider import NDLLMProvider
-from notdiamond.llms.request import amodel_select, model_select, report_latency
-from notdiamond.metrics.metric import NDMetric
-from notdiamond.prompts.prompt import NDChatPromptTemplate, NDPromptTemplate
-from notdiamond.types import NDApiKeyValidator
-
-
-
-[docs]
-class NDLLM(LLM):
- """
- Implementation of NDLLM class, the main class responsible for routing.
- The class inherits from Langchain's LLM class. Starting reference is from here:
- https://python.langchain.com/docs/modules/model_io/llms/custom_llm
-
- It's mandatory to have an API key set. If the api_key is not explicitly specified,
- it will check for NOTDIAMOND_API_KEY in the .env file.
-
- Raises:
- MissingLLMProviders: you must specify at least one LLM provider for the router to work
- ApiError: error raised when the NotDiamond API call fails.
- Ensure to set a default LLM provider to not break the code.
- """
-
- api_key: str
- """
- API key required for making calls to NotDiamond.
- You can get an API key via our dashboard: https://app.notdiamond.ai
- If an API key is not set, it will check for NOTDIAMOND_API_KEY in .env file.
- """
-
- llm_providers: Optional[List[NDLLMProvider]]
- """The list of LLM providers that are available to route between."""
-
- default: Union[NDLLMProvider, int, str]
- """
- Set a default LLM provider, so in case anything goes wrong in the flow,
- as for example NotDiamond API call fails, your code won't break and you have
- a fallback model. There are various ways to configure a default model:
-
- - Integer, specifying the index of the default provider from the llm_providers list
- - String, similar how you can specify llm_providers, of structure 'provider_name/model_name'
- - NDLLMProvider, just directly specify the object of the provider
-
- By default, we will set your first LLM in the list as the default.
- """
-
- max_model_depth: Optional[int]
- """
- If your top recommended model is down, specify up to which depth of routing you're willing to go.
- If max_model_depth is not set, it defaults to the length of the llm_providers list.
- If max_model_depth is set to 0, the init will fail.
- If the value is larger than the llm_providers list length, we reset the value to len(llm_providers).
- """
-
- latency_tracking: bool
- """
- Tracking and sending latency of LLM call to NotDiamond server as feedback, so we can improve our router.
- By default this is turned on, set it to False to turn off.
- """
-
- hash_content: bool
- """
- Hashing the content before being sent to the NotDiamond API.
- By default this is False.
- """
-
- tradeoff: Optional[str]
- """
- Define tradeoff between "cost" and "latency" for the router to determine the best LLM for a given query.
- If None is specified, then the router will not consider either cost or latency.
-
- The supported values: "cost", "latency"
-
- Defaults to None.
- """
-
- preference_id: Optional[str]
- """The ID of the router preference that was configured via the Dashboard. Defaults to None."""
-
- tools: Optional[Sequence[Union[Dict[str, Any], Callable]]]
- """Bind tools to the LLM object. The tools will be passed to the LLM object when invoking it."""
-
- callbacks: Optional[
- List[Union[BaseCallbackHandler, NDLLMBaseCallbackHandler]]
- ]
- """
- Callback handler for the LLM object. It will be passed to the LLM object when invoking it.
- Also has custom NDLLM callbacks:
- - on_model_select
- - on_latency_tracking
- - on_api_error
- """
-
- def __init__(
- self,
- llm_providers: Optional[List[NDLLMProvider]] = None,
- api_key: Optional[str] = None,
- default: Union[NDLLMProvider, int, str] = 0,
- max_model_depth: Optional[int] = None,
- latency_tracking: bool = True,
- hash_content: bool = False,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- callbacks: Optional[
- List[Union[BaseCallbackHandler, NDLLMBaseCallbackHandler]]
- ] = None,
- **kwargs,
- ) -> None:
- if api_key is None:
- api_key = settings.NOTDIAMOND_API_KEY
- NDApiKeyValidator(api_key=api_key)
-
- if llm_providers is not None:
- llm_providers = self._parse_llm_providers_data(llm_providers)
-
- if max_model_depth is None:
- max_model_depth = len(llm_providers)
-
- if max_model_depth > len(llm_providers):
- print(
- "WARNING: max_model_depth cannot be bigger than the number of LLM providers."
- )
- max_model_depth = len(llm_providers)
-
- if tradeoff is not None:
- if tradeoff not in ["cost", "latency"]:
- raise ValueError(
- "Invalid tradeoff. Accepted values: cost, latency."
- )
-
- super(NDLLM, self).__init__(
- api_key=api_key,
- llm_providers=llm_providers,
- default=default,
- max_model_depth=max_model_depth,
- latency_tracking=latency_tracking,
- hash_content=hash_content,
- tradeoff=tradeoff,
- preference_id=preference_id,
- callbacks=callbacks,
- **kwargs,
- )
-
- def __repr__(self) -> str:
- class_name = self.__class__.__name__
- address = hex(id(self)) # Gets the memory address of the object
- return f"<{class_name} object at {address}>"
-
- @property
- def chat(self):
- return self
-
- @property
- def completions(self):
- return self
-
- @property
- def default_llm_provider(self) -> Union[NDLLMProvider, None]:
- """
- Return the default LLM provider that's set on the NDLLM class.
- """
- if isinstance(self.default, int):
- return self.llm_providers[int(self.default)]
- if isinstance(self.default, str):
- if self.default.isdigit():
- return self.llm_providers[int(self.default)]
- return NDLLMProvider.from_string(self.default)
- if isinstance(self.default, NDLLMProvider):
- return self.default
- return self.llm_providers[0]
-
- @staticmethod
- def _parse_llm_providers_data(llm_providers: list) -> List[NDLLMProvider]:
- providers = []
- for llm_provider in llm_providers:
- if isinstance(llm_provider, NDLLMProvider):
- providers.append(llm_provider)
- continue
- parsed_provider = NDLLMProvider.from_string(llm_provider)
- providers.append(parsed_provider)
- return providers
-
- @property
- def _llm_type(self) -> str:
- return "NotDiamond LLM"
-
- def _call(
- self,
- prompt: str,
- stop: Optional[List[str]] = None,
- run_manager: Optional[CallbackManagerForLLMRun] = None,
- **kwargs: Any,
- ) -> str:
- if stop is not None:
- raise ValueError("stop kwargs are not permitted.")
- return "This function is deprecated for the latest LangChain version, use invoke instead"
-
-
-[docs]
- def create(
- self,
- messages: List[Dict[str, str]],
- model: Optional[List[NDLLMProvider]] = None,
- default: Optional[Union[NDLLMProvider, int, str]] = None,
- max_model_depth: Optional[int] = None,
- latency_tracking: Optional[bool] = None,
- hash_content: Optional[bool] = None,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- metric: NDMetric = NDMetric("accuracy"),
- response_model: Optional[Type[BaseModel]] = None,
- **kwargs,
- ) -> tuple[Union[AIMessage, BaseModel], str, NDLLMProvider]:
- """
- Function call to invoke the LLM, with the same interface
- as the OpenAI Python library.
-
- Parameters:
- messages (Optional[List[Dict[str, str]], optional): Can be used instead of prompt_template to pass
- the messages OpenAI style.
- model (Optional[List[NDLLMProvider]]): List of models to choose from.
- default (Optional[Union[NDLLMProvider, int, str]]): Default LLM provider.
- max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
- of routing you're willing to go.
- latency_tracking (Optional[bool]): Latency tracking flag.
- hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
- tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
- for the router to determine the best LLM for a given query.
- preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
- Defaults to None.
- **kwargs: Any other arguments that are supported by Langchain's invoke method, will be passed through.
-
- Raises:
- ApiError: when the NotDiamond API fails
-
- Returns:
- tuple[Union[AIMessage, BaseModel], str, NDLLMProvider]:
- result: response type defined by Langchain, contains the response from the LLM.
- or object of the response_model
- str: session_id returned by the NotDiamond API
- NDLLMProvider: the best LLM provider selected by the router
- """
-
- if model is not None:
- llm_providers = self._parse_llm_providers_data(model)
- self.llm_providers = llm_providers
-
- self.validate_params(
- default=default,
- max_model_depth=max_model_depth,
- latency_tracking=latency_tracking,
- hash_content=hash_content,
- tradeoff=tradeoff,
- preference_id=preference_id,
- )
-
- return self.invoke(
- messages=messages,
- metric=metric,
- response_model=response_model,
- **kwargs,
- )
-
-
-
-[docs]
- async def acreate(
- self,
- messages: List[Dict[str, str]],
- model: Optional[List[NDLLMProvider]] = None,
- default: Optional[Union[NDLLMProvider, int, str]] = None,
- max_model_depth: Optional[int] = None,
- latency_tracking: Optional[bool] = None,
- hash_content: Optional[bool] = None,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- metric: NDMetric = NDMetric("accuracy"),
- response_model: Optional[Type[BaseModel]] = None,
- **kwargs,
- ) -> tuple[Union[AIMessage, BaseModel], str, NDLLMProvider]:
- """
- Async function call to invoke the LLM, with the same interface
- as the OpenAI Python library.
-
- Parameters:
- messages (Optional[List[Dict[str, str]], optional): Can be used instead of prompt_template to pass
- the messages OpenAI style.
- model (Optional[List[NDLLMProvider]]): List of models to choose from.
- default (Optional[Union[NDLLMProvider, int, str]]): Default LLM provider.
- max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
- of routing you're willing to go.
- latency_tracking (Optional[bool]): Latency tracking flag.
- hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
- tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
- for the router to determine the best LLM for a given query.
- preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
- Defaults to None.
- **kwargs: Any other arguments that are supported by Langchain's invoke method, will be passed through.
-
- Raises:
- ApiError: when the NotDiamond API fails
-
- Returns:
- tuple[Union[AIMessage, BaseModel], str, NDLLMProvider]:
- result: response type defined by Langchain, contains the response from the LLM.
- or object of the response_model
- str: session_id returned by the NotDiamond API
- NDLLMProvider: the best LLM provider selected by the router
- """
- if model is not None and len(model) > 0:
- llm_providers = self._parse_llm_providers_data(model)
- self.llm_providers = llm_providers
-
- self.validate_params(
- default=default,
- max_model_depth=max_model_depth,
- latency_tracking=latency_tracking,
- hash_content=hash_content,
- tradeoff=tradeoff,
- preference_id=preference_id,
- )
-
- result = await self.ainvoke(
- messages=messages,
- metric=metric,
- response_model=response_model,
- **kwargs,
- )
- return result
-
-
-
-[docs]
- def invoke(
- self,
- prompt_template: Optional[
- Union[
- NDPromptTemplate,
- PromptTemplate,
- NDChatPromptTemplate,
- ChatPromptTemplate,
- str,
- ]
- ] = None,
- messages: Optional[List[Dict[str, str]]] = None,
- input: Optional[Dict[str, Any]] = None,
- metric: NDMetric = NDMetric("accuracy"),
- response_model: Optional[Type[BaseModel]] = None,
- **kwargs,
- ) -> tuple[Union[AIMessage, BaseModel], str, NDLLMProvider]:
- """
- Function to invoke the LLM. Behind the scenes what happens:
- 1. API call to NotDiamond backend to get the most suitable LLM for the given prompt
- 2. Invoke the returned LLM client side
- 3. Return the response
-
- Parameters:
- prompt_template (Optional(Union[ NDPromptTemplate, PromptTemplate, NDChatPromptTemplate, ChatPromptTemplate, str, ])):
- the prompt template defined by the user. It also supports Langchain prompt template types.
- messages (Optional[List[Dict[str, str]], optional): Can be used instead of prompt_template to pass
- the messages OpenAI style.
- input (Optional[Dict[str, Any]], optional): If the prompt_template contains variables, use input to specify
- the values for those variables. Defaults to None, assuming no variables.
- metric (NDMetric, optional): Metric used by NotDiamond router to choose the best LLM.
- Defaults to NDMetric("accuracy").
- response_model (Optional[Type[BaseModel]], optional): If present, will use JsonOutputParser to parse the response into
- the given model. In which case result will a dict.
- **kwargs: Any other arguments that are supported by Langchain's invoke method, will be passed through.
-
- Raises:
- ApiError: when the NotDiamond API fails
-
- Returns:
- tuple[Union[AIMessage, BaseModel], str, NDLLMProvider]:
- result: response type defined by Langchain, contains the response from the LLM.
- or object of the response_model
- str: session_id returned by the NotDiamond API
- NDLLMProvider: the best LLM provider selected by the router
- """
-
- # If response_model is present, we will parse the response into the given model
- # doing this here so that if validation errors occur, we can raise them before making the API call
- response_model_parser = None
- if response_model is not None:
- self.verify_against_response_model()
- response_model_parser = JsonOutputParser(
- pydantic_object=response_model
- )
-
- prompt_template = self._prepare_prompt_template(
- prompt_template,
- messages,
- response_model_parser=response_model_parser,
- )
-
- if input is None:
- input = {}
-
- prompt_template.partial_variables = {
- **prompt_template.partial_variables,
- **input,
- }
-
- best_llm, session_id = model_select(
- prompt_template=prompt_template,
- llm_providers=self.llm_providers,
- metric=metric,
- notdiamond_api_key=self.api_key,
- max_model_depth=self.max_model_depth,
- hash_content=self.hash_content,
- tradeoff=self.tradeoff,
- preference_id=self.preference_id,
- tools=self.tools,
- )
-
- is_default = False
- if not best_llm:
- best_llm = self.default_llm_provider
- is_default = True
-
- if best_llm is None:
- error_message = (
- "ND couldn't find a suitable model to call."
- + "To avoid disruptions, we recommend setting a default fallback model or increasing max model depth."
- )
- self.call_callbacks("on_api_error", error_message)
- raise ApiError(error_message)
-
- if best_llm.system_prompt is not None:
- prompt_template = prompt_template.inject_system_prompt(
- best_llm.system_prompt
- )
-
- self.call_callbacks("on_model_select", best_llm, best_llm.model)
-
- llm = self._llm_from_provider(best_llm, callbacks=self.callbacks)
-
- if self.tools:
- llm = llm.bind_tools(self.tools)
-
- chain = prompt_template | llm
-
- try:
- if self.latency_tracking:
- result = self._invoke_with_latency_tracking(
- session_id=session_id,
- chain=chain,
- llm_provider=best_llm,
- is_default=is_default,
- input=input,
- **kwargs,
- )
- else:
- result = chain.invoke(input, **kwargs)
- except (ChatGoogleGenerativeAIError, ValueError) as e:
- if (
- isinstance(prompt_template, NDChatPromptTemplate)
- and best_llm.provider == "google"
- ):
- print(
- f"WARNING: Google model's chat messages are violating requirements with error {e}."
- )
- print(
- "If you see this message, means the NotDiamond API returned a Google model as the best option,"
- + "but the LLM call will fail. So we will automatically fall back to a non-Google model, if possible."
- )
-
- non_google_llm = next(
- (
- llm_provider
- for llm_provider in self.llm_providers
- if llm_provider.provider != "google"
- ),
- None,
- )
-
- if non_google_llm is not None:
- best_llm = non_google_llm
- llm = self._llm_from_provider(
- best_llm, callbacks=self.callbacks
- )
- chain = prompt_template | llm
-
- if self.latency_tracking:
- result = self._invoke_with_latency_tracking(
- session_id=session_id,
- chain=chain,
- llm_provider=best_llm,
- is_default=is_default,
- input=input,
- **kwargs,
- )
- else:
- result = chain.invoke(input, **kwargs)
- else:
- raise e
- else:
- raise e
-
- if isinstance(result, str):
- result = AIMessage(content=result)
-
- if response_model is not None:
- parsed_dict = response_model_parser.parse(result.content)
- result = response_model.parse_obj(parsed_dict)
-
- return result, session_id, best_llm
-
-
-
-[docs]
- async def ainvoke(
- self,
- prompt_template: Optional[
- Union[
- NDPromptTemplate,
- PromptTemplate,
- NDChatPromptTemplate,
- ChatPromptTemplate,
- str,
- ]
- ] = None,
- messages: Optional[List[Dict[str, str]]] = None,
- input: Optional[Dict[str, Any]] = None,
- metric: NDMetric = NDMetric("accuracy"),
- response_model: Optional[Type[BaseModel]] = None,
- **kwargs,
- ) -> tuple[Union[AIMessage, BaseModel], str, NDLLMProvider]:
- """
- Function to invoke the LLM. Behind the scenes what happens:
- 1. API call to NotDiamond backend to get the most suitable LLM for the given prompt
- 2. Invoke the returned LLM client side
- 3. Return the response
-
- Parameters:
- prompt_template (Optional(Union[ NDPromptTemplate, PromptTemplate, NDChatPromptTemplate, ChatPromptTemplate, str, ])):
- the prompt template defined by the user. It also supports Langchain prompt template types.
- messages (Optional[List[Dict[str, str]], optional): Can be used instead of prompt_template to pass
- the messages OpenAI style.
- input (Optional[Dict[str, Any]], optional): If the prompt_template contains variables, use input to specify
- the values for those variables. Defaults to None, assuming no variables.
- metric (NDMetric, optional): Metric used by NotDiamond router to choose the best LLM.
- Defaults to NDMetric("accuracy").
- response_model (Optional[Type[BaseModel]], optional): If present, will use JsonOutputParser to parse the response into
- the given model. In which case result will a dict.
- **kwargs: Any other arguments that are supported by Langchain's invoke method, will be passed through.
-
- Raises:
- ApiError: when the NotDiamond API fails
-
- Returns:
- tuple[Union[AIMessage, BaseModel], str, NDLLMProvider]:
- result: response type defined by Langchain, contains the response from the LLM.
- or object of the response_model
- str: session_id returned by the NotDiamond API
- NDLLMProvider: the best LLM provider selected by the router
- """
-
- response_model_parser = None
- if response_model is not None:
- self.verify_against_response_model()
- response_model_parser = JsonOutputParser(
- pydantic_object=response_model
- )
-
- prompt_template = self._prepare_prompt_template(
- prompt_template,
- messages,
- response_model_parser=response_model_parser,
- )
-
- if input is None:
- input = {}
-
- prompt_template.partial_variables = {
- **prompt_template.partial_variables,
- **input,
- }
-
- best_llm, session_id = await amodel_select(
- prompt_template=prompt_template,
- llm_providers=self.llm_providers,
- metric=metric,
- notdiamond_api_key=self.api_key,
- max_model_depth=self.max_model_depth,
- hash_content=self.hash_content,
- tradeoff=self.tradeoff,
- preference_id=self.preference_id,
- tools=self.tools,
- )
-
- is_default = False
- if not best_llm:
- best_llm = self.default_llm_provider
- is_default = True
-
- if best_llm is None:
- error_message = (
- "ND couldn't find a suitable model to call."
- + "To avoid disruptions, we recommend setting a default fallback model or make max depth larger."
- )
- self.call_callbacks("on_api_error", error_message)
- raise ApiError(error_message)
-
- if best_llm.system_prompt is not None:
- prompt_template = prompt_template.inject_system_prompt(
- best_llm.system_prompt
- )
-
- self.call_callbacks("on_model_select", best_llm, best_llm.model)
-
- llm = self._llm_from_provider(best_llm, callbacks=self.callbacks)
-
- if self.tools:
- llm = llm.bind_tools(self.tools)
-
- chain = prompt_template | llm
-
- try:
- if self.latency_tracking:
- result = await self._async_invoke_with_latency_tracking(
- session_id=session_id,
- chain=chain,
- llm_provider=best_llm,
- is_default=is_default,
- input=input,
- **kwargs,
- )
- else:
- result = await chain.ainvoke(input, **kwargs)
- except (ChatGoogleGenerativeAIError, ValueError) as e:
- if (
- isinstance(prompt_template, NDChatPromptTemplate)
- and best_llm.provider == "google"
- ):
- print(
- f"WARNING: Google model's chat messages are violating requirements with error {e}."
- )
- print(
- "If you see this message, means the NotDiamond API returned a Google model as the best option,"
- + "but the LLM call will fail. So we will automatically fall back to a non-Google model, if possible."
- )
-
- non_google_llm = next(
- (
- llm_provider
- for llm_provider in self.llm_providers
- if llm_provider.provider != "google"
- ),
- None,
- )
-
- if non_google_llm is not None:
- best_llm = non_google_llm
- llm = self._llm_from_provider(
- best_llm, callbacks=self.callbacks
- )
- chain = prompt_template | llm
-
- if self.latency_tracking:
- result = (
- await self._async_invoke_with_latency_tracking(
- session_id=session_id,
- chain=chain,
- llm_provider=best_llm,
- is_default=is_default,
- input=input,
- **kwargs,
- )
- )
- else:
- result = await chain.ainvoke(input, **kwargs)
- else:
- raise e
- else:
- raise e
-
- if isinstance(result, str):
- result = AIMessage(content=result)
-
- if response_model is not None:
- parsed_dict = response_model_parser.parse(result.content)
- result = response_model.parse_obj(parsed_dict)
-
- return result, session_id, best_llm
-
-
-
-[docs]
- def stream(
- self,
- prompt_template: Optional[
- Union[
- NDPromptTemplate,
- PromptTemplate,
- NDChatPromptTemplate,
- ChatPromptTemplate,
- str,
- ]
- ] = None,
- messages: Optional[List[Dict[str, str]]] = None,
- input: Optional[Dict[str, Any]] = None,
- metric: NDMetric = NDMetric("accuracy"),
- response_model: Optional[Type[BaseModel]] = None,
- **kwargs,
- ) -> Iterator[Union[BaseMessageChunk, BaseModel]]:
- """
- This function calls the NotDiamond backend to fetch the most suitable model for the given prompt,
- and calls the LLM client side to stream the response.
-
- Parameters:
- prompt_template (Optional(Union[ NDPromptTemplate, PromptTemplate, NDChatPromptTemplate, ChatPromptTemplate, str, ])):
- the prompt template defined by the user. It also supports Langchain prompt template types.
- messages (Optional[List[Dict[str, str]], optional): Can be used instead of prompt_template to pass
- the messages OpenAI style.
- input (Optional[Dict[str, Any]], optional): If the prompt_template contains variables, use input to specify
- the values for those variables. Defaults to None, assuming no variables.
- metric (NDMetric, optional): Metric used by NotDiamond router to choose the best LLM.
- Defaults to NDMetric("accuracy").
- response_model (Optional[Type[BaseModel]], optional): If present, will use JsonOutputParser to parse the response into
- the given model. In which case result will a dict.
- **kwargs: Any other arguments that are supported by Langchain's invoke method, will be passed through.
-
- Raises:
- ApiError: when the NotDiamond API fails
-
- Yields:
- Iterator[Union[BaseMessageChunk, BaseModel]]: returns the response in chunks.
- If response_model is present, it will return the partial model object
- """
-
- response_model_parser = None
- if response_model is not None:
- self.verify_against_response_model()
- response_model_parser = JsonOutputParser(
- pydantic_object=response_model
- )
-
- prompt_template = self._prepare_prompt_template(
- prompt_template=prompt_template,
- messages=messages,
- response_model_parser=response_model_parser,
- )
-
- if input is None:
- input = {}
-
- prompt_template.partial_variables = {
- **prompt_template.partial_variables,
- **input,
- }
-
- best_llm, session_id = model_select(
- prompt_template=prompt_template,
- llm_providers=self.llm_providers,
- metric=metric,
- notdiamond_api_key=self.api_key,
- max_model_depth=self.max_model_depth,
- hash_content=self.hash_content,
- tradeoff=self.tradeoff,
- preference_id=self.preference_id,
- tools=self.tools,
- )
-
- if not best_llm:
- best_llm = self.default_llm_provider
-
- if best_llm is None:
- error_message = (
- "ND couldn't find a suitable model to call."
- + "To avoid disruptions, we recommend setting a default fallback model or make max depth larger."
- )
- self.call_callbacks("on_api_error", error_message)
- raise ApiError(error_message)
-
- if best_llm.system_prompt is not None:
- prompt_template = prompt_template.inject_system_prompt(
- best_llm.system_prompt
- )
-
- self.call_callbacks("on_model_select", best_llm, best_llm.model)
-
- llm = self._llm_from_provider(best_llm, callbacks=self.callbacks)
- if self.tools:
- llm = llm.bind_tools(self.tools)
-
- if response_model is not None:
- chain = llm | response_model_parser
- else:
- chain = llm
-
- for chunk in chain.stream(prompt_template.format(), **kwargs):
- if response_model is None:
- yield chunk
- else:
- partial_model = create_partial_model(response_model)
- yield partial_model(**chunk)
-
-
-
-[docs]
- async def astream(
- self,
- prompt_template: Optional[
- Union[
- NDPromptTemplate,
- PromptTemplate,
- NDChatPromptTemplate,
- ChatPromptTemplate,
- str,
- ]
- ] = None,
- messages: Optional[List[Dict[str, str]]] = None,
- input: Optional[Dict[str, Any]] = None,
- metric: NDMetric = NDMetric("accuracy"),
- response_model: Optional[Type[BaseModel]] = None,
- **kwargs,
- ) -> AsyncIterator[Union[BaseMessageChunk, BaseModel]]:
- """
- This function calls the NotDiamond backend to fetch the most suitable model for the given prompt,
- and calls the LLM client side to stream the response. The function is async, so it's suitable for async codebases.
-
- Parameters:
- prompt_template (Optional(Union[ NDPromptTemplate, PromptTemplate, NDChatPromptTemplate, ChatPromptTemplate, str, ])):
- the prompt template defined by the user. It also supports Langchain prompt template types.
- messages (Optional[List[Dict[str, str]], optional): Can be used instead of prompt_template to pass
- the messages OpenAI style.
- input (Optional[Dict[str, Any]], optional): If the prompt_template contains variables, use input to specify
- the values for those variables. Defaults to None, assuming no variables.
- metric (NDMetric, optional): Metric used by NotDiamond router to choose the best LLM.
- Defaults to NDMetric("accuracy").
- response_model (Optional[Type[BaseModel]], optional): If present, will use JsonOutputParser to parse the response into
- the given model. In which case result will a dict.
- **kwargs: Any other arguments that are supported by Langchain's invoke method, will be passed through.
-
- Raises:
- ApiError: when the NotDiamond API fails
-
- Yields:
- AsyncIterator[Union[BaseMessageChunk, BaseModel]]: returns the response in chunks.
- If response_model is present, it will return the partial model object
- """
-
- response_model_parser = None
- if response_model is not None:
- self.verify_against_response_model()
- response_model_parser = JsonOutputParser(
- pydantic_object=response_model
- )
-
- prompt_template = self._prepare_prompt_template(
- prompt_template=prompt_template,
- messages=messages,
- response_model_parser=response_model_parser,
- )
- best_llm, session_id = await amodel_select(
- prompt_template=prompt_template,
- llm_providers=self.llm_providers,
- metric=metric,
- notdiamond_api_key=self.api_key,
- max_model_depth=self.max_model_depth,
- hash_content=self.hash_content,
- tradeoff=self.tradeoff,
- preference_id=self.preference_id,
- tools=self.tools,
- )
-
- if input is None:
- input = {}
-
- prompt_template.partial_variables = {
- **prompt_template.partial_variables,
- **input,
- }
-
- if not best_llm:
- best_llm = self.default_llm_provider
-
- if best_llm is None:
- error_message = (
- "ND couldn't find a suitable model to call."
- + "To avoid disruptions, we recommend setting a default fallback model or make max depth larger."
- )
- self.call_callbacks("on_api_error", error_message)
- raise ApiError(error_message)
-
- if best_llm.system_prompt is not None:
- prompt_template = prompt_template.inject_system_prompt(
- best_llm.system_prompt
- )
-
- self.call_callbacks("on_model_select", best_llm, best_llm.model)
-
- llm = self._llm_from_provider(best_llm, callbacks=self.callbacks)
- if self.tools:
- llm = llm.bind_tools(self.tools)
-
- if response_model is not None:
- chain = llm | response_model_parser
- else:
- chain = llm
-
- async for chunk in chain.astream(prompt_template.format(), **kwargs):
- if response_model is None:
- yield chunk
- else:
- partial_model = create_partial_model(response_model)
- yield partial_model(**chunk)
-
-
-
-[docs]
- async def amodel_select(
- self,
- messages: Optional[List[Dict[str, str]]] = None,
- prompt_template: Optional[
- Union[
- NDPromptTemplate,
- PromptTemplate,
- NDChatPromptTemplate,
- ChatPromptTemplate,
- str,
- ]
- ] = None,
- input: Optional[Dict[str, Any]] = None,
- model: Optional[List[NDLLMProvider]] = None,
- default: Optional[Union[NDLLMProvider, int, str]] = None,
- max_model_depth: Optional[int] = None,
- latency_tracking: Optional[bool] = None,
- hash_content: Optional[bool] = None,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- metric: NDMetric = NDMetric("accuracy"),
- **kwargs,
- ) -> tuple[str, Optional[NDLLMProvider]]:
- """
- This function calls the NotDiamond backend to fetch the most suitable model for the given prompt,
- and leaves the execution of the LLM call to the developer.
- The function is async, so it's suitable for async codebases.
-
- Parameters:
- messages (Optional[List[Dict[str, str]], optional): Can be used instead of prompt_template to pass
- the messages OpenAI style.
- prompt_template (Union[ NDPromptTemplate, PromptTemplate, NDChatPromptTemplate, ChatPromptTemplate, str, ]):
- the prompt template defined by the user. It also supports Langchain prompt template types.
- input (Optional[Dict[str, Any]], optional): If the prompt_template contains variables, use input to specify
- the values for those variables. Defaults to None, assuming no variables.
- model (Optional[List[NDLLMProvider]]): List of models to choose from.
- default (Optional[Union[NDLLMProvider, int, str]]): Default LLM provider.
- max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
- of routing you're willing to go.
- latency_tracking (Optional[bool]): Latency tracking flag.
- hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
- tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
- for the router to determine the best LLM for a given query.
- preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
- Defaults to None.
- metric (NDMetric, optional): Metric used by NotDiamond router to choose the best LLM.
- Defaults to NDMetric("accuracy").
- **kwargs: Any other arguments that are supported by Langchain's invoke method, will be passed through.
-
- Returns:
- tuple[str, Optional[NDLLMProvider]]: returns the session_id and the chosen LLM provider
- """
- prompt_template = self._prepare_prompt_template(
- prompt_template,
- messages,
- )
-
- if input is None:
- input = {}
-
- prompt_template.partial_variables = {
- **prompt_template.partial_variables,
- **input,
- }
-
- if model is not None:
- llm_providers = self._parse_llm_providers_data(model)
- self.llm_providers = llm_providers
-
- self.validate_params(
- default=default,
- max_model_depth=max_model_depth,
- latency_tracking=latency_tracking,
- hash_content=hash_content,
- tradeoff=tradeoff,
- preference_id=preference_id,
- )
-
- best_llm, session_id = await amodel_select(
- prompt_template=prompt_template,
- llm_providers=self.llm_providers,
- metric=metric,
- notdiamond_api_key=self.api_key,
- max_model_depth=self.max_model_depth,
- hash_content=self.hash_content,
- tradeoff=self.tradeoff,
- preference_id=self.preference_id,
- tools=self.tools,
- )
-
- if not best_llm and self.default is not None:
- print("ND API error. Falling back to default provider.")
- best_llm = self.default_llm_provider
- self.call_callbacks("on_model_select", best_llm, best_llm.model)
-
- return session_id, best_llm
-
-
-
-[docs]
- def model_select(
- self,
- messages: Optional[List[Dict[str, str]]] = None,
- prompt_template: Optional[
- Union[
- NDPromptTemplate,
- PromptTemplate,
- NDChatPromptTemplate,
- ChatPromptTemplate,
- str,
- ]
- ] = None,
- input: Optional[Dict[str, Any]] = None,
- model: Optional[List[NDLLMProvider]] = None,
- default: Optional[Union[NDLLMProvider, int, str]] = None,
- max_model_depth: Optional[int] = None,
- latency_tracking: Optional[bool] = None,
- hash_content: Optional[bool] = None,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- metric: NDMetric = NDMetric("accuracy"),
- **kwargs,
- ) -> tuple[str, Optional[NDLLMProvider]]:
- """
- This function calls the NotDiamond backend to fetch the most suitable model for the given prompt,
- and leaves the execution of the LLM call to the developer.
-
- Parameters:
- messages (Optional[List[Dict[str, str]], optional): Can be used instead of prompt_template to pass
- the messages OpenAI style.
- prompt_template (Union[ NDPromptTemplate, PromptTemplate, NDChatPromptTemplate, ChatPromptTemplate, str, ]):
- the prompt template defined by the user. It also supports Langchain prompt template types.
- input (Optional[Dict[str, Any]], optional): If the prompt_template contains variables, use input to specify
- the values for those variables. Defaults to None, assuming no variables.
- model (Optional[List[NDLLMProvider]]): List of models to choose from.
- default (Optional[Union[NDLLMProvider, int, str]]): Default LLM provider.
- max_model_depth (Optional[int]): If your top recommended model is down, specify up to which depth
- of routing you're willing to go.
- latency_tracking (Optional[bool]): Latency tracking flag.
- hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
- tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
- for the router to determine the best LLM for a given query.
- preference_id (Optional[str]): The ID of the router preference that was configured via the Dashboard.
- Defaults to None.
- metric (NDMetric, optional): Metric used by NotDiamond router to choose the best LLM.
- Defaults to NDMetric("accuracy").
- **kwargs: Any other arguments that are supported by Langchain's invoke method, will be passed through.
-
- Returns:
- tuple[str, Optional[NDLLMProvider]]: returns the session_id and the chosen LLM provider
- """
- prompt_template = self._prepare_prompt_template(
- prompt_template,
- messages,
- )
-
- if input is None:
- input = {}
-
- prompt_template.partial_variables = {
- **prompt_template.partial_variables,
- **input,
- }
-
- if model is not None:
- llm_providers = self._parse_llm_providers_data(model)
- self.llm_providers = llm_providers
-
- self.validate_params(
- default=default,
- max_model_depth=max_model_depth,
- latency_tracking=latency_tracking,
- hash_content=hash_content,
- tradeoff=tradeoff,
- preference_id=preference_id,
- )
-
- best_llm, session_id = model_select(
- prompt_template=prompt_template,
- llm_providers=self.llm_providers,
- metric=metric,
- notdiamond_api_key=self.api_key,
- max_model_depth=self.max_model_depth,
- hash_content=self.hash_content,
- tradeoff=self.tradeoff,
- preference_id=self.preference_id,
- tools=self.tools,
- )
-
- if not best_llm and self.default is not None:
- print("ND API error. Falling back to default provider.")
- best_llm = self.default_llm_provider
- self.call_callbacks("on_model_select", best_llm, best_llm.model)
-
- return session_id, best_llm
-
-
- async def _async_invoke_with_latency_tracking(
- self,
- session_id: str,
- chain: Any,
- llm_provider: NDLLMProvider,
- input: Optional[Dict[str, Any]] = {},
- is_default: bool = True,
- **kwargs,
- ):
- if session_id in ("NO-SESSION-ID", "") and not is_default:
- error_message = (
- "ND session_id is not valid for latency tracking."
- + "Please check the API response."
- )
- self.call_callbacks("on_api_error", error_message)
- raise ApiError(error_message)
-
- start_time = time.time()
-
- result = await chain.ainvoke(input, **kwargs)
-
- end_time = time.time()
-
- if isinstance(result, str):
- result = AIMessage(content=result)
-
- tokens_completed = token_counter(
- model=llm_provider.model,
- messages=[{"role": "assistant", "content": result.content}],
- )
- tokens_per_second = tokens_completed / (end_time - start_time)
-
- report_latency(
- session_id=session_id,
- llm_provider=llm_provider,
- tokens_per_second=tokens_per_second,
- notdiamond_api_key=self.api_key,
- )
- self.call_callbacks(
- "on_latency_tracking", session_id, llm_provider, tokens_per_second
- )
-
- return result
-
- def _invoke_with_latency_tracking(
- self,
- session_id: str,
- chain: Any,
- llm_provider: NDLLMProvider,
- input: Optional[Dict[str, Any]] = {},
- is_default: bool = True,
- **kwargs,
- ):
- if session_id in ("NO-SESSION-ID", "") and not is_default:
- error_message = (
- "ND session_id is not valid for latency tracking."
- + "Please check the API response."
- )
- self.call_callbacks("on_api_error", error_message)
- raise ApiError(error_message)
-
- start_time = time.time()
- result = chain.invoke(input, **kwargs)
- end_time = time.time()
-
- if isinstance(result, str):
- result = AIMessage(content=result)
-
- tokens_completed = token_counter(
- model=llm_provider.model,
- messages=[{"role": "assistant", "content": result.content}],
- )
- tokens_per_second = tokens_completed / (end_time - start_time)
-
- report_latency(
- session_id=session_id,
- llm_provider=llm_provider,
- tokens_per_second=tokens_per_second,
- notdiamond_api_key=self.api_key,
- )
- self.call_callbacks(
- "on_latency_tracking", session_id, llm_provider, tokens_per_second
- )
-
- return result
-
- @staticmethod
- def _llm_from_provider(
- provider: NDLLMProvider,
- callbacks: Optional[
- List[Union[BaseCallbackHandler, NDLLMBaseCallbackHandler]]
- ],
- ) -> Any:
- default_kwargs = {"max_retries": 5, "timeout": 120}
- passed_kwargs = {**default_kwargs, **provider.kwargs}
-
- if provider.provider == "openai":
- return ChatOpenAI(
- openai_api_key=provider.api_key,
- model_name=provider.model,
- callbacks=callbacks,
- **passed_kwargs,
- )
- if provider.provider == "anthropic":
- return ChatAnthropic(
- anthropic_api_key=provider.api_key,
- model=provider.model,
- callbacks=callbacks,
- **passed_kwargs,
- )
- if provider.provider == "google":
- return ChatGoogleGenerativeAI(
- google_api_key=provider.api_key,
- model=provider.model,
- convert_system_message_to_human=True,
- callbacks=callbacks,
- **passed_kwargs,
- )
- if provider.provider == "cohere":
- return ChatCohere(
- cohere_api_key=provider.api_key,
- model=provider.model,
- callbacks=callbacks,
- **passed_kwargs,
- )
- if provider.provider == "mistral":
- return ChatMistralAI(
- mistral_api_key=provider.api_key,
- model=provider.model,
- callbacks=callbacks,
- **passed_kwargs,
- )
- if provider.provider == "togetherai":
- provider_settings = settings.PROVIDERS.get(provider.provider, None)
- model_prefixes = provider_settings.get("model_prefix", None)
- model_prefix = model_prefixes.get(provider.model, None)
- del passed_kwargs["max_retries"]
- del passed_kwargs["timeout"]
-
- if model_prefix is not None:
- model = f"{model_prefix}/{provider.model}"
- return Together(
- together_api_key=provider.api_key,
- model=model,
- callbacks=callbacks,
- **passed_kwargs,
- )
- if provider.provider == "perplexity":
- del passed_kwargs["max_retries"]
- passed_kwargs["request_timeout"] = passed_kwargs["timeout"]
- del passed_kwargs["timeout"]
- return ChatPerplexity(
- pplx_api_key=provider.api_key,
- model=provider.model,
- callbacks=callbacks,
- **passed_kwargs,
- )
- if provider.provider == "replicate":
- provider_settings = settings.PROVIDERS.get(provider.provider, None)
- model_prefixes = provider_settings.get("model_prefix", None)
- model_prefix = model_prefixes.get(provider.model, None)
- passed_kwargs["request_timeout"] = passed_kwargs["timeout"]
- del passed_kwargs["timeout"]
-
- if model_prefix is not None:
- model = f"replicate/{model_prefix}/{provider.model}"
- return ChatLiteLLM(
- model=model,
- callbacks=callbacks,
- replicate_api_key=provider.api_key,
- **passed_kwargs,
- )
- raise ValueError(f"Unsupported provider: {provider.provider}")
-
- @staticmethod
- def _prepare_prompt_template(
- prompt_template, messages=None, response_model_parser=None
- ) -> Union[NDPromptTemplate, NDChatPromptTemplate]:
- resulting_prompt_template = None
- if prompt_template is not None and messages is not None:
- print(
- "Warning: prompt_template value is overriding messages value. Set one of those values for optimal performance."
- )
- if prompt_template is not None:
- if isinstance(prompt_template, NDPromptTemplate) or isinstance(
- prompt_template, NDChatPromptTemplate
- ):
- resulting_prompt_template = prompt_template
- elif isinstance(prompt_template, str):
- resulting_prompt_template = NDPromptTemplate(
- template=prompt_template
- )
- elif isinstance(prompt_template, StringPromptValue):
- resulting_prompt_template = NDChatPromptTemplate.from_messages(
- prompt_template.to_messages()
- )
- elif isinstance(prompt_template, PromptTemplate):
- resulting_prompt_template = (
- NDPromptTemplate.from_langchain_prompt_template(
- prompt_template
- )
- )
- elif isinstance(prompt_template, ChatPromptTemplate):
- resulting_prompt_template = (
- NDChatPromptTemplate.from_langchain_chat_prompt_template(
- prompt_template
- )
- )
- elif isinstance(prompt_template, list):
- if all(isinstance(pt, BaseMessage) for pt in prompt_template):
- resulting_prompt_template = (
- NDChatPromptTemplate.from_messages(prompt_template)
- )
- if resulting_prompt_template is None:
- raise ValueError(
- f"Unsupported prompt_template type {type(prompt_template)}"
- )
- if messages is not None:
- resulting_prompt_template = (
- NDChatPromptTemplate.from_openai_messages(messages)
- )
-
- if resulting_prompt_template is None:
- raise ValueError("prompt_template or messages must be specified.")
-
- if response_model_parser is not None:
- resulting_prompt_template = (
- resulting_prompt_template.inject_model_instruction(
- response_model_parser
- )
- )
-
- return resulting_prompt_template
-
-
-[docs]
- def bind_tools(
- self,
- tools: Sequence[Union[Dict[str, Any], Callable]],
- ) -> "NDLLM":
- """
- Bind tools to the LLM object. The tools will be passed to the LLM object when invoking it.
- Results in the tools being available in the LLM object.
- You can access the tool_calls in the result via `result.tool_calls`.
- """
-
- for provider in self.llm_providers:
- if provider.model not in settings.PROVIDERS[provider.provider].get(
- "support_tools", []
- ):
- raise ApiError(
- f"{provider.provider}/{provider.model} does not support function calling."
- )
- self.tools = tools
-
- return self
-
-
-
-[docs]
- def call_callbacks(self, function_name: str, *args, **kwargs) -> None:
- """
- Call all callbacks with a specific function name.
- """
-
- if self.callbacks is None:
- return
-
- for callback in self.callbacks:
- if hasattr(callback, function_name):
- getattr(callback, function_name)(*args, **kwargs)
-
-
-
-[docs]
- def verify_against_response_model(self) -> bool:
- """
- Verify that the LLM providers support response modeling.
- """
-
- for provider in self.llm_providers:
- if provider.model not in settings.PROVIDERS[provider.provider].get(
- "support_response_model", []
- ):
- raise ApiError(
- f"{provider.provider}/{provider.model} does not support response modeling."
- )
-
- return True
-
-
-
-[docs]
- def validate_params(
- self,
- default: Optional[Union[NDLLMProvider, int, str]] = None,
- max_model_depth: Optional[int] = None,
- latency_tracking: Optional[bool] = None,
- hash_content: Optional[bool] = None,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- ):
- if default is not None:
- self.default = default
-
- if max_model_depth is not None:
- self.max_model_depth = max_model_depth
-
- if self.llm_providers is None or len(self.llm_providers) == 0:
- raise MissingLLMProviders(
- "No LLM provider speficied. Specify at least one."
- )
-
- if self.max_model_depth is None:
- self.max_model_depth = len(self.llm_providers)
-
- if self.max_model_depth == 0:
- raise ValueError("max_model_depth has to be bigger than 0.")
-
- if self.max_model_depth > len(self.llm_providers):
- print(
- "WARNING: max_model_depth cannot be bigger than the number of LLM providers."
- )
- self.max_model_depth = len(self.llm_providers)
-
- if tradeoff is not None:
- if tradeoff not in ["cost", "latency"]:
- raise ValueError(
- "Invalid tradeoff. Accepted values: cost, latency."
- )
- self.tradeoff = tradeoff
-
- if preference_id is not None:
- self.preference_id = preference_id
-
- if latency_tracking is not None:
- self.latency_tracking = latency_tracking
-
- if hash_content is not None:
- self.hash_content = hash_content
-
-
-
-from typing import Optional
-
-from notdiamond import settings
-from notdiamond.exceptions import UnsupportedLLMProvider
-
-POSSIBLE_PROVIDERS = list(settings.PROVIDERS.keys())
-POSSIBLE_MODELS = list(
- model
- for provider_values in settings.PROVIDERS.values()
- for values in provider_values.values()
- if isinstance(values, list)
- for model in values
-)
-
-
-
-[docs]
-class NDLLMProvider:
- """
- An NDLLM provider is represented by a combination of provider and model.
- Provider refers to the company of the foundational model, such as openai, anthropic, google.
- The model represents the model name as defined by the owner company, such as gpt-3.5-turbo
- Beside this you can also specify the API key for each provider, specify extra arguments
- that are also supported by Langchain (eg. temperature), and a system prmopt to be used
- with the provider. If the provider is selected during routing, then the system prompt will
- be used, replacing the one in the message array if there are any.
-
- All supported providers and models can be found in our docs.
-
- If the API key it's not specified, it will try to pick it up from an .env file before failing.
- As example for OpenAI it will look for OPENAI_API_KEY.
-
- Attributes:
- provider (str): The name of the LLM provider (e.g., "openai", "anthropic"). Must be one of the
- predefined providers in `POSSIBLE_PROVIDERS`.
- model (str): The name of the LLM model to use (e.g., "gpt-3.5-turbo").
- Must be one of the predefined models in `POSSIBLE_MODELS`.
- system_prompt (Optional[str], optional): The system prompt to use for the provider. Defaults to None.
- api_key (Optional[str], optional): The API key for accessing the LLM provider's services.
- Defaults to None, in which case it tries to fetch from the settings.
- openrouter_model (str): The OpenRouter model equivalent for this provider / model
- **kwargs: Additional keyword arguments that might be necessary for specific providers or models.
-
- Raises:
- UnsupportedLLMProvider: If the `provider` or `model` specified is not supported.
- """
-
-
-[docs]
- def __init__(
- self,
- provider: str,
- model: str,
- system_prompt: Optional[str] = None,
- api_key: Optional[str] = None,
- **kwargs,
- ):
- """_summary_
-
- Args:
- provider (str): The name of the LLM provider (e.g., "openai", "anthropic").
- model (str): The name of the LLM model to use (e.g., "gpt-3.5-turbo").
- system_prompt (Optional[str], optional): The system prompt to use for the provider. Defaults to None.
- api_key (Optional[str], optional): The API key for accessing the LLM provider's services.
- Defaults to None.
- **kwargs: Additional keyword arguments that might be necessary for specific providers or models.
-
- Raises:
- UnsupportedLLMProvider: If the `provider` or `model` specified is not supported.
- """
- if provider not in POSSIBLE_PROVIDERS:
- raise UnsupportedLLMProvider(
- f"Given LLM provider {provider} is not in the list of supported providers."
- )
- if model not in POSSIBLE_MODELS:
- raise UnsupportedLLMProvider(
- f"Given LLM model {model} is not in the list of supported models."
- )
-
- self.provider = provider
- self.model = model
- self.system_prompt = system_prompt
- self._openrouter_model = settings.PROVIDERS[provider][
- "openrouter_identifier"
- ].get(model, None)
- self.api_key = (
- api_key
- if api_key is not None
- else settings.PROVIDERS[provider]["api_key"]
- )
- self.kwargs = kwargs
-
-
- def __repr__(self) -> str:
- return f"{self.provider}/{self.model}"
-
- @property
- def openrouter_model(self):
- if self._openrouter_model is None:
- print("WARNING: this model is not available via OpenRouter")
- return self._openrouter_model
-
-
-[docs]
- def prepare_for_request(self):
- """
- Converts the NDLLMProvider object to a dict in the format accepted by
- the NotDiamond API.
-
- Returns:
- dict
- """
- return {"provider": self.provider, "model": self.model}
-
-
-
-[docs]
- def set_api_key(self, api_key: str) -> "NDLLMProvider":
- self.api_key = api_key
-
- return self
-
-
-
-[docs]
- @classmethod
- def from_string(cls, llm_provider: str):
- """
- We allow our users to specify LLM providers for NDLLM in the string format 'provider_name/model_name',
- as example 'openai/gpt-3.5-turbo'. Underlying our workflows we want to ensure we use NDLLMProvider as
- the base type, so this class method converts a string specification of an LLM provider into an
- NDLLMProvider object.
-
- Args:
- llm_provider (str): this is the string definition of the LLM provider
-
- Returns:
- NDLLMProvider: initialized object with correct provider and model
- """
- split_items = llm_provider.split("/")
- provider = split_items[0]
- model = split_items[1]
- return cls(provider=provider, model=model)
-
-
-
-from enum import Enum
-
-from notdiamond.llms.config import LLMConfig
-
-
-
-[docs]
-class NDLLMProviders(Enum):
- """
- NDLLMProviders serves as a registry for the supported LLM models by NotDiamond.
- It allows developers to easily specify available LLM providers for the router.
-
- Attributes:
- GPT_3_5_TURBO (NDLLMProvider): refers to 'gpt-3.5-turbo' model by OpenAI
- GPT_3_5_TURBO_0125 (NDLLMProvider): refers to 'gpt-3.5-turbo-0125' model by OpenAI
- GPT_4 (NDLLMProvider): refers to 'gpt-4' model by OpenAI
- GPT_4_0613 (NDLLMProvider): refers to 'gpt-4-0613' model by OpenAI
- GPT_4_1106_PREVIEW (NDLLMProvider): refers to 'gpt-4-1106-preview' model by OpenAI
- GPT_4_TURBO (NDLLMProvider): refers to 'gpt-4-turbo' model by OpenAI
- GPT_4_TURBO_PREVIEW (NDLLMProvider): refers to 'gpt-4-turbo-preview' model by OpenAI
- GPT_4_TURBO_2024_04_09 (NDLLMProvider): refers to 'gpt-4-turbo-2024-04-09' model by OpenAI
- GPT_4o_2024_05_13 (NDLLMProvider): refers to 'gpt-4o-2024-05-13' model by OpenAI
- GPT_4o_2024_08_06 (NDLLMProvider): refers to 'gpt-4o-2024-08-06' model by OpenAI
- GPT_4o (NDLLMProvider): refers to 'gpt-4o' model by OpenAI
- GPT_4o_MINI_2024_07_18 (NDLLMProvider): refers to 'gpt-4o-mini-2024-07-18' model by OpenAI
- GPT_4o_MINI (NDLLMProvider): refers to 'gpt-4o-mini' model by OpenAI
- GPT_4_0125_PREVIEW (NDLLMProvider): refers to 'gpt-4-0125-preview' model by OpenAI
- GPT_4_1 (NDLLMProvider): refers to 'gpt-4.1' model by OpenAI
- GPT_4_1_2025_04_14 (NDLLMProvider): refers to 'gpt-4.1-2025-04-14' model by OpenAI
- GPT_4_1_MINI (NDLLMProvider): refers to 'gpt-4.1-mini' model by OpenAI
- GPT_4_1_MINI_2025_04_14 (NDLLMProvider): refers to 'gpt-4.1-mini-2025-04-14' model by OpenAI
- GPT_4_1_NANO (NDLLMProvider): refers to 'gpt-4.1-nano' model by OpenAI
- GPT_4_1_NANO_2025_04_14 (NDLLMProvider): refers to 'gpt-4.1-nano-2025-04-14' model by OpenAI
- O1_PREVIEW (NDLLMProvider): refers to 'o1-preview' model by OpenAI
- O1_PREVIEW_2024_09_12 (NDLLMProvider): refers to 'o1-preview-2024-09-12' model by OpenAI
- O1_MINI (NDLLMProvider): refers to 'o1-mini' model by OpenAI
- O1_MINI_2024_09_12 (NDLLMProvider): refers to 'o1-mini-2024-09-12' model by OpenAI
-
- CLAUDE_2_1 (NDLLMProvider): refers to 'claude-2.1' model by Anthropic
- CLAUDE_3_OPUS_20240229 (NDLLMProvider): refers to 'claude-3-opus-20240229' model by Anthropic
- CLAUDE_3_SONNET_20240229 (NDLLMProvider): refers to 'claude-3-sonnet-20240229' model by Anthropic
- CLAUDE_3_5_SONNET_20240620 (NDLLMProvider): refers to 'claude-3-5-sonnet-20240620' model by Anthropic
- CLAUDE_3_7_SONNET_LATEST (NDLLMProvider): refers to 'claude-3-7-sonnet-latest' model by Anthropic
- CLAUDE_3_7_SONNET_20250219 (NDLLMProvider): refers to 'claude-3-7-sonnet-20250219' model by Anthropic
- CLAUDE_3_5_HAIKU_20241022 (NDLLMProvider): refers to 'claude-3-5-haiku-20241022' model by Anthropic
- CLAUDE_3_HAIKU_20240307 (NDLLMProvider): refers to 'claude-3-haiku-20240307' model by Anthropic
- CLAUDE_OPUS_4_20250514 (NDLLMProvider): refers to 'claude-opus-4-20250514' model by Anthropic
- CLAUDE_SONNET_4_20250514 (NDLLMProvider): refers to 'claude-sonnet-4-20250514' model by Anthropic
- CLAUDE_OPUS_4_0 (NDLLMProvider): refers to 'claude-opus-4-0' model by Anthropic
- CLAUDE_SONNET_4_0 (NDLLMProvider): refers to 'claude-sonnet-4-0' model by Anthropic
-
- GEMINI_PRO (NDLLMProvider): refers to 'gemini-pro' model by Google
- GEMINI_1_PRO_LATEST (NDLLMProvider): refers to 'gemini-1.0-pro-latest' model by Google
- GEMINI_15_PRO_LATEST (NDLLMProvider): refers to 'gemini-1.5-pro-latest' model by Google
- GEMINI_15_PRO_EXP_0801 (NDLLMProvider): refers to 'gemini-1.5-pro-exp-0801' model by Google
- GEMINI_15_FLASH_LATEST (NDLLMProvider): refers to 'gemini-1.5-flash-latest' model by Google
- GEMINI_20_FLASH (NDLLMProvider): refers to 'gemini-20-flash' model by Google
- GEMINI_20_FLASH_001 (NDLLMProvider): refers to 'gemini-20-flash-001' model by Google
- GEMINI_25_FLASH (NDLLMProvider): refers to 'gemini-25-flash' model by Google
- GEMINI_25_PRO (NDLLMProvider): refers to 'gemini-25-pro' model by Google
-
- COMMAND_R (NDLLMProvider): refers to 'command-r' model by Cohere
- COMMAND_R_PLUS (NDLLMProvider): refers to 'command-r-plus' model by Cohere
-
- MISTRAL_LARGE_LATEST (NDLLMProvider): refers to 'mistral-large-latest' model by Mistral AI
- MISTRAL_LARGE_2407 (NDLLMProvider): refers to 'mistral-large-2407' model by Mistral AI
- MISTRAL_LARGE_2402 (NDLLMProvider): refers to 'mistral-large-2402' model by Mistral AI
- MISTRAL_MEDIUM_LATEST (NDLLMProvider): refers to 'mistral-medium-latest' model by Mistral AI
- MISTRAL_SMALL_LATEST (NDLLMProvider): refers to 'mistral-small-latest' model by Mistral AI
- OPEN_MISTRAL_7B (NDLLMProvider): refers to 'open-mistral-7b' model by Mistral AI
- OPEN_MIXTRAL_8X7B (NDLLMProvider): refers to 'open-mixtral-8x7b' model by Mistral AI
- OPEN_MIXTRAL_8X22B (NDLLMProvider): refers to 'open-mixtral-8x22b' model by Mistral AI
- OPEN_MISTRAL_NEMO (NDLLMProvider): refers to 'open-mistral-nemo' model by Mistral AI
-
- TOGETHER_MISTRAL_7B_INSTRUCT_V0_2 (NDLLMProvider): refers to 'Mistral-7B-Instruct-v0.2' model served via TogetherAI
- TOGETHER_MIXTRAL_8X7B_INSTRUCT_V0_1 (NDLLMProvider): refers to 'Mixtral-8x7B-Instruct-v0.1' model served via TogetherAI
- TOGETHER_MIXTRAL_8X22B_INSTRUCT_V0_1 (NDLLMProvider): refers to 'Mixtral-8x22B-Instruct-v0.1' model served via TogetherAI
- TOGETHER_LLAMA_3_70B_CHAT_HF (NDLLMProvider): refers to 'Llama-3-70b-chat-hf' model served via TogetherAI
- TOGETHER_LLAMA_3_8B_CHAT_HF (NDLLMProvider): refers to 'Llama-3-8b-chat-hf' model served via TogetherAI
- TOGETHER_QWEN2_72B_INSTRUCT (NDLLMProvider): refers to 'Qwen2-72B-Instruct' model served via TogetherAI
- TOGETHER_LLAMA_3_1_8B_INSTRUCT_TURBO (NDLLMProvider): refers to 'Meta-Llama-3.1-8B-Instruct-Turbo'
- model served via TogetherAI
- TOGETHER_LLAMA_3_1_70B_INSTRUCT_TURBO (NDLLMProvider): refers to 'Meta-Llama-3.1-70B-Instruct-Turbo'
- model served via TogetherAI
- TOGETHER_LLAMA_3_1_405B_INSTRUCT_TURBO (NDLLMProvider): refers to 'Meta-Llama-3.1-405B-Instruct-Turbo'
- model served via TogetherAI
- TOGETHER_DEEPSEEK_R1 (NDLLMProvider): refers to 'DeepSeek-R1'
- model served via TogetherAI
-
- REPLICATE_MISTRAL_7B_INSTRUCT_V0_2 (NDLLMProvider): refers to "mistral-7b-instruct-v0.2" model served via Replicate
- REPLICATE_MIXTRAL_8X7B_INSTRUCT_V0_1 (NDLLMProvider): refers to "mixtral-8x7b-instruct-v0.1" model served via Replicate
- REPLICATE_META_LLAMA_3_70B_INSTRUCT (NDLLMProvider): refers to "meta-llama-3-70b-instruct" model served via Replicate
- REPLICATE_META_LLAMA_3_8B_INSTRUCT (NDLLMProvider): refers to "meta-llama-3-8b-instruct" model served via Replicate
- REPLICATE_META_LLAMA_3_1_405B_INSTRUCT (NDLLMProvider): refers to "meta-llama-3.1-405b-instruct"
- model served via Replicate
-
- SONAR (NDLLMProvider): refers to "sonar" model by Perplexity
- """
-
- GPT_3_5_TURBO = ("openai", "gpt-3.5-turbo")
- GPT_3_5_TURBO_0125 = ("openai", "gpt-3.5-turbo-0125")
- GPT_4 = ("openai", "gpt-4")
- GPT_4_0613 = ("openai", "gpt-4-0613")
- GPT_4_1106_PREVIEW = ("openai", "gpt-4-1106-preview")
- GPT_4_TURBO = ("openai", "gpt-4-turbo")
- GPT_4_TURBO_PREVIEW = ("openai", "gpt-4-turbo-preview")
- GPT_4_TURBO_2024_04_09 = ("openai", "gpt-4-turbo-2024-04-09")
- GPT_4o_2024_05_13 = ("openai", "gpt-4o-2024-05-13")
- GPT_4o_2024_08_06 = ("openai", "gpt-4o-2024-08-06")
- GPT_4o = ("openai", "gpt-4o")
- GPT_4o_MINI_2024_07_18 = ("openai", "gpt-4o-mini-2024-07-18")
- GPT_4o_MINI = ("openai", "gpt-4o-mini")
- GPT_4_0125_PREVIEW = ("openai", "gpt-4-0125-preview")
- GPT_4_1 = ("openai", "gpt-4.1")
- GPT_4_1_2025_04_14 = ("openai", "gpt-4.1-2025-04-14")
- GPT_4_1_MINI = ("openai", "gpt-4.1-mini")
- GPT_4_1_MINI_2025_04_14 = ("openai", "gpt-4.1-mini-2025-04-14")
- GPT_4_1_NANO = ("openai", "gpt-4.1-nano")
- GPT_4_1_NANO_2025_04_14 = ("openai", "gpt-4.1-nano-2025-04-14")
- O1_PREVIEW = ("openai", "o1-preview")
- O1_PREVIEW_2024_09_12 = ("openai", "o1-preview-2024-09-12")
- O1_MINI = ("openai", "o1-mini")
- O1_MINI_2024_09_12 = ("openai", "o1-mini-2024-09-12")
- CHATGPT_4o_LATEST = ("openai", "chatgpt-4o-latest")
-
- CLAUDE_2_1 = ("anthropic", "claude-2.1")
- CLAUDE_3_OPUS_20240229 = ("anthropic", "claude-3-opus-20240229")
- CLAUDE_3_SONNET_20240229 = ("anthropic", "claude-3-sonnet-20240229")
- CLAUDE_3_5_SONNET_20240620 = ("anthropic", "claude-3-5-sonnet-20240620")
- CLAUDE_3_5_SONNET_20241022 = ("anthropic", "claude-3-5-sonnet-20241022")
- CLAUDE_3_5_SONNET_LATEST = ("anthropic", "claude-3-5-sonnet-latest")
- CLAUDE_3_7_SONNET_LATEST = ("anthropic", "claude-3-7-sonnet-latest")
- CLAUDE_3_7_SONNET_20250219 = ("anthropic", "claude-3-7-sonnet-20250219")
- CLAUDE_3_5_HAIKU_20241022 = ("anthropic", "claude-3-5-haiku-20241022")
- CLAUDE_3_HAIKU_20240307 = ("anthropic", "claude-3-haiku-20240307")
- CLAUDE_OPUS_4_20250514 = ("anthropic", "claude-opus-4-20250514")
- CLAUDE_SONNET_4_20250514 = ("anthropic", "claude-sonnet-4-20250514")
- CLAUDE_OPUS_4_0 = ("anthropic", "claude-opus-4-0")
- CLAUDE_SONNET_4_0 = ("anthropic", "claude-sonnet-4-0")
-
- GEMINI_PRO = ("google", "gemini-pro")
- GEMINI_1_PRO_LATEST = ("google", "gemini-1.0-pro-latest")
- GEMINI_15_PRO_LATEST = ("google", "gemini-1.5-pro-latest")
- GEMINI_15_PRO_EXP_0801 = ("google", "gemini-1.5-pro-exp-0801")
- GEMINI_15_FLASH_LATEST = ("google", "gemini-1.5-flash-latest")
- GEMINI_20_FLASH = ("google", "gemini-2.0-flash")
- GEMINI_20_FLASH_001 = ("google", "gemini-2.0-flash-001")
- GEMINI_25_FLASH = ("google", "gemini-2.5-flash")
- GEMINI_25_PRO = ("google", "gemini-2.5-pro")
-
- COMMAND_R = ("cohere", "command-r")
- COMMAND_R_PLUS = ("cohere", "command-r-plus")
-
- MISTRAL_LARGE_LATEST = ("mistral", "mistral-large-latest")
- MISTRAL_LARGE_2407 = ("mistral", "mistral-large-2407")
- MISTRAL_LARGE_2402 = ("mistral", "mistral-large-2402")
- MISTRAL_MEDIUM_LATEST = ("mistral", "mistral-medium-latest")
- MISTRAL_SMALL_LATEST = ("mistral", "mistral-small-latest")
- CODESTRAL_LATEST = ("mistral", "codestral-latest")
- OPEN_MISTRAL_7B = ("mistral", "open-mistral-7b")
- OPEN_MIXTRAL_8X7B = ("mistral", "open-mixtral-8x7b")
- OPEN_MIXTRAL_8X22B = ("mistral", "open-mixtral-8x22b")
- OPEN_MISTRAL_NEMO = ("mistral", "open-mistral-nemo")
-
- TOGETHER_MISTRAL_7B_INSTRUCT_V0_2 = (
- "togetherai",
- "Mistral-7B-Instruct-v0.2",
- )
- TOGETHER_MIXTRAL_8X7B_INSTRUCT_V0_1 = (
- "togetherai",
- "Mixtral-8x7B-Instruct-v0.1",
- )
- TOGETHER_MIXTRAL_8X22B_INSTRUCT_V0_1 = (
- "togetherai",
- "Mixtral-8x22B-Instruct-v0.1",
- )
- TOGETHER_LLAMA_3_70B_CHAT_HF = ("togetherai", "Llama-3-70b-chat-hf")
- TOGETHER_LLAMA_3_8B_CHAT_HF = ("togetherai", "Llama-3-8b-chat-hf")
- TOGETHER_QWEN2_72B_INSTRUCT = ("togetherai", "Qwen2-72B-Instruct")
- TOGETHER_LLAMA_3_1_8B_INSTRUCT_TURBO = (
- "togetherai",
- "Meta-Llama-3.1-8B-Instruct-Turbo",
- )
- TOGETHER_LLAMA_3_1_70B_INSTRUCT_TURBO = (
- "togetherai",
- "Meta-Llama-3.1-70B-Instruct-Turbo",
- )
- TOGETHER_LLAMA_3_1_405B_INSTRUCT_TURBO = (
- "togetherai",
- "Meta-Llama-3.1-405B-Instruct-Turbo",
- )
- TOGETHER_DEEPSEEK_R1 = ("togetherai", "DeepSeek-R1")
-
- SONAR = (
- "perplexity",
- "sonar",
- )
-
- REPLICATE_MISTRAL_7B_INSTRUCT_V0_2 = (
- "replicate",
- "mistral-7b-instruct-v0.2",
- )
- REPLICATE_MIXTRAL_8X7B_INSTRUCT_V0_1 = (
- "replicate",
- "mixtral-8x7b-instruct-v0.1",
- )
- REPLICATE_META_LLAMA_3_70B_INSTRUCT = (
- "replicate",
- "meta-llama-3-70b-instruct",
- )
- REPLICATE_META_LLAMA_3_8B_INSTRUCT = (
- "replicate",
- "meta-llama-3-8b-instruct",
- )
- REPLICATE_META_LLAMA_3_1_405B_INSTRUCT = (
- "replicate",
- "meta-llama-3.1-405b-instruct",
- )
-
- def __new__(cls, provider, model):
- return LLMConfig(provider=provider, model=model)
-
-
-
-
-[docs]
-def is_o1_model(llm: LLMConfig):
- return llm in (
- NDLLMProviders.O1_PREVIEW,
- NDLLMProviders.O1_PREVIEW_2024_09_12,
- NDLLMProviders.O1_MINI,
- NDLLMProviders.O1_MINI_2024_09_12,
- )
-
-
-import json
-import logging
-from typing import Any, Callable, Dict, List, Optional, Sequence, Union
-
-import aiohttp
-import requests
-
-from notdiamond import settings
-from notdiamond._utils import _default_headers, convert_tool_to_openai_function
-from notdiamond.llms.config import LLMConfig
-from notdiamond.metrics.metric import Metric
-from notdiamond.types import ModelSelectRequestPayload
-
-LOGGER = logging.getLogger(__name__)
-LOGGER.setLevel(logging.INFO)
-
-
-
-[docs]
-def model_select_prepare(
- messages: List[Dict[str, str]],
- llm_configs: List[LLMConfig],
- metric: Metric,
- notdiamond_api_key: str,
- max_model_depth: int,
- hash_content: bool,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- tools: Optional[Sequence[Union[Dict[str, Any], Callable]]] = [],
- previous_session: Optional[str] = None,
- nd_api_url: Optional[str] = settings.NOTDIAMOND_API_URL,
- _user_agent: str = settings.DEFAULT_USER_AGENT,
-):
- """
- This is the core method for the model_select endpoint.
- It returns the best fitting LLM to call and a session ID that can be used for feedback.
-
- Parameters:
- messages (List[Dict[str, str]]): list of messages to be used for the LLM call
- llm_configs (List[LLMConfig]): a list of available LLMs that the router can decide from
- metric (Metric): metric based off which the router makes the decision. As of now only 'accuracy' supported.
- notdiamond_api_key (str): API key generated via the NotDiamond dashboard.
- max_model_depth (int): if your top recommended model is down, specify up to which depth of routing you're willing to go.
- hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
- tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
- for the router to determine the best LLM for a given query.
- preference_id (Optional[str], optional): The ID of the router preference that was configured via the Dashboard.
- Defaults to None.
- previous_session (Optional[str], optional): The session ID of a previous session, allow you to link requests.
- async_mode (bool, optional): whether to run the request in async mode. Defaults to False.
- nd_api_url (Optional[str], optional): The URL of the NotDiamond API. Defaults to None.
-
- Returns:
- tuple(url, payload, headers): returns data to be used for the API call of modelSelect
- """
- url = f"{nd_api_url}/v2/modelRouter/modelSelect"
- tools_dict = get_tools_in_openai_format(tools)
-
- payload: ModelSelectRequestPayload = {
- "messages": messages,
- "llm_providers": [
- llm_provider.prepare_for_request() for llm_provider in llm_configs
- ],
- "metric": metric.metric,
- "max_model_depth": max_model_depth,
- "hash_content": hash_content,
- }
-
- if tools_dict:
- payload["tools"] = tools_dict
- if tradeoff is not None:
- payload["tradeoff"] = tradeoff
- if preference_id is not None:
- payload["preference_id"] = preference_id
- if previous_session is not None:
- payload["previous_session"] = previous_session
-
- headers = _default_headers(notdiamond_api_key, _user_agent)
-
- return url, payload, headers
-
-
-
-
-[docs]
-def get_tools_in_openai_format(
- tools: Optional[Sequence[Union[Dict[str, Any], Callable]]],
-):
- """
- This function converts the tools list into the format that OpenAI expects.
- Does this by using langchains Model that automatically creates the dictionary on bind_tools
-
- Parameters:
- tools (Optional[Sequence[Union[Dict[str, Any], Callable]]]): list of tools to be converted
-
- Returns:
- dict: dictionary of tools in the format that OpenAI expects
- """
- if tools:
- return [
- {
- "type": "function",
- "function": convert_tool_to_openai_function(tool),
- }
- for tool in tools
- ]
-
- return None
-
-
-
-
-[docs]
-def model_select_parse(response_code, response_json, llm_configs):
- if response_code == 200:
- providers = response_json["providers"]
- session_id = response_json["session_id"]
-
- top_provider = providers[0]
-
- best_llm = list(
- filter(
- lambda x: (x.model == top_provider["model"])
- & (x.provider == top_provider["provider"]),
- llm_configs,
- )
- )[0]
- return best_llm, session_id
-
- error_message = response_json["detail"]
- LOGGER.error(f"API error: {response_code}. {error_message}")
- return None, "NO-SESSION-ID"
-
-
-
-
-[docs]
-def model_select(
- messages: List[Dict[str, str]],
- llm_configs: List[LLMConfig],
- metric: Metric,
- notdiamond_api_key: str,
- max_model_depth: int,
- hash_content: bool,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- tools: Optional[Sequence[Union[Dict[str, Any], Callable]]] = [],
- previous_session: Optional[str] = None,
- timeout: Optional[Union[float, int]] = 60,
- max_retries: Optional[int] = 3,
- nd_api_url: Optional[str] = settings.NOTDIAMOND_API_URL,
- _user_agent: str = settings.DEFAULT_USER_AGENT,
-):
- """
- This endpoint receives the prompt and routing settings, and makes a call to the NotDiamond API.
- It returns the best fitting LLM to call and a session ID that can be used for feedback.
-
- Parameters:
- messages (List[Dict[str, str]]): list of messages to be used for the LLM call
- llm_configs (List[LLMConfig]): a list of available LLMs that the router can decide from
- metric (Metric): metric based off which the router makes the decision. As of now only 'accuracy' supported.
- notdiamond_api_key (str): API key generated via the NotDiamond dashboard.
- max_model_depth (int): if your top recommended model is down, specify up to which depth of routing you're willing to go.
- hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
- tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
- for the router to determine the best LLM for a given query.
- preference_id (Optional[str], optional): The ID of the router preference that was configured via the Dashboard.
- Defaults to None.
- previous_session (Optional[str], optional): The session ID of a previous session, allow you to link requests.
- timeout (int, optional): timeout for the request. Defaults to 60.
- max_retries (int, optional): The maximum number of retries to make when calling the Not Diamond API.
- Defaults to 3.
- nd_api_url (Optional[str], optional): The URL of the NotDiamond API. Defaults to None.
- Returns:
- tuple(LLMConfig, string): returns a tuple of the chosen LLMConfig to call and a session ID string.
- In case of an error the LLM defaults to None and the session ID defaults
- to 'NO-SESSION-ID'.
- """
- url, payload, headers = model_select_prepare(
- messages=messages,
- llm_configs=llm_configs,
- metric=metric,
- notdiamond_api_key=notdiamond_api_key,
- max_model_depth=max_model_depth,
- hash_content=hash_content,
- tradeoff=tradeoff,
- preference_id=preference_id,
- tools=tools,
- previous_session=previous_session,
- nd_api_url=nd_api_url,
- _user_agent=_user_agent,
- )
-
- for n_retry in range(1, max_retries + 1):
- try:
- response = requests.post(
- url, data=json.dumps(payload), headers=headers, timeout=timeout
- )
- response_code = response.status_code
- response_json = response.json()
- break
- except Exception as e:
- LOGGER.error(
- f"Retry {n_retry} of {max_retries}: API error: {e}",
- exc_info=True,
- )
- if n_retry == max_retries:
- return None, "NO-SESSION-ID"
-
- best_llm, session_id = model_select_parse(
- response_code, response_json, llm_configs
- )
-
- return best_llm, session_id
-
-
-
-
-[docs]
-async def amodel_select(
- messages: List[Dict[str, str]],
- llm_configs: List[LLMConfig],
- metric: Metric,
- notdiamond_api_key: str,
- max_model_depth: int,
- hash_content: bool,
- tradeoff: Optional[str] = None,
- preference_id: Optional[str] = None,
- tools: Optional[Sequence[Union[Dict[str, Any], Callable]]] = [],
- previous_session: Optional[str] = None,
- timeout: Optional[Union[float, int]] = 60,
- max_retries: Optional[int] = 3,
- nd_api_url: Optional[str] = settings.NOTDIAMOND_API_URL,
- _user_agent: str = settings.DEFAULT_USER_AGENT,
-):
- """
- This endpoint receives the prompt and routing settings, and makes a call to the NotDiamond API.
- It returns the best fitting LLM to call and a session ID that can be used for feedback.
-
- Parameters:
- messages (List[Dict[str, str]]): list of messages to be used for the LLM call
- llm_configs (List[LLMConfig]): a list of available LLMs that the router can decide from
- metric (Metric): metric based off which the router makes the decision. As of now only 'accuracy' supported.
- notdiamond_api_key (str): API key generated via the NotDiamond dashboard.
- max_model_depth (int): if your top recommended model is down, specify up to which depth of routing you're willing to go.
- hash_content (Optional[bool]): Flag for hashing content before sending to NotDiamond API.
- tradeoff (Optional[str], optional): Define the "cost" or "latency" tradeoff
- for the router to determine the best LLM for a given query.
- preference_id (Optional[str], optional): The ID of the router preference that was configured via the Dashboard.
- Defaults to None.
- previous_session (Optional[str], optional): The session ID of a previous session, allow you to link requests.
- timeout (int, optional): timeout for the request. Defaults to 60.
- max_retries (int, optional): The maximum number of retries to make when calling the Not Diamond API.
- nd_api_url (Optional[str], optional): The URL of the NotDiamond API. Defaults to None.
- Returns:
- tuple(LLMConfig, string): returns a tuple of the chosen LLMConfig to call and a session ID string.
- In case of an error the LLM defaults to None and the session ID defaults
- to 'NO-SESSION-ID'.
- """
- url, payload, headers = model_select_prepare(
- messages=messages,
- llm_configs=llm_configs,
- metric=metric,
- notdiamond_api_key=notdiamond_api_key,
- max_model_depth=max_model_depth,
- hash_content=hash_content,
- tradeoff=tradeoff,
- preference_id=preference_id,
- tools=tools,
- previous_session=previous_session,
- nd_api_url=nd_api_url,
- _user_agent=_user_agent,
- )
-
- for n_retry in range(1, max_retries + 1):
- try:
- async with aiohttp.ClientSession() as session:
- async with session.post(
- url,
- data=json.dumps(payload),
- headers=headers,
- timeout=timeout,
- ) as response:
- response_code = response.status
- response_json = await response.json()
- break
- except Exception as e:
- LOGGER.error(
- f"Retry {n_retry} of {max_retries}: API error: {e}",
- exc_info=True,
- )
- if n_retry == max_retries:
- return None, "NO-SESSION-ID"
-
- best_llm, session_id = model_select_parse(
- response_code, response_json, llm_configs
- )
-
- return best_llm, session_id
-
-
-
-
-[docs]
-def report_latency(
- session_id: str,
- llm_config: LLMConfig,
- tokens_per_second: float,
- notdiamond_api_key: str,
- nd_api_url: Optional[str] = settings.NOTDIAMOND_API_URL,
- _user_agent: str = settings.DEFAULT_USER_AGENT,
-):
- """
- This method makes an API call to the NotDiamond server to report the latency of an LLM call.
- It helps fine-tune our model router and ensure we offer recommendations that meet your latency expectation.
-
- This feature can be disabled on the NDLLM class level by setting `latency_tracking` to False.
-
- Parameters:
- session_id (str): the session ID that was returned from the `invoke` or `model_select` calls, so we know which
- router call your latency report refers to.
- llm_provider (LLMConfig): specifying the LLM provider for which the latency is reported
- tokens_per_second (float): latency of the model call calculated based on time elapsed, input tokens, and output tokens
- notdiamond_api_key (str): NotDiamond API call used for authentication
- nd_api_url (Optional[str], optional): The URL of the NotDiamond API. Defaults to None.
- Returns:
- int: status code of the API call, 200 if it's success
-
- Raises:
- ApiError: if the API call to the NotDiamond backend fails, this error is raised
- """
- url = f"{nd_api_url}/v2/report/metrics/latency"
-
- payload = {
- "session_id": session_id,
- "provider": llm_config.prepare_for_request(),
- "feedback": {"tokens_per_second": tokens_per_second},
- }
-
- headers = _default_headers(notdiamond_api_key, _user_agent)
-
- try:
- response = requests.post(url, json=payload, headers=headers)
- except Exception as e:
- LOGGER.error(
- f"API error for report metrics latency: {e}", exc_info=True
- )
- return 500
-
- return response.status_code
-
-
-
-
-[docs]
-def create_preference_id(
- notdiamond_api_key: str,
- name: Optional[str] = None,
- nd_api_url: Optional[str] = settings.NOTDIAMOND_API_URL,
- _user_agent: str = settings.DEFAULT_USER_AGENT,
-) -> str:
- """
- Create a preference id with an optional name. The preference name will appear in your
- dashboard on Not Diamond.
- """
- url = f"{nd_api_url}/v2/preferences/userPreferenceCreate"
- headers = _default_headers(notdiamond_api_key, _user_agent)
- res = requests.post(url=url, headers=headers, json={"name": name})
- if res.status_code == 200:
- preference_id = res.json()["preference_id"]
- else:
- raise Exception(f"Error creating preference ID: {res.text}")
-
- return preference_id
-
-
-from typing import Optional
-
-from notdiamond import settings
-from notdiamond.exceptions import ApiError
-from notdiamond.llms.config import LLMConfig
-from notdiamond.metrics.request import feedback_request
-from notdiamond.types import NDApiKeyValidator
-
-
-
-[docs]
-class Metric:
- def __init__(self, metric: Optional[str] = "accuracy"):
- self.metric = metric
-
- def __call__(self):
- return self.metric
-
-
-[docs]
- def feedback(
- self,
- session_id: str,
- llm_config: LLMConfig,
- value: int,
- notdiamond_api_key: Optional[str] = None,
- _user_agent: str = None,
- ):
- if notdiamond_api_key is None:
- notdiamond_api_key = settings.NOTDIAMOND_API_KEY
- NDApiKeyValidator(api_key=notdiamond_api_key)
- if value not in [0, 1]:
- raise ApiError("Invalid feedback value. It must be 0 or 1.")
-
- return feedback_request(
- session_id=session_id,
- llm_config=llm_config,
- feedback_payload=self.request_payload(value),
- notdiamond_api_key=notdiamond_api_key,
- _user_agent=_user_agent,
- )
-
-
-
-
-
-
-import logging
-from typing import Dict
-
-import requests
-
-from notdiamond import settings
-from notdiamond._utils import _default_headers
-from notdiamond.exceptions import ApiError
-from notdiamond.llms.config import LLMConfig
-from notdiamond.types import FeedbackRequestPayload
-
-LOGGER = logging.getLogger(__name__)
-LOGGER.setLevel(logging.INFO)
-
-
-
-[docs]
-def feedback_request(
- session_id: str,
- llm_config: LLMConfig,
- feedback_payload: Dict[str, int],
- notdiamond_api_key: str,
- nd_api_url: str = settings.NOTDIAMOND_API_URL,
- _user_agent: str = settings.DEFAULT_USER_AGENT,
-) -> bool:
- url = f"{nd_api_url}/v2/report/metrics/feedback"
-
- payload: FeedbackRequestPayload = {
- "session_id": session_id,
- "provider": llm_config.prepare_for_request(),
- "feedback": feedback_payload,
- }
-
- headers = _default_headers(notdiamond_api_key, _user_agent)
-
- try:
- response = requests.post(url, json=payload, headers=headers)
- except Exception as e:
- raise ApiError(f"ND API error for feedback: {e}")
-
- if response.status_code != 200:
- LOGGER.error(
- f"ND API feedback error: failed to report feedback with status {response.status_code}. {response.text}"
- )
- return False
-
- return True
-
-
-import logging
-import re
-from typing import Dict, List
-
-from notdiamond.llms.config import LLMConfig
-from notdiamond.llms.providers import is_o1_model
-
-LOGGER = logging.getLogger(__name__)
-LOGGER.setLevel(logging.INFO)
-
-
-
-[docs]
-def inject_system_prompt(
- messages: List[Dict[str, str]], system_prompt: str
-) -> List[Dict[str, str]]:
- """
- Add a system prompt to an OpenAI-style message list. If a system prompt is already present, replace it.
- """
- new_messages = []
- found = False
- for msg in messages:
- # t7: replace the first system prompt with the new one
- if msg["role"] == "system" and not found:
- new_messages.append({"role": "system", "content": system_prompt})
- found = True
- else:
- new_messages.append(msg)
- if not found:
- new_messages.insert(0, {"role": "system", "content": system_prompt})
- return new_messages
-
-
-
-
-[docs]
-def _curly_escape(text: str) -> str:
- """
- Escape curly braces in the text, but only for single occurrences of alphabetic characters.
- This function will not escape double curly braces or non-alphabetic characters.
- """
- return re.sub(r"(?<!{){([a-zA-Z])}(?!})", r"{{\1}}", text)
-
-
-
-
-[docs]
-def o1_system_prompt_translate(
- messages: List[Dict[str, str]], llm: LLMConfig
-) -> List[Dict[str, str]]:
- if is_o1_model(llm):
- translated_messages = []
- for msg in messages:
- if msg["role"] == "system":
- translated_messages.append(
- {"role": "user", "content": msg["content"]}
- )
- else:
- translated_messages.append(msg)
- return translated_messages
- return messages
-
-
-import ppdeep
-
-
-
-[docs]
-def nd_hash(s: str) -> str:
- """
- Source of library from: https://github.com/elceef/ppdeep
- """
- return ppdeep.hash(s)
-
-
-from typing import Any, Dict, List, Optional, Union
-
-from langchain.prompts import PromptTemplate
-from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
-from langchain_core.output_parsers import JsonOutputParser
-from langchain_core.prompts import ChatPromptTemplate
-from langchain_core.prompts.string import get_template_variables
-
-
-
-[docs]
-class NDPromptTemplate(PromptTemplate):
- """Custom implementation of NDPromptTemplate
- Starting reference is from here:
- https://api.python.langchain.com/en/latest/prompts/langchain_core.prompts.prompt.PromptTemplate.html
- """
-
- def __init__(
- self,
- template: str,
- input_variables: Optional[List[str]] = None,
- partial_variables: Optional[Dict[str, Any]] = {},
- ):
- if input_variables is None:
- input_variables = get_template_variables(template, "f-string")
-
- if partial_variables:
- input_variables = []
-
- super(NDPromptTemplate, self).__init__(
- template=template,
- input_variables=input_variables,
- partial_variables=partial_variables,
- )
-
-
-[docs]
- @classmethod
- def from_langchain_prompt_template(cls, prompt_template: PromptTemplate):
- return cls(
- template=prompt_template.template,
- input_variables=prompt_template.input_variables,
- partial_variables=prompt_template.partial_variables,
- )
-
-
-
-[docs]
- def format(self, **kwargs: Any) -> str:
- """Format the prompt template with the given variables and convert it to NDPromptTemplate."""
- return super(NDPromptTemplate, self).format(**kwargs)
-
-
-
-
-
-
-
-
-
-[docs]
- def inject_system_prompt(self, system_prompt: str):
- self.template = system_prompt
- return self
-
-
-
-[docs]
- def inject_model_instruction(self, parser: JsonOutputParser):
- format_instructions = parser.get_format_instructions()
- format_instructions = format_instructions.replace("{", "{{").replace(
- "}", "}}"
- )
- self.template = format_instructions + "\n" + self.template
-
- return self
-
-
-
-
-
-[docs]
-class NDChatPromptTemplate(ChatPromptTemplate):
- """
- Starting reference is from
- here:https://api.python.langchain.com/en/latest/prompts/langchain_core.prompts.chat.ChatPromptTemplate.html
- """
-
- def __init__(
- self,
- messages: Optional[List] = None,
- input_variables: Optional[List[str]] = None,
- partial_variables: [str, Any] = dict,
- ):
- if messages is None:
- messages = []
- if partial_variables:
- input_variables = []
-
- super().__init__(
- messages=messages,
- input_variables=input_variables,
- partial_variables=partial_variables,
- )
-
- @property
- def template(self):
- message = """
- SYSTEM: {system_prompt}
- CONTEXT: {context_prompt}
- QUERY: {user_query}
- """
- return message
-
-
-[docs]
- @classmethod
- def from_langchain_chat_prompt_template(
- cls, chat_prompt_template: ChatPromptTemplate
- ):
- return cls(
- messages=chat_prompt_template.messages,
- input_variables=chat_prompt_template.input_variables,
- partial_variables=chat_prompt_template.partial_variables,
- )
-
-
-
-[docs]
- @classmethod
- def from_openai_messages(cls, messages: List[Dict[str, str]]):
- transformed_messages = []
- for message in messages:
- if message["role"] == "system":
- transformed_messages.append(SystemMessage(message["content"]))
- elif message["role"] == "assistant":
- transformed_messages.append(AIMessage(message["content"]))
- elif message["role"] == "user":
- transformed_messages.append(HumanMessage(message["content"]))
- else:
- raise ValueError(f"Unsupported role: {message['role']}")
- return cls(
- messages=transformed_messages,
- input_variables=None,
- partial_variables={},
- )
-
-
-
-[docs]
- def format(self, **kwargs: Any) -> str:
- """Format the prompt template with the given variables. and converts it to NDChatPromptTemplate."""
- return super(NDChatPromptTemplate, self).format(**kwargs)
-
-
-
-[docs]
- def get_last_human_message(self, formated_messages: List) -> str:
- for message in reversed(formated_messages):
- if isinstance(message, HumanMessage):
- return message.content
-
- raise ValueError("No human message found in the list of messages.")
-
-
-
-[docs]
- def get_role_of_message(
- self, message: Union[AIMessage, HumanMessage, SystemMessage]
- ) -> str:
- if isinstance(message, SystemMessage):
- return "system"
- if isinstance(message, AIMessage):
- return "assistant"
- if isinstance(message, HumanMessage):
- return "user"
- raise ValueError(f"Unsupported message type: {type(message)}")
-
-
-
-[docs]
- def prepare_for_request(self):
- formated_messages = self.format_messages(**self.partial_variables)
- messages = []
- for message in formated_messages:
- if (
- isinstance(message, SystemMessage)
- or isinstance(message, AIMessage)
- or isinstance(message, HumanMessage)
- ):
- messages.append(
- {
- "role": self.get_role_of_message(message),
- "content": message.content,
- }
- )
-
- return messages
-
-
-
-[docs]
- def inject_system_prompt(self, system_prompt: str):
- messages = self.prepare_for_request()
- new_messages = []
- found = False
- for msg in messages:
- # t7: replace the first system prompt with the new one
- if msg["role"] == "system" and not found:
- new_messages.append(
- {"role": "system", "content": system_prompt}
- )
- found = True
- else:
- new_messages.append(msg)
- if not found:
- new_messages.insert(
- 0, {"role": "system", "content": system_prompt}
- )
- return self.from_openai_messages(new_messages)
-
-
-
-[docs]
- def inject_model_instruction(self, parser: JsonOutputParser):
- format_instructions = parser.get_format_instructions()
- format_instructions = format_instructions.replace("{", "{{").replace(
- "}", "}}"
- )
- self.messages[0].content = (
- format_instructions + "\n" + self.messages[0].content
- )
-
- return self
-
-
-
-import json
-import tempfile
-import time
-from collections import OrderedDict
-from typing import Dict, List, Optional, Tuple, Union
-
-import numpy as np
-import pandas as pd
-import requests
-from tqdm import tqdm
-
-from notdiamond.exceptions import ApiError
-from notdiamond.llms.client import NotDiamond
-from notdiamond.llms.config import LLMConfig
-from notdiamond.settings import NOTDIAMOND_API_KEY, NOTDIAMOND_API_URL, VERSION
-from notdiamond.types import NDApiKeyValidator
-from notdiamond._utils import token_counter
-
-
-
-[docs]
-class CustomRouter:
- """
- Implementation of CustomRouter class, used to train custom routers using custom datasets.
-
- Attributes:
- language (str): The language of the dataset in lowercase. Defaults to "english".
- maximize (bool): Whether higher score is better. Defaults to true.
- api_key (Optional[str], optional): The NotDiamond API key. If not specified, will try to
- find it in the environment variable NOTDIAMOND_API_KEY.
- """
-
- def __init__(
- self,
- language: str = "english",
- maximize: bool = True,
- api_key: Optional[str] = None,
- ):
- if api_key is None:
- api_key = NOTDIAMOND_API_KEY
- NDApiKeyValidator(api_key=api_key)
-
- self.api_key = api_key
- self.language = language
- self.maximize = maximize
-
- def _request_train_router(
- self,
- prompt_column: str,
- dataset_file: str,
- llm_configs: List[LLMConfig],
- preference_id: Optional[str],
- nd_api_url: str,
- ) -> str:
- url = f"{nd_api_url}/v2/pzn/trainCustomRouter"
-
- files = {"dataset_file": open(dataset_file, "rb")}
-
- payload = {
- "language": self.language,
- "llm_providers": json.dumps(
- [provider.prepare_for_request() for provider in llm_configs]
- ),
- "prompt_column": prompt_column,
- "maximize": self.maximize,
- "preference_id": preference_id,
- }
-
- headers = {
- "Authorization": f"Bearer {self.api_key}",
- "User-Agent": f"Python-SDK/{VERSION}",
- }
-
- response = requests.post(
- url=url, headers=headers, data=payload, files=files
- )
- if response.status_code != 200:
- raise ApiError(
- f"ND backend error status code: {response.status_code}, {response.text}"
- )
-
- preference_id = response.json()["preference_id"]
- return preference_id
-
- def _prepare_joint_dataset(
- self,
- dataset: Dict[Union[str, LLMConfig], pd.DataFrame],
- prompt_column: str,
- response_column: str,
- score_column: str,
- ) -> Tuple[pd.DataFrame, List[LLMConfig]]:
- a_provider = list(dataset.keys())[0]
- prompts = dataset[a_provider].get(prompt_column, None)
- if prompts is None:
- raise ValueError(f"Prompt column {prompt_column} not found in df.")
- prompts = prompts.to_list()
-
- llm_configs = []
- joint_dataset = {prompt_column: prompts}
- for provider, df in dataset.items():
- llm_configs.append(provider)
-
- responses = df.get(response_column, None)
- if responses is None:
- raise ValueError(
- f"Response column {response_column} not found in df."
- )
- responses = responses.to_list()
- joint_dataset[f"{str(provider)}/response"] = responses
-
- scores = df.get(score_column, None)
- if scores is None:
- raise ValueError(
- f"Score column {score_column} not found in df."
- )
- scores = scores.to_list()
- joint_dataset[f"{str(provider)}/score"] = scores
-
- joint_df = pd.DataFrame(joint_dataset)
-
- llm_configs = NotDiamond._parse_llm_configs_data(llm_configs)
- return joint_df, llm_configs
-
-
-[docs]
- def fit(
- self,
- dataset: Dict[Union[str, LLMConfig], pd.DataFrame],
- prompt_column: str,
- response_column: str,
- score_column: str,
- preference_id: Optional[str] = None,
- nd_api_url: Optional[str] = NOTDIAMOND_API_URL,
- ) -> str:
- """
- Method to train a custom router using provided dataset.
-
- Parameters:
- dataset (Dict[str, pandas.DataFrame]): The dataset to train a custom router.
- Each key in the dictionary should be in the form of <provider>/<model>.
- prompt_column (str): The column name in each DataFrame corresponding
- to the prompts used to evaluate the LLM.
- response_column (str): The column name in each DataFrame corresponding
- to the response given by the LLM for a given prompt.
- score_column (str): The column name in each DataFrame corresponding
- to the score given to the response from the LLM.
- preference_id (Optional[str], optional): If specified, the custom router
- associated with the preference_id will be updated with the provided dataset.
- nd_api_url (Optional[str], optional): The URL of the NotDiamond API. Defaults to prod.
-
- Raises:
- ApiError: When the NotDiamond API fails
- ValueError: When parsing the provided dataset fails
- UnsupportedLLMProvider: When a provider specified in the dataset is not supported.
-
- Returns:
- str:
- preference_id: the preference_id associated with the custom router.
- Use this preference_id in your routing calls to use the custom router.
- """
-
- joint_df, llm_configs = self._prepare_joint_dataset(
- dataset, prompt_column, response_column, score_column
- )
-
- with tempfile.NamedTemporaryFile(suffix=".csv") as joint_csv:
- joint_df.to_csv(joint_csv.name, index=False)
- preference_id = self._request_train_router(
- prompt_column,
- joint_csv.name,
- llm_configs,
- preference_id,
- nd_api_url,
- )
-
- return preference_id
-
-
- def _get_latency(self, llm_config: LLMConfig, prompt: str) -> float:
- llm = NotDiamond._llm_from_config(llm_config)
- start_time = time.time()
- _ = llm.invoke([("human", prompt)])
- end_time = time.time()
- return (end_time - start_time) * 1000 # ms
-
- def _get_cost(
- self, llm_config: LLMConfig, prompt: str, response: str
- ) -> float:
- n_input_tokens = token_counter(model="gpt-4o", text=prompt)
- n_output_tokens = token_counter(model="gpt-4o", text=response)
- input_price = (
- llm_config.default_input_price
- if llm_config.input_price is None
- else llm_config.input_price
- )
- output_price = (
- llm_config.default_output_price
- if llm_config.output_price is None
- else llm_config.output_price
- )
- return (
- n_input_tokens * input_price + n_output_tokens * output_price
- ) / 1e6
-
- def _eval_custom_router(
- self,
- client: NotDiamond,
- llm_configs: List[LLMConfig],
- joint_df: pd.DataFrame,
- prompt_column: str,
- include_latency: bool,
- ) -> Tuple[pd.DataFrame, pd.DataFrame]:
- eval_results = OrderedDict()
- eval_results[prompt_column] = []
- eval_results["session_id"] = []
- eval_results["notdiamond/score"] = []
- eval_results["notdiamond/cost"] = []
- eval_results["notdiamond/response"] = []
- eval_results["notdiamond/recommended_provider"] = []
-
- if include_latency:
- eval_results["notdiamond/latency"] = []
-
- for provider in llm_configs:
- provider_score_column = (
- f"{provider.provider}/{provider.model}/score"
- )
- eval_results[provider_score_column] = []
-
- provider_response_column = (
- f"{provider.provider}/{provider.model}/response"
- )
- eval_results[provider_response_column] = []
-
- provider_cost_column = f"{provider.provider}/{provider.model}/cost"
- eval_results[provider_cost_column] = []
-
- if include_latency:
- provider_latency_column = (
- f"{provider.provider}/{provider.model}/latency"
- )
- eval_results[provider_latency_column] = []
-
- for _, row in tqdm(joint_df.iterrows(), total=len(joint_df)):
- prompt = row[prompt_column]
- eval_results[prompt_column].append(prompt)
-
- session_id, nd_provider = client.chat.completions.model_select(
- messages=[{"role": "user", "content": prompt}], timeout=60
- )
- if nd_provider is None:
- continue
-
- eval_results["session_id"].append(session_id)
-
- provider_matched = False
- for provider in llm_configs:
- provider_score = row[
- f"{provider.provider}/{provider.model}/score"
- ]
- eval_results[
- f"{provider.provider}/{provider.model}/score"
- ].append(provider_score)
-
- provider_response = row[
- f"{provider.provider}/{provider.model}/response"
- ]
- eval_results[
- f"{provider.provider}/{provider.model}/response"
- ].append(provider_response)
-
- provider_cost = self._get_cost(
- provider, prompt, provider_response
- )
- eval_results[
- f"{provider.provider}/{provider.model}/cost"
- ].append(provider_cost)
-
- if include_latency:
- provider_latency = self._get_latency(provider, prompt)
- eval_results[
- f"{provider.provider}/{provider.model}/latency"
- ].append(provider_latency)
-
- if (
- not provider_matched
- and provider.provider == nd_provider.provider
- and provider.model == nd_provider.model
- ):
- provider_matched = True
- eval_results["notdiamond/score"].append(provider_score)
- eval_results["notdiamond/cost"].append(provider_cost)
- eval_results["notdiamond/response"].append(
- provider_response
- )
- eval_results["notdiamond/recommended_provider"].append(
- f"{nd_provider.provider}/{nd_provider.model}"
- )
- if include_latency:
- eval_results["notdiamond/latency"].append(
- provider_latency
- )
-
- if not provider_matched:
- raise ValueError(
- f"""
- Custom router returned {nd_provider.provider}/{nd_provider.model}
- which is not in the set of models in the test dataset
- """
- )
-
- eval_results_df = pd.DataFrame(eval_results)
-
- eval_stats = OrderedDict()
- best_average_provider = None
- best_average_score = -(2 * int(self.maximize) - 1) * np.inf
-
- nd_average_score = eval_results_df["notdiamond/score"].mean()
- eval_stats["Not Diamond Average Score"] = [nd_average_score]
-
- nd_average_cost = eval_results_df["notdiamond/cost"].mean()
- eval_stats["Not Diamond Average Cost"] = [nd_average_cost]
-
- if include_latency:
- nd_average_latency = eval_results_df["notdiamond/latency"].mean()
- eval_stats["Not Diamond Average Latency"] = [nd_average_latency]
-
- for provider in llm_configs:
- provider_avg_score = eval_results_df[
- f"{provider.provider}/{provider.model}/score"
- ].mean()
- eval_stats[f"{provider.provider}/{provider.model}/avg_score"] = [
- provider_avg_score
- ]
-
- provider_avg_cost = eval_results_df[
- f"{provider.provider}/{provider.model}/cost"
- ].mean()
- eval_stats[f"{provider.provider}/{provider.model}/avg_cost"] = [
- provider_avg_cost
- ]
-
- if include_latency:
- provider_avg_latency = eval_results_df[
- f"{provider.provider}/{provider.model}/latency"
- ].mean()
- eval_stats[
- f"{provider.provider}/{provider.model}/avg_latency"
- ] = [provider_avg_latency]
-
- if self.maximize:
- if provider_avg_score > best_average_score:
- best_average_score = provider_avg_score
- best_average_cost = provider_avg_cost
- best_average_provider = (
- f"{provider.provider}/{provider.model}"
- )
- if include_latency:
- best_average_latency = provider_avg_latency
- else:
- if provider_avg_score < best_average_score:
- best_average_score = provider_avg_score
- best_average_cost = provider_avg_cost
- best_average_provider = (
- f"{provider.provider}/{provider.model}"
- )
- if include_latency:
- best_average_latency = provider_avg_latency
-
- eval_stats["Best Average Provider"] = [best_average_provider]
- eval_stats["Best Provider Average Score"] = [best_average_score]
- eval_stats["Best Provider Average Cost"] = [best_average_cost]
-
- if include_latency:
- eval_stats["Best Provider Average Latency"] = [
- best_average_latency
- ]
-
- first_columns = [
- "Best Average Provider",
- "Best Provider Average Score",
- "Best Provider Average Cost",
- "Best Provider Average Latency",
- "Not Diamond Average Score",
- "Not Diamond Average Cost",
- "Not Diamond Average Latency",
- ]
- else:
- first_columns = [
- "Best Average Provider",
- "Best Provider Average Score",
- "Best Provider Average Cost",
- "Not Diamond Average Score",
- "Not Diamond Average Cost",
- ]
-
- column_order = first_columns + [
- col for col in eval_stats.keys() if col not in first_columns
- ]
- ordered_eval_stats = OrderedDict()
- for col in column_order:
- ordered_eval_stats[col] = eval_stats[col]
-
- eval_stats_df = pd.DataFrame(ordered_eval_stats)
- return eval_results_df, eval_stats_df
-
-
-[docs]
- def eval(
- self,
- dataset: Dict[Union[str, LLMConfig], pd.DataFrame],
- prompt_column: str,
- response_column: str,
- score_column: str,
- preference_id: str,
- include_latency: bool = False,
- ) -> Tuple[pd.DataFrame, pd.DataFrame]:
- """
- Method to evaluate a custom router using provided dataset.
-
- Parameters:
- dataset (Dict[str, pandas.DataFrame]): The dataset to train a custom router.
- Each key in the dictionary should be in the form of <provider>/<model>.
- prompt_column (str): The column name in each DataFrame corresponding
- to the prompts used to evaluate the LLM.
- response_column (str): The column name in each DataFrame corresponding
- to the response given by the LLM for a given prompt.
- score_column (str): The column name in each DataFrame corresponding
- to the score given to the response from the LLM.
- preference_id (str): The preference_id associated with the custom router
- returned from .fit().
-
- Raises:
- ApiError: When the NotDiamond API fails
- ValueError: When parsing the provided dataset fails
- UnsupportedLLMProvider: When a provider specified in the dataset is not supported.
-
- Returns:
- Tuple[pandas.DataFrame, pandas.DataFrame]:
- eval_results_df: A DataFrame containing all the prompts, responses of each provider
- (indicated by column <provider>/<model>/response), scores of each provider
- (indicated by column <provider>/<model>/score), and notdiamond custom router
- response and score (indicated by column notdiamond/response and notdiamond/score).
- eval_stats_df: A DataFrame containing the "Best Average Provider" computed from the
- provided dataset, the "Best Provider Average Score" achieved by the "Best Average Provider",
- and the "Not Diamond Average Score" achieved through custom router.
- """
-
- joint_df, llm_configs = self._prepare_joint_dataset(
- dataset, prompt_column, response_column, score_column
- )
-
- client = NotDiamond(
- llm_configs=llm_configs,
- api_key=self.api_key,
- preference_id=preference_id,
- )
-
- eval_results_df, eval_stats_df = self._eval_custom_router(
- client, llm_configs, joint_df, prompt_column, include_latency
- )
- return eval_results_df, eval_stats_df
-
-
-
-import os
-from importlib import metadata
-from typing import (
- Any,
- AsyncIterator,
- Dict,
- Iterator,
- List,
- Optional,
- Sequence,
- Union,
-)
-
-from langchain.chat_models.base import init_chat_model
-from langchain_community.adapters.openai import convert_message_to_dict
-from langchain_core.language_models import LanguageModelInput
-from langchain_core.messages.utils import convert_to_messages
-from langchain_core.prompt_values import (
- ChatPromptValue,
- PromptValue,
- StringPromptValue,
-)
-from langchain_core.runnables import Runnable, RunnableConfig
-
-import notdiamond as nd
-
-_LANGCHAIN_PROVIDERS = {
- "openai",
- "anthropic",
- "google",
- "mistral",
- "togetherai",
- "cohere",
-}
-
-
-
-[docs]
-class NotDiamondRunnable(Runnable[LanguageModelInput, str]):
- """
- See Runnable docs for details
- https://python.langchain.com/v0.1/docs/expression_language/interface/
- """
-
- llm_configs: List
- api_key: Optional[str] = os.getenv("NOTDIAMOND_API_KEY")
- client: Any
-
-
-[docs]
- def __init__(
- self,
- nd_llm_configs: Optional[List] = None,
- nd_api_key: Optional[str] = None,
- nd_client: Optional[Any] = None,
- nd_kwargs: Optional[Dict[str, Any]] = None,
- ):
- """
- Params:
- nd_llm_configs: List of LLM configs to use.
- nd_api_key: Not Diamond API key.
- nd_client: Not Diamond client.
- nd_kwargs: Keyword arguments to pass directly to model_select.
- """
- if not nd_client:
- if not nd_api_key or not nd_llm_configs:
- raise ValueError(
- "Must provide either client or api_key and llm_configs to "
- "instantiate NotDiamondRunnable."
- )
- nd_client = nd.NotDiamond(
- llm_configs=nd_llm_configs,
- api_key=nd_api_key,
- )
- elif nd_client.llm_configs:
- for llm_config in nd_client.llm_configs:
- if isinstance(llm_config, str):
- llm_config = nd.LLMConfig.from_string(llm_config)
- if llm_config.provider not in _LANGCHAIN_PROVIDERS:
- raise ValueError(
- f"Requested provider in {llm_config} supported by Not Diamond "
- "but not langchain.chat_models.base.init_chat_model. Please "
- "remove it from your llm_configs."
- )
-
- try:
- nd_client.user_agent = (
- f"langchain-community/{metadata.version('notdiamond')}"
- )
- except AttributeError:
- pass
-
- self.client = nd_client
- self.api_key = nd_client.api_key
- self.llm_configs = nd_client.llm_configs
- self.nd_kwargs = nd_kwargs or dict()
-
-
- def _model_select(self, input: LanguageModelInput) -> str:
- messages = _convert_input_to_message_dicts(input)
- _, provider = self.client.chat.completions.model_select(
- messages=messages, **self.nd_kwargs
- )
- provider_str = _nd_provider_to_langchain_provider(str(provider))
- return provider_str
-
- async def _amodel_select(self, input: LanguageModelInput) -> str:
- messages = _convert_input_to_message_dicts(input)
- _, provider = await self.client.chat.completions.amodel_select(
- messages=messages, **self.nd_kwargs
- )
- provider_str = _nd_provider_to_langchain_provider(str(provider))
- return provider_str
-
-
-[docs]
- def stream(
- self,
- input: LanguageModelInput,
- config: Optional[RunnableConfig] = None,
- **kwargs: Optional[Any],
- ) -> Iterator[str]:
- yield self._model_select(input)
-
-
-
-[docs]
- def invoke(
- self,
- input: LanguageModelInput,
- config: Optional[RunnableConfig] = None,
- ) -> str:
- return self._model_select(input)
-
-
-
-[docs]
- def batch(
- self,
- inputs: List[LanguageModelInput],
- config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
- **kwargs: Optional[Any],
- ) -> List[str]:
- return [self._model_select(input) for input in inputs]
-
-
-
-[docs]
- async def astream(
- self,
- input: LanguageModelInput,
- config: Optional[RunnableConfig] = None,
- **kwargs: Optional[Any],
- ) -> AsyncIterator[str]:
- yield await self._amodel_select(input)
-
-
-
-[docs]
- async def ainvoke(
- self,
- input: LanguageModelInput,
- config: Optional[RunnableConfig] = None,
- **kwargs: Optional[Any],
- ) -> str:
- return await self._amodel_select(input)
-
-
-
-[docs]
- async def abatch(
- self,
- inputs: List[LanguageModelInput],
- config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
- **kwargs: Optional[Any],
- ) -> List[str]:
- return [await self._amodel_select(input) for input in inputs]
-
-
-
-
-
-[docs]
-class NotDiamondRoutedRunnable(Runnable[LanguageModelInput, Any]):
-
-[docs]
- def __init__(
- self,
- *args: Any,
- configurable_fields: Optional[List[str]] = None,
- nd_llm_configs: Optional[List] = None,
- nd_api_key: Optional[str] = None,
- nd_client: Optional[Any] = None,
- nd_kwargs: Optional[Dict[str, Any]] = None,
- **kwargs: Optional[Dict[Any, Any]],
- ) -> None:
- """
- Params:
- nd_llm_configs: List of LLM configs to use.
- nd_api_key: Not Diamond API key.
- nd_client: Not Diamond client.
- nd_kwargs: Keyword arguments to pass directly to model_select.
- """
- _nd_kwargs = {
- kw: kwargs[kw] for kw in kwargs.keys() if kw.startswith("nd_")
- }
- if nd_kwargs:
- _nd_kwargs.update(nd_kwargs)
-
- self._ndrunnable = NotDiamondRunnable(
- nd_api_key=nd_api_key,
- nd_llm_configs=nd_llm_configs,
- nd_client=nd_client,
- nd_kwargs=_nd_kwargs,
- )
- _routed_fields = ["model", "model_provider"]
- if configurable_fields is None:
- configurable_fields = []
- self._configurable_fields = _routed_fields + configurable_fields
- self._configurable_model = init_chat_model(
- *args,
- configurable_fields=self._configurable_fields,
- config_prefix="nd",
- **{kw: kwv for kw, kwv in kwargs.items() if kw not in _nd_kwargs}, # type: ignore[arg-type]
- )
-
-
-
-[docs]
- def stream(
- self,
- input: LanguageModelInput,
- config: Optional[RunnableConfig] = None,
- **kwargs: Optional[Any],
- ) -> Iterator[Any]:
- provider_str = self._ndrunnable._model_select(input)
- _config = self._build_model_config(provider_str, config)
- yield from self._configurable_model.stream(input, config=_config)
-
-
-
-[docs]
- def invoke(
- self,
- input: LanguageModelInput,
- config: Optional[RunnableConfig] = None,
- **kwargs: Optional[Any],
- ) -> Any:
- provider_str = self._ndrunnable._model_select(input)
- _config = self._build_model_config(provider_str, config)
- return self._configurable_model.invoke(input, config=_config)
-
-
-
-[docs]
- def batch(
- self,
- inputs: List[LanguageModelInput],
- config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
- **kwargs: Optional[Any],
- ) -> List[Any]:
- config = config or {}
-
- provider_strs = [
- self._ndrunnable._model_select(input) for input in inputs
- ]
- if isinstance(config, dict):
- _configs = [
- self._build_model_config(ps, config) for ps in provider_strs
- ]
- else:
- _configs = [
- self._build_model_config(ps, config[i])
- for i, ps in enumerate(provider_strs)
- ]
-
- return self._configurable_model.batch(inputs, config=_configs)
-
-
-
-[docs]
- async def astream(
- self,
- input: LanguageModelInput,
- config: Optional[RunnableConfig] = None,
- **kwargs: Optional[Any],
- ) -> AsyncIterator[Any]:
- provider_str = await self._ndrunnable._amodel_select(input)
- _config = self._build_model_config(provider_str, config)
- async for chunk in self._configurable_model.astream(
- input, config=_config
- ):
- yield chunk
-
-
-
-[docs]
- async def ainvoke(
- self,
- input: LanguageModelInput,
- config: Optional[RunnableConfig] = None,
- **kwargs: Optional[Any],
- ) -> Any:
- provider_str = await self._ndrunnable._amodel_select(input)
- _config = self._build_model_config(provider_str, config)
- return await self._configurable_model.ainvoke(input, config=_config)
-
-
-
-[docs]
- async def abatch(
- self,
- inputs: List[LanguageModelInput],
- config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
- **kwargs: Optional[Any],
- ) -> List[Any]:
- config = config or {}
-
- provider_strs = [
- await self._ndrunnable._amodel_select(input) for input in inputs
- ]
- if isinstance(config, dict):
- _configs = [
- self._build_model_config(ps, config) for ps in provider_strs
- ]
- else:
- _configs = [
- self._build_model_config(ps, config[i])
- for i, ps in enumerate(provider_strs)
- ]
-
- return await self._configurable_model.abatch(inputs, config=_configs)
-
-
-
-[docs]
- def _build_model_config(
- self, provider_str: str, config: Optional[RunnableConfig] = None
- ) -> RunnableConfig:
- """
- Provider string should take the form '{model}/{model_provider}'
- """
- config = config or RunnableConfig()
-
- model_provider, model = provider_str.split("/")
- _config = RunnableConfig(
- configurable={
- "nd_model": model,
- "nd_model_provider": model_provider,
- },
- )
-
- for k, v in config.items():
- _config["configurable"][f"nd_{k}"] = v
- return _config
-
-
-
-
-def _convert_input_to_message_dicts(
- input: LanguageModelInput,
-) -> List[Dict[str, str]]:
- if isinstance(input, PromptValue):
- output = input
- elif isinstance(input, str):
- output = StringPromptValue(text=input)
- elif isinstance(input, Sequence):
- output = ChatPromptValue(messages=convert_to_messages(input))
- else:
- raise ValueError(
- f"Invalid input type {type(input)}. "
- "Must be a PromptValue, str, or list of BaseMessages."
- )
- return [
- convert_message_to_dict(message) for message in output.to_messages()
- ]
-
-
-def _nd_provider_to_langchain_provider(llm_config_str: str) -> str:
- provider, model = llm_config_str.split("/")
- provider = (
- provider.replace("google", "google_genai")
- .replace("mistral", "mistralai")
- .replace("togetherai", "together")
- )
- return f"{provider}/{model}"
-
-"""
-Tools for working directly with OpenAI's various models.
-"""
-import logging
-from typing import List, Union
-
-from notdiamond import NotDiamond
-from notdiamond.llms.providers import NDLLMProviders
-from notdiamond.settings import NOTDIAMOND_API_KEY, OPENAI_API_KEY
-
-LOGGER = logging.getLogger(__name__)
-LOGGER.setLevel(logging.INFO)
-
-_ND_PARAMS = [
- "llm_configs",
- "default",
- "max_model_depth",
- "latency_tracking",
- "hash_content",
- "tradeoff",
- "preference_id",
- "tools",
- "callbacks",
- "nd_api_url",
- "nd_api_key",
- "user_agent",
-]
-_SHARED_PARAMS = ["timeout", "max_retries"]
-
-
-
-[docs]
-class _OpenAIBase:
- """
- Base class which wraps both an openai client and Not Diamond retry / fallback logic.
- """
-
- def __init__(self, oai_client_cls, *args, **kwargs):
- nd_kwargs = {
- k: v for k, v in kwargs.items() if k in _ND_PARAMS + _SHARED_PARAMS
- }
-
- # TODO [a9] remove llm_configs as valid constructor arg for ND client
- self._nd_client = NotDiamond(
- api_key=nd_kwargs.get("nd_api_key", NOTDIAMOND_API_KEY),
- llm_configs=["openai/gpt-3.5-turbo"],
- *args,
- **nd_kwargs,
- )
-
- # Create a OpenAI client with a dummy model - will ignore this during routing
- oai_kwargs = {k: v for k, v in kwargs.items() if k not in _ND_PARAMS}
- self._oai_client = oai_client_cls(
- *args, api_key=OPENAI_API_KEY, **oai_kwargs
- )
-
- def __getattr__(self, name):
- return getattr(self._oai_client, name)
-
- def __call__(self, *args, **kwargs):
- return self._oai_client(*args, **kwargs)
-
- def __dir__(self):
- return dir(self._oai_client)
-
- @property
- def chat(self):
- class ChatCompletions:
- def __init__(self, parent):
- self.parent = parent
-
- @property
- def completions(self):
- return self
-
- def create(self, *args, **kwargs):
- return self.parent.create(*args, **kwargs)
-
- return ChatCompletions(self)
-
- def _create_prep(self, model: Union[str, List], **kwargs):
- model = kwargs.get("model", model)
-
- if model is None:
- LOGGER.info(
- "No LLM configs provided. Not Diamond will route to all OpenAI models."
- )
- llm_configs = [
- str(p) for p in NDLLMProviders if p.provider == "openai"
- ]
- elif isinstance(model, str):
- llm_configs = model.split(",")
- elif isinstance(model, list):
- llm_configs = self._nd_client._parse_llm_configs_data(model)
-
- if "messages" not in kwargs:
- raise ValueError("'messages' argument is required")
-
- return llm_configs
-
-
-
-
-[docs]
-class OpenAI(_OpenAIBase):
- """
- Encapsulating class for an openai.OpenAI client. This supports the same methods as
- the openai package, while also supporting routed prompts with calls to `completion`.
- """
-
- def __init__(self, *args, **kwargs):
- from openai import OpenAI as OpenAIClient
-
- super().__init__(OpenAIClient, *args, **kwargs)
-
-
-[docs]
- def create(self, *args, model: Union[str, List] = None, **kwargs):
- """
- Perform chat completion using OpenAI's API, after routing the prompt to a
- specific LLM via Not Diamond.
- """
- llm_configs = self._create_prep(model, **kwargs)
- session_id, best_llm = self._nd_client.model_select(
- *args, model=llm_configs, **kwargs
- )
- response = self._oai_client.chat.completions.create(
- *args, model=str(best_llm.model), **kwargs
- )
- LOGGER.info(f"Routed prompt to {best_llm} for session ID {session_id}")
- return response
-
-
-
-
-
-[docs]
-class AsyncOpenAI(_OpenAIBase):
- """
- Encapsulating class for an openai.OpenAI client. This supports the same methods as
- the openai package, while also supporting routed prompts with calls to `completion`.
- """
-
- def __init__(self, *args, **kwargs):
- from openai import AsyncOpenAI as OpenAIClient
-
- super().__init__(OpenAIClient, *args, **kwargs)
-
-
-[docs]
- async def create(self, *args, model: Union[str, List] = None, **kwargs):
- """
- Perform async chat completion using OpenAI's API, after routing the prompt to a
- specific LLM via Not Diamond.
- """
- llm_configs = self._create_prep(model, **kwargs)
- session_id, best_llm = await self._nd_client.amodel_select(
- *args, model=llm_configs, **kwargs
- )
- response = await self._oai_client.chat.completions.create(
- *args, model=str(best_llm.model), **kwargs
- )
- LOGGER.debug(
- f"Routed prompt to {best_llm} for session ID {session_id}"
- )
- return response
-
-
-
-import logging
-from typing import Dict, List, Optional, Sequence, Tuple, Union
-
-import optuna
-import pandas as pd
-from langchain_core.callbacks import Callbacks
-from langchain_core.embeddings import Embeddings as LangchainEmbeddings
-from langchain_core.language_models import BaseLanguageModel as LangchainLLM
-from langchain_core.prompt_values import StringPromptValue
-from ragas import EvaluationDataset, SingleTurnSample
-from ragas._analytics import track_was_completed
-from ragas.cost import TokenUsageParser
-from ragas.embeddings.base import BaseRagasEmbeddings
-from ragas.evaluation import evaluate as ragas_evaluate
-from ragas.llms import BaseRagasLLM
-from ragas.metrics.base import Metric
-from ragas.run_config import RunConfig
-from tqdm import tqdm
-
-from notdiamond.llms.config import LLMConfig
-from notdiamond.toolkit.rag.evaluation_dataset import (
- RAGEvaluationDataset,
- RAGSample,
-)
-from notdiamond.toolkit.rag.llms import get_llm
-from notdiamond.toolkit.rag.workflow import BaseNDRagWorkflow
-
-LOGGER = logging.getLogger(__name__)
-LOGGER.setLevel(logging.INFO)
-
-_DEFAULT_GENERATION_TEMPERATURE = 0.7
-
-
-
-[docs]
-def get_eval_dataset(
- test_queries: pd.DataFrame,
- workflow: BaseNDRagWorkflow,
- generation_prompt: str = None,
- generator_llm: Union[LLMConfig, str] = None,
-):
- """
- Create a dataset of RAGSample objects to evaluate the performance of a RAG workflow.
-
- Args:
- test_queries: A pandas DataFrame with schema implied by the method below.
- workflow: BaseNDRagWorkflow subclass created by the user.
-
- Schema for test_queries can be found below
- https://docs.ragas.io/en/stable/references/evaluation_schema/#ragas.dataset_schema.SingleTurnSample
- """
- samples = []
- for _, row in test_queries.iterrows():
- query = row["user_input"]
- reference = row["reference"]
- generation_prompt = generation_prompt or row.get("generation_prompt")
- if generation_prompt is None:
- raise ValueError(
- "Provided test queries DataFrame does not include"
- " 'generation_prompt' and 'generation_prompt' is not provided."
- )
- generator_llm = generator_llm or row.get("generator_llm")
- if generator_llm is None:
- raise ValueError(
- "Provided test queries DataFrame does not include"
- " 'generator_llm' and 'generator_llm' is not provided."
- )
-
- retrieved_contexts = workflow.get_retrieved_context(query)
- response = workflow.get_response(query)
-
- sample = RAGSample(
- user_input=query,
- retrieved_contexts=retrieved_contexts,
- response=response,
- reference=reference,
- generation_prompt=generation_prompt,
- generator_llm=generator_llm,
- )
- samples.append(sample)
- eval_dataset = RAGEvaluationDataset(samples)
- return eval_dataset
-
-
-
-
-[docs]
-def auto_optimize(workflow: BaseNDRagWorkflow, n_trials: int):
- direction = "maximize" if workflow.objective_maximize else "minimize"
- study = optuna.create_study(
- study_name=workflow.job_name, direction=direction
- )
- study.optimize(workflow._outer_objective, n_trials=n_trials)
- workflow._set_param_values(study.best_params)
- return {"best_params": study.best_params, "trials": study.trials}
-
-
-
-def _map_to_ragas_samples(
- dataset: RAGEvaluationDataset,
-) -> Tuple[EvaluationDataset, pd.DataFrame]:
- ragas_samples = []
- extra_columns = {
- "generation_prompt": [],
- }
- for sample in dataset:
- ragas_sample = SingleTurnSample(
- user_input=sample.user_input,
- retrieved_contexts=sample.retrieved_contexts,
- reference_contexts=sample.reference_contexts,
- response=sample.response,
- multi_responses=sample.multi_responses,
- reference=sample.reference,
- rubrics=sample.rubrics,
- )
- ragas_samples.append(ragas_sample)
- extra_columns["generation_prompt"].append(sample.generation_prompt)
-
- extra_columns_df = pd.DataFrame.from_dict(extra_columns)
- return EvaluationDataset(ragas_samples), extra_columns_df
-
-
-def _evaluate_dataset(
- generator_llm: LLMConfig,
- dataset: EvaluationDataset,
- metrics: Optional[Sequence[Metric]] = None,
- llm: Optional[Union[BaseRagasLLM, LangchainLLM]] = None,
- embeddings: Optional[
- Union[BaseRagasEmbeddings, LangchainEmbeddings]
- ] = None,
- callbacks: Callbacks = None,
- in_ci: bool = False,
- run_config: RunConfig = RunConfig(),
- token_usage_parser: Optional[TokenUsageParser] = None,
- raise_exceptions: bool = False,
- column_map: Optional[Dict[str, str]] = None,
- show_progress: bool = True,
- batch_size: Optional[int] = None,
-) -> pd.DataFrame:
- LOGGER.info(f"Evaluating generations from {str(generator_llm)}")
-
- result = ragas_evaluate(
- dataset,
- metrics,
- llm,
- embeddings,
- callbacks,
- in_ci,
- run_config,
- token_usage_parser,
- raise_exceptions,
- column_map,
- show_progress,
- batch_size,
- )
- return result.to_pandas()
-
-
-def _generate_rag_eval_dataset(
- generator_llm: LLMConfig,
- dataset: RAGEvaluationDataset,
- temperature: float = _DEFAULT_GENERATION_TEMPERATURE,
-) -> RAGEvaluationDataset:
- LOGGER.info(f"Generating responses from {str(generator_llm)}")
-
- llm = get_llm(generator_llm)
- temperature = generator_llm.kwargs.get("temperature", temperature)
-
- eval_samples = []
- for sample in tqdm(dataset):
- response = llm.generate_text(
- StringPromptValue(text=sample.generation_prompt),
- temperature=temperature,
- )
- eval_sample = RAGSample(
- user_input=sample.user_input,
- retrieved_contexts=sample.retrieved_contexts,
- reference_contexts=sample.reference_contexts,
- response=response.generations[0][0].text,
- multi_responses=sample.multi_responses,
- reference=sample.reference,
- rubrics=sample.rubrics,
- generation_prompt=sample.generation_prompt,
- generator_llm=str(generator_llm),
- )
- eval_samples.append(eval_sample)
- return RAGEvaluationDataset(eval_samples)
-
-
-
-[docs]
-@track_was_completed
-def evaluate(
- dataset: RAGEvaluationDataset,
- metrics: Optional[Sequence[Metric]] = None,
- llm: Optional[Union[BaseRagasLLM, LangchainLLM]] = None,
- embeddings: Optional[
- Union[BaseRagasEmbeddings, LangchainEmbeddings]
- ] = None,
- callbacks: Callbacks = None,
- in_ci: bool = False,
- run_config: RunConfig = RunConfig(),
- token_usage_parser: Optional[TokenUsageParser] = None,
- raise_exceptions: bool = False,
- column_map: Optional[Dict[str, str]] = None,
- show_progress: bool = True,
- batch_size: Optional[int] = None,
- generator_llms: List[LLMConfig] = [],
-) -> Dict[str, pd.DataFrame]:
- dataset_llm_str = dataset[0].generator_llm
- dataset_llm_config = LLMConfig.from_string(dataset_llm_str)
-
- if dataset_llm_config not in generator_llms:
- generator_llms.append(dataset_llm_config)
-
- ragas_dataset, extra_columns = _map_to_ragas_samples(dataset)
-
- dataset_results = _evaluate_dataset(
- dataset_llm_config,
- ragas_dataset,
- metrics,
- llm,
- embeddings,
- callbacks,
- in_ci,
- run_config,
- token_usage_parser,
- raise_exceptions,
- column_map,
- show_progress,
- batch_size,
- )
-
- evaluation_results = {
- str(dataset_llm_config): pd.concat(
- [dataset_results, extra_columns], axis=1
- )
- }
-
- for llm_config in generator_llms:
- if str(llm_config) in evaluation_results:
- continue
-
- llm_dataset = _generate_rag_eval_dataset(llm_config, dataset)
- ragas_dataset, extra_columns = _map_to_ragas_samples(llm_dataset)
- dataset_results = _evaluate_dataset(
- llm_config,
- ragas_dataset,
- metrics,
- llm,
- embeddings,
- callbacks,
- in_ci,
- run_config,
- token_usage_parser,
- raise_exceptions,
- column_map,
- show_progress,
- batch_size,
- )
- evaluation_results[str(llm_config)] = pd.concat(
- [dataset_results, extra_columns], axis=1
- )
- return evaluation_results
-
-
-from dataclasses import dataclass
-from typing import Dict, List, Union, overload
-
-from ragas import MultiTurnSample, SingleTurnSample
-from ragas.dataset_schema import RagasDataset
-
-
-
-[docs]
-class RAGSample(SingleTurnSample):
- """
- Represents RAG evaluation samples.
-
- Attributes:
- user_input (Optional[str]): The input query from the user.
- retrieved_contexts (Optional[List[str]]): List of contexts retrieved for the query.
- reference_contexts (Optional[List[str]]): List of reference contexts for the query.
- response (Optional[str]): The generated response for the query.
- generation_prompt (str): The input prompt to the generator LLM.
- generator_llm (str): The LLM used to generate the response.
- multi_responses (Optional[List[str]]): List of multiple responses generated for the query.
- reference (Optional[str]): The reference answer for the query.
- rubric (Optional[Dict[str, str]]): Evaluation rubric for the sample.
- """
-
- generation_prompt: str
- generator_llm: str
-
-
-
-
-[docs]
-@dataclass
-class RAGEvaluationDataset(RagasDataset[RAGSample]):
- """
- Represents a dataset of RAG evaluation samples.
-
- Attributes:
- samples (List[BaseSample]): A list of evaluation samples.
-
- Methods:
- validate_samples(samples): Validates that all samples are of the same type.
- get_sample_type(): Returns the type of the samples in the dataset.
- to_hf_dataset(): Converts the dataset to a Hugging Face Dataset.
- to_pandas(): Converts the dataset to a pandas DataFrame.
- features(): Returns the features of the samples.
- from_list(mapping): Creates an EvaluationDataset from a list of dictionaries.
- from_dict(mapping): Creates an EvaluationDataset from a dictionary.
- to_csv(path): Converts the dataset to a CSV file.
- to_jsonl(path): Converts the dataset to a JSONL file.
- from_jsonl(path): Creates an EvaluationDataset from a JSONL file.
- """
-
- @overload
- def __getitem__(self, idx: int) -> RAGSample:
- ...
-
- @overload
- def __getitem__(self, idx: slice) -> "RAGEvaluationDataset":
- ...
-
- def __getitem__(
- self, idx: Union[int, slice]
- ) -> Union[RAGSample, "RAGEvaluationDataset"]:
- if isinstance(idx, int):
- return self.samples[idx]
- elif isinstance(idx, slice):
- return type(self)(samples=self.samples[idx])
- else:
- raise TypeError("Index must be int or slice")
-
-
-
-
-
-[docs]
- def to_list(self) -> List[Dict]:
- rows = [sample.to_dict() for sample in self.samples]
- return rows
-
-
-
-[docs]
- @classmethod
- def from_list(cls, data: List[Dict]):
- samples = []
- if all(
- "user_input" in item and isinstance(data[0]["user_input"], list)
- for item in data
- ):
- samples.extend(MultiTurnSample(**sample) for sample in data)
- else:
- samples.extend(SingleTurnSample(**sample) for sample in data)
- return cls(samples=samples)
-
-
- def __repr__(self) -> str:
- return f"RAGEvaluationDataset(features={self.features()}, len={len(self.samples)})"
-
-
-from typing import Union
-
-from langchain_cohere import CohereEmbeddings
-from langchain_mistralai import MistralAIEmbeddings
-from langchain_openai import OpenAIEmbeddings
-from ragas.embeddings import HuggingfaceEmbeddings, LangchainEmbeddingsWrapper
-from ragas.llms import LangchainLLMWrapper
-
-from ...exceptions import UnsupportedEmbeddingProvider
-from ...llms.client import NotDiamond
-from ...llms.config import EmbeddingConfig, LLMConfig
-
-
-
-[docs]
-def get_llm(llm_config_or_str: Union[LLMConfig, str]) -> LangchainLLMWrapper:
- """
- Build the LLM object compatible with evaluation metrics.
-
- Parameters:
- llm_config_or_str (Union[LLMConfig, str]): a LLMConfig object or a model string
- that specifies the LLM to construct.
- """
- if isinstance(llm_config_or_str, str):
- llm_config = LLMConfig.from_string(llm_config_or_str)
- else:
- llm_config = llm_config_or_str
-
- lc_llm = NotDiamond._llm_from_config(llm_config)
- return LangchainLLMWrapper(lc_llm)
-
-
-
-
-[docs]
-def get_embedding(
- embedding_model_config_or_str: Union[EmbeddingConfig, str]
-) -> Union[LangchainEmbeddingsWrapper, HuggingfaceEmbeddings]:
- """
- Build the embedding model object compatible with evaluation metrics.
-
- Parameters:
- embedding_model_config_or_str (Union[EmbeddingConfig, str]): an EmbeddingConfig object
- or an embedding model string that specifies the embedding model to construct.
- """
- if isinstance(embedding_model_config_or_str, str):
- embedding_config = EmbeddingConfig.from_string(
- embedding_model_config_or_str
- )
- else:
- embedding_config = embedding_model_config_or_str
-
- if embedding_config.provider == "openai":
- lc_embedding = OpenAIEmbeddings(
- model=embedding_config.model, **embedding_config.kwargs
- )
-
- elif embedding_config.provider == "cohere":
- lc_embedding = CohereEmbeddings(
- model=embedding_config.model, **embedding_config.kwargs
- )
-
- elif embedding_config.provider == "mistral":
- lc_embedding = MistralAIEmbeddings(
- model=embedding_config.model, **embedding_config.kwargs
- )
-
- elif embedding_config.provider == "huggingface":
- return HuggingfaceEmbeddings(model_name=embedding_config.model)
-
- else:
- raise UnsupportedEmbeddingProvider(
- f"Embedding model {str(embedding_config)} not supported."
- )
- return LangchainEmbeddingsWrapper(lc_embedding)
-
-
-from typing import List, Optional, Sequence, Union
-
-import pandas as pd
-from langchain_core.callbacks import Callbacks
-from langchain_core.documents import Document as LCDocument
-from llama_index.core.base.embeddings.base import (
- BaseEmbedding as LlamaIndexEmbedding,
-)
-from llama_index.core.base.llms.base import BaseLLM as LlamaIndexLLM
-from llama_index.core.schema import Document as LlamaIndexDocument
-from ragas.embeddings import BaseRagasEmbeddings, LlamaIndexEmbeddingsWrapper
-from ragas.llms import BaseRagasLLM, LlamaIndexLLMWrapper
-from ragas.run_config import RunConfig
-from ragas.testset import TestsetGenerator
-from ragas.testset.graph import KnowledgeGraph, Node, NodeType
-from ragas.testset.persona import Persona
-from ragas.testset.synthesizers import QueryDistribution
-from ragas.testset.transforms import (
- Transforms,
- apply_transforms,
- default_transforms,
-)
-
-
-
-[docs]
-class TestDataGenerator(TestsetGenerator):
-
-[docs]
- def __init__(
- self,
- llm: BaseRagasLLM,
- embedding_model: BaseRagasEmbeddings,
- knowledge_graph: KnowledgeGraph = KnowledgeGraph(),
- persona_list: Optional[List[Persona]] = None,
- ):
- """
- RAG Test data generator class.
- Generates test cases from documents for evaluating RAG workflows.
-
- Parameters:
- llm (BaseRagasLLM): An LLM object inherited from BaseRagasLLM. Obtain this
- via the get_llm tool.
- embedding_model (BaseRagasEmbeddings): An embedding model object inherited
- from BaseRagasEmbeddings. Obtain this via the get_embedding tool.
- knowledge_graph (KnowledgeGraph): The knowledge graph to use for the generation
- process. Default empty.
- """
- super().__init__(
- llm=llm,
- embedding_model=embedding_model,
- knowledge_graph=knowledge_graph,
- persona_list=persona_list,
- )
-
-
-
-[docs]
- def generate_from_docs(
- self,
- documents: Union[Sequence[LCDocument], Sequence[LlamaIndexDocument]],
- testset_size: int,
- transforms: Optional[Transforms] = None,
- transforms_llm: Optional[Union[BaseRagasLLM, LlamaIndexLLM]] = None,
- transforms_embedding_model: Optional[
- Union[BaseRagasEmbeddings, LlamaIndexEmbedding]
- ] = None,
- query_distribution: Optional[QueryDistribution] = None,
- run_config: Optional[RunConfig] = None,
- callbacks: Optional[Callbacks] = None,
- with_debugging_logs: bool = False,
- raise_exceptions: bool = True,
- ) -> pd.DataFrame:
- """
- Generates an evaluation dataset based on given Langchain or Llama Index documents and parameters.
-
- Parameters:
- documents : Sequence[LCDocument]
- A sequence of Langchain documents to use as source material
- testset_size : int
- The number of test samples to generate
- transforms : Optional[Transforms], optional
- Custom transforms to apply to the documents, by default None
- transforms_llm : Optional[BaseRagasLLM], optional
- LLM to use for transforms if different from instance LLM, by default None
- transforms_embedding_model : Optional[BaseRagasEmbeddings], optional
- Embedding model to use for transforms if different from instance model, by default None
- query_distribution : Optional[QueryDistribution], optional
- Distribution of query types to generate, by default None
- run_config : Optional[RunConfig], optional
- Configuration for the generation run, by default None
- callbacks : Optional[Callbacks], optional
- Callbacks to use during generation, by default None
- with_debugging_logs : bool, optional
- Whether to include debug logs, by default False
- raise_exceptions : bool, optional
- Whether to raise exceptions during generation, by default True
-
- Returns:
- Testset
- The generated evaluation dataset
-
- Raises:
- ValueError
- If no LLM or embedding model is provided either during initialization or as arguments
- """
- assert isinstance(
- documents, list
- ), "Documents must be a list of langchain or llama-index documents."
-
- if isinstance(documents[0], LCDocument):
- dataset = self.generate_with_langchain_docs(
- documents=documents,
- testset_size=testset_size,
- transforms=transforms,
- transforms_llm=transforms_llm,
- transforms_embedding_model=transforms_embedding_model,
- query_distribution=query_distribution,
- run_config=run_config,
- callbacks=callbacks,
- with_debugging_logs=with_debugging_logs,
- raise_exceptions=raise_exceptions,
- )
- return dataset.to_pandas()
-
- elif isinstance(documents[0], LlamaIndexDocument):
- dataset = self.generate_with_llamaindex_docs(
- documents=documents,
- testset_size=testset_size,
- transforms=transforms,
- transforms_llm=transforms_llm,
- transforms_embedding_model=transforms_embedding_model,
- query_distribution=query_distribution,
- run_config=run_config,
- callbacks=callbacks,
- with_debugging_logs=with_debugging_logs,
- raise_exceptions=raise_exceptions,
- )
- return dataset.to_pandas()
-
- raise ValueError(
- "Documents must be a list of langchain or llama-index documents."
- )
-
-
-
-[docs]
- def generate_with_llamaindex_docs(
- self,
- documents: Sequence[LlamaIndexDocument],
- testset_size: int,
- transforms: Optional[Transforms] = None,
- transforms_llm: Optional[LlamaIndexLLM] = None,
- transforms_embedding_model: Optional[LlamaIndexEmbedding] = None,
- query_distribution: Optional[QueryDistribution] = None,
- run_config: Optional[RunConfig] = None,
- callbacks: Optional[Callbacks] = None,
- with_debugging_logs=False,
- raise_exceptions: bool = True,
- ):
- """
- Generates an evaluation dataset based on given scenarios and parameters.
- """
-
- run_config = run_config or RunConfig()
-
- # force the user to provide an llm and embedding client to prevent use of default LLMs
- if not self.llm and not transforms_llm:
- raise ValueError(
- "An llm client was not provided."
- " Provide an LLM on init or as an argument to this method."
- " Alternatively you can provide your own transforms through the `transforms` parameter."
- )
- if not self.embedding_model and not transforms_embedding_model:
- raise ValueError(
- "An embedding client was not provided."
- " Provide an embedding through the transforms_embedding_model parameter."
- " Alternatively you can provide your own transforms through the `transforms` parameter."
- )
-
- if not transforms:
- # use TestsetGenerator's LLM and embedding model if no transforms_llm or transforms_embedding_model is provided
- if transforms_llm is None:
- llm_for_transforms = self.llm
- else:
- llm_for_transforms = LlamaIndexLLMWrapper(transforms_llm)
- if transforms_embedding_model is None:
- embedding_model_for_transforms = self.embedding_model
- else:
- embedding_model_for_transforms = LlamaIndexEmbeddingsWrapper(
- transforms_embedding_model
- )
-
- # create the transforms
- transforms = default_transforms(
- documents=[
- LCDocument(page_content=doc.text) for doc in documents
- ],
- llm=llm_for_transforms,
- embedding_model=embedding_model_for_transforms,
- )
-
- # convert the documents to Ragas nodes
- nodes = []
- for doc in documents:
- if doc.text is not None and doc.text.strip() != "":
- node = Node(
- type=NodeType.DOCUMENT,
- properties={
- "page_content": doc.text,
- "document_metadata": doc.metadata,
- },
- )
- nodes.append(node)
-
- kg = KnowledgeGraph(nodes=nodes)
-
- # apply transforms and update the knowledge graph
- apply_transforms(kg, transforms, run_config)
- self.knowledge_graph = kg
-
- return self.generate(
- testset_size=testset_size,
- query_distribution=query_distribution,
- run_config=run_config,
- callbacks=callbacks,
- with_debugging_logs=with_debugging_logs,
- raise_exceptions=raise_exceptions,
- )
-
-
-
-from dataclasses import dataclass
-from typing import Any, ClassVar, Dict, List, Optional, Type, Union, get_args
-
-import optuna
-
-from notdiamond.toolkit.rag.evaluation_dataset import RAGEvaluationDataset
-
-
-
-[docs]
-@dataclass
-class IntValueRange:
- """
- A range of int values for an auto-evaluated RAG pipeline. Useful for, eg. RAG context chunk size.
- """
-
- lo: int
- hi: int
- step: int
-
- def __contains__(self, value: int) -> bool:
- return (
- self.lo <= value <= self.hi and (value - self.lo) % self.step == 0
- )
-
-
-
-
-[docs]
-@dataclass
-class FloatValueRange:
- """
- A range of float values for an auto-evaluated RAG pipeline. Useful for, eg. LLM temperature.
- """
-
- lo: float
- hi: float
- step: float
-
- def __contains__(self, value: float) -> bool:
- return self.lo <= value <= self.hi
-
-
-
-
-[docs]
-@dataclass
-class CategoricalValueOptions:
- """
- A list of categorical values for an auto-evaluated RAG pipeline. Useful for, eg. embedding algorithms.
- """
-
- values: List[str]
-
- def __contains__(self, value: str) -> bool:
- return value in self.values
-
-
-
-_ALLOWED_TYPES = [IntValueRange, FloatValueRange, CategoricalValueOptions]
-
-
-
-[docs]
-class BaseNDRagWorkflow:
- """
- A base interface for a RAG workflow to be auto-evaluated by Not Diamond.
-
- Subclasses should define parameter_specs to type parameters they need to optimize,
- by using type annotations with the above dataclasses. For example:
-
- .. code-block:: python
-
- class ExampleNDRagWorkflow(BaseNDRagWorkflow):
- parameter_specs = {
- "chunk_size": (Annotated[int, IntValueRange(1000, 2500, 500)], 1000),
- "chunk_overlap": (Annotated[int, IntValueRange(50, 200, 25)], 100),
- "top_k": (Annotated[int, IntValueRange(1, 20, 1)], 5),
- "algo": (
- Annotated[
- str,
- CategoricalValueOptions(
- [
- "BM25",
- "openai_small",
- "openai_large",
- "cohere_eng",
- "cohere_multi",
- ]
- ),
- ],
- "BM25",
- ),
- "temperature": (Annotated[float, FloatValueRange(0.0, 1.0, 0.1)], 0.9),
- }
-
- """
-
- parameter_specs: ClassVar[Dict[str, tuple[Type, Any]]] = {}
-
-
-[docs]
- def __init__(
- self,
- documents: Any,
- test_queries: Optional[Any] = None,
- objective_maximize: bool = True,
- **kwargs,
- ):
- """
- Args:
- evaluation_dataset: A dataset of RAG evaluation samples.
- documents: The documents to use for RAG.
- objective_maximize: Whether to maximize or minimize the objective defined below.
- """
- if not self.parameter_specs:
- raise NotImplementedError(
- f"Class {self.__class__.__name__} must define parameter_specs"
- )
-
- self._param_ranges = {}
- self._base_param_types = {}
- for param_name, (
- param_type,
- default_value,
- ) in self.parameter_specs.items():
- type_args = get_args(param_type)
- base_type, range_type = type_args
- if range_type is None:
- raise ValueError(
- f"Expected parameter type in {_ALLOWED_TYPES} but received {param_type}"
- )
- self._param_ranges[param_name] = range_type
- self._base_param_types[param_name] = base_type
- if not isinstance(default_value, base_type):
- raise ValueError(
- f"Expected default value type {base_type} but received {type(default_value)}"
- )
- setattr(self, param_name, default_value)
-
- self.documents = documents
- self.test_queries = test_queries
- self.objective_maximize = objective_maximize
-
-
-
-[docs]
- def get_parameter_type(self, param_name: str) -> Type:
- param_type = self._base_param_types.get(param_name)
- if param_type is None:
- raise ValueError(
- f"Parameter {param_name} not found in parameter_specs"
- )
- return param_type
-
-
-
-[docs]
- def get_parameter_range(self, param_name: str) -> Type:
- param_range = self._param_ranges.get(param_name)
- if param_range is None:
- raise ValueError(
- f"Parameter {param_name} not found in parameter_specs"
- )
- return param_range
-
-
-
-[docs]
- def rag_workflow(
- self, documents: Any, test_queries: Optional[List[str]]
- ) -> RAGEvaluationDataset:
- """
- Define RAG workflow components here by setting instance attrs on `self`. Those attributes will be set
- at init-time and available when retrieving context or generating responses.
- """
- raise NotImplementedError()
-
-
-
-[docs]
- def objective(self):
- """
- Define the objective function for your RAG workflow. The workflow's hyperparameters will be optimized
- according to values of this objective.
- """
- raise NotImplementedError()
-
-
- def _outer_objective(self, trial: optuna.Trial):
- for param_name in self.parameter_specs.keys():
- param_range = self.get_parameter_range(param_name)
- if isinstance(param_range, IntValueRange):
- param_value = trial.suggest_int(
- param_name,
- param_range.lo,
- param_range.hi,
- step=param_range.step,
- )
- elif isinstance(param_range, FloatValueRange):
- param_value = trial.suggest_float(
- param_name,
- param_range.lo,
- param_range.hi,
- step=param_range.step,
- )
- elif isinstance(param_range, CategoricalValueOptions):
- param_value = trial.suggest_categorical(
- param_name, param_range.values
- )
- else:
- raise ValueError(
- f"Expected parameter type in {_ALLOWED_TYPES} but received unknown parameter type: {param_range}"
- )
- setattr(self, param_name, param_value)
-
- self.evaluation_dataset = self.rag_workflow(
- self.documents, self.test_queries
- )
- result = self.objective()
- self._reset_param_values()
- return result
-
- def _get_default_param_values(self):
- return {
- param_name: default_value
- for param_name, (_, default_value) in self.parameter_specs.items()
- }
-
- def _set_param_values(
- self, param_values: Dict[str, Union[int, float, str]]
- ):
- for param_name in self.parameter_specs.keys():
- param_value = param_values.get(param_name)
- param_type = self.get_parameter_type(param_name)
- param_range = self.get_parameter_range(param_name)
- if param_value is None:
- raise ValueError(
- f"Best value for {param_name} not found. This should not happen."
- )
- elif not isinstance(param_value, param_type):
- raise ValueError(
- f"Expected parameter type {param_type} but received {type(param_value)}"
- )
- elif param_value not in param_range:
- raise ValueError(
- f"Parameter value {param_value} not in range {param_range}"
- )
- setattr(self, param_name, param_value)
-
- def _reset_param_values(self):
- for param_name in self.parameter_specs.keys():
- setattr(
- self, param_name, self._get_default_param_values()[param_name]
- )
-
-
-
-
-
-
-
-
-
-
-
-from typing import Any, Dict, List
-
-from pydantic import BaseModel, field_validator
-
-from notdiamond.exceptions import InvalidApiKey, MissingApiKey
-
-
-
-[docs]
-class NDApiKeyValidator(BaseModel):
- api_key: str
-
-
-[docs]
- @field_validator("api_key", mode="before")
- @classmethod
- def api_key_must_be_a_string(cls, v) -> str:
- if not isinstance(v, str):
- raise InvalidApiKey("ND API key should be a string")
- return v
-
-
-
-[docs]
- @field_validator("api_key", mode="after")
- @classmethod
- def string_must_not_be_empty(cls, v):
- if len(v) == 0:
- raise MissingApiKey("ND API key should be longer than 0")
- return v
-
-
-
-
-
-[docs]
-class ModelSelectRequestPayload(BaseModel):
- prompt_template: str
- formatted_prompt: str
- components: Dict[str, Dict]
- llm_configs: List[Dict]
- metric: str
- max_model_depth: int
-
-
-
-
-[docs]
-class FeedbackRequestPayload(BaseModel):
- session_id: str
- provider: Dict[str, Any]
- feedback: Dict[str, int]
-
-
-"""Tools to provide pretty/human-readable display of objects."""
-
-from __future__ import annotations as _annotations
-
-import types
-import typing
-from typing import Any
-
-import typing_extensions
-
-from . import _typing_extra
-
-if typing.TYPE_CHECKING:
- ReprArgs: typing_extensions.TypeAlias = 'typing.Iterable[tuple[str | None, Any]]'
- RichReprResult: typing_extensions.TypeAlias = (
- 'typing.Iterable[Any | tuple[Any] | tuple[str, Any] | tuple[str, Any, Any]]'
- )
-
-
-class PlainRepr(str):
- """String class where repr doesn't include quotes. Useful with Representation when you want to return a string
- representation of something that is valid (or pseudo-valid) python.
- """
-
- def __repr__(self) -> str:
- return str(self)
-
-
-class Representation:
- # Mixin to provide `__str__`, `__repr__`, and `__pretty__` and `__rich_repr__` methods.
- # `__pretty__` is used by [devtools](https://python-devtools.helpmanual.io/).
- # `__rich_repr__` is used by [rich](https://rich.readthedocs.io/en/stable/pretty.html).
- # (this is not a docstring to avoid adding a docstring to classes which inherit from Representation)
-
- # we don't want to use a type annotation here as it can break get_type_hints
- __slots__ = () # type: typing.Collection[str]
-
- def __repr_args__(self) -> ReprArgs:
- """Returns the attributes to show in __str__, __repr__, and __pretty__ this is generally overridden.
-
- Can either return:
- * name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]`
- * or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]`
- """
- attrs_names = self.__slots__
- if not attrs_names and hasattr(self, '__dict__'):
- attrs_names = self.__dict__.keys()
- attrs = ((s, getattr(self, s)) for s in attrs_names)
- return [(a, v) for a, v in attrs if v is not None]
-
- def __repr_name__(self) -> str:
- """Name of the instance's class, used in __repr__."""
- return self.__class__.__name__
-
- def __repr_str__(self, join_str: str) -> str:
- return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())
-
- def __pretty__(self, fmt: typing.Callable[[Any], Any], **kwargs: Any) -> typing.Generator[Any, None, None]:
- """Used by devtools (https://python-devtools.helpmanual.io/) to pretty print objects."""
- yield self.__repr_name__() + '('
- yield 1
- for name, value in self.__repr_args__():
- if name is not None:
- yield name + '='
- yield fmt(value)
- yield ','
- yield 0
- yield -1
- yield ')'
-
- def __rich_repr__(self) -> RichReprResult:
- """Used by Rich (https://rich.readthedocs.io/en/stable/pretty.html) to pretty print objects."""
- for name, field_repr in self.__repr_args__():
- if name is None:
- yield field_repr
- else:
- yield name, field_repr
-
- def __str__(self) -> str:
- return self.__repr_str__(' ')
-
- def __repr__(self) -> str:
- return f'{self.__repr_name__()}({self.__repr_str__(", ")})'
-
-
-def display_as_type(obj: Any) -> str:
- """Pretty representation of a type, should be as close as possible to the original type definition string.
-
- Takes some logic from `typing._type_repr`.
- """
- if isinstance(obj, types.FunctionType):
- return obj.__name__
- elif obj is ...:
- return '...'
- elif isinstance(obj, Representation):
- return repr(obj)
- elif isinstance(obj, typing_extensions.TypeAliasType):
- return str(obj)
-
- if not isinstance(obj, (_typing_extra.typing_base, _typing_extra.WithArgsTypes, type)):
- obj = obj.__class__
-
- if _typing_extra.origin_is_union(typing_extensions.get_origin(obj)):
- args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
- return f'Union[{args}]'
- elif isinstance(obj, _typing_extra.WithArgsTypes):
- if typing_extensions.get_origin(obj) == typing_extensions.Literal:
- args = ', '.join(map(repr, typing_extensions.get_args(obj)))
- else:
- args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
- try:
- return f'{obj.__qualname__}[{args}]'
- except AttributeError:
- return str(obj) # handles TypeAliasType in 3.12
- elif isinstance(obj, type):
- return obj.__qualname__
- else:
- return repr(obj).replace('typing.', '').replace('typing_extensions.', '')
-
-"""Logic for creating models."""
-
-from __future__ import annotations as _annotations
-
-import operator
-import sys
-import types
-import typing
-import warnings
-from copy import copy, deepcopy
-from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- ClassVar,
- Dict,
- Generator,
- Literal,
- Mapping,
- Set,
- Tuple,
- TypeVar,
- Union,
- cast,
- overload,
-)
-
-import pydantic_core
-import typing_extensions
-from pydantic_core import PydanticUndefined
-from typing_extensions import Self, TypeAlias, Unpack
-
-from ._internal import (
- _config,
- _decorators,
- _fields,
- _forward_ref,
- _generics,
- _import_utils,
- _mock_val_ser,
- _model_construction,
- _repr,
- _typing_extra,
- _utils,
-)
-from ._migration import getattr_migration
-from .aliases import AliasChoices, AliasPath
-from .annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
-from .config import ConfigDict
-from .errors import PydanticUndefinedAnnotation, PydanticUserError
-from .json_schema import DEFAULT_REF_TEMPLATE, GenerateJsonSchema, JsonSchemaMode, JsonSchemaValue, model_json_schema
-from .plugin._schema_validator import PluggableSchemaValidator
-from .warnings import PydanticDeprecatedSince20
-
-if TYPE_CHECKING:
- from inspect import Signature
- from pathlib import Path
-
- from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator
-
- from ._internal._utils import AbstractSetIntStr, MappingIntStrAny
- from .deprecated.parse import Protocol as DeprecatedParseProtocol
- from .fields import ComputedFieldInfo, FieldInfo, ModelPrivateAttr
-else:
- # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
- # and https://youtrack.jetbrains.com/issue/PY-51428
- DeprecationWarning = PydanticDeprecatedSince20
-
-__all__ = 'BaseModel', 'create_model'
-
-# Keep these type aliases available at runtime:
-TupleGenerator: TypeAlias = Generator[Tuple[str, Any], None, None]
-# Keep this type alias in sync with the stub definition in `pydantic-core`:
-IncEx: TypeAlias = Union[
- Set[int], Set[str], Mapping[int, Union['IncEx', Literal[True]]], Mapping[str, Union['IncEx', Literal[True]]]
-]
-
-_object_setattr = _model_construction.object_setattr
-
-
-class BaseModel(metaclass=_model_construction.ModelMetaclass):
- """Usage docs: https://docs.pydantic.dev/2.9/concepts/models/
-
- A base class for creating Pydantic models.
-
- Attributes:
- __class_vars__: The names of the class variables defined on the model.
- __private_attributes__: Metadata about the private attributes of the model.
- __signature__: The synthesized `__init__` [`Signature`][inspect.Signature] of the model.
-
- __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
- __pydantic_core_schema__: The core schema of the model.
- __pydantic_custom_init__: Whether the model has a custom `__init__` function.
- __pydantic_decorators__: Metadata containing the decorators defined on the model.
- This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1.
- __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to
- __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
- __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
- __pydantic_post_init__: The name of the post-init method for the model, if defined.
- __pydantic_root_model__: Whether the model is a [`RootModel`][pydantic.root_model.RootModel].
- __pydantic_serializer__: The `pydantic-core` `SchemaSerializer` used to dump instances of the model.
- __pydantic_validator__: The `pydantic-core` `SchemaValidator` used to validate instances of the model.
-
- __pydantic_extra__: A dictionary containing extra values, if [`extra`][pydantic.config.ConfigDict.extra]
- is set to `'allow'`.
- __pydantic_fields_set__: The names of fields explicitly set during instantiation.
- __pydantic_private__: Values of private attributes set on the model instance.
- """
-
- # Class attributes:
- # `model_fields` and `__pydantic_decorators__` must be set for
- # `GenerateSchema.model_schema` to work for a plain `BaseModel` annotation.
-
- model_config: ClassVar[ConfigDict] = ConfigDict()
- """
- Configuration for the model, should be a dictionary conforming to [`ConfigDict`][pydantic.config.ConfigDict].
- """
-
- # Because `dict` is in the local namespace of the `BaseModel` class, we use `Dict` for annotations.
- # TODO v3 fallback to `dict` when the deprecated `dict` method gets removed.
- model_fields: ClassVar[Dict[str, FieldInfo]] = {} # noqa: UP006
- """
- Metadata about the fields defined on the model,
- mapping of field names to [`FieldInfo`][pydantic.fields.FieldInfo] objects.
-
- This replaces `Model.__fields__` from Pydantic V1.
- """
-
- model_computed_fields: ClassVar[Dict[str, ComputedFieldInfo]] = {} # noqa: UP006
- """A dictionary of computed field names and their corresponding `ComputedFieldInfo` objects."""
-
- __class_vars__: ClassVar[set[str]]
- """The names of the class variables defined on the model."""
-
- __private_attributes__: ClassVar[Dict[str, ModelPrivateAttr]] # noqa: UP006
- """Metadata about the private attributes of the model."""
-
- __signature__: ClassVar[Signature]
- """The synthesized `__init__` [`Signature`][inspect.Signature] of the model."""
-
- __pydantic_complete__: ClassVar[bool] = False
- """Whether model building is completed, or if there are still undefined fields."""
-
- __pydantic_core_schema__: ClassVar[CoreSchema]
- """The core schema of the model."""
-
- __pydantic_custom_init__: ClassVar[bool]
- """Whether the model has a custom `__init__` method."""
-
- __pydantic_decorators__: ClassVar[_decorators.DecoratorInfos] = _decorators.DecoratorInfos()
- """Metadata containing the decorators defined on the model.
- This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1."""
-
- __pydantic_generic_metadata__: ClassVar[_generics.PydanticGenericMetadata]
- """Metadata for generic models; contains data used for a similar purpose to
- __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these."""
-
- __pydantic_parent_namespace__: ClassVar[Dict[str, Any] | None] = None # noqa: UP006
- """Parent namespace of the model, used for automatic rebuilding of models."""
-
- __pydantic_post_init__: ClassVar[None | Literal['model_post_init']]
- """The name of the post-init method for the model, if defined."""
-
- __pydantic_root_model__: ClassVar[bool] = False
- """Whether the model is a [`RootModel`][pydantic.root_model.RootModel]."""
-
- __pydantic_serializer__: ClassVar[SchemaSerializer]
- """The `pydantic-core` `SchemaSerializer` used to dump instances of the model."""
-
- __pydantic_validator__: ClassVar[SchemaValidator | PluggableSchemaValidator]
- """The `pydantic-core` `SchemaValidator` used to validate instances of the model."""
-
- __pydantic_extra__: dict[str, Any] | None = _model_construction.NoInitField(init=False)
- """A dictionary containing extra values, if [`extra`][pydantic.config.ConfigDict.extra] is set to `'allow'`."""
-
- __pydantic_fields_set__: set[str] = _model_construction.NoInitField(init=False)
- """The names of fields explicitly set during instantiation."""
-
- __pydantic_private__: dict[str, Any] | None = _model_construction.NoInitField(init=False)
- """Values of private attributes set on the model instance."""
-
- if not TYPE_CHECKING:
- # Prevent `BaseModel` from being instantiated directly
- # (defined in an `if not TYPE_CHECKING` block for clarity and to avoid type checking errors):
- __pydantic_core_schema__ = _mock_val_ser.MockCoreSchema(
- 'Pydantic models should inherit from BaseModel, BaseModel cannot be instantiated directly',
- code='base-model-instantiated',
- )
- __pydantic_validator__ = _mock_val_ser.MockValSer(
- 'Pydantic models should inherit from BaseModel, BaseModel cannot be instantiated directly',
- val_or_ser='validator',
- code='base-model-instantiated',
- )
- __pydantic_serializer__ = _mock_val_ser.MockValSer(
- 'Pydantic models should inherit from BaseModel, BaseModel cannot be instantiated directly',
- val_or_ser='serializer',
- code='base-model-instantiated',
- )
-
- __slots__ = '__dict__', '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__'
-
- def __init__(self, /, **data: Any) -> None:
- """Create a new model by parsing and validating input data from keyword arguments.
-
- Raises [`ValidationError`][pydantic_core.ValidationError] if the input data cannot be
- validated to form a valid model.
-
- `self` is explicitly positional-only to allow `self` as a field name.
- """
- # `__tracebackhide__` tells pytest and some other tools to omit this function from tracebacks
- __tracebackhide__ = True
- validated_self = self.__pydantic_validator__.validate_python(data, self_instance=self)
- if self is not validated_self:
- warnings.warn(
- 'A custom validator is returning a value other than `self`.\n'
- "Returning anything other than `self` from a top level model validator isn't supported when validating via `__init__`.\n"
- 'See the `model_validator` docs (https://docs.pydantic.dev/latest/concepts/validators/#model-validators) for more details.',
- category=None,
- )
-
- # The following line sets a flag that we use to determine when `__init__` gets overridden by the user
- __init__.__pydantic_base_init__ = True # pyright: ignore[reportFunctionMemberAccess]
-
- @property
- def model_extra(self) -> dict[str, Any] | None:
- """Get extra fields set during validation.
-
- Returns:
- A dictionary of extra fields, or `None` if `config.extra` is not set to `"allow"`.
- """
- return self.__pydantic_extra__
-
- @property
- def model_fields_set(self) -> set[str]:
- """Returns the set of fields that have been explicitly set on this model instance.
-
- Returns:
- A set of strings representing the fields that have been set,
- i.e. that were not filled from defaults.
- """
- return self.__pydantic_fields_set__
-
- @classmethod
- def model_construct(cls, _fields_set: set[str] | None = None, **values: Any) -> Self: # noqa: C901
- """Creates a new instance of the `Model` class with validated data.
-
- Creates a new model setting `__dict__` and `__pydantic_fields_set__` from trusted or pre-validated data.
- Default values are respected, but no other validation is performed.
-
- !!! note
- `model_construct()` generally respects the `model_config.extra` setting on the provided model.
- That is, if `model_config.extra == 'allow'`, then all extra passed values are added to the model instance's `__dict__`
- and `__pydantic_extra__` fields. If `model_config.extra == 'ignore'` (the default), then all extra passed values are ignored.
- Because no validation is performed with a call to `model_construct()`, having `model_config.extra == 'forbid'` does not result in
- an error if extra values are passed, but they will be ignored.
-
- Args:
- _fields_set: A set of field names that were originally explicitly set during instantiation. If provided,
- this is directly used for the [`model_fields_set`][pydantic.BaseModel.model_fields_set] attribute.
- Otherwise, the field names from the `values` argument will be used.
- values: Trusted or pre-validated data dictionary.
-
- Returns:
- A new instance of the `Model` class with validated data.
- """
- m = cls.__new__(cls)
- fields_values: dict[str, Any] = {}
- fields_set = set()
-
- for name, field in cls.model_fields.items():
- if field.alias is not None and field.alias in values:
- fields_values[name] = values.pop(field.alias)
- fields_set.add(name)
-
- if (name not in fields_set) and (field.validation_alias is not None):
- validation_aliases: list[str | AliasPath] = (
- field.validation_alias.choices
- if isinstance(field.validation_alias, AliasChoices)
- else [field.validation_alias]
- )
-
- for alias in validation_aliases:
- if isinstance(alias, str) and alias in values:
- fields_values[name] = values.pop(alias)
- fields_set.add(name)
- break
- elif isinstance(alias, AliasPath):
- value = alias.search_dict_for_path(values)
- if value is not PydanticUndefined:
- fields_values[name] = value
- fields_set.add(name)
- break
-
- if name not in fields_set:
- if name in values:
- fields_values[name] = values.pop(name)
- fields_set.add(name)
- elif not field.is_required():
- fields_values[name] = field.get_default(call_default_factory=True)
- if _fields_set is None:
- _fields_set = fields_set
-
- _extra: dict[str, Any] | None = values if cls.model_config.get('extra') == 'allow' else None
- _object_setattr(m, '__dict__', fields_values)
- _object_setattr(m, '__pydantic_fields_set__', _fields_set)
- if not cls.__pydantic_root_model__:
- _object_setattr(m, '__pydantic_extra__', _extra)
-
- if cls.__pydantic_post_init__:
- m.model_post_init(None)
- # update private attributes with values set
- if hasattr(m, '__pydantic_private__') and m.__pydantic_private__ is not None:
- for k, v in values.items():
- if k in m.__private_attributes__:
- m.__pydantic_private__[k] = v
-
- elif not cls.__pydantic_root_model__:
- # Note: if there are any private attributes, cls.__pydantic_post_init__ would exist
- # Since it doesn't, that means that `__pydantic_private__` should be set to None
- _object_setattr(m, '__pydantic_private__', None)
-
- return m
-
- def model_copy(self, *, update: dict[str, Any] | None = None, deep: bool = False) -> Self:
- """Usage docs: https://docs.pydantic.dev/2.9/concepts/serialization/#model_copy
-
- Returns a copy of the model.
-
- Args:
- update: Values to change/add in the new model. Note: the data is not validated
- before creating the new model. You should trust this data.
- deep: Set to `True` to make a deep copy of the model.
-
- Returns:
- New model instance.
- """
- copied = self.__deepcopy__() if deep else self.__copy__()
- if update:
- if self.model_config.get('extra') == 'allow':
- for k, v in update.items():
- if k in self.model_fields:
- copied.__dict__[k] = v
- else:
- if copied.__pydantic_extra__ is None:
- copied.__pydantic_extra__ = {}
- copied.__pydantic_extra__[k] = v
- else:
- copied.__dict__.update(update)
- copied.__pydantic_fields_set__.update(update.keys())
- return copied
-
- def model_dump(
- self,
- *,
- mode: Literal['json', 'python'] | str = 'python',
- include: IncEx | None = None,
- exclude: IncEx | None = None,
- context: Any | None = None,
- by_alias: bool = False,
- exclude_unset: bool = False,
- exclude_defaults: bool = False,
- exclude_none: bool = False,
- round_trip: bool = False,
- warnings: bool | Literal['none', 'warn', 'error'] = True,
- serialize_as_any: bool = False,
- ) -> dict[str, Any]:
- """Usage docs: https://docs.pydantic.dev/2.9/concepts/serialization/#modelmodel_dump
-
- Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
-
- Args:
- mode: The mode in which `to_python` should run.
- If mode is 'json', the output will only contain JSON serializable types.
- If mode is 'python', the output may contain non-JSON-serializable Python objects.
- include: A set of fields to include in the output.
- exclude: A set of fields to exclude from the output.
- context: Additional context to pass to the serializer.
- by_alias: Whether to use the field's alias in the dictionary key if defined.
- exclude_unset: Whether to exclude fields that have not been explicitly set.
- exclude_defaults: Whether to exclude fields that are set to their default value.
- exclude_none: Whether to exclude fields that have a value of `None`.
- round_trip: If True, dumped values should be valid as input for non-idempotent types such as Json[T].
- warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors,
- "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
- serialize_as_any: Whether to serialize fields with duck-typing serialization behavior.
-
- Returns:
- A dictionary representation of the model.
- """
- return self.__pydantic_serializer__.to_python(
- self,
- mode=mode,
- by_alias=by_alias,
- include=include,
- exclude=exclude,
- context=context,
- exclude_unset=exclude_unset,
- exclude_defaults=exclude_defaults,
- exclude_none=exclude_none,
- round_trip=round_trip,
- warnings=warnings,
- serialize_as_any=serialize_as_any,
- )
-
- def model_dump_json(
- self,
- *,
- indent: int | None = None,
- include: IncEx | None = None,
- exclude: IncEx | None = None,
- context: Any | None = None,
- by_alias: bool = False,
- exclude_unset: bool = False,
- exclude_defaults: bool = False,
- exclude_none: bool = False,
- round_trip: bool = False,
- warnings: bool | Literal['none', 'warn', 'error'] = True,
- serialize_as_any: bool = False,
- ) -> str:
- """Usage docs: https://docs.pydantic.dev/2.9/concepts/serialization/#modelmodel_dump_json
-
- Generates a JSON representation of the model using Pydantic's `to_json` method.
-
- Args:
- indent: Indentation to use in the JSON output. If None is passed, the output will be compact.
- include: Field(s) to include in the JSON output.
- exclude: Field(s) to exclude from the JSON output.
- context: Additional context to pass to the serializer.
- by_alias: Whether to serialize using field aliases.
- exclude_unset: Whether to exclude fields that have not been explicitly set.
- exclude_defaults: Whether to exclude fields that are set to their default value.
- exclude_none: Whether to exclude fields that have a value of `None`.
- round_trip: If True, dumped values should be valid as input for non-idempotent types such as Json[T].
- warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors,
- "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
- serialize_as_any: Whether to serialize fields with duck-typing serialization behavior.
-
- Returns:
- A JSON string representation of the model.
- """
- return self.__pydantic_serializer__.to_json(
- self,
- indent=indent,
- include=include,
- exclude=exclude,
- context=context,
- by_alias=by_alias,
- exclude_unset=exclude_unset,
- exclude_defaults=exclude_defaults,
- exclude_none=exclude_none,
- round_trip=round_trip,
- warnings=warnings,
- serialize_as_any=serialize_as_any,
- ).decode()
-
- @classmethod
- def model_json_schema(
- cls,
- by_alias: bool = True,
- ref_template: str = DEFAULT_REF_TEMPLATE,
- schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
- mode: JsonSchemaMode = 'validation',
- ) -> dict[str, Any]:
- """Generates a JSON schema for a model class.
-
- Args:
- by_alias: Whether to use attribute aliases or not.
- ref_template: The reference template.
- schema_generator: To override the logic used to generate the JSON schema, as a subclass of
- `GenerateJsonSchema` with your desired modifications
- mode: The mode in which to generate the schema.
-
- Returns:
- The JSON schema for the given model class.
- """
- return model_json_schema(
- cls, by_alias=by_alias, ref_template=ref_template, schema_generator=schema_generator, mode=mode
- )
-
- @classmethod
- def model_parametrized_name(cls, params: tuple[type[Any], ...]) -> str:
- """Compute the class name for parametrizations of generic classes.
-
- This method can be overridden to achieve a custom naming scheme for generic BaseModels.
-
- Args:
- params: Tuple of types of the class. Given a generic class
- `Model` with 2 type variables and a concrete model `Model[str, int]`,
- the value `(str, int)` would be passed to `params`.
-
- Returns:
- String representing the new class where `params` are passed to `cls` as type variables.
-
- Raises:
- TypeError: Raised when trying to generate concrete names for non-generic models.
- """
- if not issubclass(cls, typing.Generic):
- raise TypeError('Concrete names should only be generated for generic models.')
-
- # Any strings received should represent forward references, so we handle them specially below.
- # If we eventually move toward wrapping them in a ForwardRef in __class_getitem__ in the future,
- # we may be able to remove this special case.
- param_names = [param if isinstance(param, str) else _repr.display_as_type(param) for param in params]
- params_component = ', '.join(param_names)
- return f'{cls.__name__}[{params_component}]'
-
- def model_post_init(self, __context: Any) -> None:
- """Override this method to perform additional initialization after `__init__` and `model_construct`.
- This is useful if you want to do some validation that requires the entire model to be initialized.
- """
- pass
-
- @classmethod
- def model_rebuild(
- cls,
- *,
- force: bool = False,
- raise_errors: bool = True,
- _parent_namespace_depth: int = 2,
- _types_namespace: dict[str, Any] | None = None,
- ) -> bool | None:
- """Try to rebuild the pydantic-core schema for the model.
-
- This may be necessary when one of the annotations is a ForwardRef which could not be resolved during
- the initial attempt to build the schema, and automatic rebuilding fails.
-
- Args:
- force: Whether to force the rebuilding of the model schema, defaults to `False`.
- raise_errors: Whether to raise errors, defaults to `True`.
- _parent_namespace_depth: The depth level of the parent namespace, defaults to 2.
- _types_namespace: The types namespace, defaults to `None`.
-
- Returns:
- Returns `None` if the schema is already "complete" and rebuilding was not required.
- If rebuilding _was_ required, returns `True` if rebuilding was successful, otherwise `False`.
- """
- if not force and cls.__pydantic_complete__:
- return None
- else:
- if '__pydantic_core_schema__' in cls.__dict__:
- delattr(cls, '__pydantic_core_schema__') # delete cached value to ensure full rebuild happens
- if _types_namespace is not None:
- types_namespace: dict[str, Any] | None = _types_namespace.copy()
- else:
- if _parent_namespace_depth > 0:
- frame_parent_ns = (
- _typing_extra.parent_frame_namespace(parent_depth=_parent_namespace_depth, force=True) or {}
- )
- cls_parent_ns = (
- _model_construction.unpack_lenient_weakvaluedict(cls.__pydantic_parent_namespace__) or {}
- )
- types_namespace = {**cls_parent_ns, **frame_parent_ns}
- cls.__pydantic_parent_namespace__ = _model_construction.build_lenient_weakvaluedict(types_namespace)
- else:
- types_namespace = _model_construction.unpack_lenient_weakvaluedict(
- cls.__pydantic_parent_namespace__
- )
-
- types_namespace = _typing_extra.merge_cls_and_parent_ns(cls, types_namespace)
-
- # manually override defer_build so complete_model_class doesn't skip building the model again
- config = {**cls.model_config, 'defer_build': False}
- return _model_construction.complete_model_class(
- cls,
- cls.__name__,
- _config.ConfigWrapper(config, check=False),
- raise_errors=raise_errors,
- types_namespace=types_namespace,
- )
-
- @classmethod
- def model_validate(
- cls,
- obj: Any,
- *,
- strict: bool | None = None,
- from_attributes: bool | None = None,
- context: Any | None = None,
- ) -> Self:
- """Validate a pydantic model instance.
-
- Args:
- obj: The object to validate.
- strict: Whether to enforce types strictly.
- from_attributes: Whether to extract data from object attributes.
- context: Additional context to pass to the validator.
-
- Raises:
- ValidationError: If the object could not be validated.
-
- Returns:
- The validated model instance.
- """
- # `__tracebackhide__` tells pytest and some other tools to omit this function from tracebacks
- __tracebackhide__ = True
- return cls.__pydantic_validator__.validate_python(
- obj, strict=strict, from_attributes=from_attributes, context=context
- )
-
- @classmethod
- def model_validate_json(
- cls,
- json_data: str | bytes | bytearray,
- *,
- strict: bool | None = None,
- context: Any | None = None,
- ) -> Self:
- """Usage docs: https://docs.pydantic.dev/2.9/concepts/json/#json-parsing
-
- Validate the given JSON data against the Pydantic model.
-
- Args:
- json_data: The JSON data to validate.
- strict: Whether to enforce types strictly.
- context: Extra variables to pass to the validator.
-
- Returns:
- The validated Pydantic model.
-
- Raises:
- ValidationError: If `json_data` is not a JSON string or the object could not be validated.
- """
- # `__tracebackhide__` tells pytest and some other tools to omit this function from tracebacks
- __tracebackhide__ = True
- return cls.__pydantic_validator__.validate_json(json_data, strict=strict, context=context)
-
- @classmethod
- def model_validate_strings(
- cls,
- obj: Any,
- *,
- strict: bool | None = None,
- context: Any | None = None,
- ) -> Self:
- """Validate the given object with string data against the Pydantic model.
-
- Args:
- obj: The object containing string data to validate.
- strict: Whether to enforce types strictly.
- context: Extra variables to pass to the validator.
-
- Returns:
- The validated Pydantic model.
- """
- # `__tracebackhide__` tells pytest and some other tools to omit this function from tracebacks
- __tracebackhide__ = True
- return cls.__pydantic_validator__.validate_strings(obj, strict=strict, context=context)
-
- @classmethod
- def __get_pydantic_core_schema__(cls, source: type[BaseModel], handler: GetCoreSchemaHandler, /) -> CoreSchema:
- """Hook into generating the model's CoreSchema.
-
- Args:
- source: The class we are generating a schema for.
- This will generally be the same as the `cls` argument if this is a classmethod.
- handler: A callable that calls into Pydantic's internal CoreSchema generation logic.
-
- Returns:
- A `pydantic-core` `CoreSchema`.
- """
- # Only use the cached value from this _exact_ class; we don't want one from a parent class
- # This is why we check `cls.__dict__` and don't use `cls.__pydantic_core_schema__` or similar.
- schema = cls.__dict__.get('__pydantic_core_schema__')
- if schema is not None and not isinstance(schema, _mock_val_ser.MockCoreSchema):
- # Due to the way generic classes are built, it's possible that an invalid schema may be temporarily
- # set on generic classes. I think we could resolve this to ensure that we get proper schema caching
- # for generics, but for simplicity for now, we just always rebuild if the class has a generic origin.
- if not cls.__pydantic_generic_metadata__['origin']:
- return cls.__pydantic_core_schema__
-
- return handler(source)
-
- @classmethod
- def __get_pydantic_json_schema__(
- cls,
- core_schema: CoreSchema,
- handler: GetJsonSchemaHandler,
- /,
- ) -> JsonSchemaValue:
- """Hook into generating the model's JSON schema.
-
- Args:
- core_schema: A `pydantic-core` CoreSchema.
- You can ignore this argument and call the handler with a new CoreSchema,
- wrap this CoreSchema (`{'type': 'nullable', 'schema': current_schema}`),
- or just call the handler with the original schema.
- handler: Call into Pydantic's internal JSON schema generation.
- This will raise a `pydantic.errors.PydanticInvalidForJsonSchema` if JSON schema
- generation fails.
- Since this gets called by `BaseModel.model_json_schema` you can override the
- `schema_generator` argument to that function to change JSON schema generation globally
- for a type.
-
- Returns:
- A JSON schema, as a Python object.
- """
- return handler(core_schema)
-
- @classmethod
- def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
- """This is intended to behave just like `__init_subclass__`, but is called by `ModelMetaclass`
- only after the class is actually fully initialized. In particular, attributes like `model_fields` will
- be present when this is called.
-
- This is necessary because `__init_subclass__` will always be called by `type.__new__`,
- and it would require a prohibitively large refactor to the `ModelMetaclass` to ensure that
- `type.__new__` was called in such a manner that the class would already be sufficiently initialized.
-
- This will receive the same `kwargs` that would be passed to the standard `__init_subclass__`, namely,
- any kwargs passed to the class definition that aren't used internally by pydantic.
-
- Args:
- **kwargs: Any keyword arguments passed to the class definition that aren't used internally
- by pydantic.
- """
- pass
-
- def __class_getitem__(
- cls, typevar_values: type[Any] | tuple[type[Any], ...]
- ) -> type[BaseModel] | _forward_ref.PydanticRecursiveRef:
- cached = _generics.get_cached_generic_type_early(cls, typevar_values)
- if cached is not None:
- return cached
-
- if cls is BaseModel:
- raise TypeError('Type parameters should be placed on typing.Generic, not BaseModel')
- if not hasattr(cls, '__parameters__'):
- raise TypeError(f'{cls} cannot be parametrized because it does not inherit from typing.Generic')
- if not cls.__pydantic_generic_metadata__['parameters'] and typing.Generic not in cls.__bases__:
- raise TypeError(f'{cls} is not a generic class')
-
- if not isinstance(typevar_values, tuple):
- typevar_values = (typevar_values,)
- _generics.check_parameters_count(cls, typevar_values)
-
- # Build map from generic typevars to passed params
- typevars_map: dict[_typing_extra.TypeVarType, type[Any]] = dict(
- zip(cls.__pydantic_generic_metadata__['parameters'], typevar_values)
- )
-
- if _utils.all_identical(typevars_map.keys(), typevars_map.values()) and typevars_map:
- submodel = cls # if arguments are equal to parameters it's the same object
- _generics.set_cached_generic_type(cls, typevar_values, submodel)
- else:
- parent_args = cls.__pydantic_generic_metadata__['args']
- if not parent_args:
- args = typevar_values
- else:
- args = tuple(_generics.replace_types(arg, typevars_map) for arg in parent_args)
-
- origin = cls.__pydantic_generic_metadata__['origin'] or cls
- model_name = origin.model_parametrized_name(args)
- params = tuple(
- {param: None for param in _generics.iter_contained_typevars(typevars_map.values())}
- ) # use dict as ordered set
-
- with _generics.generic_recursion_self_type(origin, args) as maybe_self_type:
- if maybe_self_type is not None:
- return maybe_self_type
-
- cached = _generics.get_cached_generic_type_late(cls, typevar_values, origin, args)
- if cached is not None:
- return cached
-
- # Attempt to rebuild the origin in case new types have been defined
- try:
- # depth 3 gets you above this __class_getitem__ call
- origin.model_rebuild(_parent_namespace_depth=3)
- except PydanticUndefinedAnnotation:
- # It's okay if it fails, it just means there are still undefined types
- # that could be evaluated later.
- # TODO: Make sure validation fails if there are still undefined types, perhaps using MockValidator
- pass
-
- submodel = _generics.create_generic_submodel(model_name, origin, args, params)
-
- # Update cache
- _generics.set_cached_generic_type(cls, typevar_values, submodel, origin, args)
-
- return submodel
-
- def __copy__(self) -> Self:
- """Returns a shallow copy of the model."""
- cls = type(self)
- m = cls.__new__(cls)
- _object_setattr(m, '__dict__', copy(self.__dict__))
- _object_setattr(m, '__pydantic_extra__', copy(self.__pydantic_extra__))
- _object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
-
- if not hasattr(self, '__pydantic_private__') or self.__pydantic_private__ is None:
- _object_setattr(m, '__pydantic_private__', None)
- else:
- _object_setattr(
- m,
- '__pydantic_private__',
- {k: v for k, v in self.__pydantic_private__.items() if v is not PydanticUndefined},
- )
-
- return m
-
- def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self:
- """Returns a deep copy of the model."""
- cls = type(self)
- m = cls.__new__(cls)
- _object_setattr(m, '__dict__', deepcopy(self.__dict__, memo=memo))
- _object_setattr(m, '__pydantic_extra__', deepcopy(self.__pydantic_extra__, memo=memo))
- # This next line doesn't need a deepcopy because __pydantic_fields_set__ is a set[str],
- # and attempting a deepcopy would be marginally slower.
- _object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
-
- if not hasattr(self, '__pydantic_private__') or self.__pydantic_private__ is None:
- _object_setattr(m, '__pydantic_private__', None)
- else:
- _object_setattr(
- m,
- '__pydantic_private__',
- deepcopy({k: v for k, v in self.__pydantic_private__.items() if v is not PydanticUndefined}, memo=memo),
- )
-
- return m
-
- if not TYPE_CHECKING:
- # We put `__getattr__` in a non-TYPE_CHECKING block because otherwise, mypy allows arbitrary attribute access
- # The same goes for __setattr__ and __delattr__, see: https://github.com/pydantic/pydantic/issues/8643
-
- def __getattr__(self, item: str) -> Any:
- private_attributes = object.__getattribute__(self, '__private_attributes__')
- if item in private_attributes:
- attribute = private_attributes[item]
- if hasattr(attribute, '__get__'):
- return attribute.__get__(self, type(self)) # type: ignore
-
- try:
- # Note: self.__pydantic_private__ cannot be None if self.__private_attributes__ has items
- return self.__pydantic_private__[item] # type: ignore
- except KeyError as exc:
- raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}') from exc
- else:
- # `__pydantic_extra__` can fail to be set if the model is not yet fully initialized.
- # See `BaseModel.__repr_args__` for more details
- try:
- pydantic_extra = object.__getattribute__(self, '__pydantic_extra__')
- except AttributeError:
- pydantic_extra = None
-
- if pydantic_extra:
- try:
- return pydantic_extra[item]
- except KeyError as exc:
- raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}') from exc
- else:
- if hasattr(self.__class__, item):
- return super().__getattribute__(item) # Raises AttributeError if appropriate
- else:
- # this is the current error
- raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}')
-
- def __setattr__(self, name: str, value: Any) -> None:
- if name in self.__class_vars__:
- raise AttributeError(
- f'{name!r} is a ClassVar of `{self.__class__.__name__}` and cannot be set on an instance. '
- f'If you want to set a value on the class, use `{self.__class__.__name__}.{name} = value`.'
- )
- elif not _fields.is_valid_field_name(name):
- if self.__pydantic_private__ is None or name not in self.__private_attributes__:
- _object_setattr(self, name, value)
- else:
- attribute = self.__private_attributes__[name]
- if hasattr(attribute, '__set__'):
- attribute.__set__(self, value) # type: ignore
- else:
- self.__pydantic_private__[name] = value
- return
-
- self._check_frozen(name, value)
-
- attr = getattr(self.__class__, name, None)
- if isinstance(attr, property):
- attr.__set__(self, value)
- elif self.model_config.get('validate_assignment', None):
- self.__pydantic_validator__.validate_assignment(self, name, value)
- elif self.model_config.get('extra') != 'allow' and name not in self.model_fields:
- # TODO - matching error
- raise ValueError(f'"{self.__class__.__name__}" object has no field "{name}"')
- elif self.model_config.get('extra') == 'allow' and name not in self.model_fields:
- if self.model_extra and name in self.model_extra:
- self.__pydantic_extra__[name] = value # type: ignore
- else:
- try:
- getattr(self, name)
- except AttributeError:
- # attribute does not already exist on instance, so put it in extra
- self.__pydantic_extra__[name] = value # type: ignore
- else:
- # attribute _does_ already exist on instance, and was not in extra, so update it
- _object_setattr(self, name, value)
- else:
- self.__dict__[name] = value
- self.__pydantic_fields_set__.add(name)
-
- def __delattr__(self, item: str) -> Any:
- if item in self.__private_attributes__:
- attribute = self.__private_attributes__[item]
- if hasattr(attribute, '__delete__'):
- attribute.__delete__(self) # type: ignore
- return
-
- try:
- # Note: self.__pydantic_private__ cannot be None if self.__private_attributes__ has items
- del self.__pydantic_private__[item] # type: ignore
- return
- except KeyError as exc:
- raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}') from exc
-
- self._check_frozen(item, None)
-
- if item in self.model_fields:
- object.__delattr__(self, item)
- elif self.__pydantic_extra__ is not None and item in self.__pydantic_extra__:
- del self.__pydantic_extra__[item]
- else:
- try:
- object.__delattr__(self, item)
- except AttributeError:
- raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}')
-
- def _check_frozen(self, name: str, value: Any) -> None:
- if self.model_config.get('frozen', None):
- typ = 'frozen_instance'
- elif getattr(self.model_fields.get(name), 'frozen', False):
- typ = 'frozen_field'
- else:
- return
- error: pydantic_core.InitErrorDetails = {
- 'type': typ,
- 'loc': (name,),
- 'input': value,
- }
- raise pydantic_core.ValidationError.from_exception_data(self.__class__.__name__, [error])
-
- def __getstate__(self) -> dict[Any, Any]:
- private = self.__pydantic_private__
- if private:
- private = {k: v for k, v in private.items() if v is not PydanticUndefined}
- return {
- '__dict__': self.__dict__,
- '__pydantic_extra__': self.__pydantic_extra__,
- '__pydantic_fields_set__': self.__pydantic_fields_set__,
- '__pydantic_private__': private,
- }
-
- def __setstate__(self, state: dict[Any, Any]) -> None:
- _object_setattr(self, '__pydantic_fields_set__', state.get('__pydantic_fields_set__', {}))
- _object_setattr(self, '__pydantic_extra__', state.get('__pydantic_extra__', {}))
- _object_setattr(self, '__pydantic_private__', state.get('__pydantic_private__', {}))
- _object_setattr(self, '__dict__', state.get('__dict__', {}))
-
- if not TYPE_CHECKING:
-
- def __eq__(self, other: Any) -> bool:
- if isinstance(other, BaseModel):
- # When comparing instances of generic types for equality, as long as all field values are equal,
- # only require their generic origin types to be equal, rather than exact type equality.
- # This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1).
- self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__
- other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__
-
- # Perform common checks first
- if not (
- self_type == other_type
- and getattr(self, '__pydantic_private__', None) == getattr(other, '__pydantic_private__', None)
- and self.__pydantic_extra__ == other.__pydantic_extra__
- ):
- return False
-
- # We only want to compare pydantic fields but ignoring fields is costly.
- # We'll perform a fast check first, and fallback only when needed
- # See GH-7444 and GH-7825 for rationale and a performance benchmark
-
- # First, do the fast (and sometimes faulty) __dict__ comparison
- if self.__dict__ == other.__dict__:
- # If the check above passes, then pydantic fields are equal, we can return early
- return True
-
- # We don't want to trigger unnecessary costly filtering of __dict__ on all unequal objects, so we return
- # early if there are no keys to ignore (we would just return False later on anyway)
- model_fields = type(self).model_fields.keys()
- if self.__dict__.keys() <= model_fields and other.__dict__.keys() <= model_fields:
- return False
-
- # If we reach here, there are non-pydantic-fields keys, mapped to unequal values, that we need to ignore
- # Resort to costly filtering of the __dict__ objects
- # We use operator.itemgetter because it is much faster than dict comprehensions
- # NOTE: Contrary to standard python class and instances, when the Model class has a default value for an
- # attribute and the model instance doesn't have a corresponding attribute, accessing the missing attribute
- # raises an error in BaseModel.__getattr__ instead of returning the class attribute
- # So we can use operator.itemgetter() instead of operator.attrgetter()
- getter = operator.itemgetter(*model_fields) if model_fields else lambda _: _utils._SENTINEL
- try:
- return getter(self.__dict__) == getter(other.__dict__)
- except KeyError:
- # In rare cases (such as when using the deprecated BaseModel.copy() method),
- # the __dict__ may not contain all model fields, which is how we can get here.
- # getter(self.__dict__) is much faster than any 'safe' method that accounts
- # for missing keys, and wrapping it in a `try` doesn't slow things down much
- # in the common case.
- self_fields_proxy = _utils.SafeGetItemProxy(self.__dict__)
- other_fields_proxy = _utils.SafeGetItemProxy(other.__dict__)
- return getter(self_fields_proxy) == getter(other_fields_proxy)
-
- # other instance is not a BaseModel
- else:
- return NotImplemented # delegate to the other item in the comparison
-
- if TYPE_CHECKING:
- # We put `__init_subclass__` in a TYPE_CHECKING block because, even though we want the type-checking benefits
- # described in the signature of `__init_subclass__` below, we don't want to modify the default behavior of
- # subclass initialization.
-
- def __init_subclass__(cls, **kwargs: Unpack[ConfigDict]):
- """This signature is included purely to help type-checkers check arguments to class declaration, which
- provides a way to conveniently set model_config key/value pairs.
-
- ```py
- from pydantic import BaseModel
-
- class MyModel(BaseModel, extra='allow'): ...
- ```
-
- However, this may be deceiving, since the _actual_ calls to `__init_subclass__` will not receive any
- of the config arguments, and will only receive any keyword arguments passed during class initialization
- that are _not_ expected keys in ConfigDict. (This is due to the way `ModelMetaclass.__new__` works.)
-
- Args:
- **kwargs: Keyword arguments passed to the class definition, which set model_config
-
- Note:
- You may want to override `__pydantic_init_subclass__` instead, which behaves similarly but is called
- *after* the class is fully initialized.
- """
-
- def __iter__(self) -> TupleGenerator:
- """So `dict(model)` works."""
- yield from [(k, v) for (k, v) in self.__dict__.items() if not k.startswith('_')]
- extra = self.__pydantic_extra__
- if extra:
- yield from extra.items()
-
- def __repr__(self) -> str:
- return f'{self.__repr_name__()}({self.__repr_str__(", ")})'
-
- def __repr_args__(self) -> _repr.ReprArgs:
- for k, v in self.__dict__.items():
- field = self.model_fields.get(k)
- if field and field.repr:
- yield k, v
-
- # `__pydantic_extra__` can fail to be set if the model is not yet fully initialized.
- # This can happen if a `ValidationError` is raised during initialization and the instance's
- # repr is generated as part of the exception handling. Therefore, we use `getattr` here
- # with a fallback, even though the type hints indicate the attribute will always be present.
- try:
- pydantic_extra = object.__getattribute__(self, '__pydantic_extra__')
- except AttributeError:
- pydantic_extra = None
-
- if pydantic_extra is not None:
- yield from ((k, v) for k, v in pydantic_extra.items())
- yield from ((k, getattr(self, k)) for k, v in self.model_computed_fields.items() if v.repr)
-
- # take logic from `_repr.Representation` without the side effects of inheritance, see #5740
- __repr_name__ = _repr.Representation.__repr_name__
- __repr_str__ = _repr.Representation.__repr_str__
- __pretty__ = _repr.Representation.__pretty__
- __rich_repr__ = _repr.Representation.__rich_repr__
-
- def __str__(self) -> str:
- return self.__repr_str__(' ')
-
- # ##### Deprecated methods from v1 #####
- @property
- @typing_extensions.deprecated(
- 'The `__fields__` attribute is deprecated, use `model_fields` instead.', category=None
- )
- def __fields__(self) -> dict[str, FieldInfo]:
- warnings.warn(
- 'The `__fields__` attribute is deprecated, use `model_fields` instead.', category=PydanticDeprecatedSince20
- )
- return self.model_fields
-
- @property
- @typing_extensions.deprecated(
- 'The `__fields_set__` attribute is deprecated, use `model_fields_set` instead.',
- category=None,
- )
- def __fields_set__(self) -> set[str]:
- warnings.warn(
- 'The `__fields_set__` attribute is deprecated, use `model_fields_set` instead.',
- category=PydanticDeprecatedSince20,
- )
- return self.__pydantic_fields_set__
-
- @typing_extensions.deprecated('The `dict` method is deprecated; use `model_dump` instead.', category=None)
- def dict( # noqa: D102
- self,
- *,
- include: IncEx | None = None,
- exclude: IncEx | None = None,
- by_alias: bool = False,
- exclude_unset: bool = False,
- exclude_defaults: bool = False,
- exclude_none: bool = False,
- ) -> Dict[str, Any]: # noqa UP006
- warnings.warn('The `dict` method is deprecated; use `model_dump` instead.', category=PydanticDeprecatedSince20)
- return self.model_dump(
- include=include,
- exclude=exclude,
- by_alias=by_alias,
- exclude_unset=exclude_unset,
- exclude_defaults=exclude_defaults,
- exclude_none=exclude_none,
- )
-
- @typing_extensions.deprecated('The `json` method is deprecated; use `model_dump_json` instead.', category=None)
- def json( # noqa: D102
- self,
- *,
- include: IncEx | None = None,
- exclude: IncEx | None = None,
- by_alias: bool = False,
- exclude_unset: bool = False,
- exclude_defaults: bool = False,
- exclude_none: bool = False,
- encoder: Callable[[Any], Any] | None = PydanticUndefined, # type: ignore[assignment]
- models_as_dict: bool = PydanticUndefined, # type: ignore[assignment]
- **dumps_kwargs: Any,
- ) -> str:
- warnings.warn(
- 'The `json` method is deprecated; use `model_dump_json` instead.', category=PydanticDeprecatedSince20
- )
- if encoder is not PydanticUndefined:
- raise TypeError('The `encoder` argument is no longer supported; use field serializers instead.')
- if models_as_dict is not PydanticUndefined:
- raise TypeError('The `models_as_dict` argument is no longer supported; use a model serializer instead.')
- if dumps_kwargs:
- raise TypeError('`dumps_kwargs` keyword arguments are no longer supported.')
- return self.model_dump_json(
- include=include,
- exclude=exclude,
- by_alias=by_alias,
- exclude_unset=exclude_unset,
- exclude_defaults=exclude_defaults,
- exclude_none=exclude_none,
- )
-
- @classmethod
- @typing_extensions.deprecated('The `parse_obj` method is deprecated; use `model_validate` instead.', category=None)
- def parse_obj(cls, obj: Any) -> Self: # noqa: D102
- warnings.warn(
- 'The `parse_obj` method is deprecated; use `model_validate` instead.', category=PydanticDeprecatedSince20
- )
- return cls.model_validate(obj)
-
- @classmethod
- @typing_extensions.deprecated(
- 'The `parse_raw` method is deprecated; if your data is JSON use `model_validate_json`, '
- 'otherwise load the data then use `model_validate` instead.',
- category=None,
- )
- def parse_raw( # noqa: D102
- cls,
- b: str | bytes,
- *,
- content_type: str | None = None,
- encoding: str = 'utf8',
- proto: DeprecatedParseProtocol | None = None,
- allow_pickle: bool = False,
- ) -> Self: # pragma: no cover
- warnings.warn(
- 'The `parse_raw` method is deprecated; if your data is JSON use `model_validate_json`, '
- 'otherwise load the data then use `model_validate` instead.',
- category=PydanticDeprecatedSince20,
- )
- from .deprecated import parse
-
- try:
- obj = parse.load_str_bytes(
- b,
- proto=proto,
- content_type=content_type,
- encoding=encoding,
- allow_pickle=allow_pickle,
- )
- except (ValueError, TypeError) as exc:
- import json
-
- # try to match V1
- if isinstance(exc, UnicodeDecodeError):
- type_str = 'value_error.unicodedecode'
- elif isinstance(exc, json.JSONDecodeError):
- type_str = 'value_error.jsondecode'
- elif isinstance(exc, ValueError):
- type_str = 'value_error'
- else:
- type_str = 'type_error'
-
- # ctx is missing here, but since we've added `input` to the error, we're not pretending it's the same
- error: pydantic_core.InitErrorDetails = {
- # The type: ignore on the next line is to ignore the requirement of LiteralString
- 'type': pydantic_core.PydanticCustomError(type_str, str(exc)), # type: ignore
- 'loc': ('__root__',),
- 'input': b,
- }
- raise pydantic_core.ValidationError.from_exception_data(cls.__name__, [error])
- return cls.model_validate(obj)
-
- @classmethod
- @typing_extensions.deprecated(
- 'The `parse_file` method is deprecated; load the data from file, then if your data is JSON '
- 'use `model_validate_json`, otherwise `model_validate` instead.',
- category=None,
- )
- def parse_file( # noqa: D102
- cls,
- path: str | Path,
- *,
- content_type: str | None = None,
- encoding: str = 'utf8',
- proto: DeprecatedParseProtocol | None = None,
- allow_pickle: bool = False,
- ) -> Self:
- warnings.warn(
- 'The `parse_file` method is deprecated; load the data from file, then if your data is JSON '
- 'use `model_validate_json`, otherwise `model_validate` instead.',
- category=PydanticDeprecatedSince20,
- )
- from .deprecated import parse
-
- obj = parse.load_file(
- path,
- proto=proto,
- content_type=content_type,
- encoding=encoding,
- allow_pickle=allow_pickle,
- )
- return cls.parse_obj(obj)
-
- @classmethod
- @typing_extensions.deprecated(
- 'The `from_orm` method is deprecated; set '
- "`model_config['from_attributes']=True` and use `model_validate` instead.",
- category=None,
- )
- def from_orm(cls, obj: Any) -> Self: # noqa: D102
- warnings.warn(
- 'The `from_orm` method is deprecated; set '
- "`model_config['from_attributes']=True` and use `model_validate` instead.",
- category=PydanticDeprecatedSince20,
- )
- if not cls.model_config.get('from_attributes', None):
- raise PydanticUserError(
- 'You must set the config attribute `from_attributes=True` to use from_orm', code=None
- )
- return cls.model_validate(obj)
-
- @classmethod
- @typing_extensions.deprecated('The `construct` method is deprecated; use `model_construct` instead.', category=None)
- def construct(cls, _fields_set: set[str] | None = None, **values: Any) -> Self: # noqa: D102
- warnings.warn(
- 'The `construct` method is deprecated; use `model_construct` instead.', category=PydanticDeprecatedSince20
- )
- return cls.model_construct(_fields_set=_fields_set, **values)
-
- @typing_extensions.deprecated(
- 'The `copy` method is deprecated; use `model_copy` instead. '
- 'See the docstring of `BaseModel.copy` for details about how to handle `include` and `exclude`.',
- category=None,
- )
- def copy(
- self,
- *,
- include: AbstractSetIntStr | MappingIntStrAny | None = None,
- exclude: AbstractSetIntStr | MappingIntStrAny | None = None,
- update: Dict[str, Any] | None = None, # noqa UP006
- deep: bool = False,
- ) -> Self: # pragma: no cover
- """Returns a copy of the model.
-
- !!! warning "Deprecated"
- This method is now deprecated; use `model_copy` instead.
-
- If you need `include` or `exclude`, use:
-
- ```py
- data = self.model_dump(include=include, exclude=exclude, round_trip=True)
- data = {**data, **(update or {})}
- copied = self.model_validate(data)
- ```
-
- Args:
- include: Optional set or mapping specifying which fields to include in the copied model.
- exclude: Optional set or mapping specifying which fields to exclude in the copied model.
- update: Optional dictionary of field-value pairs to override field values in the copied model.
- deep: If True, the values of fields that are Pydantic models will be deep-copied.
-
- Returns:
- A copy of the model with included, excluded and updated fields as specified.
- """
- warnings.warn(
- 'The `copy` method is deprecated; use `model_copy` instead. '
- 'See the docstring of `BaseModel.copy` for details about how to handle `include` and `exclude`.',
- category=PydanticDeprecatedSince20,
- )
- from .deprecated import copy_internals
-
- values = dict(
- copy_internals._iter(
- self, to_dict=False, by_alias=False, include=include, exclude=exclude, exclude_unset=False
- ),
- **(update or {}),
- )
- if self.__pydantic_private__ is None:
- private = None
- else:
- private = {k: v for k, v in self.__pydantic_private__.items() if v is not PydanticUndefined}
-
- if self.__pydantic_extra__ is None:
- extra: dict[str, Any] | None = None
- else:
- extra = self.__pydantic_extra__.copy()
- for k in list(self.__pydantic_extra__):
- if k not in values: # k was in the exclude
- extra.pop(k)
- for k in list(values):
- if k in self.__pydantic_extra__: # k must have come from extra
- extra[k] = values.pop(k)
-
- # new `__pydantic_fields_set__` can have unset optional fields with a set value in `update` kwarg
- if update:
- fields_set = self.__pydantic_fields_set__ | update.keys()
- else:
- fields_set = set(self.__pydantic_fields_set__)
-
- # removing excluded fields from `__pydantic_fields_set__`
- if exclude:
- fields_set -= set(exclude)
-
- return copy_internals._copy_and_set_values(self, values, fields_set, extra, private, deep=deep)
-
- @classmethod
- @typing_extensions.deprecated('The `schema` method is deprecated; use `model_json_schema` instead.', category=None)
- def schema( # noqa: D102
- cls, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLATE
- ) -> Dict[str, Any]: # noqa UP006
- warnings.warn(
- 'The `schema` method is deprecated; use `model_json_schema` instead.', category=PydanticDeprecatedSince20
- )
- return cls.model_json_schema(by_alias=by_alias, ref_template=ref_template)
-
- @classmethod
- @typing_extensions.deprecated(
- 'The `schema_json` method is deprecated; use `model_json_schema` and json.dumps instead.',
- category=None,
- )
- def schema_json( # noqa: D102
- cls, *, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLATE, **dumps_kwargs: Any
- ) -> str: # pragma: no cover
- warnings.warn(
- 'The `schema_json` method is deprecated; use `model_json_schema` and json.dumps instead.',
- category=PydanticDeprecatedSince20,
- )
- import json
-
- from .deprecated.json import pydantic_encoder
-
- return json.dumps(
- cls.model_json_schema(by_alias=by_alias, ref_template=ref_template),
- default=pydantic_encoder,
- **dumps_kwargs,
- )
-
- @classmethod
- @typing_extensions.deprecated('The `validate` method is deprecated; use `model_validate` instead.', category=None)
- def validate(cls, value: Any) -> Self: # noqa: D102
- warnings.warn(
- 'The `validate` method is deprecated; use `model_validate` instead.', category=PydanticDeprecatedSince20
- )
- return cls.model_validate(value)
-
- @classmethod
- @typing_extensions.deprecated(
- 'The `update_forward_refs` method is deprecated; use `model_rebuild` instead.',
- category=None,
- )
- def update_forward_refs(cls, **localns: Any) -> None: # noqa: D102
- warnings.warn(
- 'The `update_forward_refs` method is deprecated; use `model_rebuild` instead.',
- category=PydanticDeprecatedSince20,
- )
- if localns: # pragma: no cover
- raise TypeError('`localns` arguments are not longer accepted.')
- cls.model_rebuild(force=True)
-
- @typing_extensions.deprecated(
- 'The private method `_iter` will be removed and should no longer be used.', category=None
- )
- def _iter(self, *args: Any, **kwargs: Any) -> Any:
- warnings.warn(
- 'The private method `_iter` will be removed and should no longer be used.',
- category=PydanticDeprecatedSince20,
- )
- from .deprecated import copy_internals
-
- return copy_internals._iter(self, *args, **kwargs)
-
- @typing_extensions.deprecated(
- 'The private method `_copy_and_set_values` will be removed and should no longer be used.',
- category=None,
- )
- def _copy_and_set_values(self, *args: Any, **kwargs: Any) -> Any:
- warnings.warn(
- 'The private method `_copy_and_set_values` will be removed and should no longer be used.',
- category=PydanticDeprecatedSince20,
- )
- from .deprecated import copy_internals
-
- return copy_internals._copy_and_set_values(self, *args, **kwargs)
-
- @classmethod
- @typing_extensions.deprecated(
- 'The private method `_get_value` will be removed and should no longer be used.',
- category=None,
- )
- def _get_value(cls, *args: Any, **kwargs: Any) -> Any:
- warnings.warn(
- 'The private method `_get_value` will be removed and should no longer be used.',
- category=PydanticDeprecatedSince20,
- )
- from .deprecated import copy_internals
-
- return copy_internals._get_value(cls, *args, **kwargs)
-
- @typing_extensions.deprecated(
- 'The private method `_calculate_keys` will be removed and should no longer be used.',
- category=None,
- )
- def _calculate_keys(self, *args: Any, **kwargs: Any) -> Any:
- warnings.warn(
- 'The private method `_calculate_keys` will be removed and should no longer be used.',
- category=PydanticDeprecatedSince20,
- )
- from .deprecated import copy_internals
-
- return copy_internals._calculate_keys(self, *args, **kwargs)
-
-
-ModelT = TypeVar('ModelT', bound=BaseModel)
-
-
-@overload
-def create_model(
- model_name: str,
- /,
- *,
- __config__: ConfigDict | None = None,
- __doc__: str | None = None,
- __base__: None = None,
- __module__: str = __name__,
- __validators__: dict[str, Callable[..., Any]] | None = None,
- __cls_kwargs__: dict[str, Any] | None = None,
- **field_definitions: Any,
-) -> type[BaseModel]: ...
-
-
-@overload
-def create_model(
- model_name: str,
- /,
- *,
- __config__: ConfigDict | None = None,
- __doc__: str | None = None,
- __base__: type[ModelT] | tuple[type[ModelT], ...],
- __module__: str = __name__,
- __validators__: dict[str, Callable[..., Any]] | None = None,
- __cls_kwargs__: dict[str, Any] | None = None,
- **field_definitions: Any,
-) -> type[ModelT]: ...
-
-
-def create_model( # noqa: C901
- model_name: str,
- /,
- *,
- __config__: ConfigDict | None = None,
- __doc__: str | None = None,
- __base__: type[ModelT] | tuple[type[ModelT], ...] | None = None,
- __module__: str | None = None,
- __validators__: dict[str, Callable[..., Any]] | None = None,
- __cls_kwargs__: dict[str, Any] | None = None,
- __slots__: tuple[str, ...] | None = None,
- **field_definitions: Any,
-) -> type[ModelT]:
- """Usage docs: https://docs.pydantic.dev/2.9/concepts/models/#dynamic-model-creation
-
- Dynamically creates and returns a new Pydantic model, in other words, `create_model` dynamically creates a
- subclass of [`BaseModel`][pydantic.BaseModel].
-
- Args:
- model_name: The name of the newly created model.
- __config__: The configuration of the new model.
- __doc__: The docstring of the new model.
- __base__: The base class or classes for the new model.
- __module__: The name of the module that the model belongs to;
- if `None`, the value is taken from `sys._getframe(1)`
- __validators__: A dictionary of methods that validate fields. The keys are the names of the validation methods to
- be added to the model, and the values are the validation methods themselves. You can read more about functional
- validators [here](https://docs.pydantic.dev/2.9/concepts/validators/#field-validators).
- __cls_kwargs__: A dictionary of keyword arguments for class creation, such as `metaclass`.
- __slots__: Deprecated. Should not be passed to `create_model`.
- **field_definitions: Attributes of the new model. They should be passed in the format:
- `<name>=(<type>, <default value>)`, `<name>=(<type>, <FieldInfo>)`, or `typing.Annotated[<type>, <FieldInfo>]`.
- Any additional metadata in `typing.Annotated[<type>, <FieldInfo>, ...]` will be ignored.
-
- Returns:
- The new [model][pydantic.BaseModel].
-
- Raises:
- PydanticUserError: If `__base__` and `__config__` are both passed.
- """
- if __slots__ is not None:
- # __slots__ will be ignored from here on
- warnings.warn('__slots__ should not be passed to create_model', RuntimeWarning)
-
- if __base__ is not None:
- if __config__ is not None:
- raise PydanticUserError(
- 'to avoid confusion `__config__` and `__base__` cannot be used together',
- code='create-model-config-base',
- )
- if not isinstance(__base__, tuple):
- __base__ = (__base__,)
- else:
- __base__ = (cast('type[ModelT]', BaseModel),)
-
- __cls_kwargs__ = __cls_kwargs__ or {}
-
- fields = {}
- annotations = {}
-
- for f_name, f_def in field_definitions.items():
- if not _fields.is_valid_field_name(f_name):
- warnings.warn(f'fields may not start with an underscore, ignoring "{f_name}"', RuntimeWarning)
- if isinstance(f_def, tuple):
- f_def = cast('tuple[str, Any]', f_def)
- try:
- f_annotation, f_value = f_def
- except ValueError as e:
- raise PydanticUserError(
- 'Field definitions should be a `(<type>, <default>)`.',
- code='create-model-field-definitions',
- ) from e
-
- elif _typing_extra.is_annotated(f_def):
- (f_annotation, f_value, *_) = typing_extensions.get_args(
- f_def
- ) # first two input are expected from Annotated, refer to https://docs.python.org/3/library/typing.html#typing.Annotated
- FieldInfo = _import_utils.import_cached_field_info()
-
- if not isinstance(f_value, FieldInfo):
- raise PydanticUserError(
- 'Field definitions should be a Annotated[<type>, <FieldInfo>]',
- code='create-model-field-definitions',
- )
-
- else:
- f_annotation, f_value = None, f_def
-
- if f_annotation:
- annotations[f_name] = f_annotation
- fields[f_name] = f_value
-
- if __module__ is None:
- f = sys._getframe(1)
- __module__ = f.f_globals['__name__']
-
- namespace: dict[str, Any] = {'__annotations__': annotations, '__module__': __module__}
- if __doc__:
- namespace.update({'__doc__': __doc__})
- if __validators__:
- namespace.update(__validators__)
- namespace.update(fields)
- if __config__:
- namespace['model_config'] = _config.ConfigWrapper(__config__).config_dict
- resolved_bases = types.resolve_bases(__base__)
- meta, ns, kwds = types.prepare_class(model_name, resolved_bases, kwds=__cls_kwargs__)
- if resolved_bases is not __base__:
- ns['__orig_bases__'] = __base__
- namespace.update(ns)
-
- return meta(
- model_name,
- resolved_bases,
- namespace,
- __pydantic_reset_parent_namespace__=False,
- _create_model_module=__module__,
- **kwds,
- )
-
-
-__getattr__ = getattr_migration(__name__)
-
-from __future__ import annotations
-
-import json
-import random
-import typing as t
-from abc import ABC, abstractmethod
-from collections import defaultdict
-from dataclasses import dataclass, field
-from uuid import UUID
-
-import numpy as np
-import requests
-from datasets import Dataset as HFDataset
-from pydantic import BaseModel, field_validator
-
-from ragas._version import __version__
-from ragas.callbacks import ChainRunEncoder, parse_run_traces
-from ragas.cost import CostCallbackHandler
-from ragas.exceptions import UploadException
-from ragas.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
-from ragas.sdk import (
- RAGAS_API_SOURCE,
- build_evaluation_app_url,
- check_api_response,
- get_api_url,
- get_app_token,
- get_app_url,
- upload_packet,
-)
-from ragas.utils import safe_nanmean
-
-if t.TYPE_CHECKING:
- from pathlib import Path
-
- from datasets import Dataset as HFDataset
- from pandas import DataFrame as PandasDataframe
-
- from ragas.callbacks import ChainRun
- from ragas.cost import TokenUsage
-
-
-class BaseSample(BaseModel):
- """
- Base class for evaluation samples.
- """
-
- def to_dict(self) -> t.Dict:
- """
- Get the dictionary representation of the sample without attributes that are None.
- """
- return self.model_dump(exclude_none=True)
-
- def get_features(self) -> t.List[str]:
- """
- Get the features of the sample that are not None.
- """
- return list(self.to_dict().keys())
-
- def to_string(self) -> str:
- """
- Get the string representation of the sample.
- """
- sample_dict = self.to_dict()
- return "".join(f"\n{key}:\n\t{val}\n" for key, val in sample_dict.items())
-
-
-class SingleTurnSample(BaseSample):
- """
- Represents evaluation samples for single-turn interactions.
-
- Attributes
- ----------
- user_input : Optional[str]
- The input query from the user.
- retrieved_contexts : Optional[List[str]]
- List of contexts retrieved for the query.
- reference_contexts : Optional[List[str]]
- List of reference contexts for the query.
- response : Optional[str]
- The generated response for the query.
- multi_responses : Optional[List[str]]
- List of multiple responses generated for the query.
- reference : Optional[str]
- The reference answer for the query.
- rubric : Optional[Dict[str, str]]
- Evaluation rubric for the sample.
- """
-
- user_input: t.Optional[str] = None
- retrieved_contexts: t.Optional[t.List[str]] = None
- reference_contexts: t.Optional[t.List[str]] = None
- response: t.Optional[str] = None
- multi_responses: t.Optional[t.List[str]] = None
- reference: t.Optional[str] = None
- rubrics: t.Optional[t.Dict[str, str]] = None
-
-
-class MultiTurnSample(BaseSample):
- """
- Represents evaluation samples for multi-turn interactions.
-
- Attributes
- ----------
- user_input : List[Union[HumanMessage, AIMessage, ToolMessage]]
- A list of messages representing the conversation turns.
- reference : Optional[str], optional
- The reference answer or expected outcome for the conversation.
- reference_tool_calls : Optional[List[ToolCall]], optional
- A list of expected tool calls for the conversation.
- rubrics : Optional[Dict[str, str]], optional
- Evaluation rubrics for the conversation.
- reference_topics : Optional[List[str]], optional
- A list of reference topics for the conversation.
- """
-
- user_input: t.List[t.Union[HumanMessage, AIMessage, ToolMessage]]
- reference: t.Optional[str] = None
- reference_tool_calls: t.Optional[t.List[ToolCall]] = None
- rubrics: t.Optional[t.Dict[str, str]] = None
- reference_topics: t.Optional[t.List[str]] = None
-
- @field_validator("user_input")
- @classmethod
- def validate_user_input(
- cls,
- messages: t.List[t.Union[HumanMessage, AIMessage, ToolMessage]],
- ) -> t.List[t.Union[HumanMessage, AIMessage, ToolMessage]]:
- """Validates the user input messages."""
- if not (
- isinstance(m, (HumanMessage, AIMessage, ToolMessage)) for m in messages
- ):
- raise ValueError(
- "All inputs must be instances of HumanMessage, AIMessage, or ToolMessage."
- )
-
- prev_message = None
- for m in messages:
- if isinstance(m, ToolMessage):
- if not isinstance(prev_message, AIMessage):
- raise ValueError(
- "ToolMessage instances must be preceded by an AIMessage instance."
- )
- if prev_message.tool_calls is None:
- raise ValueError(
- f"ToolMessage instances must be preceded by an AIMessage instance with tool_calls. Got {prev_message}"
- )
- prev_message = m
-
- return messages
-
- def to_messages(self):
- """Converts the user input messages to a list of dictionaries."""
- return [m.model_dump() for m in self.user_input]
-
- def pretty_repr(self):
- """Returns a pretty string representation of the conversation."""
- lines = []
- for m in self.user_input:
- lines.append(m.pretty_repr())
-
- return "\n".join(lines)
-
-
-Sample = t.TypeVar("Sample", bound=BaseSample)
-T = t.TypeVar("T", bound="RagasDataset")
-
-
-@dataclass
-class RagasDataset(ABC, t.Generic[Sample]):
- samples: t.List[Sample]
-
- def __post_init__(self):
- self.samples = self.validate_samples(self.samples)
-
- @abstractmethod
- def to_list(self) -> t.List[t.Dict]:
- """Converts the samples to a list of dictionaries."""
- pass
-
- @classmethod
- @abstractmethod
- def from_list(cls: t.Type[T], data: t.List[t.Dict]) -> T:
- """Creates an RagasDataset from a list of dictionaries."""
- pass
-
- def validate_samples(self, samples: t.List[Sample]) -> t.List[Sample]:
- """Validates that all samples are of the same type."""
- if len(samples) == 0:
- return samples
-
- first_sample_type = type(samples[0])
- for i, sample in enumerate(samples):
- if not isinstance(sample, first_sample_type):
- raise ValueError(
- f"Sample at index {i} is of type {type(sample)}, expected {first_sample_type}"
- )
-
- return samples
-
- def get_sample_type(self) -> t.Type[Sample]:
- """Returns the type of the samples in the dataset."""
- return type(self.samples[0])
-
- def to_hf_dataset(self) -> HFDataset:
- """Converts the dataset to a Hugging Face Dataset."""
- try:
- from datasets import Dataset as HFDataset
- except ImportError:
- raise ImportError(
- "datasets is not installed. Please install it to use this function."
- )
-
- return HFDataset.from_list(self.to_list())
-
- @classmethod
- def from_hf_dataset(cls: t.Type[T], dataset: HFDataset) -> T:
- """Creates an EvaluationDataset from a Hugging Face Dataset."""
- return cls.from_list(dataset.to_list())
-
- def to_pandas(self) -> PandasDataframe:
- """Converts the dataset to a pandas DataFrame."""
- try:
- import pandas as pd
- except ImportError:
- raise ImportError(
- "pandas is not installed. Please install it to use this function."
- )
-
- data = self.to_list()
- return pd.DataFrame(data)
-
- @classmethod
- def from_pandas(cls, dataframe: PandasDataframe):
- """Creates an EvaluationDataset from a pandas DataFrame."""
- return cls.from_list(dataframe.to_dict(orient="records"))
-
- def features(self):
- """Returns the features of the samples."""
- return self.samples[0].get_features()
-
- @classmethod
- def from_dict(cls: t.Type[T], mapping: t.Dict) -> T:
- """Creates an EvaluationDataset from a dictionary."""
- samples = []
- if all(
- "user_input" in item and isinstance(mapping[0]["user_input"], list)
- for item in mapping
- ):
- samples.extend(MultiTurnSample(**sample) for sample in mapping)
- else:
- samples.extend(SingleTurnSample(**sample) for sample in mapping)
- return cls(samples=samples)
-
- def to_csv(self, path: t.Union[str, Path]):
- """Converts the dataset to a CSV file."""
- import csv
-
- data = self.to_list()
- if not data:
- return
-
- fieldnames = data[0].keys()
-
- with open(path, "w", newline="") as csvfile:
- writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
- writer.writeheader()
- for row in data:
- writer.writerow(row)
-
- def to_jsonl(self, path: t.Union[str, Path]):
- """Converts the dataset to a JSONL file."""
- with open(path, "w") as jsonlfile:
- for sample in self.to_list():
- jsonlfile.write(json.dumps(sample, ensure_ascii=False) + "\n")
-
- @classmethod
- def from_jsonl(cls: t.Type[T], path: t.Union[str, Path]) -> T:
- """Creates an EvaluationDataset from a JSONL file."""
- with open(path, "r") as jsonlfile:
- data = [json.loads(line) for line in jsonlfile]
- return cls.from_list(data)
-
- def __iter__(self) -> t.Iterator[Sample]: # type: ignore
- return iter(self.samples)
-
- def __len__(self) -> int:
- return len(self.samples)
-
- def __str__(self) -> str:
- return f"EvaluationDataset(features={self.features()}, len={len(self.samples)})"
-
- def __repr__(self) -> str:
- return self.__str__()
-
-
-SingleTurnSampleOrMultiTurnSample = t.Union[SingleTurnSample, MultiTurnSample]
-
-
-@dataclass
-class EvaluationDataset(RagasDataset[SingleTurnSampleOrMultiTurnSample]):
- """
- Represents a dataset of evaluation samples.
-
- Attributes
- ----------
- samples : List[BaseSample]
- A list of evaluation samples.
-
- Methods
- -------
- validate_samples(samples)
- Validates that all samples are of the same type.
- get_sample_type()
- Returns the type of the samples in the dataset.
- to_hf_dataset()
- Converts the dataset to a Hugging Face Dataset.
- to_pandas()
- Converts the dataset to a pandas DataFrame.
- features()
- Returns the features of the samples.
- from_list(mapping)
- Creates an EvaluationDataset from a list of dictionaries.
- from_dict(mapping)
- Creates an EvaluationDataset from a dictionary.
- to_csv(path)
- Converts the dataset to a CSV file.
- to_jsonl(path)
- Converts the dataset to a JSONL file.
- from_jsonl(path)
- Creates an EvaluationDataset from a JSONL file.
- """
-
- @t.overload
- def __getitem__(self, idx: int) -> SingleTurnSampleOrMultiTurnSample: ...
-
- @t.overload
- def __getitem__(self, idx: slice) -> "EvaluationDataset": ...
-
- def __getitem__(
- self, idx: t.Union[int, slice]
- ) -> t.Union[SingleTurnSampleOrMultiTurnSample, "EvaluationDataset"]:
- if isinstance(idx, int):
- return self.samples[idx]
- elif isinstance(idx, slice):
- return type(self)(samples=self.samples[idx])
- else:
- raise TypeError("Index must be int or slice")
-
- def is_multi_turn(self) -> bool:
- return self.get_sample_type() == MultiTurnSample
-
- def to_list(self) -> t.List[t.Dict]:
- rows = [sample.to_dict() for sample in self.samples]
-
- if self.get_sample_type() == MultiTurnSample:
- for sample in rows:
- for item in sample["user_input"]:
- if not isinstance(item["content"], str):
- item["content"] = json.dumps(
- item["content"], ensure_ascii=False
- )
-
- return rows
-
- @classmethod
- def from_list(cls, data: t.List[t.Dict]) -> EvaluationDataset:
- samples = []
- if all(
- "user_input" in item and isinstance(data[0]["user_input"], list)
- for item in data
- ):
- samples.extend(MultiTurnSample(**sample) for sample in data)
- else:
- samples.extend(SingleTurnSample(**sample) for sample in data)
- return cls(samples=samples)
-
- def __repr__(self) -> str:
- return f"EvaluationDataset(features={self.features()}, len={len(self.samples)})"
-
-
-@dataclass
-class EvaluationResult:
- """
- A class to store and process the results of the evaluation.
-
- Attributes
- ----------
- scores : Dataset
- The dataset containing the scores of the evaluation.
- dataset : Dataset, optional
- The original dataset used for the evaluation. Default is None.
- binary_columns : list of str, optional
- List of columns that are binary metrics. Default is an empty list.
- cost_cb : CostCallbackHandler, optional
- The callback handler for cost computation. Default is None.
- """
-
- scores: t.List[t.Dict[str, t.Any]]
- dataset: EvaluationDataset
- binary_columns: t.List[str] = field(default_factory=list)
- cost_cb: t.Optional[CostCallbackHandler] = None
- traces: t.List[t.Dict[str, t.Any]] = field(default_factory=list)
- ragas_traces: t.Dict[str, ChainRun] = field(default_factory=dict, repr=False)
- run_id: t.Optional[UUID] = None
-
- def __post_init__(self):
- # transform scores from list of dicts to dict of lists
- self._scores_dict = {
- k: [d[k] for d in self.scores] for k in self.scores[0].keys()
- }
-
- values = []
- self._repr_dict = {}
- for metric_name in self._scores_dict.keys():
- value = safe_nanmean(self._scores_dict[metric_name])
- self._repr_dict[metric_name] = value
- if metric_name not in self.binary_columns:
- value = t.cast(float, value)
- values.append(value + 1e-10)
-
- # parse the traces
- run_id = str(self.run_id) if self.run_id is not None else None
- self.traces = parse_run_traces(self.ragas_traces, run_id)
-
- def __repr__(self) -> str:
- score_strs = [f"'{k}': {v:0.4f}" for k, v in self._repr_dict.items()]
- return "{" + ", ".join(score_strs) + "}"
-
- def __getitem__(self, key: str) -> t.List[float]:
- return self._scores_dict[key]
-
- def to_pandas(self, batch_size: int | None = None, batched: bool = False):
- """
- Convert the result to a pandas DataFrame.
-
- Parameters
- ----------
- batch_size : int, optional
- The batch size for conversion. Default is None.
- batched : bool, optional
- Whether to convert in batches. Default is False.
-
- Returns
- -------
- pandas.DataFrame
- The result as a pandas DataFrame.
-
- Raises
- ------
- ValueError
- If the dataset is not provided.
- """
- try:
- import pandas as pd
- except ImportError:
- raise ImportError(
- "pandas is not installed. Please install it to use this function."
- )
-
- if self.dataset is None:
- raise ValueError("dataset is not provided for the results class")
- assert len(self.scores) == len(self.dataset)
- # convert both to pandas dataframes and concatenate
- scores_df = pd.DataFrame(self.scores)
- dataset_df = self.dataset.to_pandas()
- return pd.concat([dataset_df, scores_df], axis=1)
-
- def total_tokens(self) -> t.Union[t.List[TokenUsage], TokenUsage]:
- """
- Compute the total tokens used in the evaluation.
-
- Returns
- -------
- list of TokenUsage or TokenUsage
- The total tokens used.
-
- Raises
- ------
- ValueError
- If the cost callback handler is not provided.
- """
- if self.cost_cb is None:
- raise ValueError(
- "The evaluate() run was not configured for computing cost. Please provide a token_usage_parser function to evaluate() to compute cost."
- )
- return self.cost_cb.total_tokens()
-
- def total_cost(
- self,
- cost_per_input_token: t.Optional[float] = None,
- cost_per_output_token: t.Optional[float] = None,
- per_model_costs: t.Dict[str, t.Tuple[float, float]] = {},
- ) -> float:
- """
- Compute the total cost of the evaluation.
-
- Parameters
- ----------
- cost_per_input_token : float, optional
- The cost per input token. Default is None.
- cost_per_output_token : float, optional
- The cost per output token. Default is None.
- per_model_costs : dict of str to tuple of float, optional
- The per model costs. Default is an empty dictionary.
-
- Returns
- -------
- float
- The total cost of the evaluation.
-
- Raises
- ------
- ValueError
- If the cost callback handler is not provided.
- """
- if self.cost_cb is None:
- raise ValueError(
- "The evaluate() run was not configured for computing cost. Please provide a token_usage_parser function to evaluate() to compute cost."
- )
- return self.cost_cb.total_cost(
- cost_per_input_token, cost_per_output_token, per_model_costs
- )
-
- def upload(
- self,
- verbose: bool = True,
- ) -> str:
- from datetime import datetime, timezone
-
- timestamp = datetime.now(timezone.utc).isoformat()
- root_trace = [
- trace for trace in self.ragas_traces.values() if trace.parent_run_id is None
- ][0]
- packet = json.dumps(
- {
- "run_id": str(root_trace.run_id),
- "created_at": timestamp,
- "evaluation_run": [t.model_dump() for t in self.ragas_traces.values()],
- },
- cls=ChainRunEncoder,
- )
- response = upload_packet(
- path="/alignment/evaluation",
- data_json_string=packet,
- )
-
- # check status codes
- app_url = get_app_url()
- evaluation_app_url = build_evaluation_app_url(app_url, root_trace.run_id)
- if response.status_code == 409:
- # this evalution already exists
- if verbose:
- print(f"Evaluation run already exists. View at {evaluation_app_url}")
- return evaluation_app_url
- elif response.status_code != 200:
- # any other error
- raise UploadException(
- status_code=response.status_code,
- message=f"Failed to upload results: {response.text}",
- )
-
- if verbose:
- print(f"Evaluation results uploaded! View at {evaluation_app_url}")
- return evaluation_app_url
-
-
-class PromptAnnotation(BaseModel):
- prompt_input: t.Dict[str, t.Any]
- prompt_output: t.Dict[str, t.Any]
- edited_output: t.Optional[t.Dict[str, t.Any]] = None
-
- def __getitem__(self, key):
- return getattr(self, key)
-
-
-class SampleAnnotation(BaseModel):
- metric_input: t.Dict[str, t.Any]
- metric_output: float
- prompts: t.Dict[str, PromptAnnotation]
- is_accepted: bool
- target: t.Optional[float] = None
-
- def __getitem__(self, key):
- return getattr(self, key)
-
-
-class MetricAnnotation(BaseModel):
- root: t.Dict[str, t.List[SampleAnnotation]]
-
- def __getitem__(self, key):
- return SingleMetricAnnotation(name=key, samples=self.root[key])
-
- @classmethod
- def _process_dataset(
- cls, dataset: dict, metric_name: t.Optional[str]
- ) -> "MetricAnnotation":
- """
- Process raw dataset into MetricAnnotation format
-
- Parameters
- ----------
- dataset : dict
- Raw dataset to process
- metric_name : str, optional
- Name of the specific metric to filter
-
- Returns
- -------
- MetricAnnotation
- Processed annotation data
- """
- if metric_name is not None and metric_name not in dataset:
- raise ValueError(f"Split {metric_name} not found in the dataset.")
-
- return cls(
- root={
- key: [SampleAnnotation(**sample) for sample in value]
- for key, value in dataset.items()
- if metric_name is None or key == metric_name
- }
- )
-
- @classmethod
- def from_json(cls, path: str, metric_name: t.Optional[str]) -> "MetricAnnotation":
- """Load annotations from a JSON file"""
- dataset = json.load(open(path))
- return cls._process_dataset(dataset, metric_name)
-
- @classmethod
- def from_app(
- cls,
- run_id: str,
- metric_name: t.Optional[str] = None,
- ) -> "MetricAnnotation":
- """
- Fetch annotations from a URL using either evaluation result or run_id
-
- Parameters
- ----------
- run_id : str
- Direct run ID to fetch annotations
- metric_name : str, optional
- Name of the specific metric to filter
-
- Returns
- -------
- MetricAnnotation
- Annotation data from the API
-
- Raises
- ------
- ValueError
- If run_id is not provided
- """
- if run_id is None:
- raise ValueError("run_id must be provided")
-
- endpoint = f"/api/v1/alignment/evaluation/annotation/{run_id}"
-
- app_token = get_app_token()
- base_url = get_api_url()
- app_url = get_app_url()
-
- response = requests.get(
- f"{base_url}{endpoint}",
- headers={
- "Content-Type": "application/json",
- "x-app-token": app_token,
- "x-source": RAGAS_API_SOURCE,
- "x-app-version": __version__,
- },
- )
-
- check_api_response(response)
- dataset = response.json()["data"]
-
- if not dataset:
- evaluation_url = build_evaluation_app_url(app_url, run_id)
- raise ValueError(
- f"No annotations found. Please annotate the Evaluation first then run this method. "
- f"\nNote: you can annotate the evaluations using the Ragas app by going to {evaluation_url}"
- )
-
- return cls._process_dataset(dataset, metric_name)
-
- def __len__(self):
- return sum(len(value) for value in self.root.values())
-
-
-class SingleMetricAnnotation(BaseModel):
- name: str
- samples: t.List[SampleAnnotation]
-
- def to_evaluation_dataset(self) -> EvaluationDataset:
- samples = [sample.metric_input for sample in self.samples]
- return EvaluationDataset.from_list(samples)
-
- def __getitem__(self, idx):
- return self.samples[idx]
-
- def __repr__(self):
- return f"SingleMetricAnnotation(name={self.name}, len={len(self.samples)})"
-
- def __iter__(self) -> t.Iterator[SampleAnnotation]: # type: ignore
- return iter(self.samples)
-
- def select(self, indices: t.List[int]) -> "SingleMetricAnnotation":
- return SingleMetricAnnotation(
- name=self.name,
- samples=[self.samples[idx] for idx in indices],
- )
-
- @classmethod
- def from_json(cls, path) -> "SingleMetricAnnotation":
- dataset = json.load(open(path))
-
- return cls(
- name=dataset["name"],
- samples=[SampleAnnotation(**sample) for sample in dataset["samples"]],
- )
-
- def filter(self, function: t.Optional[t.Callable] = None):
- if function is None:
- function = lambda x: True # noqa: E731
-
- return SingleMetricAnnotation(
- name=self.name,
- samples=[sample for sample in self.samples if function(sample)],
- )
-
- def __len__(self):
- return len(self.samples)
-
- def train_test_split(
- self,
- test_size: float = 0.2,
- seed: int = 42,
- stratify: t.Optional[t.List[t.Any]] = None,
- ) -> t.Tuple["SingleMetricAnnotation", "SingleMetricAnnotation"]:
- """
- Split the dataset into training and testing sets.
-
- Parameters:
- test_size (float): The proportion of the dataset to include in the test split.
- seed (int): Random seed for reproducibility.
- stratify (list): The column values to stratify the split on.
- """
- raise NotImplementedError
-
- def sample(
- self, n: int, stratify_key: t.Optional[str] = None
- ) -> "SingleMetricAnnotation":
- """
- Create a subset of the dataset.
-
- Parameters:
- n (int): The number of samples to include in the subset.
- stratify_key (str): The column to stratify the subset on.
-
- Returns:
- SingleMetricAnnotation: A subset of the dataset with `n` samples.
- """
- if n > len(self.samples):
- raise ValueError(
- "Requested sample size exceeds the number of available samples."
- )
-
- if stratify_key is None:
- # Simple random sampling
- sampled_indices = random.sample(range(len(self.samples)), n)
- sampled_samples = [self.samples[i] for i in sampled_indices]
- else:
- # Stratified sampling
- class_groups = defaultdict(list)
- for idx, sample in enumerate(self.samples):
- key = sample[stratify_key]
- class_groups[key].append(idx)
-
- # Determine the proportion of samples to take from each class
- total_samples = sum(len(indices) for indices in class_groups.values())
- proportions = {
- cls: len(indices) / total_samples
- for cls, indices in class_groups.items()
- }
-
- sampled_indices = []
- for cls, indices in class_groups.items():
- cls_sample_count = int(np.round(proportions[cls] * n))
- cls_sample_count = min(
- cls_sample_count, len(indices)
- ) # Don't oversample
- sampled_indices.extend(random.sample(indices, cls_sample_count))
-
- # Handle any rounding discrepancies to ensure exactly `n` samples
- while len(sampled_indices) < n:
- remaining_indices = set(range(len(self.samples))) - set(sampled_indices)
- if not remaining_indices:
- break
- sampled_indices.append(random.choice(list(remaining_indices)))
-
- sampled_samples = [self.samples[i] for i in sampled_indices]
-
- return SingleMetricAnnotation(name=self.name, samples=sampled_samples)
-
- def batch(
- self,
- batch_size: int,
- drop_last_batch: bool = False,
- ):
- """
- Create a batch iterator.
-
- Parameters:
- batch_size (int): The number of samples in each batch.
- stratify (str): The column to stratify the batches on.
- drop_last_batch (bool): Whether to drop the last batch if it is smaller than the specified batch size.
- """
-
- samples = self.samples[:]
- random.shuffle(samples)
-
- all_batches = [
- samples[i : i + batch_size]
- for i in range(0, len(samples), batch_size)
- if len(samples[i : i + batch_size]) == batch_size or not drop_last_batch
- ]
-
- return all_batches
-
- def stratified_batches(
- self,
- batch_size: int,
- stratify_key: str,
- drop_last_batch: bool = False,
- replace: bool = False,
- ) -> t.List[t.List[SampleAnnotation]]:
- """
- Create stratified batches based on a specified key, ensuring proportional representation.
-
- Parameters:
- batch_size (int): Number of samples per batch.
- stratify_key (str): Key in `metric_input` used for stratification (e.g., class labels).
- drop_last_batch (bool): If True, drops the last batch if it has fewer samples than `batch_size`.
- replace (bool): If True, allows reusing samples from the same class to fill a batch if necessary.
-
- Returns:
- List[List[SampleAnnotation]]: A list of stratified batches, each batch being a list of SampleAnnotation objects.
- """
- # Group samples based on the stratification key
- class_groups = defaultdict(list)
- for sample in self.samples:
- key = sample[stratify_key]
- class_groups[key].append(sample)
-
- # Shuffle each class group for randomness
- for group in class_groups.values():
- random.shuffle(group)
-
- # Determine the number of batches required
- total_samples = len(self.samples)
- num_batches = (
- np.ceil(total_samples / batch_size).astype(int)
- if drop_last_batch
- else np.floor(total_samples / batch_size).astype(int)
- )
- samples_per_class_per_batch = {
- cls: max(1, len(samples) // num_batches)
- for cls, samples in class_groups.items()
- }
-
- # Create stratified batches
- all_batches = []
- while len(all_batches) < num_batches:
- batch = []
- for cls, samples in list(class_groups.items()):
- # Determine the number of samples to take from this class
- count = min(
- samples_per_class_per_batch[cls],
- len(samples),
- batch_size - len(batch),
- )
- if count > 0:
- # Add samples from the current class
- batch.extend(samples[:count])
- class_groups[cls] = samples[count:] # Remove used samples
- elif replace and len(batch) < batch_size:
- # Reuse samples if `replace` is True
- batch.extend(random.choices(samples, k=batch_size - len(batch)))
-
- # Shuffle the batch to mix classes
- random.shuffle(batch)
- if len(batch) == batch_size or not drop_last_batch:
- all_batches.append(batch)
-
- return all_batches
-
- def get_prompt_annotations(self) -> t.Dict[str, t.List[PromptAnnotation]]:
- """
- Get all the prompt annotations for each prompt as a list.
- """
- prompt_annotations = defaultdict(list)
- for sample in self.samples:
- if sample.is_accepted:
- for prompt_name, prompt_annotation in sample.prompts.items():
- prompt_annotations[prompt_name].append(prompt_annotation)
- return prompt_annotations
-
-"""
-The typing module: Support for gradual typing as defined by PEP 484 and subsequent PEPs.
-
-Among other things, the module includes the following:
-* Generic, Protocol, and internal machinery to support generic aliases.
- All subscripted types like X[int], Union[int, str] are generic aliases.
-* Various "special forms" that have unique meanings in type annotations:
- NoReturn, Never, ClassVar, Self, Concatenate, Unpack, and others.
-* Classes whose instances can be type arguments to generic classes and functions:
- TypeVar, ParamSpec, TypeVarTuple.
-* Public helper functions: get_type_hints, overload, cast, final, and others.
-* Several protocols to support duck-typing:
- SupportsFloat, SupportsIndex, SupportsAbs, and others.
-* Special types: NewType, NamedTuple, TypedDict.
-* Deprecated wrapper submodules for re and io related types.
-* Deprecated aliases for builtin types and collections.abc ABCs.
-
-Any name not present in __all__ is an implementation detail
-that may be changed without notice. Use at your own risk!
-"""
-
-from abc import abstractmethod, ABCMeta
-import collections
-from collections import defaultdict
-import collections.abc
-import contextlib
-import functools
-import operator
-import re as stdlib_re # Avoid confusion with the re we export.
-import sys
-import types
-import warnings
-from types import WrapperDescriptorType, MethodWrapperType, MethodDescriptorType, GenericAlias
-
-
-try:
- from _typing import _idfunc
-except ImportError:
- def _idfunc(_, x):
- return x
-
-# Please keep __all__ alphabetized within each category.
-__all__ = [
- # Super-special typing primitives.
- 'Annotated',
- 'Any',
- 'Callable',
- 'ClassVar',
- 'Concatenate',
- 'Final',
- 'ForwardRef',
- 'Generic',
- 'Literal',
- 'Optional',
- 'ParamSpec',
- 'Protocol',
- 'Tuple',
- 'Type',
- 'TypeVar',
- 'TypeVarTuple',
- 'Union',
-
- # ABCs (from collections.abc).
- 'AbstractSet', # collections.abc.Set.
- 'ByteString',
- 'Container',
- 'ContextManager',
- 'Hashable',
- 'ItemsView',
- 'Iterable',
- 'Iterator',
- 'KeysView',
- 'Mapping',
- 'MappingView',
- 'MutableMapping',
- 'MutableSequence',
- 'MutableSet',
- 'Sequence',
- 'Sized',
- 'ValuesView',
- 'Awaitable',
- 'AsyncIterator',
- 'AsyncIterable',
- 'Coroutine',
- 'Collection',
- 'AsyncGenerator',
- 'AsyncContextManager',
-
- # Structural checks, a.k.a. protocols.
- 'Reversible',
- 'SupportsAbs',
- 'SupportsBytes',
- 'SupportsComplex',
- 'SupportsFloat',
- 'SupportsIndex',
- 'SupportsInt',
- 'SupportsRound',
-
- # Concrete collection types.
- 'ChainMap',
- 'Counter',
- 'Deque',
- 'Dict',
- 'DefaultDict',
- 'List',
- 'OrderedDict',
- 'Set',
- 'FrozenSet',
- 'NamedTuple', # Not really a type.
- 'TypedDict', # Not really a type.
- 'Generator',
-
- # Other concrete types.
- 'BinaryIO',
- 'IO',
- 'Match',
- 'Pattern',
- 'TextIO',
-
- # One-off things.
- 'AnyStr',
- 'assert_type',
- 'assert_never',
- 'cast',
- 'clear_overloads',
- 'dataclass_transform',
- 'final',
- 'get_args',
- 'get_origin',
- 'get_overloads',
- 'get_type_hints',
- 'is_typeddict',
- 'LiteralString',
- 'Never',
- 'NewType',
- 'no_type_check',
- 'no_type_check_decorator',
- 'NoReturn',
- 'NotRequired',
- 'overload',
- 'ParamSpecArgs',
- 'ParamSpecKwargs',
- 'Required',
- 'reveal_type',
- 'runtime_checkable',
- 'Self',
- 'Text',
- 'TYPE_CHECKING',
- 'TypeAlias',
- 'TypeGuard',
- 'Unpack',
-]
-
-# The pseudo-submodules 're' and 'io' are part of the public
-# namespace, but excluded from __all__ because they might stomp on
-# legitimate imports of those modules.
-
-
-def _type_convert(arg, module=None, *, allow_special_forms=False):
- """For converting None to type(None), and strings to ForwardRef."""
- if arg is None:
- return type(None)
- if isinstance(arg, str):
- return ForwardRef(arg, module=module, is_class=allow_special_forms)
- return arg
-
-
-def _type_check(arg, msg, is_argument=True, module=None, *, allow_special_forms=False):
- """Check that the argument is a type, and return it (internal helper).
-
- As a special case, accept None and return type(None) instead. Also wrap strings
- into ForwardRef instances. Consider several corner cases, for example plain
- special forms like Union are not valid, while Union[int, str] is OK, etc.
- The msg argument is a human-readable error message, e.g.::
-
- "Union[arg, ...]: arg should be a type."
-
- We append the repr() of the actual value (truncated to 100 chars).
- """
- invalid_generic_forms = (Generic, Protocol)
- if not allow_special_forms:
- invalid_generic_forms += (ClassVar,)
- if is_argument:
- invalid_generic_forms += (Final,)
-
- arg = _type_convert(arg, module=module, allow_special_forms=allow_special_forms)
- if (isinstance(arg, _GenericAlias) and
- arg.__origin__ in invalid_generic_forms):
- raise TypeError(f"{arg} is not valid as type argument")
- if arg in (Any, LiteralString, NoReturn, Never, Self, TypeAlias):
- return arg
- if allow_special_forms and arg in (ClassVar, Final):
- return arg
- if isinstance(arg, _SpecialForm) or arg in (Generic, Protocol):
- raise TypeError(f"Plain {arg} is not valid as type argument")
- if type(arg) is tuple:
- raise TypeError(f"{msg} Got {arg!r:.100}.")
- return arg
-
-
-def _is_param_expr(arg):
- return arg is ... or isinstance(arg,
- (tuple, list, ParamSpec, _ConcatenateGenericAlias))
-
-
-def _should_unflatten_callable_args(typ, args):
- """Internal helper for munging collections.abc.Callable's __args__.
-
- The canonical representation for a Callable's __args__ flattens the
- argument types, see https://github.com/python/cpython/issues/86361.
-
- For example::
-
- >>> import collections.abc
- >>> P = ParamSpec('P')
- >>> collections.abc.Callable[[int, int], str].__args__ == (int, int, str)
- True
- >>> collections.abc.Callable[P, str].__args__ == (P, str)
- True
-
- As a result, if we need to reconstruct the Callable from its __args__,
- we need to unflatten it.
- """
- return (
- typ.__origin__ is collections.abc.Callable
- and not (len(args) == 2 and _is_param_expr(args[0]))
- )
-
-
-def _type_repr(obj):
- """Return the repr() of an object, special-casing types (internal helper).
-
- If obj is a type, we return a shorter version than the default
- type.__repr__, based on the module and qualified name, which is
- typically enough to uniquely identify a type. For everything
- else, we fall back on repr(obj).
- """
- if isinstance(obj, types.GenericAlias):
- return repr(obj)
- if isinstance(obj, type):
- if obj.__module__ == 'builtins':
- return obj.__qualname__
- return f'{obj.__module__}.{obj.__qualname__}'
- if obj is ...:
- return('...')
- if isinstance(obj, types.FunctionType):
- return obj.__name__
- return repr(obj)
-
-
-def _collect_parameters(args):
- """Collect all type variables and parameter specifications in args
- in order of first appearance (lexicographic order).
-
- For example::
-
- >>> P = ParamSpec('P')
- >>> T = TypeVar('T')
- >>> _collect_parameters((T, Callable[P, T]))
- (~T, ~P)
- """
- parameters = []
- for t in args:
- if isinstance(t, type):
- # We don't want __parameters__ descriptor of a bare Python class.
- pass
- elif isinstance(t, tuple):
- # `t` might be a tuple, when `ParamSpec` is substituted with
- # `[T, int]`, or `[int, *Ts]`, etc.
- for x in t:
- for collected in _collect_parameters([x]):
- if collected not in parameters:
- parameters.append(collected)
- elif hasattr(t, '__typing_subst__'):
- if t not in parameters:
- parameters.append(t)
- else:
- for x in getattr(t, '__parameters__', ()):
- if x not in parameters:
- parameters.append(x)
- return tuple(parameters)
-
-
-def _check_generic(cls, parameters, elen):
- """Check correct count for parameters of a generic cls (internal helper).
-
- This gives a nice error message in case of count mismatch.
- """
- if not elen:
- raise TypeError(f"{cls} is not a generic class")
- alen = len(parameters)
- if alen != elen:
- raise TypeError(f"Too {'many' if alen > elen else 'few'} arguments for {cls};"
- f" actual {alen}, expected {elen}")
-
-def _unpack_args(args):
- newargs = []
- for arg in args:
- subargs = getattr(arg, '__typing_unpacked_tuple_args__', None)
- if subargs is not None and not (subargs and subargs[-1] is ...):
- newargs.extend(subargs)
- else:
- newargs.append(arg)
- return newargs
-
-def _deduplicate(params, *, unhashable_fallback=False):
- # Weed out strict duplicates, preserving the first of each occurrence.
- try:
- return dict.fromkeys(params)
- except TypeError:
- if not unhashable_fallback:
- raise
- # Happens for cases like `Annotated[dict, {'x': IntValidator()}]`
- return _deduplicate_unhashable(params)
-
-def _deduplicate_unhashable(unhashable_params):
- new_unhashable = []
- for t in unhashable_params:
- if t not in new_unhashable:
- new_unhashable.append(t)
- return new_unhashable
-
-def _compare_args_orderless(first_args, second_args):
- first_unhashable = _deduplicate_unhashable(first_args)
- second_unhashable = _deduplicate_unhashable(second_args)
- t = list(second_unhashable)
- try:
- for elem in first_unhashable:
- t.remove(elem)
- except ValueError:
- return False
- return not t
-
-def _remove_dups_flatten(parameters):
- """Internal helper for Union creation and substitution.
-
- Flatten Unions among parameters, then remove duplicates.
- """
- # Flatten out Union[Union[...], ...].
- params = []
- for p in parameters:
- if isinstance(p, (_UnionGenericAlias, types.UnionType)):
- params.extend(p.__args__)
- else:
- params.append(p)
-
- return tuple(_deduplicate(params, unhashable_fallback=True))
-
-
-def _flatten_literal_params(parameters):
- """Internal helper for Literal creation: flatten Literals among parameters."""
- params = []
- for p in parameters:
- if isinstance(p, _LiteralGenericAlias):
- params.extend(p.__args__)
- else:
- params.append(p)
- return tuple(params)
-
-
-_cleanups = []
-
-
-def _tp_cache(func=None, /, *, typed=False):
- """Internal wrapper caching __getitem__ of generic types.
-
- For non-hashable arguments, the original function is used as a fallback.
- """
- def decorator(func):
- cached = functools.lru_cache(typed=typed)(func)
- _cleanups.append(cached.cache_clear)
-
- @functools.wraps(func)
- def inner(*args, **kwds):
- try:
- return cached(*args, **kwds)
- except TypeError:
- pass # All real errors (not unhashable args) are raised below.
- return func(*args, **kwds)
- return inner
-
- if func is not None:
- return decorator(func)
-
- return decorator
-
-def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
- """Evaluate all forward references in the given type t.
-
- For use of globalns and localns see the docstring for get_type_hints().
- recursive_guard is used to prevent infinite recursion with a recursive
- ForwardRef.
- """
- if isinstance(t, ForwardRef):
- return t._evaluate(globalns, localns, recursive_guard)
- if isinstance(t, (_GenericAlias, GenericAlias, types.UnionType)):
- if isinstance(t, GenericAlias):
- args = tuple(
- ForwardRef(arg) if isinstance(arg, str) else arg
- for arg in t.__args__
- )
- is_unpacked = t.__unpacked__
- if _should_unflatten_callable_args(t, args):
- t = t.__origin__[(args[:-1], args[-1])]
- else:
- t = t.__origin__[args]
- if is_unpacked:
- t = Unpack[t]
- ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__)
- if ev_args == t.__args__:
- return t
- if isinstance(t, GenericAlias):
- return GenericAlias(t.__origin__, ev_args)
- if isinstance(t, types.UnionType):
- return functools.reduce(operator.or_, ev_args)
- else:
- return t.copy_with(ev_args)
- return t
-
-
-class _Final:
- """Mixin to prohibit subclassing."""
-
- __slots__ = ('__weakref__',)
-
- def __init_subclass__(cls, /, *args, **kwds):
- if '_root' not in kwds:
- raise TypeError("Cannot subclass special typing classes")
-
-class _Immutable:
- """Mixin to indicate that object should not be copied."""
-
- __slots__ = ()
-
- def __copy__(self):
- return self
-
- def __deepcopy__(self, memo):
- return self
-
-
-class _NotIterable:
- """Mixin to prevent iteration, without being compatible with Iterable.
-
- That is, we could do::
-
- def __iter__(self): raise TypeError()
-
- But this would make users of this mixin duck type-compatible with
- collections.abc.Iterable - isinstance(foo, Iterable) would be True.
-
- Luckily, we can instead prevent iteration by setting __iter__ to None, which
- is treated specially.
- """
-
- __slots__ = ()
- __iter__ = None
-
-
-# Internal indicator of special typing constructs.
-# See __doc__ instance attribute for specific docs.
-class _SpecialForm(_Final, _NotIterable, _root=True):
- __slots__ = ('_name', '__doc__', '_getitem')
-
- def __init__(self, getitem):
- self._getitem = getitem
- self._name = getitem.__name__
- self.__doc__ = getitem.__doc__
-
- def __getattr__(self, item):
- if item in {'__name__', '__qualname__'}:
- return self._name
-
- raise AttributeError(item)
-
- def __mro_entries__(self, bases):
- raise TypeError(f"Cannot subclass {self!r}")
-
- def __repr__(self):
- return 'typing.' + self._name
-
- def __reduce__(self):
- return self._name
-
- def __call__(self, *args, **kwds):
- raise TypeError(f"Cannot instantiate {self!r}")
-
- def __or__(self, other):
- return Union[self, other]
-
- def __ror__(self, other):
- return Union[other, self]
-
- def __instancecheck__(self, obj):
- raise TypeError(f"{self} cannot be used with isinstance()")
-
- def __subclasscheck__(self, cls):
- raise TypeError(f"{self} cannot be used with issubclass()")
-
- @_tp_cache
- def __getitem__(self, parameters):
- return self._getitem(self, parameters)
-
-
-class _LiteralSpecialForm(_SpecialForm, _root=True):
- def __getitem__(self, parameters):
- if not isinstance(parameters, tuple):
- parameters = (parameters,)
- return self._getitem(self, *parameters)
-
-
-class _AnyMeta(type):
- def __instancecheck__(self, obj):
- if self is Any:
- raise TypeError("typing.Any cannot be used with isinstance()")
- return super().__instancecheck__(obj)
-
- def __repr__(self):
- if self is Any:
- return "typing.Any"
- return super().__repr__() # respect to subclasses
-
-
-class Any(metaclass=_AnyMeta):
- """Special type indicating an unconstrained type.
-
- - Any is compatible with every type.
- - Any assumed to have all methods.
- - All values assumed to be instances of Any.
-
- Note that all the above statements are true from the point of view of
- static type checkers. At runtime, Any should not be used with instance
- checks.
- """
-
- def __new__(cls, *args, **kwargs):
- if cls is Any:
- raise TypeError("Any cannot be instantiated")
- return super().__new__(cls)
-
-
-@_SpecialForm
-def NoReturn(self, parameters):
- """Special type indicating functions that never return.
-
- Example::
-
- from typing import NoReturn
-
- def stop() -> NoReturn:
- raise Exception('no way')
-
- NoReturn can also be used as a bottom type, a type that
- has no values. Starting in Python 3.11, the Never type should
- be used for this concept instead. Type checkers should treat the two
- equivalently.
- """
- raise TypeError(f"{self} is not subscriptable")
-
-# This is semantically identical to NoReturn, but it is implemented
-# separately so that type checkers can distinguish between the two
-# if they want.
-@_SpecialForm
-def Never(self, parameters):
- """The bottom type, a type that has no members.
-
- This can be used to define a function that should never be
- called, or a function that never returns::
-
- from typing import Never
-
- def never_call_me(arg: Never) -> None:
- pass
-
- def int_or_str(arg: int | str) -> None:
- never_call_me(arg) # type checker error
- match arg:
- case int():
- print("It's an int")
- case str():
- print("It's a str")
- case _:
- never_call_me(arg) # OK, arg is of type Never
- """
- raise TypeError(f"{self} is not subscriptable")
-
-
-@_SpecialForm
-def Self(self, parameters):
- """Used to spell the type of "self" in classes.
-
- Example::
-
- from typing import Self
-
- class Foo:
- def return_self(self) -> Self:
- ...
- return self
-
- This is especially useful for:
- - classmethods that are used as alternative constructors
- - annotating an `__enter__` method which returns self
- """
- raise TypeError(f"{self} is not subscriptable")
-
-
-@_SpecialForm
-def LiteralString(self, parameters):
- """Represents an arbitrary literal string.
-
- Example::
-
- from typing import LiteralString
-
- def run_query(sql: LiteralString) -> None:
- ...
-
- def caller(arbitrary_string: str, literal_string: LiteralString) -> None:
- run_query("SELECT * FROM students") # OK
- run_query(literal_string) # OK
- run_query("SELECT * FROM " + literal_string) # OK
- run_query(arbitrary_string) # type checker error
- run_query( # type checker error
- f"SELECT * FROM students WHERE name = {arbitrary_string}"
- )
-
- Only string literals and other LiteralStrings are compatible
- with LiteralString. This provides a tool to help prevent
- security issues such as SQL injection.
- """
- raise TypeError(f"{self} is not subscriptable")
-
-
-@_SpecialForm
-def ClassVar(self, parameters):
- """Special type construct to mark class variables.
-
- An annotation wrapped in ClassVar indicates that a given
- attribute is intended to be used as a class variable and
- should not be set on instances of that class.
-
- Usage::
-
- class Starship:
- stats: ClassVar[dict[str, int]] = {} # class variable
- damage: int = 10 # instance variable
-
- ClassVar accepts only types and cannot be further subscribed.
-
- Note that ClassVar is not a class itself, and should not
- be used with isinstance() or issubclass().
- """
- item = _type_check(parameters, f'{self} accepts only single type.')
- return _GenericAlias(self, (item,))
-
-@_SpecialForm
-def Final(self, parameters):
- """Special typing construct to indicate final names to type checkers.
-
- A final name cannot be re-assigned or overridden in a subclass.
-
- For example::
-
- MAX_SIZE: Final = 9000
- MAX_SIZE += 1 # Error reported by type checker
-
- class Connection:
- TIMEOUT: Final[int] = 10
-
- class FastConnector(Connection):
- TIMEOUT = 1 # Error reported by type checker
-
- There is no runtime checking of these properties.
- """
- item = _type_check(parameters, f'{self} accepts only single type.')
- return _GenericAlias(self, (item,))
-
-@_SpecialForm
-def Union(self, parameters):
- """Union type; Union[X, Y] means either X or Y.
-
- On Python 3.10 and higher, the | operator
- can also be used to denote unions;
- X | Y means the same thing to the type checker as Union[X, Y].
-
- To define a union, use e.g. Union[int, str]. Details:
- - The arguments must be types and there must be at least one.
- - None as an argument is a special case and is replaced by
- type(None).
- - Unions of unions are flattened, e.g.::
-
- assert Union[Union[int, str], float] == Union[int, str, float]
-
- - Unions of a single argument vanish, e.g.::
-
- assert Union[int] == int # The constructor actually returns int
-
- - Redundant arguments are skipped, e.g.::
-
- assert Union[int, str, int] == Union[int, str]
-
- - When comparing unions, the argument order is ignored, e.g.::
-
- assert Union[int, str] == Union[str, int]
-
- - You cannot subclass or instantiate a union.
- - You can use Optional[X] as a shorthand for Union[X, None].
- """
- if parameters == ():
- raise TypeError("Cannot take a Union of no types.")
- if not isinstance(parameters, tuple):
- parameters = (parameters,)
- msg = "Union[arg, ...]: each arg must be a type."
- parameters = tuple(_type_check(p, msg) for p in parameters)
- parameters = _remove_dups_flatten(parameters)
- if len(parameters) == 1:
- return parameters[0]
- if len(parameters) == 2 and type(None) in parameters:
- return _UnionGenericAlias(self, parameters, name="Optional")
- return _UnionGenericAlias(self, parameters)
-
-@_SpecialForm
-def Optional(self, parameters):
- """Optional[X] is equivalent to Union[X, None]."""
- arg = _type_check(parameters, f"{self} requires a single type.")
- return Union[arg, type(None)]
-
-@_LiteralSpecialForm
-@_tp_cache(typed=True)
-def Literal(self, *parameters):
- """Special typing form to define literal types (a.k.a. value types).
-
- This form can be used to indicate to type checkers that the corresponding
- variable or function parameter has a value equivalent to the provided
- literal (or one of several literals)::
-
- def validate_simple(data: Any) -> Literal[True]: # always returns True
- ...
-
- MODE = Literal['r', 'rb', 'w', 'wb']
- def open_helper(file: str, mode: MODE) -> str:
- ...
-
- open_helper('/some/path', 'r') # Passes type check
- open_helper('/other/path', 'typo') # Error in type checker
-
- Literal[...] cannot be subclassed. At runtime, an arbitrary value
- is allowed as type argument to Literal[...], but type checkers may
- impose restrictions.
- """
- # There is no '_type_check' call because arguments to Literal[...] are
- # values, not types.
- parameters = _flatten_literal_params(parameters)
-
- try:
- parameters = tuple(p for p, _ in _deduplicate(list(_value_and_type_iter(parameters))))
- except TypeError: # unhashable parameters
- pass
-
- return _LiteralGenericAlias(self, parameters)
-
-
-@_SpecialForm
-def TypeAlias(self, parameters):
- """Special form for marking type aliases.
-
- Use TypeAlias to indicate that an assignment should
- be recognized as a proper type alias definition by type
- checkers.
-
- For example::
-
- Predicate: TypeAlias = Callable[..., bool]
-
- It's invalid when used anywhere except as in the example above.
- """
- raise TypeError(f"{self} is not subscriptable")
-
-
-@_SpecialForm
-def Concatenate(self, parameters):
- """Special form for annotating higher-order functions.
-
- ``Concatenate`` can be used in conjunction with ``ParamSpec`` and
- ``Callable`` to represent a higher-order function which adds, removes or
- transforms the parameters of a callable.
-
- For example::
-
- Callable[Concatenate[int, P], int]
-
- See PEP 612 for detailed information.
- """
- if parameters == ():
- raise TypeError("Cannot take a Concatenate of no types.")
- if not isinstance(parameters, tuple):
- parameters = (parameters,)
- if not (parameters[-1] is ... or isinstance(parameters[-1], ParamSpec)):
- raise TypeError("The last parameter to Concatenate should be a "
- "ParamSpec variable or ellipsis.")
- msg = "Concatenate[arg, ...]: each arg must be a type."
- parameters = (*(_type_check(p, msg) for p in parameters[:-1]), parameters[-1])
- return _ConcatenateGenericAlias(self, parameters,
- _paramspec_tvars=True)
-
-
-@_SpecialForm
-def TypeGuard(self, parameters):
- """Special typing construct for marking user-defined type guard functions.
-
- ``TypeGuard`` can be used to annotate the return type of a user-defined
- type guard function. ``TypeGuard`` only accepts a single type argument.
- At runtime, functions marked this way should return a boolean.
-
- ``TypeGuard`` aims to benefit *type narrowing* -- a technique used by static
- type checkers to determine a more precise type of an expression within a
- program's code flow. Usually type narrowing is done by analyzing
- conditional code flow and applying the narrowing to a block of code. The
- conditional expression here is sometimes referred to as a "type guard".
-
- Sometimes it would be convenient to use a user-defined boolean function
- as a type guard. Such a function should use ``TypeGuard[...]`` as its
- return type to alert static type checkers to this intention.
-
- Using ``-> TypeGuard`` tells the static type checker that for a given
- function:
-
- 1. The return value is a boolean.
- 2. If the return value is ``True``, the type of its argument
- is the type inside ``TypeGuard``.
-
- For example::
-
- def is_str(val: Union[str, float]):
- # "isinstance" type guard
- if isinstance(val, str):
- # Type of ``val`` is narrowed to ``str``
- ...
- else:
- # Else, type of ``val`` is narrowed to ``float``.
- ...
-
- Strict type narrowing is not enforced -- ``TypeB`` need not be a narrower
- form of ``TypeA`` (it can even be a wider form) and this may lead to
- type-unsafe results. The main reason is to allow for things like
- narrowing ``List[object]`` to ``List[str]`` even though the latter is not
- a subtype of the former, since ``List`` is invariant. The responsibility of
- writing type-safe type guards is left to the user.
-
- ``TypeGuard`` also works with type variables. For more information, see
- PEP 647 (User-Defined Type Guards).
- """
- item = _type_check(parameters, f'{self} accepts only single type.')
- return _GenericAlias(self, (item,))
-
-
-class ForwardRef(_Final, _root=True):
- """Internal wrapper to hold a forward reference."""
-
- __slots__ = ('__forward_arg__', '__forward_code__',
- '__forward_evaluated__', '__forward_value__',
- '__forward_is_argument__', '__forward_is_class__',
- '__forward_module__')
-
- def __init__(self, arg, is_argument=True, module=None, *, is_class=False):
- if not isinstance(arg, str):
- raise TypeError(f"Forward reference must be a string -- got {arg!r}")
-
- # If we do `def f(*args: *Ts)`, then we'll have `arg = '*Ts'`.
- # Unfortunately, this isn't a valid expression on its own, so we
- # do the unpacking manually.
- if arg.startswith('*'):
- arg_to_compile = f'({arg},)[0]' # E.g. (*Ts,)[0] or (*tuple[int, int],)[0]
- else:
- arg_to_compile = arg
- try:
- code = compile(arg_to_compile, '<string>', 'eval')
- except SyntaxError:
- raise SyntaxError(f"Forward reference must be an expression -- got {arg!r}")
-
- self.__forward_arg__ = arg
- self.__forward_code__ = code
- self.__forward_evaluated__ = False
- self.__forward_value__ = None
- self.__forward_is_argument__ = is_argument
- self.__forward_is_class__ = is_class
- self.__forward_module__ = module
-
- def _evaluate(self, globalns, localns, recursive_guard):
- if self.__forward_arg__ in recursive_guard:
- return self
- if not self.__forward_evaluated__ or localns is not globalns:
- if globalns is None and localns is None:
- globalns = localns = {}
- elif globalns is None:
- globalns = localns
- elif localns is None:
- localns = globalns
- if self.__forward_module__ is not None:
- globalns = getattr(
- sys.modules.get(self.__forward_module__, None), '__dict__', globalns
- )
- type_ = _type_check(
- eval(self.__forward_code__, globalns, localns),
- "Forward references must evaluate to types.",
- is_argument=self.__forward_is_argument__,
- allow_special_forms=self.__forward_is_class__,
- )
- self.__forward_value__ = _eval_type(
- type_, globalns, localns, recursive_guard | {self.__forward_arg__}
- )
- self.__forward_evaluated__ = True
- return self.__forward_value__
-
- def __eq__(self, other):
- if not isinstance(other, ForwardRef):
- return NotImplemented
- if self.__forward_evaluated__ and other.__forward_evaluated__:
- return (self.__forward_arg__ == other.__forward_arg__ and
- self.__forward_value__ == other.__forward_value__)
- return (self.__forward_arg__ == other.__forward_arg__ and
- self.__forward_module__ == other.__forward_module__)
-
- def __hash__(self):
- return hash((self.__forward_arg__, self.__forward_module__))
-
- def __or__(self, other):
- return Union[self, other]
-
- def __ror__(self, other):
- return Union[other, self]
-
- def __repr__(self):
- if self.__forward_module__ is None:
- module_repr = ''
- else:
- module_repr = f', module={self.__forward_module__!r}'
- return f'ForwardRef({self.__forward_arg__!r}{module_repr})'
-
-
-def _is_unpacked_typevartuple(x: Any) -> bool:
- return ((not isinstance(x, type)) and
- getattr(x, '__typing_is_unpacked_typevartuple__', False))
-
-
-def _is_typevar_like(x: Any) -> bool:
- return isinstance(x, (TypeVar, ParamSpec)) or _is_unpacked_typevartuple(x)
-
-
-class _PickleUsingNameMixin:
- """Mixin enabling pickling based on self.__name__."""
-
- def __reduce__(self):
- return self.__name__
-
-
-class _BoundVarianceMixin:
- """Mixin giving __init__ bound and variance arguments.
-
- This is used by TypeVar and ParamSpec, which both employ the notions of
- a type 'bound' (restricting type arguments to be a subtype of some
- specified type) and type 'variance' (determining subtype relations between
- generic types).
- """
- def __init__(self, bound, covariant, contravariant):
- """Used to setup TypeVars and ParamSpec's bound, covariant and
- contravariant attributes.
- """
- if covariant and contravariant:
- raise ValueError("Bivariant types are not supported.")
- self.__covariant__ = bool(covariant)
- self.__contravariant__ = bool(contravariant)
- if bound:
- self.__bound__ = _type_check(bound, "Bound must be a type.")
- else:
- self.__bound__ = None
-
- def __or__(self, right):
- return Union[self, right]
-
- def __ror__(self, left):
- return Union[left, self]
-
- def __repr__(self):
- if self.__covariant__:
- prefix = '+'
- elif self.__contravariant__:
- prefix = '-'
- else:
- prefix = '~'
- return prefix + self.__name__
-
-
-class TypeVar(_Final, _Immutable, _BoundVarianceMixin, _PickleUsingNameMixin,
- _root=True):
- """Type variable.
-
- Usage::
-
- T = TypeVar('T') # Can be anything
- A = TypeVar('A', str, bytes) # Must be str or bytes
-
- Type variables exist primarily for the benefit of static type
- checkers. They serve as the parameters for generic types as well
- as for generic function definitions. See class Generic for more
- information on generic types. Generic functions work as follows:
-
- def repeat(x: T, n: int) -> List[T]:
- '''Return a list containing n references to x.'''
- return [x]*n
-
- def longest(x: A, y: A) -> A:
- '''Return the longest of two strings.'''
- return x if len(x) >= len(y) else y
-
- The latter example's signature is essentially the overloading
- of (str, str) -> str and (bytes, bytes) -> bytes. Also note
- that if the arguments are instances of some subclass of str,
- the return type is still plain str.
-
- At runtime, isinstance(x, T) and issubclass(C, T) will raise TypeError.
-
- Type variables defined with covariant=True or contravariant=True
- can be used to declare covariant or contravariant generic types.
- See PEP 484 for more details. By default generic types are invariant
- in all type variables.
-
- Type variables can be introspected. e.g.:
-
- T.__name__ == 'T'
- T.__constraints__ == ()
- T.__covariant__ == False
- T.__contravariant__ = False
- A.__constraints__ == (str, bytes)
-
- Note that only type variables defined in global scope can be pickled.
- """
-
- def __init__(self, name, *constraints, bound=None,
- covariant=False, contravariant=False):
- self.__name__ = name
- super().__init__(bound, covariant, contravariant)
- if constraints and bound is not None:
- raise TypeError("Constraints cannot be combined with bound=...")
- if constraints and len(constraints) == 1:
- raise TypeError("A single constraint is not allowed")
- msg = "TypeVar(name, constraint, ...): constraints must be types."
- self.__constraints__ = tuple(_type_check(t, msg) for t in constraints)
- def_mod = _caller()
- if def_mod != 'typing':
- self.__module__ = def_mod
-
- def __typing_subst__(self, arg):
- msg = "Parameters to generic types must be types."
- arg = _type_check(arg, msg, is_argument=True)
- if ((isinstance(arg, _GenericAlias) and arg.__origin__ is Unpack) or
- (isinstance(arg, GenericAlias) and getattr(arg, '__unpacked__', False))):
- raise TypeError(f"{arg} is not valid as type argument")
- return arg
-
-
-class TypeVarTuple(_Final, _Immutable, _PickleUsingNameMixin, _root=True):
- """Type variable tuple.
-
- Usage:
-
- Ts = TypeVarTuple('Ts') # Can be given any name
-
- Just as a TypeVar (type variable) is a placeholder for a single type,
- a TypeVarTuple is a placeholder for an *arbitrary* number of types. For
- example, if we define a generic class using a TypeVarTuple:
-
- class C(Generic[*Ts]): ...
-
- Then we can parameterize that class with an arbitrary number of type
- arguments:
-
- C[int] # Fine
- C[int, str] # Also fine
- C[()] # Even this is fine
-
- For more details, see PEP 646.
-
- Note that only TypeVarTuples defined in global scope can be pickled.
- """
-
- def __init__(self, name):
- self.__name__ = name
-
- # Used for pickling.
- def_mod = _caller()
- if def_mod != 'typing':
- self.__module__ = def_mod
-
- def __iter__(self):
- yield Unpack[self]
-
- def __repr__(self):
- return self.__name__
-
- def __typing_subst__(self, arg):
- raise TypeError("Substitution of bare TypeVarTuple is not supported")
-
- def __typing_prepare_subst__(self, alias, args):
- params = alias.__parameters__
- typevartuple_index = params.index(self)
- for param in params[typevartuple_index + 1:]:
- if isinstance(param, TypeVarTuple):
- raise TypeError(f"More than one TypeVarTuple parameter in {alias}")
-
- alen = len(args)
- plen = len(params)
- left = typevartuple_index
- right = plen - typevartuple_index - 1
- var_tuple_index = None
- fillarg = None
- for k, arg in enumerate(args):
- if not isinstance(arg, type):
- subargs = getattr(arg, '__typing_unpacked_tuple_args__', None)
- if subargs and len(subargs) == 2 and subargs[-1] is ...:
- if var_tuple_index is not None:
- raise TypeError("More than one unpacked arbitrary-length tuple argument")
- var_tuple_index = k
- fillarg = subargs[0]
- if var_tuple_index is not None:
- left = min(left, var_tuple_index)
- right = min(right, alen - var_tuple_index - 1)
- elif left + right > alen:
- raise TypeError(f"Too few arguments for {alias};"
- f" actual {alen}, expected at least {plen-1}")
-
- return (
- *args[:left],
- *([fillarg]*(typevartuple_index - left)),
- tuple(args[left: alen - right]),
- *([fillarg]*(plen - right - left - typevartuple_index - 1)),
- *args[alen - right:],
- )
-
-
-class ParamSpecArgs(_Final, _Immutable, _root=True):
- """The args for a ParamSpec object.
-
- Given a ParamSpec object P, P.args is an instance of ParamSpecArgs.
-
- ParamSpecArgs objects have a reference back to their ParamSpec:
-
- P.args.__origin__ is P
-
- This type is meant for runtime introspection and has no special meaning to
- static type checkers.
- """
- def __init__(self, origin):
- self.__origin__ = origin
-
- def __repr__(self):
- return f"{self.__origin__.__name__}.args"
-
- def __eq__(self, other):
- if not isinstance(other, ParamSpecArgs):
- return NotImplemented
- return self.__origin__ == other.__origin__
-
-
-class ParamSpecKwargs(_Final, _Immutable, _root=True):
- """The kwargs for a ParamSpec object.
-
- Given a ParamSpec object P, P.kwargs is an instance of ParamSpecKwargs.
-
- ParamSpecKwargs objects have a reference back to their ParamSpec:
-
- P.kwargs.__origin__ is P
-
- This type is meant for runtime introspection and has no special meaning to
- static type checkers.
- """
- def __init__(self, origin):
- self.__origin__ = origin
-
- def __repr__(self):
- return f"{self.__origin__.__name__}.kwargs"
-
- def __eq__(self, other):
- if not isinstance(other, ParamSpecKwargs):
- return NotImplemented
- return self.__origin__ == other.__origin__
-
-
-class ParamSpec(_Final, _Immutable, _BoundVarianceMixin, _PickleUsingNameMixin,
- _root=True):
- """Parameter specification variable.
-
- Usage::
-
- P = ParamSpec('P')
-
- Parameter specification variables exist primarily for the benefit of static
- type checkers. They are used to forward the parameter types of one
- callable to another callable, a pattern commonly found in higher order
- functions and decorators. They are only valid when used in ``Concatenate``,
- or as the first argument to ``Callable``, or as parameters for user-defined
- Generics. See class Generic for more information on generic types. An
- example for annotating a decorator::
-
- T = TypeVar('T')
- P = ParamSpec('P')
-
- def add_logging(f: Callable[P, T]) -> Callable[P, T]:
- '''A type-safe decorator to add logging to a function.'''
- def inner(*args: P.args, **kwargs: P.kwargs) -> T:
- logging.info(f'{f.__name__} was called')
- return f(*args, **kwargs)
- return inner
-
- @add_logging
- def add_two(x: float, y: float) -> float:
- '''Add two numbers together.'''
- return x + y
-
- Parameter specification variables can be introspected. e.g.:
-
- P.__name__ == 'P'
-
- Note that only parameter specification variables defined in global scope can
- be pickled.
- """
-
- @property
- def args(self):
- return ParamSpecArgs(self)
-
- @property
- def kwargs(self):
- return ParamSpecKwargs(self)
-
- def __init__(self, name, *, bound=None, covariant=False, contravariant=False):
- self.__name__ = name
- super().__init__(bound, covariant, contravariant)
- def_mod = _caller()
- if def_mod != 'typing':
- self.__module__ = def_mod
-
- def __typing_subst__(self, arg):
- if isinstance(arg, (list, tuple)):
- arg = tuple(_type_check(a, "Expected a type.") for a in arg)
- elif not _is_param_expr(arg):
- raise TypeError(f"Expected a list of types, an ellipsis, "
- f"ParamSpec, or Concatenate. Got {arg}")
- return arg
-
- def __typing_prepare_subst__(self, alias, args):
- params = alias.__parameters__
- i = params.index(self)
- if i >= len(args):
- raise TypeError(f"Too few arguments for {alias}")
- # Special case where Z[[int, str, bool]] == Z[int, str, bool] in PEP 612.
- if len(params) == 1 and not _is_param_expr(args[0]):
- assert i == 0
- args = (args,)
- # Convert lists to tuples to help other libraries cache the results.
- elif isinstance(args[i], list):
- args = (*args[:i], tuple(args[i]), *args[i+1:])
- return args
-
-def _is_dunder(attr):
- return attr.startswith('__') and attr.endswith('__')
-
-class _BaseGenericAlias(_Final, _root=True):
- """The central part of the internal API.
-
- This represents a generic version of type 'origin' with type arguments 'params'.
- There are two kind of these aliases: user defined and special. The special ones
- are wrappers around builtin collections and ABCs in collections.abc. These must
- have 'name' always set. If 'inst' is False, then the alias can't be instantiated;
- this is used by e.g. typing.List and typing.Dict.
- """
-
- def __init__(self, origin, *, inst=True, name=None):
- self._inst = inst
- self._name = name
- self.__origin__ = origin
- self.__slots__ = None # This is not documented.
-
- def __call__(self, *args, **kwargs):
- if not self._inst:
- raise TypeError(f"Type {self._name} cannot be instantiated; "
- f"use {self.__origin__.__name__}() instead")
- result = self.__origin__(*args, **kwargs)
- try:
- result.__orig_class__ = self
- # Some objects raise TypeError (or something even more exotic)
- # if you try to set attributes on them; we guard against that here
- except Exception:
- pass
- return result
-
- def __mro_entries__(self, bases):
- res = []
- if self.__origin__ not in bases:
- res.append(self.__origin__)
- i = bases.index(self)
- for b in bases[i+1:]:
- if isinstance(b, _BaseGenericAlias) or issubclass(b, Generic):
- break
- else:
- res.append(Generic)
- return tuple(res)
-
- def __getattr__(self, attr):
- if attr in {'__name__', '__qualname__'}:
- return self._name or self.__origin__.__name__
-
- # We are careful for copy and pickle.
- # Also for simplicity we don't relay any dunder names
- if '__origin__' in self.__dict__ and not _is_dunder(attr):
- return getattr(self.__origin__, attr)
- raise AttributeError(attr)
-
- def __setattr__(self, attr, val):
- if _is_dunder(attr) or attr in {'_name', '_inst', '_nparams',
- '_paramspec_tvars'}:
- super().__setattr__(attr, val)
- else:
- setattr(self.__origin__, attr, val)
-
- def __instancecheck__(self, obj):
- return self.__subclasscheck__(type(obj))
-
- def __subclasscheck__(self, cls):
- raise TypeError("Subscripted generics cannot be used with"
- " class and instance checks")
-
- def __dir__(self):
- return list(set(super().__dir__()
- + [attr for attr in dir(self.__origin__) if not _is_dunder(attr)]))
-
-
-# Special typing constructs Union, Optional, Generic, Callable and Tuple
-# use three special attributes for internal bookkeeping of generic types:
-# * __parameters__ is a tuple of unique free type parameters of a generic
-# type, for example, Dict[T, T].__parameters__ == (T,);
-# * __origin__ keeps a reference to a type that was subscripted,
-# e.g., Union[T, int].__origin__ == Union, or the non-generic version of
-# the type.
-# * __args__ is a tuple of all arguments used in subscripting,
-# e.g., Dict[T, int].__args__ == (T, int).
-
-
-class _GenericAlias(_BaseGenericAlias, _root=True):
- # The type of parameterized generics.
- #
- # That is, for example, `type(List[int])` is `_GenericAlias`.
- #
- # Objects which are instances of this class include:
- # * Parameterized container types, e.g. `Tuple[int]`, `List[int]`.
- # * Note that native container types, e.g. `tuple`, `list`, use
- # `types.GenericAlias` instead.
- # * Parameterized classes:
- # T = TypeVar('T')
- # class C(Generic[T]): pass
- # # C[int] is a _GenericAlias
- # * `Callable` aliases, generic `Callable` aliases, and
- # parameterized `Callable` aliases:
- # T = TypeVar('T')
- # # _CallableGenericAlias inherits from _GenericAlias.
- # A = Callable[[], None] # _CallableGenericAlias
- # B = Callable[[T], None] # _CallableGenericAlias
- # C = B[int] # _CallableGenericAlias
- # * Parameterized `Final`, `ClassVar` and `TypeGuard`:
- # # All _GenericAlias
- # Final[int]
- # ClassVar[float]
- # TypeVar[bool]
-
- def __init__(self, origin, args, *, inst=True, name=None,
- _paramspec_tvars=False):
- super().__init__(origin, inst=inst, name=name)
- if not isinstance(args, tuple):
- args = (args,)
- self.__args__ = tuple(... if a is _TypingEllipsis else
- a for a in args)
- self.__parameters__ = _collect_parameters(args)
- self._paramspec_tvars = _paramspec_tvars
- if not name:
- self.__module__ = origin.__module__
-
- def __eq__(self, other):
- if not isinstance(other, _GenericAlias):
- return NotImplemented
- return (self.__origin__ == other.__origin__
- and self.__args__ == other.__args__)
-
- def __hash__(self):
- return hash((self.__origin__, self.__args__))
-
- def __or__(self, right):
- return Union[self, right]
-
- def __ror__(self, left):
- return Union[left, self]
-
- @_tp_cache
- def __getitem__(self, args):
- # Parameterizes an already-parameterized object.
- #
- # For example, we arrive here doing something like:
- # T1 = TypeVar('T1')
- # T2 = TypeVar('T2')
- # T3 = TypeVar('T3')
- # class A(Generic[T1]): pass
- # B = A[T2] # B is a _GenericAlias
- # C = B[T3] # Invokes _GenericAlias.__getitem__
- #
- # We also arrive here when parameterizing a generic `Callable` alias:
- # T = TypeVar('T')
- # C = Callable[[T], None]
- # C[int] # Invokes _GenericAlias.__getitem__
-
- if self.__origin__ in (Generic, Protocol):
- # Can't subscript Generic[...] or Protocol[...].
- raise TypeError(f"Cannot subscript already-subscripted {self}")
- if not self.__parameters__:
- raise TypeError(f"{self} is not a generic class")
-
- # Preprocess `args`.
- if not isinstance(args, tuple):
- args = (args,)
- args = tuple(_type_convert(p) for p in args)
- args = _unpack_args(args)
- new_args = self._determine_new_args(args)
- r = self.copy_with(new_args)
- return r
-
- def _determine_new_args(self, args):
- # Determines new __args__ for __getitem__.
- #
- # For example, suppose we had:
- # T1 = TypeVar('T1')
- # T2 = TypeVar('T2')
- # class A(Generic[T1, T2]): pass
- # T3 = TypeVar('T3')
- # B = A[int, T3]
- # C = B[str]
- # `B.__args__` is `(int, T3)`, so `C.__args__` should be `(int, str)`.
- # Unfortunately, this is harder than it looks, because if `T3` is
- # anything more exotic than a plain `TypeVar`, we need to consider
- # edge cases.
-
- params = self.__parameters__
- # In the example above, this would be {T3: str}
- for param in params:
- prepare = getattr(param, '__typing_prepare_subst__', None)
- if prepare is not None:
- args = prepare(self, args)
- alen = len(args)
- plen = len(params)
- if alen != plen:
- raise TypeError(f"Too {'many' if alen > plen else 'few'} arguments for {self};"
- f" actual {alen}, expected {plen}")
- new_arg_by_param = dict(zip(params, args))
- return tuple(self._make_substitution(self.__args__, new_arg_by_param))
-
- def _make_substitution(self, args, new_arg_by_param):
- """Create a list of new type arguments."""
- new_args = []
- for old_arg in args:
- if isinstance(old_arg, type):
- new_args.append(old_arg)
- continue
-
- substfunc = getattr(old_arg, '__typing_subst__', None)
- if substfunc:
- new_arg = substfunc(new_arg_by_param[old_arg])
- else:
- subparams = getattr(old_arg, '__parameters__', ())
- if not subparams:
- new_arg = old_arg
- else:
- subargs = []
- for x in subparams:
- if isinstance(x, TypeVarTuple):
- subargs.extend(new_arg_by_param[x])
- else:
- subargs.append(new_arg_by_param[x])
- new_arg = old_arg[tuple(subargs)]
-
- if self.__origin__ == collections.abc.Callable and isinstance(new_arg, tuple):
- # Consider the following `Callable`.
- # C = Callable[[int], str]
- # Here, `C.__args__` should be (int, str) - NOT ([int], str).
- # That means that if we had something like...
- # P = ParamSpec('P')
- # T = TypeVar('T')
- # C = Callable[P, T]
- # D = C[[int, str], float]
- # ...we need to be careful; `new_args` should end up as
- # `(int, str, float)` rather than `([int, str], float)`.
- new_args.extend(new_arg)
- elif _is_unpacked_typevartuple(old_arg):
- # Consider the following `_GenericAlias`, `B`:
- # class A(Generic[*Ts]): ...
- # B = A[T, *Ts]
- # If we then do:
- # B[float, int, str]
- # The `new_arg` corresponding to `T` will be `float`, and the
- # `new_arg` corresponding to `*Ts` will be `(int, str)`. We
- # should join all these types together in a flat list
- # `(float, int, str)` - so again, we should `extend`.
- new_args.extend(new_arg)
- elif isinstance(old_arg, tuple):
- # Corner case:
- # P = ParamSpec('P')
- # T = TypeVar('T')
- # class Base(Generic[P]): ...
- # Can be substituted like this:
- # X = Base[[int, T]]
- # In this case, `old_arg` will be a tuple:
- new_args.append(
- tuple(self._make_substitution(old_arg, new_arg_by_param)),
- )
- else:
- new_args.append(new_arg)
- return new_args
-
- def copy_with(self, args):
- return self.__class__(self.__origin__, args, name=self._name, inst=self._inst,
- _paramspec_tvars=self._paramspec_tvars)
-
- def __repr__(self):
- if self._name:
- name = 'typing.' + self._name
- else:
- name = _type_repr(self.__origin__)
- if self.__args__:
- args = ", ".join([_type_repr(a) for a in self.__args__])
- else:
- # To ensure the repr is eval-able.
- args = "()"
- return f'{name}[{args}]'
-
- def __reduce__(self):
- if self._name:
- origin = globals()[self._name]
- else:
- origin = self.__origin__
- args = tuple(self.__args__)
- if len(args) == 1 and not isinstance(args[0], tuple):
- args, = args
- return operator.getitem, (origin, args)
-
- def __mro_entries__(self, bases):
- if isinstance(self.__origin__, _SpecialForm):
- raise TypeError(f"Cannot subclass {self!r}")
-
- if self._name: # generic version of an ABC or built-in class
- return super().__mro_entries__(bases)
- if self.__origin__ is Generic:
- if Protocol in bases:
- return ()
- i = bases.index(self)
- for b in bases[i+1:]:
- if isinstance(b, _BaseGenericAlias) and b is not self:
- return ()
- return (self.__origin__,)
-
- def __iter__(self):
- yield Unpack[self]
-
-
-# _nparams is the number of accepted parameters, e.g. 0 for Hashable,
-# 1 for List and 2 for Dict. It may be -1 if variable number of
-# parameters are accepted (needs custom __getitem__).
-
-class _SpecialGenericAlias(_NotIterable, _BaseGenericAlias, _root=True):
- def __init__(self, origin, nparams, *, inst=True, name=None):
- if name is None:
- name = origin.__name__
- super().__init__(origin, inst=inst, name=name)
- self._nparams = nparams
- if origin.__module__ == 'builtins':
- self.__doc__ = f'A generic version of {origin.__qualname__}.'
- else:
- self.__doc__ = f'A generic version of {origin.__module__}.{origin.__qualname__}.'
-
- @_tp_cache
- def __getitem__(self, params):
- if not isinstance(params, tuple):
- params = (params,)
- msg = "Parameters to generic types must be types."
- params = tuple(_type_check(p, msg) for p in params)
- _check_generic(self, params, self._nparams)
- return self.copy_with(params)
-
- def copy_with(self, params):
- return _GenericAlias(self.__origin__, params,
- name=self._name, inst=self._inst)
-
- def __repr__(self):
- return 'typing.' + self._name
-
- def __subclasscheck__(self, cls):
- if isinstance(cls, _SpecialGenericAlias):
- return issubclass(cls.__origin__, self.__origin__)
- if not isinstance(cls, _GenericAlias):
- return issubclass(cls, self.__origin__)
- return super().__subclasscheck__(cls)
-
- def __reduce__(self):
- return self._name
-
- def __or__(self, right):
- return Union[self, right]
-
- def __ror__(self, left):
- return Union[left, self]
-
-class _CallableGenericAlias(_NotIterable, _GenericAlias, _root=True):
- def __repr__(self):
- assert self._name == 'Callable'
- args = self.__args__
- if len(args) == 2 and _is_param_expr(args[0]):
- return super().__repr__()
- return (f'typing.Callable'
- f'[[{", ".join([_type_repr(a) for a in args[:-1]])}], '
- f'{_type_repr(args[-1])}]')
-
- def __reduce__(self):
- args = self.__args__
- if not (len(args) == 2 and _is_param_expr(args[0])):
- args = list(args[:-1]), args[-1]
- return operator.getitem, (Callable, args)
-
-
-class _CallableType(_SpecialGenericAlias, _root=True):
- def copy_with(self, params):
- return _CallableGenericAlias(self.__origin__, params,
- name=self._name, inst=self._inst,
- _paramspec_tvars=True)
-
- def __getitem__(self, params):
- if not isinstance(params, tuple) or len(params) != 2:
- raise TypeError("Callable must be used as "
- "Callable[[arg, ...], result].")
- args, result = params
- # This relaxes what args can be on purpose to allow things like
- # PEP 612 ParamSpec. Responsibility for whether a user is using
- # Callable[...] properly is deferred to static type checkers.
- if isinstance(args, list):
- params = (tuple(args), result)
- else:
- params = (args, result)
- return self.__getitem_inner__(params)
-
- @_tp_cache
- def __getitem_inner__(self, params):
- args, result = params
- msg = "Callable[args, result]: result must be a type."
- result = _type_check(result, msg)
- if args is Ellipsis:
- return self.copy_with((_TypingEllipsis, result))
- if not isinstance(args, tuple):
- args = (args,)
- args = tuple(_type_convert(arg) for arg in args)
- params = args + (result,)
- return self.copy_with(params)
-
-
-class _TupleType(_SpecialGenericAlias, _root=True):
- @_tp_cache
- def __getitem__(self, params):
- if not isinstance(params, tuple):
- params = (params,)
- if len(params) >= 2 and params[-1] is ...:
- msg = "Tuple[t, ...]: t must be a type."
- params = tuple(_type_check(p, msg) for p in params[:-1])
- return self.copy_with((*params, _TypingEllipsis))
- msg = "Tuple[t0, t1, ...]: each t must be a type."
- params = tuple(_type_check(p, msg) for p in params)
- return self.copy_with(params)
-
-
-class _UnionGenericAlias(_NotIterable, _GenericAlias, _root=True):
- def copy_with(self, params):
- return Union[params]
-
- def __eq__(self, other):
- if not isinstance(other, (_UnionGenericAlias, types.UnionType)):
- return NotImplemented
- try: # fast path
- return set(self.__args__) == set(other.__args__)
- except TypeError: # not hashable, slow path
- return _compare_args_orderless(self.__args__, other.__args__)
-
- def __hash__(self):
- return hash(frozenset(self.__args__))
-
- def __repr__(self):
- args = self.__args__
- if len(args) == 2:
- if args[0] is type(None):
- return f'typing.Optional[{_type_repr(args[1])}]'
- elif args[1] is type(None):
- return f'typing.Optional[{_type_repr(args[0])}]'
- return super().__repr__()
-
- def __instancecheck__(self, obj):
- return self.__subclasscheck__(type(obj))
-
- def __subclasscheck__(self, cls):
- for arg in self.__args__:
- if issubclass(cls, arg):
- return True
-
- def __reduce__(self):
- func, (origin, args) = super().__reduce__()
- return func, (Union, args)
-
-
-def _value_and_type_iter(parameters):
- return ((p, type(p)) for p in parameters)
-
-
-class _LiteralGenericAlias(_GenericAlias, _root=True):
- def __eq__(self, other):
- if not isinstance(other, _LiteralGenericAlias):
- return NotImplemented
-
- return set(_value_and_type_iter(self.__args__)) == set(_value_and_type_iter(other.__args__))
-
- def __hash__(self):
- return hash(frozenset(_value_and_type_iter(self.__args__)))
-
-
-class _ConcatenateGenericAlias(_GenericAlias, _root=True):
- def copy_with(self, params):
- if isinstance(params[-1], (list, tuple)):
- return (*params[:-1], *params[-1])
- if isinstance(params[-1], _ConcatenateGenericAlias):
- params = (*params[:-1], *params[-1].__args__)
- return super().copy_with(params)
-
-
-@_SpecialForm
-def Unpack(self, parameters):
- """Type unpack operator.
-
- The type unpack operator takes the child types from some container type,
- such as `tuple[int, str]` or a `TypeVarTuple`, and 'pulls them out'.
-
- For example::
-
- # For some generic class `Foo`:
- Foo[Unpack[tuple[int, str]]] # Equivalent to Foo[int, str]
-
- Ts = TypeVarTuple('Ts')
- # Specifies that `Bar` is generic in an arbitrary number of types.
- # (Think of `Ts` as a tuple of an arbitrary number of individual
- # `TypeVar`s, which the `Unpack` is 'pulling out' directly into the
- # `Generic[]`.)
- class Bar(Generic[Unpack[Ts]]): ...
- Bar[int] # Valid
- Bar[int, str] # Also valid
-
- From Python 3.11, this can also be done using the `*` operator::
-
- Foo[*tuple[int, str]]
- class Bar(Generic[*Ts]): ...
-
- Note that there is only some runtime checking of this operator. Not
- everything the runtime allows may be accepted by static type checkers.
-
- For more information, see PEP 646.
- """
- item = _type_check(parameters, f'{self} accepts only single type.')
- return _UnpackGenericAlias(origin=self, args=(item,))
-
-
-class _UnpackGenericAlias(_GenericAlias, _root=True):
- def __repr__(self):
- # `Unpack` only takes one argument, so __args__ should contain only
- # a single item.
- return '*' + repr(self.__args__[0])
-
- def __getitem__(self, args):
- if self.__typing_is_unpacked_typevartuple__:
- return args
- return super().__getitem__(args)
-
- @property
- def __typing_unpacked_tuple_args__(self):
- assert self.__origin__ is Unpack
- assert len(self.__args__) == 1
- arg, = self.__args__
- if isinstance(arg, _GenericAlias):
- assert arg.__origin__ is tuple
- return arg.__args__
- return None
-
- @property
- def __typing_is_unpacked_typevartuple__(self):
- assert self.__origin__ is Unpack
- assert len(self.__args__) == 1
- return isinstance(self.__args__[0], TypeVarTuple)
-
-
-class Generic:
- """Abstract base class for generic types.
-
- A generic type is typically declared by inheriting from
- this class parameterized with one or more type variables.
- For example, a generic mapping type might be defined as::
-
- class Mapping(Generic[KT, VT]):
- def __getitem__(self, key: KT) -> VT:
- ...
- # Etc.
-
- This class can then be used as follows::
-
- def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
- try:
- return mapping[key]
- except KeyError:
- return default
- """
- __slots__ = ()
- _is_protocol = False
-
- @_tp_cache
- def __class_getitem__(cls, params):
- """Parameterizes a generic class.
-
- At least, parameterizing a generic class is the *main* thing this method
- does. For example, for some generic class `Foo`, this is called when we
- do `Foo[int]` - there, with `cls=Foo` and `params=int`.
-
- However, note that this method is also called when defining generic
- classes in the first place with `class Foo(Generic[T]): ...`.
- """
- if not isinstance(params, tuple):
- params = (params,)
-
- params = tuple(_type_convert(p) for p in params)
- if cls in (Generic, Protocol):
- # Generic and Protocol can only be subscripted with unique type variables.
- if not params:
- raise TypeError(
- f"Parameter list to {cls.__qualname__}[...] cannot be empty"
- )
- if not all(_is_typevar_like(p) for p in params):
- raise TypeError(
- f"Parameters to {cls.__name__}[...] must all be type variables "
- f"or parameter specification variables.")
- if len(set(params)) != len(params):
- raise TypeError(
- f"Parameters to {cls.__name__}[...] must all be unique")
- else:
- # Subscripting a regular Generic subclass.
- for param in cls.__parameters__:
- prepare = getattr(param, '__typing_prepare_subst__', None)
- if prepare is not None:
- params = prepare(cls, params)
- _check_generic(cls, params, len(cls.__parameters__))
-
- new_args = []
- for param, new_arg in zip(cls.__parameters__, params):
- if isinstance(param, TypeVarTuple):
- new_args.extend(new_arg)
- else:
- new_args.append(new_arg)
- params = tuple(new_args)
-
- return _GenericAlias(cls, params,
- _paramspec_tvars=True)
-
- def __init_subclass__(cls, *args, **kwargs):
- super().__init_subclass__(*args, **kwargs)
- tvars = []
- if '__orig_bases__' in cls.__dict__:
- error = Generic in cls.__orig_bases__
- else:
- error = (Generic in cls.__bases__ and
- cls.__name__ != 'Protocol' and
- type(cls) != _TypedDictMeta)
- if error:
- raise TypeError("Cannot inherit from plain Generic")
- if '__orig_bases__' in cls.__dict__:
- tvars = _collect_parameters(cls.__orig_bases__)
- # Look for Generic[T1, ..., Tn].
- # If found, tvars must be a subset of it.
- # If not found, tvars is it.
- # Also check for and reject plain Generic,
- # and reject multiple Generic[...].
- gvars = None
- for base in cls.__orig_bases__:
- if (isinstance(base, _GenericAlias) and
- base.__origin__ is Generic):
- if gvars is not None:
- raise TypeError(
- "Cannot inherit from Generic[...] multiple times.")
- gvars = base.__parameters__
- if gvars is not None:
- tvarset = set(tvars)
- gvarset = set(gvars)
- if not tvarset <= gvarset:
- s_vars = ', '.join(str(t) for t in tvars if t not in gvarset)
- s_args = ', '.join(str(g) for g in gvars)
- raise TypeError(f"Some type variables ({s_vars}) are"
- f" not listed in Generic[{s_args}]")
- tvars = gvars
- cls.__parameters__ = tuple(tvars)
-
-
-class _TypingEllipsis:
- """Internal placeholder for ... (ellipsis)."""
-
-
-_TYPING_INTERNALS = ['__parameters__', '__orig_bases__', '__orig_class__',
- '_is_protocol', '_is_runtime_protocol', '__final__']
-
-_SPECIAL_NAMES = ['__abstractmethods__', '__annotations__', '__dict__', '__doc__',
- '__init__', '__module__', '__new__', '__slots__',
- '__subclasshook__', '__weakref__', '__class_getitem__']
-
-# These special attributes will be not collected as protocol members.
-EXCLUDED_ATTRIBUTES = _TYPING_INTERNALS + _SPECIAL_NAMES + ['_MutableMapping__marker']
-
-
-def _get_protocol_attrs(cls):
- """Collect protocol members from a protocol class objects.
-
- This includes names actually defined in the class dictionary, as well
- as names that appear in annotations. Special names (above) are skipped.
- """
- attrs = set()
- for base in cls.__mro__[:-1]: # without object
- if base.__name__ in ('Protocol', 'Generic'):
- continue
- annotations = getattr(base, '__annotations__', {})
- for attr in list(base.__dict__.keys()) + list(annotations.keys()):
- if not attr.startswith('_abc_') and attr not in EXCLUDED_ATTRIBUTES:
- attrs.add(attr)
- return attrs
-
-
-def _is_callable_members_only(cls):
- # PEP 544 prohibits using issubclass() with protocols that have non-method members.
- return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls))
-
-
-def _no_init_or_replace_init(self, *args, **kwargs):
- cls = type(self)
-
- if cls._is_protocol:
- raise TypeError('Protocols cannot be instantiated')
-
- # Already using a custom `__init__`. No need to calculate correct
- # `__init__` to call. This can lead to RecursionError. See bpo-45121.
- if cls.__init__ is not _no_init_or_replace_init:
- return
-
- # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
- # The first instantiation of the subclass will call `_no_init_or_replace_init` which
- # searches for a proper new `__init__` in the MRO. The new `__init__`
- # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
- # instantiation of the protocol subclass will thus use the new
- # `__init__` and no longer call `_no_init_or_replace_init`.
- for base in cls.__mro__:
- init = base.__dict__.get('__init__', _no_init_or_replace_init)
- if init is not _no_init_or_replace_init:
- cls.__init__ = init
- break
- else:
- # should not happen
- cls.__init__ = object.__init__
-
- cls.__init__(self, *args, **kwargs)
-
-
-def _caller(depth=1, default='__main__'):
- try:
- return sys._getframe(depth + 1).f_globals.get('__name__', default)
- except (AttributeError, ValueError): # For platforms without _getframe()
- return None
-
-
-def _allow_reckless_class_checks(depth=3):
- """Allow instance and class checks for special stdlib modules.
-
- The abc and functools modules indiscriminately call isinstance() and
- issubclass() on the whole MRO of a user class, which may contain protocols.
- """
- return _caller(depth) in {'abc', 'functools', None}
-
-
-_PROTO_ALLOWLIST = {
- 'collections.abc': [
- 'Callable', 'Awaitable', 'Iterable', 'Iterator', 'AsyncIterable',
- 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible',
- ],
- 'contextlib': ['AbstractContextManager', 'AbstractAsyncContextManager'],
-}
-
-
-class _ProtocolMeta(ABCMeta):
- # This metaclass is really unfortunate and exists only because of
- # the lack of __instancehook__.
- def __instancecheck__(cls, instance):
- # We need this method for situations where attributes are
- # assigned in __init__.
- if (
- getattr(cls, '_is_protocol', False) and
- not getattr(cls, '_is_runtime_protocol', False) and
- not _allow_reckless_class_checks(depth=2)
- ):
- raise TypeError("Instance and class checks can only be used with"
- " @runtime_checkable protocols")
-
- if ((not getattr(cls, '_is_protocol', False) or
- _is_callable_members_only(cls)) and
- issubclass(instance.__class__, cls)):
- return True
- if cls._is_protocol:
- if all(hasattr(instance, attr) and
- # All *methods* can be blocked by setting them to None.
- (not callable(getattr(cls, attr, None)) or
- getattr(instance, attr) is not None)
- for attr in _get_protocol_attrs(cls)):
- return True
- return super().__instancecheck__(instance)
-
-
-class Protocol(Generic, metaclass=_ProtocolMeta):
- """Base class for protocol classes.
-
- Protocol classes are defined as::
-
- class Proto(Protocol):
- def meth(self) -> int:
- ...
-
- Such classes are primarily used with static type checkers that recognize
- structural subtyping (static duck-typing).
-
- For example::
-
- class C:
- def meth(self) -> int:
- return 0
-
- def func(x: Proto) -> int:
- return x.meth()
-
- func(C()) # Passes static type check
-
- See PEP 544 for details. Protocol classes decorated with
- @typing.runtime_checkable act as simple-minded runtime protocols that check
- only the presence of given attributes, ignoring their type signatures.
- Protocol classes can be generic, they are defined as::
-
- class GenProto(Protocol[T]):
- def meth(self) -> T:
- ...
- """
-
- __slots__ = ()
- _is_protocol = True
- _is_runtime_protocol = False
-
- def __init_subclass__(cls, *args, **kwargs):
- super().__init_subclass__(*args, **kwargs)
-
- # Determine if this is a protocol or a concrete subclass.
- if not cls.__dict__.get('_is_protocol', False):
- cls._is_protocol = any(b is Protocol for b in cls.__bases__)
-
- # Set (or override) the protocol subclass hook.
- def _proto_hook(other):
- if not cls.__dict__.get('_is_protocol', False):
- return NotImplemented
-
- # First, perform various sanity checks.
- if not getattr(cls, '_is_runtime_protocol', False):
- if _allow_reckless_class_checks():
- return NotImplemented
- raise TypeError("Instance and class checks can only be used with"
- " @runtime_checkable protocols")
- if not _is_callable_members_only(cls):
- if _allow_reckless_class_checks():
- return NotImplemented
- raise TypeError("Protocols with non-method members"
- " don't support issubclass()")
- if not isinstance(other, type):
- # Same error message as for issubclass(1, int).
- raise TypeError('issubclass() arg 1 must be a class')
-
- # Second, perform the actual structural compatibility check.
- for attr in _get_protocol_attrs(cls):
- for base in other.__mro__:
- # Check if the members appears in the class dictionary...
- if attr in base.__dict__:
- if base.__dict__[attr] is None:
- return NotImplemented
- break
-
- # ...or in annotations, if it is a sub-protocol.
- annotations = getattr(base, '__annotations__', {})
- if (isinstance(annotations, collections.abc.Mapping) and
- attr in annotations and
- issubclass(other, Generic) and other._is_protocol):
- break
- else:
- return NotImplemented
- return True
-
- if '__subclasshook__' not in cls.__dict__:
- cls.__subclasshook__ = _proto_hook
-
- # We have nothing more to do for non-protocols...
- if not cls._is_protocol:
- return
-
- # ... otherwise check consistency of bases, and prohibit instantiation.
- for base in cls.__bases__:
- if not (base in (object, Generic) or
- base.__module__ in _PROTO_ALLOWLIST and
- base.__name__ in _PROTO_ALLOWLIST[base.__module__] or
- issubclass(base, Generic) and base._is_protocol):
- raise TypeError('Protocols can only inherit from other'
- ' protocols, got %r' % base)
- if cls.__init__ is Protocol.__init__:
- cls.__init__ = _no_init_or_replace_init
-
-
-class _AnnotatedAlias(_NotIterable, _GenericAlias, _root=True):
- """Runtime representation of an annotated type.
-
- At its core 'Annotated[t, dec1, dec2, ...]' is an alias for the type 't'
- with extra annotations. The alias behaves like a normal typing alias.
- Instantiating is the same as instantiating the underlying type; binding
- it to types is also the same.
-
- The metadata itself is stored in a '__metadata__' attribute as a tuple.
- """
-
- def __init__(self, origin, metadata):
- if isinstance(origin, _AnnotatedAlias):
- metadata = origin.__metadata__ + metadata
- origin = origin.__origin__
- super().__init__(origin, origin)
- self.__metadata__ = metadata
-
- def copy_with(self, params):
- assert len(params) == 1
- new_type = params[0]
- return _AnnotatedAlias(new_type, self.__metadata__)
-
- def __repr__(self):
- return "typing.Annotated[{}, {}]".format(
- _type_repr(self.__origin__),
- ", ".join(repr(a) for a in self.__metadata__)
- )
-
- def __reduce__(self):
- return operator.getitem, (
- Annotated, (self.__origin__,) + self.__metadata__
- )
-
- def __eq__(self, other):
- if not isinstance(other, _AnnotatedAlias):
- return NotImplemented
- return (self.__origin__ == other.__origin__
- and self.__metadata__ == other.__metadata__)
-
- def __hash__(self):
- return hash((self.__origin__, self.__metadata__))
-
- def __getattr__(self, attr):
- if attr in {'__name__', '__qualname__'}:
- return 'Annotated'
- return super().__getattr__(attr)
-
-
-class Annotated:
- """Add context-specific metadata to a type.
-
- Example: Annotated[int, runtime_check.Unsigned] indicates to the
- hypothetical runtime_check module that this type is an unsigned int.
- Every other consumer of this type can ignore this metadata and treat
- this type as int.
-
- The first argument to Annotated must be a valid type.
-
- Details:
-
- - It's an error to call `Annotated` with less than two arguments.
- - Access the metadata via the ``__metadata__`` attribute::
-
- assert Annotated[int, '$'].__metadata__ == ('$',)
-
- - Nested Annotated types are flattened::
-
- assert Annotated[Annotated[T, Ann1, Ann2], Ann3] == Annotated[T, Ann1, Ann2, Ann3]
-
- - Instantiating an annotated type is equivalent to instantiating the
- underlying type::
-
- assert Annotated[C, Ann1](5) == C(5)
-
- - Annotated can be used as a generic type alias::
-
- Optimized: TypeAlias = Annotated[T, runtime.Optimize()]
- assert Optimized[int] == Annotated[int, runtime.Optimize()]
-
- OptimizedList: TypeAlias = Annotated[list[T], runtime.Optimize()]
- assert OptimizedList[int] == Annotated[list[int], runtime.Optimize()]
-
- - Annotated cannot be used with an unpacked TypeVarTuple::
-
- Variadic: TypeAlias = Annotated[*Ts, Ann1] # NOT valid
-
- This would be equivalent to::
-
- Annotated[T1, T2, T3, ..., Ann1]
-
- where T1, T2 etc. are TypeVars, which would be invalid, because
- only one type should be passed to Annotated.
- """
-
- __slots__ = ()
-
- def __new__(cls, *args, **kwargs):
- raise TypeError("Type Annotated cannot be instantiated.")
-
- def __class_getitem__(cls, params):
- if not isinstance(params, tuple):
- params = (params,)
- return cls._class_getitem_inner(cls, *params)
-
- @_tp_cache(typed=True)
- def _class_getitem_inner(cls, *params):
- if len(params) < 2:
- raise TypeError("Annotated[...] should be used "
- "with at least two arguments (a type and an "
- "annotation).")
- if _is_unpacked_typevartuple(params[0]):
- raise TypeError("Annotated[...] should not be used with an "
- "unpacked TypeVarTuple")
- msg = "Annotated[t, ...]: t must be a type."
- origin = _type_check(params[0], msg, allow_special_forms=True)
- metadata = tuple(params[1:])
- return _AnnotatedAlias(origin, metadata)
-
- def __init_subclass__(cls, *args, **kwargs):
- raise TypeError(
- "Cannot subclass {}.Annotated".format(cls.__module__)
- )
-
-
-def runtime_checkable(cls):
- """Mark a protocol class as a runtime protocol.
-
- Such protocol can be used with isinstance() and issubclass().
- Raise TypeError if applied to a non-protocol class.
- This allows a simple-minded structural check very similar to
- one trick ponies in collections.abc such as Iterable.
-
- For example::
-
- @runtime_checkable
- class Closable(Protocol):
- def close(self): ...
-
- assert isinstance(open('/some/file'), Closable)
-
- Warning: this will check only the presence of the required methods,
- not their type signatures!
- """
- if not issubclass(cls, Generic) or not cls._is_protocol:
- raise TypeError('@runtime_checkable can be only applied to protocol classes,'
- ' got %r' % cls)
- cls._is_runtime_protocol = True
- return cls
-
-
-def cast(typ, val):
- """Cast a value to a type.
-
- This returns the value unchanged. To the type checker this
- signals that the return value has the designated type, but at
- runtime we intentionally don't check anything (we want this
- to be as fast as possible).
- """
- return val
-
-
-def assert_type(val, typ, /):
- """Ask a static type checker to confirm that the value is of the given type.
-
- At runtime this does nothing: it returns the first argument unchanged with no
- checks or side effects, no matter the actual type of the argument.
-
- When a static type checker encounters a call to assert_type(), it
- emits an error if the value is not of the specified type::
-
- def greet(name: str) -> None:
- assert_type(name, str) # OK
- assert_type(name, int) # type checker error
- """
- return val
-
-
-_allowed_types = (types.FunctionType, types.BuiltinFunctionType,
- types.MethodType, types.ModuleType,
- WrapperDescriptorType, MethodWrapperType, MethodDescriptorType)
-
-
-def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
- """Return type hints for an object.
-
- This is often the same as obj.__annotations__, but it handles
- forward references encoded as string literals and recursively replaces all
- 'Annotated[T, ...]' with 'T' (unless 'include_extras=True').
-
- The argument may be a module, class, method, or function. The annotations
- are returned as a dictionary. For classes, annotations include also
- inherited members.
-
- TypeError is raised if the argument is not of a type that can contain
- annotations, and an empty dictionary is returned if no annotations are
- present.
-
- BEWARE -- the behavior of globalns and localns is counterintuitive
- (unless you are familiar with how eval() and exec() work). The
- search order is locals first, then globals.
-
- - If no dict arguments are passed, an attempt is made to use the
- globals from obj (or the respective module's globals for classes),
- and these are also used as the locals. If the object does not appear
- to have globals, an empty dictionary is used. For classes, the search
- order is globals first then locals.
-
- - If one dict argument is passed, it is used for both globals and
- locals.
-
- - If two dict arguments are passed, they specify globals and
- locals, respectively.
- """
- if getattr(obj, '__no_type_check__', None):
- return {}
- # Classes require a special treatment.
- if isinstance(obj, type):
- hints = {}
- for base in reversed(obj.__mro__):
- if globalns is None:
- base_globals = getattr(sys.modules.get(base.__module__, None), '__dict__', {})
- else:
- base_globals = globalns
- ann = base.__dict__.get('__annotations__', {})
- if isinstance(ann, types.GetSetDescriptorType):
- ann = {}
- base_locals = dict(vars(base)) if localns is None else localns
- if localns is None and globalns is None:
- # This is surprising, but required. Before Python 3.10,
- # get_type_hints only evaluated the globalns of
- # a class. To maintain backwards compatibility, we reverse
- # the globalns and localns order so that eval() looks into
- # *base_globals* first rather than *base_locals*.
- # This only affects ForwardRefs.
- base_globals, base_locals = base_locals, base_globals
- for name, value in ann.items():
- if value is None:
- value = type(None)
- if isinstance(value, str):
- value = ForwardRef(value, is_argument=False, is_class=True)
- value = _eval_type(value, base_globals, base_locals)
- hints[name] = value
- return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()}
-
- if globalns is None:
- if isinstance(obj, types.ModuleType):
- globalns = obj.__dict__
- else:
- nsobj = obj
- # Find globalns for the unwrapped object.
- while hasattr(nsobj, '__wrapped__'):
- nsobj = nsobj.__wrapped__
- globalns = getattr(nsobj, '__globals__', {})
- if localns is None:
- localns = globalns
- elif localns is None:
- localns = globalns
- hints = getattr(obj, '__annotations__', None)
- if hints is None:
- # Return empty annotations for something that _could_ have them.
- if isinstance(obj, _allowed_types):
- return {}
- else:
- raise TypeError('{!r} is not a module, class, method, '
- 'or function.'.format(obj))
- hints = dict(hints)
- for name, value in hints.items():
- if value is None:
- value = type(None)
- if isinstance(value, str):
- # class-level forward refs were handled above, this must be either
- # a module-level annotation or a function argument annotation
- value = ForwardRef(
- value,
- is_argument=not isinstance(obj, types.ModuleType),
- is_class=False,
- )
- hints[name] = _eval_type(value, globalns, localns)
- return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()}
-
-
-def _strip_annotations(t):
- """Strip the annotations from a given type."""
- if isinstance(t, _AnnotatedAlias):
- return _strip_annotations(t.__origin__)
- if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired):
- return _strip_annotations(t.__args__[0])
- if isinstance(t, _GenericAlias):
- stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
- if stripped_args == t.__args__:
- return t
- return t.copy_with(stripped_args)
- if isinstance(t, GenericAlias):
- stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
- if stripped_args == t.__args__:
- return t
- return GenericAlias(t.__origin__, stripped_args)
- if isinstance(t, types.UnionType):
- stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
- if stripped_args == t.__args__:
- return t
- return functools.reduce(operator.or_, stripped_args)
-
- return t
-
-
-def get_origin(tp):
- """Get the unsubscripted version of a type.
-
- This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar,
- Annotated, and others. Return None for unsupported types.
-
- Examples::
-
- >>> P = ParamSpec('P')
- >>> assert get_origin(Literal[42]) is Literal
- >>> assert get_origin(int) is None
- >>> assert get_origin(ClassVar[int]) is ClassVar
- >>> assert get_origin(Generic) is Generic
- >>> assert get_origin(Generic[T]) is Generic
- >>> assert get_origin(Union[T, int]) is Union
- >>> assert get_origin(List[Tuple[T, T]][int]) is list
- >>> assert get_origin(P.args) is P
- """
- if isinstance(tp, _AnnotatedAlias):
- return Annotated
- if isinstance(tp, (_BaseGenericAlias, GenericAlias,
- ParamSpecArgs, ParamSpecKwargs)):
- return tp.__origin__
- if tp is Generic:
- return Generic
- if isinstance(tp, types.UnionType):
- return types.UnionType
- return None
-
-
-def get_args(tp):
- """Get type arguments with all substitutions performed.
-
- For unions, basic simplifications used by Union constructor are performed.
-
- Examples::
-
- >>> T = TypeVar('T')
- >>> assert get_args(Dict[str, int]) == (str, int)
- >>> assert get_args(int) == ()
- >>> assert get_args(Union[int, Union[T, int], str][int]) == (int, str)
- >>> assert get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])
- >>> assert get_args(Callable[[], T][int]) == ([], int)
- """
- if isinstance(tp, _AnnotatedAlias):
- return (tp.__origin__,) + tp.__metadata__
- if isinstance(tp, (_GenericAlias, GenericAlias)):
- res = tp.__args__
- if _should_unflatten_callable_args(tp, res):
- res = (list(res[:-1]), res[-1])
- return res
- if isinstance(tp, types.UnionType):
- return tp.__args__
- return ()
-
-
-def is_typeddict(tp):
- """Check if an annotation is a TypedDict class.
-
- For example::
-
- >>> from typing import TypedDict
- >>> class Film(TypedDict):
- ... title: str
- ... year: int
- ...
- >>> is_typeddict(Film)
- True
- >>> is_typeddict(dict)
- False
- """
- return isinstance(tp, _TypedDictMeta)
-
-
-_ASSERT_NEVER_REPR_MAX_LENGTH = 100
-
-
-def assert_never(arg: Never, /) -> Never:
- """Statically assert that a line of code is unreachable.
-
- Example::
-
- def int_or_str(arg: int | str) -> None:
- match arg:
- case int():
- print("It's an int")
- case str():
- print("It's a str")
- case _:
- assert_never(arg)
-
- If a type checker finds that a call to assert_never() is
- reachable, it will emit an error.
-
- At runtime, this throws an exception when called.
- """
- value = repr(arg)
- if len(value) > _ASSERT_NEVER_REPR_MAX_LENGTH:
- value = value[:_ASSERT_NEVER_REPR_MAX_LENGTH] + '...'
- raise AssertionError(f"Expected code to be unreachable, but got: {value}")
-
-
-def no_type_check(arg):
- """Decorator to indicate that annotations are not type hints.
-
- The argument must be a class or function; if it is a class, it
- applies recursively to all methods and classes defined in that class
- (but not to methods defined in its superclasses or subclasses).
-
- This mutates the function(s) or class(es) in place.
- """
- if isinstance(arg, type):
- for key in dir(arg):
- obj = getattr(arg, key)
- if (
- not hasattr(obj, '__qualname__')
- or obj.__qualname__ != f'{arg.__qualname__}.{obj.__name__}'
- or getattr(obj, '__module__', None) != arg.__module__
- ):
- # We only modify objects that are defined in this type directly.
- # If classes / methods are nested in multiple layers,
- # we will modify them when processing their direct holders.
- continue
- # Instance, class, and static methods:
- if isinstance(obj, types.FunctionType):
- obj.__no_type_check__ = True
- if isinstance(obj, types.MethodType):
- obj.__func__.__no_type_check__ = True
- # Nested types:
- if isinstance(obj, type):
- no_type_check(obj)
- try:
- arg.__no_type_check__ = True
- except TypeError: # built-in classes
- pass
- return arg
-
-
-def no_type_check_decorator(decorator):
- """Decorator to give another decorator the @no_type_check effect.
-
- This wraps the decorator with something that wraps the decorated
- function in @no_type_check.
- """
- @functools.wraps(decorator)
- def wrapped_decorator(*args, **kwds):
- func = decorator(*args, **kwds)
- func = no_type_check(func)
- return func
-
- return wrapped_decorator
-
-
-def _overload_dummy(*args, **kwds):
- """Helper for @overload to raise when called."""
- raise NotImplementedError(
- "You should not call an overloaded function. "
- "A series of @overload-decorated functions "
- "outside a stub module should always be followed "
- "by an implementation that is not @overload-ed.")
-
-
-# {module: {qualname: {firstlineno: func}}}
-_overload_registry = defaultdict(functools.partial(defaultdict, dict))
-
-
-def overload(func):
- """Decorator for overloaded functions/methods.
-
- In a stub file, place two or more stub definitions for the same
- function in a row, each decorated with @overload.
-
- For example::
-
- @overload
- def utf8(value: None) -> None: ...
- @overload
- def utf8(value: bytes) -> bytes: ...
- @overload
- def utf8(value: str) -> bytes: ...
-
- In a non-stub file (i.e. a regular .py file), do the same but
- follow it with an implementation. The implementation should *not*
- be decorated with @overload::
-
- @overload
- def utf8(value: None) -> None: ...
- @overload
- def utf8(value: bytes) -> bytes: ...
- @overload
- def utf8(value: str) -> bytes: ...
- def utf8(value):
- ... # implementation goes here
-
- The overloads for a function can be retrieved at runtime using the
- get_overloads() function.
- """
- # classmethod and staticmethod
- f = getattr(func, "__func__", func)
- try:
- _overload_registry[f.__module__][f.__qualname__][f.__code__.co_firstlineno] = func
- except AttributeError:
- # Not a normal function; ignore.
- pass
- return _overload_dummy
-
-
-def get_overloads(func):
- """Return all defined overloads for *func* as a sequence."""
- # classmethod and staticmethod
- f = getattr(func, "__func__", func)
- if f.__module__ not in _overload_registry:
- return []
- mod_dict = _overload_registry[f.__module__]
- if f.__qualname__ not in mod_dict:
- return []
- return list(mod_dict[f.__qualname__].values())
-
-
-def clear_overloads():
- """Clear all overloads in the registry."""
- _overload_registry.clear()
-
-
-def final(f):
- """Decorator to indicate final methods and final classes.
-
- Use this decorator to indicate to type checkers that the decorated
- method cannot be overridden, and decorated class cannot be subclassed.
-
- For example::
-
- class Base:
- @final
- def done(self) -> None:
- ...
- class Sub(Base):
- def done(self) -> None: # Error reported by type checker
- ...
-
- @final
- class Leaf:
- ...
- class Other(Leaf): # Error reported by type checker
- ...
-
- There is no runtime checking of these properties. The decorator
- attempts to set the ``__final__`` attribute to ``True`` on the decorated
- object to allow runtime introspection.
- """
- try:
- f.__final__ = True
- except (AttributeError, TypeError):
- # Skip the attribute silently if it is not writable.
- # AttributeError happens if the object has __slots__ or a
- # read-only property, TypeError if it's a builtin class.
- pass
- return f
-
-
-# Some unconstrained type variables. These are used by the container types.
-# (These are not for export.)
-T = TypeVar('T') # Any type.
-KT = TypeVar('KT') # Key type.
-VT = TypeVar('VT') # Value type.
-T_co = TypeVar('T_co', covariant=True) # Any type covariant containers.
-V_co = TypeVar('V_co', covariant=True) # Any type covariant containers.
-VT_co = TypeVar('VT_co', covariant=True) # Value type covariant containers.
-T_contra = TypeVar('T_contra', contravariant=True) # Ditto contravariant.
-# Internal type variable used for Type[].
-CT_co = TypeVar('CT_co', covariant=True, bound=type)
-
-# A useful type variable with constraints. This represents string types.
-# (This one *is* for export!)
-AnyStr = TypeVar('AnyStr', bytes, str)
-
-
-# Various ABCs mimicking those in collections.abc.
-_alias = _SpecialGenericAlias
-
-Hashable = _alias(collections.abc.Hashable, 0) # Not generic.
-Awaitable = _alias(collections.abc.Awaitable, 1)
-Coroutine = _alias(collections.abc.Coroutine, 3)
-AsyncIterable = _alias(collections.abc.AsyncIterable, 1)
-AsyncIterator = _alias(collections.abc.AsyncIterator, 1)
-Iterable = _alias(collections.abc.Iterable, 1)
-Iterator = _alias(collections.abc.Iterator, 1)
-Reversible = _alias(collections.abc.Reversible, 1)
-Sized = _alias(collections.abc.Sized, 0) # Not generic.
-Container = _alias(collections.abc.Container, 1)
-Collection = _alias(collections.abc.Collection, 1)
-Callable = _CallableType(collections.abc.Callable, 2)
-Callable.__doc__ = \
- """Deprecated alias to collections.abc.Callable.
-
- Callable[[int], str] signifies a function that takes a single
- parameter of type int and returns a str.
-
- The subscription syntax must always be used with exactly two
- values: the argument list and the return type.
- The argument list must be a list of types, a ParamSpec,
- Concatenate or ellipsis. The return type must be a single type.
-
- There is no syntax to indicate optional or keyword arguments;
- such function types are rarely used as callback types.
- """
-AbstractSet = _alias(collections.abc.Set, 1, name='AbstractSet')
-MutableSet = _alias(collections.abc.MutableSet, 1)
-# NOTE: Mapping is only covariant in the value type.
-Mapping = _alias(collections.abc.Mapping, 2)
-MutableMapping = _alias(collections.abc.MutableMapping, 2)
-Sequence = _alias(collections.abc.Sequence, 1)
-MutableSequence = _alias(collections.abc.MutableSequence, 1)
-ByteString = _alias(collections.abc.ByteString, 0) # Not generic
-# Tuple accepts variable number of parameters.
-Tuple = _TupleType(tuple, -1, inst=False, name='Tuple')
-Tuple.__doc__ = \
- """Deprecated alias to builtins.tuple.
-
- Tuple[X, Y] is the cross-product type of X and Y.
-
- Example: Tuple[T1, T2] is a tuple of two elements corresponding
- to type variables T1 and T2. Tuple[int, float, str] is a tuple
- of an int, a float and a string.
-
- To specify a variable-length tuple of homogeneous type, use Tuple[T, ...].
- """
-List = _alias(list, 1, inst=False, name='List')
-Deque = _alias(collections.deque, 1, name='Deque')
-Set = _alias(set, 1, inst=False, name='Set')
-FrozenSet = _alias(frozenset, 1, inst=False, name='FrozenSet')
-MappingView = _alias(collections.abc.MappingView, 1)
-KeysView = _alias(collections.abc.KeysView, 1)
-ItemsView = _alias(collections.abc.ItemsView, 2)
-ValuesView = _alias(collections.abc.ValuesView, 1)
-ContextManager = _alias(contextlib.AbstractContextManager, 1, name='ContextManager')
-AsyncContextManager = _alias(contextlib.AbstractAsyncContextManager, 1, name='AsyncContextManager')
-Dict = _alias(dict, 2, inst=False, name='Dict')
-DefaultDict = _alias(collections.defaultdict, 2, name='DefaultDict')
-OrderedDict = _alias(collections.OrderedDict, 2)
-Counter = _alias(collections.Counter, 1)
-ChainMap = _alias(collections.ChainMap, 2)
-Generator = _alias(collections.abc.Generator, 3)
-AsyncGenerator = _alias(collections.abc.AsyncGenerator, 2)
-Type = _alias(type, 1, inst=False, name='Type')
-Type.__doc__ = \
- """Deprecated alias to builtins.type.
-
- builtins.type or typing.Type can be used to annotate class objects.
- For example, suppose we have the following classes::
-
- class User: ... # Abstract base for User classes
- class BasicUser(User): ...
- class ProUser(User): ...
- class TeamUser(User): ...
-
- And a function that takes a class argument that's a subclass of
- User and returns an instance of the corresponding class::
-
- U = TypeVar('U', bound=User)
- def new_user(user_class: Type[U]) -> U:
- user = user_class()
- # (Here we could write the user object to a database)
- return user
-
- joe = new_user(BasicUser)
-
- At this point the type checker knows that joe has type BasicUser.
- """
-
-
-@runtime_checkable
-class SupportsInt(Protocol):
- """An ABC with one abstract method __int__."""
-
- __slots__ = ()
-
- @abstractmethod
- def __int__(self) -> int:
- pass
-
-
-@runtime_checkable
-class SupportsFloat(Protocol):
- """An ABC with one abstract method __float__."""
-
- __slots__ = ()
-
- @abstractmethod
- def __float__(self) -> float:
- pass
-
-
-@runtime_checkable
-class SupportsComplex(Protocol):
- """An ABC with one abstract method __complex__."""
-
- __slots__ = ()
-
- @abstractmethod
- def __complex__(self) -> complex:
- pass
-
-
-@runtime_checkable
-class SupportsBytes(Protocol):
- """An ABC with one abstract method __bytes__."""
-
- __slots__ = ()
-
- @abstractmethod
- def __bytes__(self) -> bytes:
- pass
-
-
-@runtime_checkable
-class SupportsIndex(Protocol):
- """An ABC with one abstract method __index__."""
-
- __slots__ = ()
-
- @abstractmethod
- def __index__(self) -> int:
- pass
-
-
-@runtime_checkable
-class SupportsAbs(Protocol[T_co]):
- """An ABC with one abstract method __abs__ that is covariant in its return type."""
-
- __slots__ = ()
-
- @abstractmethod
- def __abs__(self) -> T_co:
- pass
-
-
-@runtime_checkable
-class SupportsRound(Protocol[T_co]):
- """An ABC with one abstract method __round__ that is covariant in its return type."""
-
- __slots__ = ()
-
- @abstractmethod
- def __round__(self, ndigits: int = 0) -> T_co:
- pass
-
-
-def _make_nmtuple(name, types, module, defaults = ()):
- fields = [n for n, t in types]
- types = {n: _type_check(t, f"field {n} annotation must be a type")
- for n, t in types}
- nm_tpl = collections.namedtuple(name, fields,
- defaults=defaults, module=module)
- nm_tpl.__annotations__ = nm_tpl.__new__.__annotations__ = types
- return nm_tpl
-
-
-# attributes prohibited to set in NamedTuple class syntax
-_prohibited = frozenset({'__new__', '__init__', '__slots__', '__getnewargs__',
- '_fields', '_field_defaults',
- '_make', '_replace', '_asdict', '_source'})
-
-_special = frozenset({'__module__', '__name__', '__annotations__'})
-
-
-class NamedTupleMeta(type):
- def __new__(cls, typename, bases, ns):
- assert _NamedTuple in bases
- for base in bases:
- if base is not _NamedTuple and base is not Generic:
- raise TypeError(
- 'can only inherit from a NamedTuple type and Generic')
- bases = tuple(tuple if base is _NamedTuple else base for base in bases)
- types = ns.get('__annotations__', {})
- default_names = []
- for field_name in types:
- if field_name in ns:
- default_names.append(field_name)
- elif default_names:
- raise TypeError(f"Non-default namedtuple field {field_name} "
- f"cannot follow default field"
- f"{'s' if len(default_names) > 1 else ''} "
- f"{', '.join(default_names)}")
- nm_tpl = _make_nmtuple(typename, types.items(),
- defaults=[ns[n] for n in default_names],
- module=ns['__module__'])
- nm_tpl.__bases__ = bases
- if Generic in bases:
- class_getitem = Generic.__class_getitem__.__func__
- nm_tpl.__class_getitem__ = classmethod(class_getitem)
- # update from user namespace without overriding special namedtuple attributes
- for key in ns:
- if key in _prohibited:
- raise AttributeError("Cannot overwrite NamedTuple attribute " + key)
- elif key not in _special and key not in nm_tpl._fields:
- setattr(nm_tpl, key, ns[key])
- if Generic in bases:
- nm_tpl.__init_subclass__()
- return nm_tpl
-
-
-def NamedTuple(typename, fields=None, /, **kwargs):
- """Typed version of namedtuple.
-
- Usage::
-
- class Employee(NamedTuple):
- name: str
- id: int
-
- This is equivalent to::
-
- Employee = collections.namedtuple('Employee', ['name', 'id'])
-
- The resulting class has an extra __annotations__ attribute, giving a
- dict that maps field names to types. (The field names are also in
- the _fields attribute, which is part of the namedtuple API.)
- An alternative equivalent functional syntax is also accepted::
-
- Employee = NamedTuple('Employee', [('name', str), ('id', int)])
- """
- if fields is None:
- fields = kwargs.items()
- elif kwargs:
- raise TypeError("Either list of fields or keywords"
- " can be provided to NamedTuple, not both")
- return _make_nmtuple(typename, fields, module=_caller())
-
-_NamedTuple = type.__new__(NamedTupleMeta, 'NamedTuple', (), {})
-
-def _namedtuple_mro_entries(bases):
- assert NamedTuple in bases
- return (_NamedTuple,)
-
-NamedTuple.__mro_entries__ = _namedtuple_mro_entries
-
-
-class _TypedDictMeta(type):
- def __new__(cls, name, bases, ns, total=True):
- """Create a new typed dict class object.
-
- This method is called when TypedDict is subclassed,
- or when TypedDict is instantiated. This way
- TypedDict supports all three syntax forms described in its docstring.
- Subclasses and instances of TypedDict return actual dictionaries.
- """
- for base in bases:
- if type(base) is not _TypedDictMeta and base is not Generic:
- raise TypeError('cannot inherit from both a TypedDict type '
- 'and a non-TypedDict base class')
-
- if any(issubclass(b, Generic) for b in bases):
- generic_base = (Generic,)
- else:
- generic_base = ()
-
- tp_dict = type.__new__(_TypedDictMeta, name, (*generic_base, dict), ns)
-
- annotations = {}
- own_annotations = ns.get('__annotations__', {})
- msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type"
- own_annotations = {
- n: _type_check(tp, msg, module=tp_dict.__module__)
- for n, tp in own_annotations.items()
- }
- required_keys = set()
- optional_keys = set()
-
- for base in bases:
- annotations.update(base.__dict__.get('__annotations__', {}))
-
- base_required = base.__dict__.get('__required_keys__', set())
- required_keys |= base_required
- optional_keys -= base_required
-
- base_optional = base.__dict__.get('__optional_keys__', set())
- required_keys -= base_optional
- optional_keys |= base_optional
-
- annotations.update(own_annotations)
- for annotation_key, annotation_type in own_annotations.items():
- annotation_origin = get_origin(annotation_type)
- if annotation_origin is Annotated:
- annotation_args = get_args(annotation_type)
- if annotation_args:
- annotation_type = annotation_args[0]
- annotation_origin = get_origin(annotation_type)
-
- if annotation_origin is Required:
- is_required = True
- elif annotation_origin is NotRequired:
- is_required = False
- else:
- is_required = total
-
- if is_required:
- required_keys.add(annotation_key)
- optional_keys.discard(annotation_key)
- else:
- optional_keys.add(annotation_key)
- required_keys.discard(annotation_key)
-
- assert required_keys.isdisjoint(optional_keys), (
- f"Required keys overlap with optional keys in {name}:"
- f" {required_keys=}, {optional_keys=}"
- )
- tp_dict.__annotations__ = annotations
- tp_dict.__required_keys__ = frozenset(required_keys)
- tp_dict.__optional_keys__ = frozenset(optional_keys)
- if not hasattr(tp_dict, '__total__'):
- tp_dict.__total__ = total
- return tp_dict
-
- __call__ = dict # static method
-
- def __subclasscheck__(cls, other):
- # Typed dicts are only for static structural subtyping.
- raise TypeError('TypedDict does not support instance and class checks')
-
- __instancecheck__ = __subclasscheck__
-
-
-def TypedDict(typename, fields=None, /, *, total=True, **kwargs):
- """A simple typed namespace. At runtime it is equivalent to a plain dict.
-
- TypedDict creates a dictionary type such that a type checker will expect all
- instances to have a certain set of keys, where each key is
- associated with a value of a consistent type. This expectation
- is not checked at runtime.
-
- Usage::
-
- >>> class Point2D(TypedDict):
- ... x: int
- ... y: int
- ... label: str
- ...
- >>> a: Point2D = {'x': 1, 'y': 2, 'label': 'good'} # OK
- >>> b: Point2D = {'z': 3, 'label': 'bad'} # Fails type check
- >>> Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first')
- True
-
- The type info can be accessed via the Point2D.__annotations__ dict, and
- the Point2D.__required_keys__ and Point2D.__optional_keys__ frozensets.
- TypedDict supports an additional equivalent form::
-
- Point2D = TypedDict('Point2D', {'x': int, 'y': int, 'label': str})
-
- By default, all keys must be present in a TypedDict. It is possible
- to override this by specifying totality::
-
- class Point2D(TypedDict, total=False):
- x: int
- y: int
-
- This means that a Point2D TypedDict can have any of the keys omitted. A type
- checker is only expected to support a literal False or True as the value of
- the total argument. True is the default, and makes all items defined in the
- class body be required.
-
- The Required and NotRequired special forms can also be used to mark
- individual keys as being required or not required::
-
- class Point2D(TypedDict):
- x: int # the "x" key must always be present (Required is the default)
- y: NotRequired[int] # the "y" key can be omitted
-
- See PEP 655 for more details on Required and NotRequired.
- """
- if fields is None:
- fields = kwargs
- elif kwargs:
- raise TypeError("TypedDict takes either a dict or keyword arguments,"
- " but not both")
- if kwargs:
- warnings.warn(
- "The kwargs-based syntax for TypedDict definitions is deprecated "
- "in Python 3.11, will be removed in Python 3.13, and may not be "
- "understood by third-party type checkers.",
- DeprecationWarning,
- stacklevel=2,
- )
-
- ns = {'__annotations__': dict(fields)}
- module = _caller()
- if module is not None:
- # Setting correct module is necessary to make typed dict classes pickleable.
- ns['__module__'] = module
-
- return _TypedDictMeta(typename, (), ns, total=total)
-
-_TypedDict = type.__new__(_TypedDictMeta, 'TypedDict', (), {})
-TypedDict.__mro_entries__ = lambda bases: (_TypedDict,)
-
-
-@_SpecialForm
-def Required(self, parameters):
- """Special typing construct to mark a TypedDict key as required.
-
- This is mainly useful for total=False TypedDicts.
-
- For example::
-
- class Movie(TypedDict, total=False):
- title: Required[str]
- year: int
-
- m = Movie(
- title='The Matrix', # typechecker error if key is omitted
- year=1999,
- )
-
- There is no runtime checking that a required key is actually provided
- when instantiating a related TypedDict.
- """
- item = _type_check(parameters, f'{self._name} accepts only a single type.')
- return _GenericAlias(self, (item,))
-
-
-@_SpecialForm
-def NotRequired(self, parameters):
- """Special typing construct to mark a TypedDict key as potentially missing.
-
- For example::
-
- class Movie(TypedDict):
- title: str
- year: NotRequired[int]
-
- m = Movie(
- title='The Matrix', # typechecker error if key is omitted
- year=1999,
- )
- """
- item = _type_check(parameters, f'{self._name} accepts only a single type.')
- return _GenericAlias(self, (item,))
-
-
-class NewType:
- """NewType creates simple unique types with almost zero runtime overhead.
-
- NewType(name, tp) is considered a subtype of tp
- by static type checkers. At runtime, NewType(name, tp) returns
- a dummy callable that simply returns its argument.
-
- Usage::
-
- UserId = NewType('UserId', int)
-
- def name_by_id(user_id: UserId) -> str:
- ...
-
- UserId('user') # Fails type check
-
- name_by_id(42) # Fails type check
- name_by_id(UserId(42)) # OK
-
- num = UserId(5) + 1 # type: int
- """
-
- __call__ = _idfunc
-
- def __init__(self, name, tp):
- self.__qualname__ = name
- if '.' in name:
- name = name.rpartition('.')[-1]
- self.__name__ = name
- self.__supertype__ = tp
- def_mod = _caller()
- if def_mod != 'typing':
- self.__module__ = def_mod
-
- def __mro_entries__(self, bases):
- # We defined __mro_entries__ to get a better error message
- # if a user attempts to subclass a NewType instance. bpo-46170
- superclass_name = self.__name__
-
- class Dummy:
- def __init_subclass__(cls):
- subclass_name = cls.__name__
- raise TypeError(
- f"Cannot subclass an instance of NewType. Perhaps you were looking for: "
- f"`{subclass_name} = NewType({subclass_name!r}, {superclass_name})`"
- )
-
- return (Dummy,)
-
- def __repr__(self):
- return f'{self.__module__}.{self.__qualname__}'
-
- def __reduce__(self):
- return self.__qualname__
-
- def __or__(self, other):
- return Union[self, other]
-
- def __ror__(self, other):
- return Union[other, self]
-
-
-# Python-version-specific alias (Python 2: unicode; Python 3: str)
-Text = str
-
-
-# Constant that's True when type checking, but False here.
-TYPE_CHECKING = False
-
-
-class IO(Generic[AnyStr]):
- """Generic base class for TextIO and BinaryIO.
-
- This is an abstract, generic version of the return of open().
-
- NOTE: This does not distinguish between the different possible
- classes (text vs. binary, read vs. write vs. read/write,
- append-only, unbuffered). The TextIO and BinaryIO subclasses
- below capture the distinctions between text vs. binary, which is
- pervasive in the interface; however we currently do not offer a
- way to track the other distinctions in the type system.
- """
-
- __slots__ = ()
-
- @property
- @abstractmethod
- def mode(self) -> str:
- pass
-
- @property
- @abstractmethod
- def name(self) -> str:
- pass
-
- @abstractmethod
- def close(self) -> None:
- pass
-
- @property
- @abstractmethod
- def closed(self) -> bool:
- pass
-
- @abstractmethod
- def fileno(self) -> int:
- pass
-
- @abstractmethod
- def flush(self) -> None:
- pass
-
- @abstractmethod
- def isatty(self) -> bool:
- pass
-
- @abstractmethod
- def read(self, n: int = -1) -> AnyStr:
- pass
-
- @abstractmethod
- def readable(self) -> bool:
- pass
-
- @abstractmethod
- def readline(self, limit: int = -1) -> AnyStr:
- pass
-
- @abstractmethod
- def readlines(self, hint: int = -1) -> List[AnyStr]:
- pass
-
- @abstractmethod
- def seek(self, offset: int, whence: int = 0) -> int:
- pass
-
- @abstractmethod
- def seekable(self) -> bool:
- pass
-
- @abstractmethod
- def tell(self) -> int:
- pass
-
- @abstractmethod
- def truncate(self, size: int = None) -> int:
- pass
-
- @abstractmethod
- def writable(self) -> bool:
- pass
-
- @abstractmethod
- def write(self, s: AnyStr) -> int:
- pass
-
- @abstractmethod
- def writelines(self, lines: List[AnyStr]) -> None:
- pass
-
- @abstractmethod
- def __enter__(self) -> 'IO[AnyStr]':
- pass
-
- @abstractmethod
- def __exit__(self, type, value, traceback) -> None:
- pass
-
-
-class BinaryIO(IO[bytes]):
- """Typed version of the return of open() in binary mode."""
-
- __slots__ = ()
-
- @abstractmethod
- def write(self, s: Union[bytes, bytearray]) -> int:
- pass
-
- @abstractmethod
- def __enter__(self) -> 'BinaryIO':
- pass
-
-
-class TextIO(IO[str]):
- """Typed version of the return of open() in text mode."""
-
- __slots__ = ()
-
- @property
- @abstractmethod
- def buffer(self) -> BinaryIO:
- pass
-
- @property
- @abstractmethod
- def encoding(self) -> str:
- pass
-
- @property
- @abstractmethod
- def errors(self) -> Optional[str]:
- pass
-
- @property
- @abstractmethod
- def line_buffering(self) -> bool:
- pass
-
- @property
- @abstractmethod
- def newlines(self) -> Any:
- pass
-
- @abstractmethod
- def __enter__(self) -> 'TextIO':
- pass
-
-
-class _DeprecatedType(type):
- def __getattribute__(cls, name):
- if name not in {"__dict__", "__module__", "__doc__"} and name in cls.__dict__:
- warnings.warn(
- f"{cls.__name__} is deprecated, import directly "
- f"from typing instead. {cls.__name__} will be removed "
- "in Python 3.12.",
- DeprecationWarning,
- stacklevel=2,
- )
- return super().__getattribute__(name)
-
-
-class io(metaclass=_DeprecatedType):
- """Wrapper namespace for IO generic classes."""
-
- __all__ = ['IO', 'TextIO', 'BinaryIO']
- IO = IO
- TextIO = TextIO
- BinaryIO = BinaryIO
-
-
-io.__name__ = __name__ + '.io'
-sys.modules[io.__name__] = io
-
-Pattern = _alias(stdlib_re.Pattern, 1)
-Match = _alias(stdlib_re.Match, 1)
-
-class re(metaclass=_DeprecatedType):
- """Wrapper namespace for re type aliases."""
-
- __all__ = ['Pattern', 'Match']
- Pattern = Pattern
- Match = Match
-
-
-re.__name__ = __name__ + '.re'
-sys.modules[re.__name__] = re
-
-
-def reveal_type(obj: T, /) -> T:
- """Ask a static type checker to reveal the inferred type of an expression.
-
- When a static type checker encounters a call to ``reveal_type()``,
- it will emit the inferred type of the argument::
-
- x: int = 1
- reveal_type(x)
-
- Running a static type checker (e.g., mypy) on this example
- will produce output similar to 'Revealed type is "builtins.int"'.
-
- At runtime, the function prints the runtime type of the
- argument and returns the argument unchanged.
- """
- print(f"Runtime type is {type(obj).__name__!r}", file=sys.stderr)
- return obj
-
-
-def dataclass_transform(
- *,
- eq_default: bool = True,
- order_default: bool = False,
- kw_only_default: bool = False,
- field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = (),
- **kwargs: Any,
-) -> Callable[[T], T]:
- """Decorator to mark an object as providing dataclass-like behaviour.
-
- The decorator can be applied to a function, class, or metaclass.
-
- Example usage with a decorator function::
-
- T = TypeVar("T")
-
- @dataclass_transform()
- def create_model(cls: type[T]) -> type[T]:
- ...
- return cls
-
- @create_model
- class CustomerModel:
- id: int
- name: str
-
- On a base class::
-
- @dataclass_transform()
- class ModelBase: ...
-
- class CustomerModel(ModelBase):
- id: int
- name: str
-
- On a metaclass::
-
- @dataclass_transform()
- class ModelMeta(type): ...
-
- class ModelBase(metaclass=ModelMeta): ...
-
- class CustomerModel(ModelBase):
- id: int
- name: str
-
- The ``CustomerModel`` classes defined above will
- be treated by type checkers similarly to classes created with
- ``@dataclasses.dataclass``.
- For example, type checkers will assume these classes have
- ``__init__`` methods that accept ``id`` and ``name``.
-
- The arguments to this decorator can be used to customize this behavior:
- - ``eq_default`` indicates whether the ``eq`` parameter is assumed to be
- ``True`` or ``False`` if it is omitted by the caller.
- - ``order_default`` indicates whether the ``order`` parameter is
- assumed to be True or False if it is omitted by the caller.
- - ``kw_only_default`` indicates whether the ``kw_only`` parameter is
- assumed to be True or False if it is omitted by the caller.
- - ``field_specifiers`` specifies a static list of supported classes
- or functions that describe fields, similar to ``dataclasses.field()``.
- - Arbitrary other keyword arguments are accepted in order to allow for
- possible future extensions.
-
- At runtime, this decorator records its arguments in the
- ``__dataclass_transform__`` attribute on the decorated object.
- It has no other runtime effect.
-
- See PEP 681 for more details.
- """
- def decorator(cls_or_fn):
- cls_or_fn.__dataclass_transform__ = {
- "eq_default": eq_default,
- "order_default": order_default,
- "kw_only_default": kw_only_default,
- "field_specifiers": field_specifiers,
- "kwargs": kwargs,
- }
- return cls_or_fn
- return decorator
-').appendTo(this.out); - this.output = $('
' + - '' + - _("Hide Search Matches") + - "
" - ) - ); - }, - - /** - * helper function to hide the search marks again - */ - hideSearchWords: () => { - document - .querySelectorAll("#searchbox .highlight-link") - .forEach((el) => el.remove()); - document - .querySelectorAll("span.highlighted") - .forEach((el) => el.classList.remove("highlighted")); - localStorage.removeItem("sphinx_highlight_terms") - }, - - initEscapeListener: () => { - // only install a listener if it is really needed - if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) return; - - document.addEventListener("keydown", (event) => { - // bail for input elements - if (BLACKLISTED_KEY_CONTROL_ELEMENTS.has(document.activeElement.tagName)) return; - // bail with special keys - if (event.shiftKey || event.altKey || event.ctrlKey || event.metaKey) return; - if (DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS && (event.key === "Escape")) { - SphinxHighlight.hideSearchWords(); - event.preventDefault(); - } - }); - }, -}; - -_ready(() => { - /* Do not call highlightSearchWords() when we are on the search page. - * It will highlight words from the *previous* search query. - */ - if (typeof Search === "undefined") SphinxHighlight.highlightSearchWords(); - SphinxHighlight.initEscapeListener(); -}); diff --git a/docs/_build/html/index.html b/docs/_build/html/index.html deleted file mode 100644 index 6f97f108..00000000 --- a/docs/_build/html/index.html +++ /dev/null @@ -1,131 +0,0 @@ - - - - - - -Contents:
- -| - n | ||
| - |
- notdiamond | - |
| - |
- notdiamond.llms.config | - |
| - |
- notdiamond.llms.request | - |
| - |
- notdiamond.metrics.metric | - |
| - |
- notdiamond.metrics.request | - |
| - |
- notdiamond.toolkit | - |
_If you already have existing projects in either OpenAI SDK or LangChain, check out our OpenAI and Langchain integration guides. Otherwise, continue reading.
-Create a main.py file in the same folder as the .env file you created earlier, or try it in Colab
from notdiamond.llms.llm import NDLLM
-from notdiamond.prompts.prompt import NDPrompt, NDContext, NDQuery, NDPromptTemplate
-from notdiamond.llms.providers import NDLLMProviders
-
-
-# Define your prompt and query
-prompt = NDPrompt("You are a world class software developer.") # The system prompt, defines the LLM's role
-query = NDQuery("Write a merge sort in Python.") # The specific query written by an end-user
-
-# Define the prompt template to combine prompt and query into a single string
-prompt_template = NDPromptTemplate("About you: {prompt}\n{query}",
- partial_variables={"prompt": prompt, "query": query})
-
-# Define the available LLMs you'd like to route between
-llm_providers = [NDLLMProviders.gpt_3_5_turbo, NDLLMProviders.gpt_4, NDLLMProviders.claude_2_1, NDLLMProviders.claude_3_opus_20240229, NDLLMProviders.gemini_pro]
-
-# Create the NDLLM object -> like a 'meta-LLM' combining all of the specified models
-nd_llm = NDLLM(llm_providers=llm_providers)
-
-# After fuzzy hashing the inputs, the best LLM is determined by the ND API and the LLM is called client-side
-result, session_id, provider = nd_llm.invoke(prompt_template=prompt_template)
-
-
-print("ND session ID: ", session_id) # A unique ID of the invoke. Important for future references back to ND API
-print("LLM called: ", provider.model) # The LLM routed to
-print("LLM output: ", result.content) # The LLM response
-> 👍 Run it!
-python main.py
from langchain_core.prompts import PromptTemplate
-from langchain_openai import ChatOpenAI
-
-from notdiamond import settings
-from notdiamond.llms.llm import NDLLM
-from notdiamond.llms.provider import NDLLMProvider
-from notdiamond.llms.providers import NDLLMProviders
-
-# 1. Simple Langchain project with PromptTemplate
-context = "You live in a hidden city that has remained undiscovered for centuries. The city is located in a dense jungle."
-user_input = "Tell me a joke about your city."
-
-prompt_template = PromptTemplate.from_template(
- "You are a world class storyteller that writes funny jokes. {context} {user_input}"
-)
-
-model = ChatOpenAI(model_name="gpt-3.5-turbo", openai_api_key=settings.OPENAI_API_KEY)
-chain = prompt_template | model
-
-result = chain.invoke({"context": context, "user_input": user_input})
-print("1 - LANGCHAIN RESULT")
-print(result.content)
-
-# 1. Switch to NotDiamond
-context = "You live in a hidden city that has remained undiscovered for centuries. The city is located in a dense jungle."
-user_input = "Tell me a joke about your city."
-
-prompt_template = PromptTemplate.from_template(
- "You are a world class storyteller that writes funny jokes. {context} {user_input}"
-)
-
-nd_llm = NDLLM(
- llm_providers=[
- NDLLMProviders.GPT_3_5_TURBO,
- NDLLMProviders.GPT_4,
- NDLLMProviders.CLAUDE_2_1,
- ]
-)
-result, session_id, _ = nd_llm.invoke(
- prompt_template=prompt_template,
- input={"context": context, "user_input": user_input},
-)
-
-print("1 - ND RESULTS")
-print(result.content)
-
-# 2. Switch to NotDiamond and have more control over providers
-context = "You live in a hidden city that has remained undiscovered for centuries. The city is located in a dense jungle."
-user_input = "Tell me a joke about your city."
-
-prompt_template = PromptTemplate.from_template(
- "You are a world class storyteller that writes funny jokes. {context} {user_input}"
-)
-
-gpt35_provider = NDLLMProvider(
- provider="openai",
- model="gpt-3.5-turbo",
- temperature=0,
- request_timeout=1200,
- max_retries=5,
- max_tokens=2000,
-)
-
-claude_provider = NDLLMProvider(provider="anthropic", model="claude-2.1", temperature=1)
-
-nd_llm = NDLLM(llm_providers=[gpt35_provider, claude_provider])
-result, session_id, best_llm = nd_llm.invoke(
- prompt_template=prompt_template,
- input={"context": context, "user_input": user_input},
-)
-
-print("2 - ND RESULTS")
-print(best_llm.provider)
-print(result.content)
-There are multiple cookbooks available to help you get started with NotDiamond. Check out the cookbooks folder in the GitHub repository.
Not Diamond is an AI model router that automatically determines which LLM is best-suited to respond to any query, improving LLM output quality by combining multiple LLMs into a meta-model that learns when to call each LLM.
-Maximize output quality: Not Diamond outperforms every foundation model on major evaluation benchmarks by always calling the best model for every prompt.
Reduce cost and latency: Not Diamond lets you define intelligent cost and latency tradeoffs to efficiently leverage smaller and cheaper models without degrading quality.
Train your own router: Not Diamond lets you train your own custom routers optimized to your data and use case.
Python, TypeScript, and REST API support: Not Diamond works across a variety of stacks.
Python: Requires Python 3.10+. It’s recommended that you create and activate a virtualenv prior to installing the package. For this example, we’ll be installing the optional additional create dependencies, which you can learn more about here.
-pip install notdiamond[create]
-Create a .env file with your Not Diamond API key and the API keys of the models you want to route between:
-NOTDIAMOND_API_KEY = "YOUR_NOTDIAMOND_API_KEY"
-OPENAI_API_KEY = "YOUR_OPENAI_API_KEY"
-ANTHROPIC_API_KEY = "YOUR_ANTHROPIC_API_KEY"
-Create a new file in the same directory as your .env file and copy and run the code below (you can toggle between Python and TypeScript in the top left of the code block):
-from notdiamond import NotDiamond
-
-# Define the Not Diamond routing client
-client = NotDiamond()
-
-# The best LLM is determined by Not Diamond based on the messages and specified models
-result, session_id, provider = client.chat.completions.create(
- messages=[
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Consiely explain merge sort."} # Adjust as desired
- ],
- model=['openai/gpt-3.5-turbo', 'openai/gpt-4o', 'anthropic/claude-3-5-sonnet-20240620']
-)
-
-print("ND session ID: ", session_id) # A unique ID of Not Diamond's recommendation
-print("LLM called: ", provider.model) # The LLM routed to
-print("LLM output: ", result.content) # The LLM response
-FeedbackRequestPayload
-ModelSelectRequestPayloadModelSelectRequestPayload.componentsModelSelectRequestPayload.formatted_promptModelSelectRequestPayload.llm_providersModelSelectRequestPayload.max_model_depthModelSelectRequestPayload.metricModelSelectRequestPayload.model_computed_fieldsModelSelectRequestPayload.model_configModelSelectRequestPayload.model_fieldsModelSelectRequestPayload.prompt_templateNDApiKeyValidator
-NotDiamondNotDiamond.ConfigNotDiamond.api_keyNotDiamond.defaultNotDiamond.hash_contentNotDiamond.latency_trackingNotDiamond.llm_configsNotDiamond.max_model_depthNotDiamond.max_retriesNotDiamond.model_configNotDiamond.nd_api_urlNotDiamond.preference_idNotDiamond.timeoutNotDiamond.toolsNotDiamond.tradeoffNotDiamond.user_agentEmbeddingConfig
-LLMConfig
-Entrypoint for fallback and retry features without changing existing code.
-Add this to existing codebase without other modifications to enable the following capabilities:
-Fallback to a different model if a model invocation fails.
If configured, fallback to a different provider if a model invocation fails -(eg. azure/gpt-4o fails -> invoke openai/gpt-4o)
Load-balance between models and providers, if specified.
Pass timeout and retry configurations to each invoke, optionally configured per model.
Pass model-specific messages on each retry (prepended to the provided messages parameter)
client (Union[ClientType, List[ClientType]]) – Clients to apply retry/fallback logic to.
models (Union[Dict[str, float], List[str]]) –
Models to use of the format <provider>/<model>. -Supports two formats:
----
-- -
List of models, eg. [“openai/gpt-4o”, “azure/gpt-4o”]. Models will be prioritized as listed.
- -
Dict of models to weights for load balancing, eg. {“openai/gpt-4o”: 0.9, “azure/gpt-4o”: 0.1}. -If a model invocation fails, the next model is selected by sampling using the remaining weights.
max_retries (Union[int, Dict[str, int]]) – Maximum number of retries. Can be configured globally or per model.
timeout (Union[float, Dict[str, float]]) – Timeout in seconds per model. Can be configured globally or per model.
model_messages (Dict[str, OpenAIMessagesType]) – Model-specific messages to prepend to messages on each invocation, formatted OpenAI-style. Can be -configured using any role which is valid as an initial message (eg. “system” or “user”, but not “assistant”).
api_key (Optional[str]) – Not Diamond API key for authentication. Unused for now - will offer logging and metrics in the future.
async_mode (bool) – Whether to manage clients as async.
backoff (Union[float, Dict[str, float]]) – Backoff factor for exponential backoff per each retry. Can be configured globally or per model.
Manager object that handles retries and fallbacks. Not required for usage.
-RetryManager
-If models is a list, the fallback model is selected in order after removing the failed model. -eg. If “openai/gpt-4o” fails for the list:
----
-- -
[“openai/gpt-4o”, “azure/gpt-4o”], “azure/gpt-4o” will be tried next
- -
[“openai/gpt-4o-mini”, “openai/gpt-4o”, “azure/gpt-4o”], “openai/gpt-4o-mini” will be tried next.
If models is a dict, the next model is selected by sampling using the remaining weights. -eg. If “openai/gpt-4o” fails for the dict:
----
-- -
{“openai/gpt-4o”: 0.9, “azure/gpt-4o”: 0.1}, “azure/gpt-4o” will be invoked 100% of the time
- -
{“openai/gpt-4o”: 0.5, “azure/gpt-4o”: 0.25, “openai/gpt-4o-mini”: 0.25}, then “azure/gpt-4o” and -“openai/gpt-4o-mini” can be invoked with 50% probability each.
Please refer to tests/test_init.py for more examples on how to use notdiamond.init.
-# ...existing workflow code, including client initialization...
-openai_client = OpenAI(...)
-azure_client = AzureOpenAI(...)
-
-# Add `notdiamond.init` to the workflow.
-notdiamond.init(
- [openai_client, azure_client],
- models={"openai/gpt-4o": 0.9, "azure/gpt-4o": 0.1},
- max_retries={"openai/gpt-4o": 3, "azure/gpt-4o": 1},
- timeout={"openai/gpt-4o": 10.0, "azure/gpt-4o": 5.0},
- model_messages={
- "openai/gpt-4o": [{"role": "user", "content": "Here is a prompt for OpenAI."}],
- "azure/gpt-4o": [{"role": "user", "content": "Here is a prompt for Azure."}],
- },
- api_key="sk-...",
- backoff=2.0,
-)
-
-# ...continue existing workflow code...
-response = openai_client.chat.completions.create(
- model="notdiamond",
- messages=[{"role": "user", "content": "Hello!"}]
-)
-Entrypoint for fallback and retry features without changing existing code.
-Add this to existing codebase without other modifications to enable the following capabilities:
-Fallback to a different model if a model invocation fails.
If configured, fallback to a different provider if a model invocation fails -(eg. azure/gpt-4o fails -> invoke openai/gpt-4o)
Load-balance between models and providers, if specified.
Pass timeout and retry configurations to each invoke, optionally configured per model.
Pass model-specific messages on each retry (prepended to the provided messages parameter)
client (Union[ClientType, List[ClientType]]) – Clients to apply retry/fallback logic to.
models (Union[Dict[str, float], List[str]]) –
Models to use of the format <provider>/<model>. -Supports two formats:
----
-- -
List of models, eg. [“openai/gpt-4o”, “azure/gpt-4o”]. Models will be prioritized as listed.
- -
Dict of models to weights for load balancing, eg. {“openai/gpt-4o”: 0.9, “azure/gpt-4o”: 0.1}. -If a model invocation fails, the next model is selected by sampling using the remaining weights.
max_retries (Union[int, Dict[str, int]]) – Maximum number of retries. Can be configured globally or per model.
timeout (Union[float, Dict[str, float]]) – Timeout in seconds per model. Can be configured globally or per model.
model_messages (Dict[str, OpenAIMessagesType]) – Model-specific messages to prepend to messages on each invocation, formatted OpenAI-style. Can be -configured using any role which is valid as an initial message (eg. “system” or “user”, but not “assistant”).
api_key (Optional[str]) – Not Diamond API key for authentication. Unused for now - will offer logging and metrics in the future.
async_mode (bool) – Whether to manage clients as async.
backoff (Union[float, Dict[str, float]]) – Backoff factor for exponential backoff per each retry. Can be configured globally or per model.
Manager object that handles retries and fallbacks. Not required for usage.
-RetryManager
-If models is a list, the fallback model is selected in order after removing the failed model. -eg. If “openai/gpt-4o” fails for the list:
----
-- -
[“openai/gpt-4o”, “azure/gpt-4o”], “azure/gpt-4o” will be tried next
- -
[“openai/gpt-4o-mini”, “openai/gpt-4o”, “azure/gpt-4o”], “openai/gpt-4o-mini” will be tried next.
If models is a dict, the next model is selected by sampling using the remaining weights. -eg. If “openai/gpt-4o” fails for the dict:
----
-- -
{“openai/gpt-4o”: 0.9, “azure/gpt-4o”: 0.1}, “azure/gpt-4o” will be invoked 100% of the time
- -
{“openai/gpt-4o”: 0.5, “azure/gpt-4o”: 0.25, “openai/gpt-4o-mini”: 0.25}, then “azure/gpt-4o” and -“openai/gpt-4o-mini” can be invoked with 50% probability each.
Please refer to tests/test_init.py for more examples on how to use notdiamond.init.
-# ...existing workflow code, including client initialization...
-openai_client = OpenAI(...)
-azure_client = AzureOpenAI(...)
-
-# Add `notdiamond.init` to the workflow.
-notdiamond.init(
- [openai_client, azure_client],
- models={"openai/gpt-4o": 0.9, "azure/gpt-4o": 0.1},
- max_retries={"openai/gpt-4o": 3, "azure/gpt-4o": 1},
- timeout={"openai/gpt-4o": 10.0, "azure/gpt-4o": 5.0},
- model_messages={
- "openai/gpt-4o": [{"role": "user", "content": "Here is a prompt for OpenAI."}],
- "azure/gpt-4o": [{"role": "user", "content": "Here is a prompt for Azure."}],
- },
- api_key="sk-...",
- backoff=2.0,
-)
-
-# ...continue existing workflow code...
-response = openai_client.chat.completions.create(
- model="notdiamond",
- messages=[{"role": "user", "content": "Hello!"}]
-)
-Bases: _NDRouterClient
Create a new model by parsing and validating input data from keyword arguments.
-Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be -validated to form a valid model.
-self is explicitly positional-only to allow self as a field name.
-nd_api_url (str | None)
user_agent (str | None)
api_key (str)
llm_configs (List[LLMConfig | str] | None)
default (LLMConfig | int | str)
max_model_depth (int | None)
latency_tracking (bool)
hash_content (bool)
tradeoff (str | None)
preference_id (str | None)
tools (Sequence[Dict[str, Any] | Callable] | None)
callbacks (List | None)
max_retries (int)
timeout (float)
API key required for making calls to NotDiamond. -You can get an API key via our dashboard: https://app.notdiamond.ai -If an API key is not set, it will check for NOTDIAMOND_API_KEY in .env file.
-Set a default LLM, so in case anything goes wrong in the flow, -as for example NotDiamond API call fails, your code won’t break and you have -a fallback model. There are various ways to configure a default model:
-Integer, specifying the index of the default provider from the llm_configs list
String, similar how you can specify llm_configs, of structure ‘provider_name/model_name’
LLMConfig, just directly specify the object of the provider
By default, we will set your first LLM in the list as the default.
-Hashing the content before being sent to the NotDiamond API. -By default this is False.
-Tracking and sending latency of LLM call to NotDiamond server as feedback, so we can improve our router. -By default this is turned on, set it to False to turn off.
-The list of LLMs that are available to route between.
-If your top recommended model is down, specify up to which depth of routing you’re willing to go. -If max_model_depth is not set, it defaults to the length of the llm_configs list. -If max_model_depth is set to 0, the init will fail. -If the value is larger than the llm_configs list length, we reset the value to len(llm_configs).
-The maximum number of retries to make when calling the Not Diamond API.
-Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
-The URL of the NotDiamond API. Defaults to settings.NOTDIAMOND_API_URL.
-The ID of the router preference that was configured via the Dashboard. Defaults to None.
-The timeout for the Not Diamond API call.
-Bind tools to the LLM object. The tools will be passed to the LLM object when invoking it.
-[DEPRECATED] The tradeoff constructor parameter is deprecated and will be removed in a future version. -Please specify the tradeoff when using model_select or invocation methods.
-Define tradeoff between “cost” and “latency” for the router to determine the best LLM for a given query. -If None is specified, then the router will not consider either cost or latency.
-The supported values: “cost”, “latency”
-Defaults to None.
-Bases: object
A NotDiamond embedding provider config (or EmbeddingConfig) is represented by a combination of provider and model. -Provider refers to the company of the foundational model, such as openai, anthropic, google. -The model represents the model name as defined by the owner company, such as text-embedding-3-large -Beside this you can also specify the API key for each provider or extra arguments -that are also supported by Langchain.
-All supported providers and models can be found in our docs.
-If the API key is not specified, the Config will try to read the key from an .env file before failing. -For example, the Config will look for OPENAI_API_KEY to authenticate any OpenAI provider.
-The name of the LLM provider (e.g., “openai”, “anthropic”). Must be one of the -predefined providers in POSSIBLE_EMBEDDING_PROVIDERS.
-str
-The name of the LLM model to use (e.g., “gpt-3.5-turbo”). -Must be one of the predefined models in POSSIBLE_MODELS.
-str
-The API key for accessing the LLM provider’s services. -Defaults to None, in which case it tries to fetch from the environment.
-Optional[str], optional
-Additional keyword arguments that might be necessary for specific providers or models.
-UnsupportedLLMProvider – If the provider or model specified is not supported.
-provider (str)
model (str)
api_key (str | None)
_summary_
-provider (str) – The name of the embedding provider (e.g., “openai”, “anthropic”).
model (str) – The name of the embedding model to use (e.g., “text-embedding-3-large”).
api_key (Optional[str], optional) – The API key for accessing the embedding provider’s services. -Defaults to None.
**kwargs – Additional keyword arguments that might be necessary for specific providers or models.
UnsupportedEmbeddingProvider – If the provider or model specified is not supported.
-_summary_
-provider (str) – The name of the embedding provider (e.g., “openai”, “anthropic”).
model (str) – The name of the embedding model to use (e.g., “text-embedding-3-large”).
api_key (Optional[str], optional) – The API key for accessing the embedding provider’s services. -Defaults to None.
**kwargs – Additional keyword arguments that might be necessary for specific providers or models.
UnsupportedEmbeddingProvider – If the provider or model specified is not supported.
-We allow our users to specify LLM providers for NotDiamond in the string format ‘provider_name/model_name’, -for example ‘openai/gpt-3.5-turbo’. Our workflows expect LLMConfig as -the base type, so this class method converts a string specification of an LLM provider into an -LLMConfig object.
-llm_provider (str) – this is the string definition of the LLM provider
-initialized object with correct provider and model
-api_key (str)
-Bases: object
A NotDiamond LLM provider config (or LLMConfig) is represented by a combination of provider and model. -Provider refers to the company of the foundational model, such as openai, anthropic, google. -The model represents the model name as defined by the owner company, such as gpt-3.5-turbo -Beside this you can also specify the API key for each provider, specify extra arguments -that are also supported by Langchain (eg. temperature), and a system prmopt to be used -with the provider. If the provider is selected during routing, then the system prompt will -be used, replacing the one in the message array if there are any.
-All supported providers and models can be found in our docs.
-If the API key it’s not specified, it will try to pick it up from an .env file before failing. -As example for OpenAI it will look for OPENAI_API_KEY.
-The name of the LLM provider (e.g., “openai”, “anthropic”). Must be one of the -predefined providers in POSSIBLE_PROVIDERS.
-str
-The name of the LLM model to use (e.g., “gpt-3.5-turbo”). -Must be one of the predefined models in POSSIBLE_MODELS.
-str
-The system prompt to use for the provider. Defaults to None.
-Optional[str], optional
-The API key for accessing the LLM provider’s services. -Defaults to None, in which case it tries to fetch from the settings.
-Optional[str], optional
-The OpenRouter model equivalent for this provider / model
-str
-Additional keyword arguments that might be necessary for specific providers or models.
-UnsupportedLLMProvider – If the provider or model specified is not supported.
-provider (str)
model (str)
is_custom (bool)
system_prompt (str | None)
context_length (int | None)
input_price (float | None)
custom_input_price (float | None)
output_price (float | None)
custom_output_price (float | None)
latency (float | None)
custom_latency (float | None)
api_key (str | None)
_summary_
-provider (str) – The name of the LLM provider (e.g., “openai”, “anthropic”).
model (str) – The name of the LLM model to use (e.g., “gpt-3.5-turbo”).
is_custom (bool) – Whether this is a custom model. Defaults to False.
system_prompt (Optional[str], optional) – The system prompt to use for the provider. Defaults to None.
context_length (Optional[int], optional) – Custom context window length for the provider/model.
custom_input_price (Optional[float], optional) – Custom input price (USD) per million tokens for this -provider/model; will default to public input price if available.
custom_output_price (Optional[float], optional) – Custom output price (USD) per million tokens for this -provider/model; will default to public output price if available.
custom_latency (Optional[float], optional) – Custom latency (time to first token) for provider/model.
api_key (Optional[str], optional) – The API key for accessing the LLM provider’s services. -Defaults to None.
**kwargs – Additional keyword arguments that might be necessary for specific providers or models.
input_price (float | None)
output_price (float | None)
latency (float | None)
UnsupportedLLMProvider – If the provider or model specified is not supported.
-_summary_
-provider (str) – The name of the LLM provider (e.g., “openai”, “anthropic”).
model (str) – The name of the LLM model to use (e.g., “gpt-3.5-turbo”).
is_custom (bool) – Whether this is a custom model. Defaults to False.
system_prompt (Optional[str], optional) – The system prompt to use for the provider. Defaults to None.
context_length (Optional[int], optional) – Custom context window length for the provider/model.
custom_input_price (Optional[float], optional) – Custom input price (USD) per million tokens for this -provider/model; will default to public input price if available.
custom_output_price (Optional[float], optional) – Custom output price (USD) per million tokens for this -provider/model; will default to public output price if available.
custom_latency (Optional[float], optional) – Custom latency (time to first token) for provider/model.
api_key (Optional[str], optional) – The API key for accessing the LLM provider’s services. -Defaults to None.
**kwargs – Additional keyword arguments that might be necessary for specific providers or models.
input_price (float | None)
output_price (float | None)
latency (float | None)
UnsupportedLLMProvider – If the provider or model specified is not supported.
-We allow our users to specify LLM providers for NotDiamond in the string format ‘provider_name/model_name’, -for example ‘openai/gpt-3.5-turbo’. Within our workflows we want to ensure we use LLMConfig as -the base type, so this class method converts a string specification of an LLM provider into an -LLMConfig object.
-llm_provider (str) – this is the string definition of the LLM provider
-initialized object with correct provider and model
-Converts the LLMConfig object to a dict in the format accepted by -the NotDiamond API.
-dict
-api_key (str)
-Bases: Enum
NDLLMProviders serves as a registry for the supported LLM models by NotDiamond. -It allows developers to easily specify available LLM providers for the router.
-refers to ‘gpt-3.5-turbo’ model by OpenAI
-NDLLMProvider
-refers to ‘gpt-3.5-turbo-0125’ model by OpenAI
-NDLLMProvider
-refers to ‘gpt-4’ model by OpenAI
-NDLLMProvider
-refers to ‘gpt-4-0613’ model by OpenAI
-NDLLMProvider
-refers to ‘gpt-4-1106-preview’ model by OpenAI
-NDLLMProvider
-refers to ‘gpt-4-turbo’ model by OpenAI
-NDLLMProvider
-refers to ‘gpt-4-turbo-preview’ model by OpenAI
-NDLLMProvider
-refers to ‘gpt-4-turbo-2024-04-09’ model by OpenAI
-NDLLMProvider
-refers to ‘gpt-4o-2024-05-13’ model by OpenAI
-NDLLMProvider
-refers to ‘gpt-4o-2024-08-06’ model by OpenAI
-NDLLMProvider
-refers to ‘gpt-4o’ model by OpenAI
-NDLLMProvider
-refers to ‘gpt-4o-mini-2024-07-18’ model by OpenAI
-NDLLMProvider
-refers to ‘gpt-4o-mini’ model by OpenAI
-NDLLMProvider
-refers to ‘gpt-4-0125-preview’ model by OpenAI
-NDLLMProvider
-refers to ‘gpt-4.1’ model by OpenAI
-NDLLMProvider
-refers to ‘gpt-4.1-2025-04-14’ model by OpenAI
-NDLLMProvider
-refers to ‘gpt-4.1-mini’ model by OpenAI
-NDLLMProvider
-refers to ‘gpt-4.1-mini-2025-04-14’ model by OpenAI
-NDLLMProvider
-refers to ‘gpt-4.1-nano’ model by OpenAI
-NDLLMProvider
-refers to ‘gpt-4.1-nano-2025-04-14’ model by OpenAI
-NDLLMProvider
-refers to ‘o1-preview’ model by OpenAI
-NDLLMProvider
-refers to ‘o1-preview-2024-09-12’ model by OpenAI
-NDLLMProvider
-refers to ‘o1-mini’ model by OpenAI
-NDLLMProvider
-refers to ‘o1-mini-2024-09-12’ model by OpenAI
-NDLLMProvider
-refers to ‘claude-2.1’ model by Anthropic
-NDLLMProvider
-refers to ‘claude-3-opus-20240229’ model by Anthropic
-NDLLMProvider
-refers to ‘claude-3-sonnet-20240229’ model by Anthropic
-NDLLMProvider
-refers to ‘claude-3-5-sonnet-20240620’ model by Anthropic
-NDLLMProvider
-refers to ‘claude-3-7-sonnet-latest’ model by Anthropic
-NDLLMProvider
-refers to ‘claude-3-7-sonnet-20250219’ model by Anthropic
-NDLLMProvider
-refers to ‘claude-3-5-haiku-20241022’ model by Anthropic
-NDLLMProvider
-refers to ‘claude-3-haiku-20240307’ model by Anthropic
-NDLLMProvider
-refers to ‘claude-opus-4-20250514’ model by Anthropic
-NDLLMProvider
-refers to ‘claude-sonnet-4-20250514’ model by Anthropic
-NDLLMProvider
-refers to ‘claude-opus-4-0’ model by Anthropic
-NDLLMProvider
-refers to ‘claude-sonnet-4-0’ model by Anthropic
-NDLLMProvider
-refers to ‘gemini-pro’ model by Google
-NDLLMProvider
-refers to ‘gemini-1.0-pro-latest’ model by Google
-NDLLMProvider
-refers to ‘gemini-1.5-pro-latest’ model by Google
-NDLLMProvider
-refers to ‘gemini-1.5-pro-exp-0801’ model by Google
-NDLLMProvider
-refers to ‘gemini-1.5-flash-latest’ model by Google
-NDLLMProvider
-refers to ‘gemini-20-flash’ model by Google
-NDLLMProvider
-refers to ‘gemini-20-flash-001’ model by Google
-NDLLMProvider
-refers to ‘gemini-25-flash’ model by Google
-NDLLMProvider
-refers to ‘gemini-25-pro’ model by Google
-NDLLMProvider
-refers to ‘command-r’ model by Cohere
-NDLLMProvider
-refers to ‘command-r-plus’ model by Cohere
-NDLLMProvider
-refers to ‘mistral-large-latest’ model by Mistral AI
-NDLLMProvider
-refers to ‘mistral-large-2407’ model by Mistral AI
-NDLLMProvider
-refers to ‘mistral-large-2402’ model by Mistral AI
-NDLLMProvider
-refers to ‘mistral-medium-latest’ model by Mistral AI
-NDLLMProvider
-refers to ‘mistral-small-latest’ model by Mistral AI
-NDLLMProvider
-refers to ‘open-mistral-7b’ model by Mistral AI
-NDLLMProvider
-refers to ‘open-mixtral-8x7b’ model by Mistral AI
-NDLLMProvider
-refers to ‘open-mixtral-8x22b’ model by Mistral AI
-NDLLMProvider
-refers to ‘open-mistral-nemo’ model by Mistral AI
-NDLLMProvider
-refers to ‘Mistral-7B-Instruct-v0.2’ model served via TogetherAI
-NDLLMProvider
-refers to ‘Mixtral-8x7B-Instruct-v0.1’ model served via TogetherAI
-NDLLMProvider
-refers to ‘Mixtral-8x22B-Instruct-v0.1’ model served via TogetherAI
-NDLLMProvider
-refers to ‘Llama-3-70b-chat-hf’ model served via TogetherAI
-NDLLMProvider
-refers to ‘Llama-3-8b-chat-hf’ model served via TogetherAI
-NDLLMProvider
-refers to ‘Qwen2-72B-Instruct’ model served via TogetherAI
-NDLLMProvider
-refers to ‘Meta-Llama-3.1-8B-Instruct-Turbo’ -model served via TogetherAI
-NDLLMProvider
-refers to ‘Meta-Llama-3.1-70B-Instruct-Turbo’ -model served via TogetherAI
-NDLLMProvider
-refers to ‘Meta-Llama-3.1-405B-Instruct-Turbo’ -model served via TogetherAI
-NDLLMProvider
-refers to ‘DeepSeek-R1’ -model served via TogetherAI
-NDLLMProvider
-refers to “mistral-7b-instruct-v0.2” model served via Replicate
-NDLLMProvider
-refers to “mixtral-8x7b-instruct-v0.1” model served via Replicate
-NDLLMProvider
-refers to “meta-llama-3-70b-instruct” model served via Replicate
-NDLLMProvider
-refers to “meta-llama-3-8b-instruct” model served via Replicate
-NDLLMProvider
-refers to “meta-llama-3.1-405b-instruct” -model served via Replicate
-NDLLMProvider
-refers to “sonar” model by Perplexity
-NDLLMProvider
-This endpoint receives the prompt and routing settings, and makes a call to the NotDiamond API. -It returns the best fitting LLM to call and a session ID that can be used for feedback.
-messages (List[Dict[str, str]]) – list of messages to be used for the LLM call
llm_configs (List[LLMConfig]) – a list of available LLMs that the router can decide from
metric (Metric) – metric based off which the router makes the decision. As of now only ‘accuracy’ supported.
notdiamond_api_key (str) – API key generated via the NotDiamond dashboard.
max_model_depth (int) – if your top recommended model is down, specify up to which depth of routing you’re willing to go.
hash_content (Optional[bool]) – Flag for hashing content before sending to NotDiamond API.
tradeoff (Optional[str], optional) – Define the “cost” or “latency” tradeoff -for the router to determine the best LLM for a given query.
preference_id (Optional[str], optional) – The ID of the router preference that was configured via the Dashboard. -Defaults to None.
previous_session (Optional[str], optional) – The session ID of a previous session, allowing you to link requests.
timeout (int, optional) – timeout for the request. Defaults to 60.
max_retries (int, optional) – The maximum number of retries to make when calling the Not Diamond API.
nd_api_url (Optional[str], optional) – The URL of the NotDiamond API. Defaults to None.
tools (Sequence[Dict[str, Any] | Callable] | None)
_user_agent (str)
In case of an error the LLM defaults to None and the session ID defaults -to ‘NO-SESSION-ID’.
-tuple(LLMConfig, string)
-Create a preference id with an optional name. The preference name will appear in your -dashboard on Not Diamond.
-notdiamond_api_key (str)
name (str | None)
nd_api_url (str | None)
_user_agent (str)
str
-This function converts the tools list into the format that OpenAI expects. -It does this by using LangChain’s model, which automatically creates the dictionary via bind_tools.
-tools (Optional[Sequence[Union[Dict[str, Any], Callable]]]) – list of tools to be converted
-dictionary of tools in the format that OpenAI expects
-dict
-This endpoint receives the prompt and routing settings, and makes a call to the NotDiamond API. -It returns the best fitting LLM to call and a session ID that can be used for feedback.
-messages (List[Dict[str, str]]) – list of messages to be used for the LLM call
llm_configs (List[LLMConfig]) – a list of available LLMs that the router can decide from
metric (Metric) – metric based off which the router makes the decision. As of now only ‘accuracy’ supported.
notdiamond_api_key (str) – API key generated via the NotDiamond dashboard.
max_model_depth (int) – if your top recommended model is down, specify up to which depth of routing you’re willing to go.
hash_content (Optional[bool]) – Flag for hashing content before sending to NotDiamond API.
tradeoff (Optional[str], optional) – Define the “cost” or “latency” tradeoff -for the router to determine the best LLM for a given query.
preference_id (Optional[str], optional) – The ID of the router preference that was configured via the Dashboard. -Defaults to None.
previous_session (Optional[str], optional) – The session ID of a previous session, allowing you to link requests.
timeout (int, optional) – timeout for the request. Defaults to 60.
max_retries (int, optional) – The maximum number of retries to make when calling the Not Diamond API. -Defaults to 3.
nd_api_url (Optional[str], optional) – The URL of the NotDiamond API. Defaults to None.
tools (Sequence[Dict[str, Any] | Callable] | None)
_user_agent (str)
In case of an error the LLM defaults to None and the session ID defaults -to ‘NO-SESSION-ID’.
-tuple(LLMConfig, string)
-This is the core method for the model_select endpoint. -It returns the best fitting LLM to call and a session ID that can be used for feedback.
-messages (List[Dict[str, str]]) – list of messages to be used for the LLM call
llm_configs (List[LLMConfig]) – a list of available LLMs that the router can decide from
metric (Metric) – metric based off which the router makes the decision. As of now only ‘accuracy’ supported.
notdiamond_api_key (str) – API key generated via the NotDiamond dashboard.
max_model_depth (int) – if your top recommended model is down, specify up to which depth of routing you’re willing to go.
hash_content (Optional[bool]) – Flag for hashing content before sending to NotDiamond API.
tradeoff (Optional[str], optional) – Define the “cost” or “latency” tradeoff -for the router to determine the best LLM for a given query.
preference_id (Optional[str], optional) – The ID of the router preference that was configured via the Dashboard. -Defaults to None.
previous_session (Optional[str], optional) – The session ID of a previous session, allow you to link requests.
async_mode (bool, optional) – whether to run the request in async mode. Defaults to False.
nd_api_url (Optional[str], optional) – The URL of the NotDiamond API. Defaults to None.
tools (Sequence[Dict[str, Any] | Callable] | None)
_user_agent (str)
returns data to be used for the API call of modelSelect
-tuple(url, payload, headers)
-This method makes an API call to the NotDiamond server to report the latency of an LLM call. -It helps fine-tune our model router and ensure we offer recommendations that meet your latency expectation.
-This feature can be disabled on the NDLLM class level by setting latency_tracking to False.
-session_id (str) – the session ID that was returned from the invoke or model_select calls, so we know which -router call your latency report refers to.
llm_provider (LLMConfig) – specifying the LLM provider for which the latency is reported
tokens_per_second (float) – latency of the model call calculated based on time elapsed, input tokens, and output tokens
notdiamond_api_key (str) – NotDiamond API key used for authentication
nd_api_url (Optional[str], optional) – The URL of the NotDiamond API. Defaults to None.
llm_config (LLMConfig)
_user_agent (str)
status code of the API call; 200 if successful
-int
-ApiError – if the API call to the NotDiamond backend fails, this error is raised
-Bases: object
metric (str | None)
-session_id (str)
llm_config (LLMConfig)
feedback_payload (Dict[str, int])
notdiamond_api_key (str)
nd_api_url (str)
_user_agent (str)
bool
-Bases: object
Implementation of CustomRouter class, used to train custom routers using custom datasets.
-language (str)
maximize (bool)
api_key (str | None)
The language of the dataset in lowercase. Defaults to “english”.
-str
-Whether higher score is better. Defaults to true.
-bool
-The NotDiamond API key. If not specified, will try to -find it in the environment variable NOTDIAMOND_API_KEY.
-Optional[str], optional
-Method to evaluate a custom router using provided dataset.
-dataset (Dict[str, pandas.DataFrame]) – The dataset to train a custom router. -Each key in the dictionary should be in the form of <provider>/<model>.
prompt_column (str) – The column name in each DataFrame corresponding -to the prompts used to evaluate the LLM.
response_column (str) – The column name in each DataFrame corresponding -to the response given by the LLM for a given prompt.
score_column (str) – The column name in each DataFrame corresponding -to the score given to the response from the LLM.
preference_id (str) – The preference_id associated with the custom router -returned from .fit().
include_latency (bool)
ApiError – When the NotDiamond API fails
ValueError – When parsing the provided dataset fails
UnsupportedLLMProvider – When a provider specified in the dataset is not supported.
(indicated by column <provider>/<model>/response), scores of each provider -(indicated by column <provider>/<model>/score), and notdiamond custom router -response and score (indicated by column notdiamond/response and notdiamond/score).
-provided dataset, the “Best Provider Average Score” achieved by the “Best Average Provider”, -and the “Not Diamond Average Score” achieved through custom router.
-Tuple[pandas.DataFrame, pandas.DataFrame]
-Method to train a custom router using provided dataset.
-dataset (Dict[str, pandas.DataFrame]) – The dataset to train a custom router. -Each key in the dictionary should be in the form of <provider>/<model>.
prompt_column (str) – The column name in each DataFrame corresponding -to the prompts used to evaluate the LLM.
response_column (str) – The column name in each DataFrame corresponding -to the response given by the LLM for a given prompt.
score_column (str) – The column name in each DataFrame corresponding -to the score given to the response from the LLM.
preference_id (Optional[str], optional) – If specified, the custom router -associated with the preference_id will be updated with the provided dataset.
nd_api_url (Optional[str], optional) – The URL of the NotDiamond API. Defaults to prod.
ApiError – When the NotDiamond API fails
ValueError – When parsing the provided dataset fails
UnsupportedLLMProvider – When a provider specified in the dataset is not supported.
Use this preference_id in your routing calls to use the custom router.
-str
-Bases: Exception
Bases: object
llm_providers (List[Dict[str, str]])
tools (List[Dict[str, str]] | None)
max_model_depth (int)
tradeoff (str | None)
preference_id (str | None)
hash_content (bool | None)
model (str)
messages (list)
api_base (str)
model_response (ModelResponse)
print_verbose (Callable)
response (dict)
-str
-Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
-model (str) – The name of the language model to use for text completion. see all supported LLMs: https://docs.litellm.ai/docs/providers/
messages (List) – A list of message objects representing the conversation context (default is an empty list).
PARAMS (OPTIONAL)
functions (List, optional) – A list of functions to apply to the conversation messages (default is an empty list).
function_call (str, optional) – The name of the function to call within the conversation (default is an empty string).
temperature (float, optional) – The temperature parameter for controlling the randomness of the output (default is 1.0).
top_p (float, optional) – The top-p parameter for nucleus sampling (default is 1.0).
n (int, optional) – The number of completions to generate (default is 1).
stream (bool, optional) – If True, return a streaming response (default is False).
stream_options (dict, optional) – A dictionary containing options for the streaming response. Only use this if stream is True.
stop (string/list, optional) –
Up to 4 sequences where the LLM API will stop generating further tokens.
max_tokens (integer, optional) – The maximum number of tokens in the generated completion (default is infinity).
presence_penalty (float, optional) – It is used to penalize new tokens based on their existence in the text so far.
frequency_penalty (float | None) – It is used to penalize new tokens based on their frequency in the text so far.
logit_bias (dict, optional) – Used to modify the probability of specific tokens appearing in the completion.
user (str, optional) – A unique identifier representing your end-user. This can help the LLM provider to monitor and detect abuse.
metadata (dict, optional) – Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc.
api_base (str, optional) – Base URL for the API (default is None).
api_version (str, optional) – API version (default is None).
api_key (str, optional) – API key (default is None).
model_list (list, optional) – List of api base, version, keys
timeout (float, optional) – The maximum execution time in seconds for the completion request.
Params (LITELLM Specific)
mock_response (str, optional) – If provided, return a mock completion response for testing or debugging purposes (default is None).
custom_llm_provider (str, optional) – Used for Non-OpenAI LLMs, Example usage for bedrock, set model=”amazon.titan-tg1-large” and custom_llm_provider=”bedrock”
response_format (dict | Type[BaseModel] | None)
seed (int | None)
tools (List | None)
tool_choice (str | None)
parallel_tool_calls (bool | None)
logprobs (bool | None)
top_logprobs (int | None)
base_url (str | None)
extra_headers (dict | None)
A response object containing the generated completion and associated metadata.
-ModelResponse
-Notes
-This function is an asynchronous version of the completion function.
The completion function is called using run_in_executor to execute synchronously in the event loop.
If stream is True, the function returns an async generator that yields completion lines.
Perform a completion() using any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly) -:param model: The name of the language model to use for text completion. see all supported LLMs: https://docs.litellm.ai/docs/providers/ -:type model: str -:param messages: A list of message objects representing the conversation context (default is an empty list). -:type messages: List
-functions (List, optional): A list of functions to apply to the conversation messages (default is an empty list). -function_call (str, optional): The name of the function to call within the conversation (default is an empty string). -temperature (float, optional): The temperature parameter for controlling the randomness of the output (default is 1.0). -top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0). -n (int, optional): The number of completions to generate (default is 1). -stream (bool, optional): If True, return a streaming response (default is False). -stream_options (dict, optional): A dictionary containing options for the streaming response. Only set this when you set stream: true. -stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens. -max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity). -presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far. -frequency_penalty: It is used to penalize new tokens based on their frequency in the text so far. -logit_bias (dict, optional): Used to modify the probability of specific tokens appearing in the completion. -user (str, optional): A unique identifier representing your end-user. This can help the LLM provider to monitor and detect abuse. -logprobs (bool, optional): Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message -top_logprobs (int, optional): An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. -metadata (dict, optional): Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc. 
-api_base (str, optional): Base URL for the API (default is None). -api_version (str, optional): API version (default is None). -api_key (str, optional): API key (default is None). -model_list (list, optional): List of api base, version, keys -extra_headers (dict, optional): Additional headers to include in the request.
-mock_response (str, optional): If provided, return a mock completion response for testing or debugging purposes (default is None). -custom_llm_provider (str, optional): Used for Non-OpenAI LLMs, Example usage for bedrock, set model=”amazon.titan-tg1-large” and custom_llm_provider=”bedrock” -max_retries (int, optional): The number of retries to attempt (default is 0).
-A response object containing the generated completion and associated metadata.
-ModelResponse
-model (str)
messages (List)
timeout (float | str | Timeout | None)
temperature (float | None)
top_p (float | None)
n (int | None)
stream (bool | None)
stream_options (dict | None)
max_tokens (int | None)
presence_penalty (float | None)
frequency_penalty (float | None)
logit_bias (dict | None)
user (str | None)
response_format (dict | Type[BaseModel] | None)
seed (int | None)
tools (List | None)
tool_choice (str | dict | None)
logprobs (bool | None)
top_logprobs (int | None)
parallel_tool_calls (bool | None)
extra_headers (dict | None)
functions (List | None)
function_call (str | None)
base_url (str | None)
api_version (str | None)
api_key (str | None)
model_list (list | None)
Note
-This function is used to perform completions() using the specified language model.
It supports various optional parameters for customizing the completion behavior.
If ‘mock_response’ is provided, a mock completion response is returned for testing or debugging.
llm_provider (str)
dynamic_api_key (str | None)
Returns the provider for a given model name - e.g. ‘azure/chatgpt-v-2’ -> ‘azure’
-For router -> Can also give the whole litellm param dict -> this function will extract the relevant details
-Raises Error - if unable to map model to a provider
-model (str)
custom_llm_provider (str | None)
api_base (str | None)
api_key (str | None)
litellm_params (LiteLLM_Params | None)
Tuple[str, str, str | None, str | None]
-