diff --git a/sdk/agenta/sdk/assets.py b/sdk/agenta/sdk/assets.py index 4457ab357b..a584371842 100644 --- a/sdk/agenta/sdk/assets.py +++ b/sdk/agenta/sdk/assets.py @@ -1,3 +1,8 @@ +from typing import Dict, Optional, Tuple + +from litellm import cost_calculator + + supported_llm_models = { "anthropic": [ "anthropic/claude-sonnet-4-5", @@ -206,6 +211,58 @@ providers_list = list(supported_llm_models.keys()) + +def _get_model_costs(model: str) -> Optional[Tuple[float, float]]: + """ + Get the input and output costs per 1M tokens for a model. + + Uses litellm's cost_calculator (same as tracing/inline.py) for consistency. + + Args: + model: The model name (e.g., "gpt-4o" or "anthropic/claude-3-opus-20240229") + + Returns: + Tuple of (input_cost, output_cost) per 1M tokens, or None if not found. + """ + try: + costs = cost_calculator.cost_per_token( + model=model, + prompt_tokens=1_000_000, + completion_tokens=1_000_000, + ) + if costs: + input_cost, output_cost = costs + if input_cost > 0 or output_cost > 0: + return (input_cost, output_cost) + except Exception: + pass + return None + + +def _build_model_metadata() -> Dict[str, Dict[str, Dict[str, float]]]: + """ + Build metadata dictionary with costs for all supported models. + + Returns: + Nested dict: {provider: {model: {"input": cost, "output": cost}}} + """ + metadata: Dict[str, Dict[str, Dict[str, float]]] = {} + + for provider, models in supported_llm_models.items(): + metadata[provider] = {} + for model in models: + costs = _get_model_costs(model) + if costs: + metadata[provider][model] = { + "input": costs[0], + "output": costs[1], + } + + return metadata + + +model_metadata = _build_model_metadata() + model_to_provider_mapping = { model: provider for provider, models in supported_llm_models.items() diff --git a/sdk/agenta/sdk/types.py b/sdk/agenta/sdk/types.py index f5953a142f..243886ac31 100644 --- a/sdk/agenta/sdk/types.py +++ b/sdk/agenta/sdk/types.py @@ -8,7 +8,7 @@ from starlette.responses import StreamingResponse -from agenta.sdk.assets import supported_llm_models +from agenta.sdk.assets import supported_llm_models, model_metadata from agenta.client.backend.types import AgentaNodesResponse, AgentaNodeDto @@ -23,7 +23,11 @@ def MCField( # pylint: disable=invalid-name ) -> Field: # Pydantic 2.12+ no longer allows post-creation mutation of field properties if isinstance(choices, dict): - json_extra = {"choices": choices, "x-parameter": "grouped_choice"} + json_extra = { + "choices": choices, + "x-parameter": "grouped_choice", + "x-model-metadata": model_metadata, + } elif isinstance(choices, list): json_extra = {"choices": choices, "x-parameter": "choice"} else: diff --git a/web/oss/src/components/SelectLLMProvider/index.tsx b/web/oss/src/components/SelectLLMProvider/index.tsx index 15bbb04df7..3a0bf30556 100644 --- a/web/oss/src/components/SelectLLMProvider/index.tsx +++ b/web/oss/src/components/SelectLLMProvider/index.tsx @@ -1,13 +1,13 @@ import {useMemo, useRef, useState} from "react" import {CaretRight, Plus, X} from "@phosphor-icons/react" -import {Select, Input, Button, Divider, InputRef, Popover} from "antd" +import {Button, Divider, Input, InputRef, Popover, Select, Tooltip, Typography} from "antd" import clsx from "clsx" import useLazyEffect from "@/oss/hooks/useLazyEffect" import {useVaultSecret} from "@/oss/hooks/useVaultSecret" import {capitalize} from "@/oss/lib/helpers/utils" -import {SecretDTOProvider, PROVIDER_LABELS} from "@/oss/lib/Types" +import {PROVIDER_LABELS, SecretDTOProvider} from "@/oss/lib/Types" import LLMIcons from "../LLMIcons" import Anthropic from "../LLMIcons/assets/Anthropic" @@ -25,6 +25,7 @@ interface ProviderOption { label: string value: string key?: string + metadata?: Record } interface ProviderGroup { @@ -169,6 +170,7 @@ const SelectLLMProvider = ({ label: resolvedLabel, value: resolvedValue, key: option?.key ?? resolvedValue, + metadata: option?.metadata, } }) .filter(Boolean) as ProviderOption[]) ?? [], @@ -208,6 +210,68 @@ const SelectLLMProvider = ({ setTimeout(() => setOpen(false), 0) } + const formatCost = (cost: number) => { + const value = Number(cost) + if (isNaN(value)) return "N/A" + return value < 0.01 ? value.toFixed(4) : value.toFixed(2) + } + + const renderTooltipContent = (metadata: Record) => ( +
+ {(metadata.input !== undefined || metadata.output !== undefined) && ( + <> +
+ + Input: + + + ${formatCost(metadata.input)} / 1M + +
+
+ + Output:{" "} + + + ${formatCost(metadata.output)} / 1M + +
+ + )} +
+ ) + + const renderOptionContent = (option: ProviderOption) => { + const Icon = getProviderIcon(option.value) || LLMIcons[option.label] + return ( +
+
+ {Icon && } + {option.label} +
+
+ ) + } + + const renderOption = (option: ProviderOption) => { + const content = renderOptionContent(option) + + if (option.metadata) { + return ( + + {content} + + ) + } + + return content + } + return ( <>