Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions sdk/agenta/sdk/assets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from typing import Dict, Optional, Tuple

from litellm import cost_calculator


supported_llm_models = {
"anthropic": [
"anthropic/claude-sonnet-4-5",
Expand Down Expand Up @@ -206,6 +211,58 @@

providers_list = list(supported_llm_models.keys())


def _get_model_costs(model: str) -> Optional[Tuple[float, float]]:
"""
Get the input and output costs per 1M tokens for a model.

Uses litellm's cost_calculator (same as tracing/inline.py) for consistency.

Args:
model: The model name (e.g., "gpt-4o" or "anthropic/claude-3-opus-20240229")

Returns:
Tuple of (input_cost, output_cost) per 1M tokens, or None if not found.
"""
try:
costs = cost_calculator.cost_per_token(
model=model,
prompt_tokens=1_000_000,
completion_tokens=1_000_000,
)
if costs:
input_cost, output_cost = costs
if input_cost > 0 or output_cost > 0:
return (input_cost, output_cost)
except Exception:
pass
return None


def _build_model_metadata() -> Dict[str, Dict[str, Dict[str, float]]]:
"""
Build metadata dictionary with costs for all supported models.

Returns:
Nested dict: {provider: {model: {"input": cost, "output": cost}}}
"""
metadata: Dict[str, Dict[str, Dict[str, float]]] = {}

for provider, models in supported_llm_models.items():
metadata[provider] = {}
for model in models:
costs = _get_model_costs(model)
if costs:
metadata[provider][model] = {
"input": costs[0],
"output": costs[1],
}

return metadata


model_metadata = _build_model_metadata()

model_to_provider_mapping = {
model: provider
for provider, models in supported_llm_models.items()
Expand Down
8 changes: 6 additions & 2 deletions sdk/agenta/sdk/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from starlette.responses import StreamingResponse


from agenta.sdk.assets import supported_llm_models
from agenta.sdk.assets import supported_llm_models, model_metadata
from agenta.client.backend.types import AgentaNodesResponse, AgentaNodeDto


Expand All @@ -23,7 +23,11 @@ def MCField( # pylint: disable=invalid-name
) -> Field:
# Pydantic 2.12+ no longer allows post-creation mutation of field properties
if isinstance(choices, dict):
json_extra = {"choices": choices, "x-parameter": "grouped_choice"}
json_extra = {
"choices": choices,
"x-parameter": "grouped_choice",
"x-model-metadata": model_metadata,
}
elif isinstance(choices, list):
json_extra = {"choices": choices, "x-parameter": "choice"}
else:
Expand Down
99 changes: 80 additions & 19 deletions web/oss/src/components/SelectLLMProvider/index.tsx
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import {useMemo, useRef, useState} from "react"

import {CaretRight, Plus, X} from "@phosphor-icons/react"
import {Select, Input, Button, Divider, InputRef, Popover} from "antd"
import {Button, Divider, Input, InputRef, Popover, Select, Tooltip, Typography} from "antd"
import clsx from "clsx"

import useLazyEffect from "@/oss/hooks/useLazyEffect"
import {useVaultSecret} from "@/oss/hooks/useVaultSecret"
import {capitalize} from "@/oss/lib/helpers/utils"
import {SecretDTOProvider, PROVIDER_LABELS} from "@/oss/lib/Types"
import {PROVIDER_LABELS, SecretDTOProvider} from "@/oss/lib/Types"

import LLMIcons from "../LLMIcons"
import Anthropic from "../LLMIcons/assets/Anthropic"
Expand All @@ -25,6 +25,7 @@ interface ProviderOption {
label: string
value: string
key?: string
metadata?: Record<string, any>
}

interface ProviderGroup {
Expand Down Expand Up @@ -169,6 +170,7 @@ const SelectLLMProvider = ({
label: resolvedLabel,
value: resolvedValue,
key: option?.key ?? resolvedValue,
metadata: option?.metadata,
}
})
.filter(Boolean) as ProviderOption[]) ?? [],
Expand Down Expand Up @@ -208,6 +210,68 @@ const SelectLLMProvider = ({
setTimeout(() => setOpen(false), 0)
}

const formatCost = (cost: number) => {
const value = Number(cost)
if (isNaN(value)) return "N/A"
return value < 0.01 ? value.toFixed(4) : value.toFixed(2)
}

const renderTooltipContent = (metadata: Record<string, any>) => (
<div className="flex flex-col gap-0.5">
{(metadata.input !== undefined || metadata.output !== undefined) && (
<>
<div className="flex justify-between gap-4">
<Typography.Text className="text-[10px] text-nowrap">
Input:
</Typography.Text>
<Typography.Text className="text-[10px] text-nowrap">
${formatCost(metadata.input)} / 1M
</Typography.Text>
</div>
<div className="flex justify-between gap-4">
<Typography.Text className="text-[10px] text-nowrap">
Output:{" "}
</Typography.Text>
<Typography.Text className="text-[10px] text-nowrap">
${formatCost(metadata.output)} / 1M
</Typography.Text>
</div>
</>
)}
</div>
)

const renderOptionContent = (option: ProviderOption) => {
const Icon = getProviderIcon(option.value) || LLMIcons[option.label]
return (
<div className="flex items-center gap-2 w-full justify-between group h-full">
<div className="flex items-center gap-2 overflow-hidden w-full">
{Icon && <Icon className="w-4 h-4 flex-shrink-0" />}
<span className="truncate">{option.label}</span>
</div>
</div>
)
}

const renderOption = (option: ProviderOption) => {
const content = renderOptionContent(option)

if (option.metadata) {
return (
<Tooltip
title={renderTooltipContent(option.metadata)}
placement="right"
mouseEnterDelay={0.3}
color="white"
>
{content}
</Tooltip>
)
}

return content
}

return (
<>
<Select
Expand All @@ -225,6 +289,7 @@ const SelectLLMProvider = ({
placeholder="Select a provider"
style={{width: "100%"}}
virtual={false}
optionLabelProp="label"
className={clsx([
"[&_.ant-select-item-option-content]:flex [&_.ant-select-item-option-content]:items-center [&_.ant-select-item-option-content]:gap-2 [&_.ant-select-selection-item]:!flex [&_.ant-select-selection-item]:!items-center [&_.ant-select-selection-item]:!gap-2",
className,
Expand Down Expand Up @@ -292,10 +357,7 @@ const SelectLLMProvider = ({
handleSelect(option.value)
}}
>
{Icon && (
<Icon className="w-4 h-4 flex-shrink-0" />
)}
<span>{option.label}</span>
{renderOption(option)}
</div>
))}
</div>
Expand Down Expand Up @@ -369,27 +431,26 @@ const SelectLLMProvider = ({
}
>
{group.options?.map((option) => {
const Icon =
getProviderIcon(group.label || "") || LLMIcons[option.label]
return (
<Option key={option.key ?? option.value} value={option.value}>
<div className="flex items-center gap-2">
{Icon && <Icon className="w-4 h-4 flex-shrink-0" />}
<span>{option.label}</span>
</div>
<Option
key={option.key ?? option.value}
value={option.value}
label={renderOptionContent(option)}
>
{renderOption(option)}
</Option>
)
})}
</OptGroup>
) : (
group.options?.map((option) => {
const Icon = getProviderIcon(option.value) || LLMIcons[option.label]
return (
<Option key={option.key ?? option.value} value={option.value}>
<div className="flex items-center gap-2">
{Icon && <Icon className="w-4 h-4 flex-shrink-0" />}
<span>{option.label}</span>
</div>
<Option
key={option.key ?? option.value}
value={option.value}
label={renderOptionContent(option)}
>
{renderOption(option)}
</Option>
)
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {useSessions} from "@/oss/state/newObservability/hooks/useSessions"
import {openSessionDrawerWithUrlAtom} from "@/oss/state/url/session"

import {AUTO_REFRESH_INTERVAL} from "../../constants"

import EmptySessions from "./assets/EmptySessions"
import {getSessionColumns, SessionRow} from "./assets/getSessionColumns"

Expand Down
Loading