Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 90 additions & 14 deletions apps/beeai-cli/src/beeai_cli/commands/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def _ollama_exe() -> str:

RECOMMENDED_LLM_MODELS = [
f"{ModelProviderType.WATSONX}:ibm/granite-3-3-8b-instruct",
f"{ModelProviderType.AWS_BEDROCK}:anthropic.claude-3-sonnet-20240229-v1:0",
f"{ModelProviderType.OPENAI}:gpt-4o",
f"{ModelProviderType.ANTHROPIC}:claude-sonnet-4-20250514",
f"{ModelProviderType.CEREBRAS}:llama-3.3-70b",
Expand Down Expand Up @@ -72,9 +73,17 @@ def _ollama_exe() -> str:
]

LLM_PROVIDERS = [
Choice(
name="Amazon Bedrock".ljust(20),
value=(ModelProviderType.AWS_BEDROCK, "Amazon Bedrock", None),
),
Choice(
name="Anthropic Claude".ljust(20),
value=(ModelProviderType.ANTHROPIC, "Anthropic Claude", "https://api.anthropic.com/v1"),
value=(
ModelProviderType.ANTHROPIC,
"Anthropic Claude",
"https://api.anthropic.com/v1",
),
),
Choice(
name="Cerebras".ljust(20) + "🆓 has a free tier",
Expand All @@ -86,23 +95,44 @@ def _ollama_exe() -> str:
),
Choice(
name="Cohere".ljust(20) + "🆓 has a free tier",
value=(ModelProviderType.COHERE, "Cohere", "https://api.cohere.ai/compatibility/v1"),
value=(
ModelProviderType.COHERE,
"Cohere",
"https://api.cohere.ai/compatibility/v1",
),
),
Choice(
name="DeepSeek",
value=(ModelProviderType.DEEPSEEK, "DeepSeek", "https://api.deepseek.com/v1"),
),
Choice(name="DeepSeek", value=(ModelProviderType.DEEPSEEK, "DeepSeek", "https://api.deepseek.com/v1")),
Choice(
name="Google Gemini".ljust(20) + "🆓 has a free tier",
value=(ModelProviderType.GEMINI, "Google Gemini", "https://generativelanguage.googleapis.com/v1beta/openai"),
value=(
ModelProviderType.GEMINI,
"Google Gemini",
"https://generativelanguage.googleapis.com/v1beta/openai",
),
),
Choice(
name="GitHub Models".ljust(20) + "🆓 has a free tier",
value=(ModelProviderType.GITHUB, "GitHub Models", "https://models.github.ai/inference"),
value=(
ModelProviderType.GITHUB,
"GitHub Models",
"https://models.github.ai/inference",
),
),
Choice(
name="Groq".ljust(20) + "🆓 has a free tier",
value=(ModelProviderType.GROQ, "Groq", "https://api.groq.com/openai/v1"),
),
Choice(name="IBM watsonx".ljust(20), value=(ModelProviderType.WATSONX, "IBM watsonx", None)),
Choice(name="Jan".ljust(20) + "💻 local", value=(ModelProviderType.JAN, "Jan", "http://localhost:1337/v1")),
Choice(
name="IBM watsonx".ljust(20),
value=(ModelProviderType.WATSONX, "IBM watsonx", None),
),
Choice(
name="Jan".ljust(20) + "💻 local",
value=(ModelProviderType.JAN, "Jan", "http://localhost:1337/v1"),
),
Choice(
name="Mistral".ljust(20) + "🆓 has a free tier",
value=(ModelProviderType.MISTRAL, "Mistral", "https://api.mistral.ai/v1"),
Expand All @@ -113,7 +143,11 @@ def _ollama_exe() -> str:
),
Choice(
name="NVIDIA NIM".ljust(20),
value=(ModelProviderType.NVIDIA, "NVIDIA NIM", "https://integrate.api.nvidia.com/v1"),
value=(
ModelProviderType.NVIDIA,
"NVIDIA NIM",
"https://integrate.api.nvidia.com/v1",
),
),
Choice(
name="Ollama".ljust(20) + "💻 local",
Expand All @@ -125,15 +159,23 @@ def _ollama_exe() -> str:
),
Choice(
name="OpenRouter".ljust(20) + "🆓 has some free models",
value=(ModelProviderType.OPENROUTER, "OpenRouter", "https://openrouter.ai/api/v1"),
value=(
ModelProviderType.OPENROUTER,
"OpenRouter",
"https://openrouter.ai/api/v1",
),
),
Choice(
name="Perplexity".ljust(20),
value=(ModelProviderType.PERPLEXITY, "Perplexity", "https://api.perplexity.ai"),
),
Choice(
name="Together.ai".ljust(20) + "🆓 has a free tier",
value=(ModelProviderType.TOGETHER, "together.ai", "https://api.together.xyz/v1"),
value=(
ModelProviderType.TOGETHER,
"together.ai",
"https://api.together.xyz/v1",
),
),
Choice(
name="🛠️ Other (RITS, Amazon Bedrock, vLLM, ..., any OpenAI-compatible API)",
Expand Down Expand Up @@ -182,6 +224,7 @@ async def _add_provider(capability: ModelCapability, use_true_localhost: bool =
provider_name: str
base_url: str
watsonx_project_id, watsonx_space_id = None, None
aws_region, aws_access_key_id = None, None
choices = LLM_PROVIDERS if capability == ModelCapability.LLM else EMBEDDING_PROVIDERS
provider_type, provider_name, base_url = await inquirer.fuzzy( # type: ignore
message=f"Select {capability} provider (type to search):", choices=choices
Expand Down Expand Up @@ -231,14 +274,46 @@ async def _add_provider(capability: ModelCapability, use_true_localhost: bool =
watsonx_project_id = watsonx_project_or_space_id if watsonx_project_or_space == "project" else None
watsonx_space_id = watsonx_project_or_space_id if watsonx_project_or_space == "space" else None

if (api_key := os.environ.get(f"{provider_type.upper()}_API_KEY")) is None or not await inquirer.confirm( # type: ignore
if provider_type == ModelProviderType.AWS_BEDROCK:
aws_region = await inquirer.select( # type: ignore
message="Select AWS region:",
choices=[
"us-east-1",
"us-west-2",
"eu-central-1",
"ap-northeast-1",
"ap-southeast-2",
],
default="us-east-1",
).execute_async()
base_url = f"https://bedrock-runtime.{aws_region}.amazonaws.com/openai/v1"
if (
os.environ.get("AWS_ACCESS_KEY_ID")
and os.environ.get("AWS_SECRET_ACCESS_KEY")
and await inquirer.confirm( # type: ignore
message="Use AWS credentials from environment variables (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY)?",
default=True,
).execute_async()
):
aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID")
api_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
else:
aws_access_key_id = await inquirer.text( # type: ignore
message="Enter AWS Access Key ID:", validate=EmptyInputValidator()
).execute_async()
api_key = await inquirer.secret( # type: ignore
message="Enter AWS Secret Access Key:", validate=EmptyInputValidator()
).execute_async()
elif (api_key := os.environ.get(f"{provider_type.upper()}_API_KEY")) is None or not await inquirer.confirm( # type: ignore
message=f"Use the API key from environment variable '{provider_type.upper()}_API_KEY'?",
default=True,
).execute_async():
api_key: str = (
api_key = (
"dummy"
if provider_type in {ModelProviderType.OLLAMA, ModelProviderType.JAN}
else await inquirer.secret(message="Enter API key:", validate=EmptyInputValidator()).execute_async() # type: ignore
else await inquirer.secret( # type: ignore
message="Enter API key:", validate=EmptyInputValidator()
).execute_async()
)

try:
Expand Down Expand Up @@ -286,9 +361,10 @@ async def _add_provider(capability: ModelCapability, use_true_localhost: bool =
name=provider_name,
type=ModelProviderType(provider_type),
base_url=base_url,
api_key=api_key,
api_key=api_key or "",
watsonx_space_id=watsonx_space_id,
watsonx_project_id=watsonx_project_id,
aws_access_key_id=aws_access_key_id,
)

except httpx.HTTPError as e:
Expand Down
4 changes: 4 additions & 0 deletions apps/beeai-sdk/src/beeai_sdk/platform/model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

class ModelProviderType(StrEnum):
ANTHROPIC = "anthropic"
AWS_BEDROCK = "aws_bedrock"
CEREBRAS = "cerebras"
CHUTES = "chutes"
COHERE = "cohere"
Expand Down Expand Up @@ -52,6 +53,7 @@ class ModelProvider(pydantic.BaseModel):
base_url: pydantic.HttpUrl
watsonx_project_id: str | None = None
watsonx_space_id: str | None = None
aws_access_key_id: str | None = None
created_at: pydantic.AwareDatetime
capabilities: set[ModelCapability]

Expand All @@ -64,6 +66,7 @@ async def create(
base_url: str | pydantic.HttpUrl,
watsonx_project_id: str | None = None,
watsonx_space_id: str | None = None,
aws_access_key_id: str | None = None,
api_key: str,
client: PlatformClient | None = None,
) -> ModelProvider:
Expand All @@ -79,6 +82,7 @@ async def create(
"base_url": str(base_url),
"watsonx_project_id": watsonx_project_id,
"watsonx_space_id": watsonx_space_id,
"aws_access_key_id": aws_access_key_id,
"api_key": api_key,
},
)
Expand Down
1 change: 1 addition & 0 deletions apps/beeai-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dependencies = [
"openai>=1.97.0",
"authlib>=1.6.4",
"async-lru>=2.0.5",
"aws-bedrock-token-generator>=1.1.0",
]

[dependency-groups]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ async def create_model_provider(
base_url=request.base_url,
watsonx_project_id=request.watsonx_project_id,
watsonx_space_id=request.watsonx_space_id,
aws_access_key_id=request.aws_access_key_id,
api_key=request.api_key.get_secret_value(),
)
return EntityModel(model_provider)
Expand Down
17 changes: 16 additions & 1 deletion apps/beeai-server/src/beeai_server/api/routes/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import ibm_watsonx_ai
import ibm_watsonx_ai.foundation_models.embeddings
import openai
import openai.pagination
import openai.types.chat
from fastapi import Depends, HTTPException
from fastapi.concurrency import run_in_threadpool
Expand Down Expand Up @@ -44,6 +43,22 @@ async def create_chat_completion(

api_key = await model_provider_service.get_provider_api_key(model_provider_id=provider.id)

if provider.type == ModelProviderType.AWS_BEDROCK:
import aws_bedrock_token_generator
import botocore.credentials

# exchange aws_secret_access_key for short-lived Bedrock API key
api_key = await run_in_threadpool(
aws_bedrock_token_generator.provide_token,
region=provider.aws_region,
aws_credentials_provider=botocore.credentials.EnvProvider(
{
"AWS_ACCESS_KEY_ID": provider.aws_access_key_id,
"AWS_SECRET_ACCESS_KEY": api_key,
}
),
)

if provider.type == ModelProviderType.WATSONX:
model = ibm_watsonx_ai.foundation_models.ModelInference(
model_id=model_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class CreateModelProviderRequest(BaseModel):
base_url: HttpUrl
watsonx_project_id: str | None = None
watsonx_space_id: str | None = None
aws_access_key_id: str | None = None
api_key: Secret[str]


Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
# SPDX-License-Identifier: Apache-2.0

import re
from datetime import datetime
from enum import StrEnum
from typing import Any, Literal
Expand All @@ -15,6 +16,7 @@

class ModelProviderType(StrEnum):
ANTHROPIC = "anthropic"
AWS_BEDROCK = "aws_bedrock"
CEREBRAS = "cerebras"
CHUTES = "chutes"
COHERE = "cohere"
Expand Down Expand Up @@ -75,13 +77,27 @@ class ModelProvider(BaseModel):
exclude=True,
)

# AWS Bedrock specific fields
aws_access_key_id: str | None = Field(None, description="AWS access key ID for Bedrock", exclude=True)

@model_validator(mode="after")
def validate_watsonx_config(self):
def validate_provider_config(self):
    """Validate provider-type-specific configuration after model construction.

    - WatsonX providers must set exactly one of ``watsonx_project_id`` or
      ``watsonx_space_id`` (XOR — setting both or neither is rejected).
    - AWS Bedrock providers must set ``aws_access_key_id``.

    Returns:
        The validated model instance (pydantic ``mode="after"`` validator contract).

    Raises:
        ValueError: if the provider-specific requirements above are not met.
    """
    if self.type == ModelProviderType.WATSONX and not (bool(self.watsonx_project_id) ^ bool(self.watsonx_space_id)):
        raise ValueError("WatsonX providers must have either watsonx_project_id or watsonx_space_id")
    if self.type == ModelProviderType.AWS_BEDROCK and not self.aws_access_key_id:
        raise ValueError("AWS Bedrock providers must have aws_access_key_id")
    return self

@computed_field
@property
def aws_region(self) -> str | None:
    """Derive the AWS region from the Bedrock runtime base URL.

    Parses URLs of the form ``https://bedrock-runtime.<region>.amazonaws.com/...``
    and returns ``<region>``; returns None for non-Bedrock providers.

    NOTE(review): if a Bedrock provider was configured with a base_url that does
    not match this pattern (e.g. a proxy or VPC endpoint), this returns None —
    confirm downstream callers (e.g. boto3 client construction) tolerate that.
    """
    if self.type == ModelProviderType.AWS_BEDROCK:
        match = re.search(r"bedrock-runtime\.([^.]+)\.amazonaws\.com", str(self.base_url))
        if match:
            return match.group(1)
    return None

@computed_field
@property
def capabilities(self) -> set[ModelCapability]:
Expand All @@ -99,6 +115,26 @@ def _parse_openai_compatible_model(self, model: dict[str, Any]) -> Model:
async def load_models(self, api_key: str) -> list[Model]:
async with AsyncClient() as client:
match self.type:
case ModelProviderType.AWS_BEDROCK:
import boto3

response = boto3.client(
"bedrock",
region_name=self.aws_region,
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=api_key,
).list_foundation_models(byInferenceType="ON_DEMAND")
return [
Model(
id=f"{self.type}:{model['modelId']}",
created=int(datetime.now().timestamp()),
object="model",
owned_by=model["providerName"],
provider=self._model_provider_info,
)
for model in response["modelSummaries"]
if "TEXT" in model["outputModalities"]
]
case ModelProviderType.WATSONX:
response = await client.get(f"{self.base_url}/ml/v1/foundation_model_specs?version=2025-08-27")
response_models = response.raise_for_status().json()["resources"]
Expand Down Expand Up @@ -189,6 +225,7 @@ class ModelWithScore(BaseModel):

_PROVIDER_CAPABILITIES: dict[ModelProviderType, set[ModelCapability]] = {
ModelProviderType.ANTHROPIC: {ModelCapability.LLM},
ModelProviderType.AWS_BEDROCK: {ModelCapability.LLM, ModelCapability.EMBEDDING},
ModelProviderType.CEREBRAS: {ModelCapability.LLM},
ModelProviderType.CHUTES: {ModelCapability.LLM},
ModelProviderType.COHERE: {ModelCapability.LLM, ModelCapability.EMBEDDING},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
# SPDX-License-Identifier: Apache-2.0

"""Add the aws_access_key_id column to the model_providers table.

Supports AWS Bedrock model providers, which store an access key ID alongside
the secret (kept in the existing api_key storage).

Revision ID: 198d161f5b5c
Revises: 73e2d8596ada
Create Date: 2025-10-02 15:55:24.226680

"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "198d161f5b5c"
down_revision: str | None = "73e2d8596ada"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    """Upgrade schema: add the nullable aws_access_key_id column to model_providers."""
    aws_key_column = sa.Column("aws_access_key_id", sa.String(length=256), nullable=True)
    op.add_column("model_providers", aws_key_column)


def downgrade() -> None:
    """Downgrade schema: drop the aws_access_key_id column added by this revision."""
    table_name = "model_providers"
    op.drop_column(table_name, "aws_access_key_id")
Loading
Loading