diff --git a/agents/github_specialist/_context.md b/agents/github_specialist/_context.md
new file mode 100644
index 0000000000..8a4ebb3d0a
--- /dev/null
+++ b/agents/github_specialist/_context.md
@@ -0,0 +1 @@
+§§include(/a0/agents/github_specialist/_context.md)
\ No newline at end of file
diff --git a/agents/github_specialist/prompts/agent.system.main.role.md b/agents/github_specialist/prompts/agent.system.main.role.md
new file mode 100644
index 0000000000..6d01d07fbc
--- /dev/null
+++ b/agents/github_specialist/prompts/agent.system.main.role.md
@@ -0,0 +1 @@
+§§include(/a0/agents/github_specialist/prompts/agent.system.main.role.md)
\ No newline at end of file
diff --git a/agents/tdd_test/_context.md b/agents/tdd_test/_context.md
new file mode 100644
index 0000000000..c2563905aa
--- /dev/null
+++ b/agents/tdd_test/_context.md
@@ -0,0 +1 @@
+§§include(/a0/agents/tdd_test/_context.md)
\ No newline at end of file
diff --git a/agents/tdd_test/prompts/agent.system.main.role.md b/agents/tdd_test/prompts/agent.system.main.role.md
new file mode 100644
index 0000000000..7619d42682
--- /dev/null
+++ b/agents/tdd_test/prompts/agent.system.main.role.md
@@ -0,0 +1 @@
+§§include(/a0/agents/tdd_test/prompts/agent.system.main.role.md)
\ No newline at end of file
diff --git a/docs/res/a0-vector-graphics/banner.svg b/docs/res/a0-vector-graphics/banner.svg
index 5a98670be1..2d2e751c39 100644
--- a/docs/res/a0-vector-graphics/banner.svg
+++ b/docs/res/a0-vector-graphics/banner.svg
@@ -1,51 +1 @@
-
-
-
+§§include(/a0/docs/res/a0-vector-graphics/banner.svg)
\ No newline at end of file
diff --git a/docs/res/a0-vector-graphics/dark.svg b/docs/res/a0-vector-graphics/dark.svg
old mode 100755
new mode 100644
index 140246cd02..85ada8c5dd
--- a/docs/res/a0-vector-graphics/dark.svg
+++ b/docs/res/a0-vector-graphics/dark.svg
@@ -1,20 +1 @@
-
-
\ No newline at end of file
+§§include(/a0/docs/res/a0-vector-graphics/dark.svg)
\ No newline at end of file
diff --git a/docs/res/a0-vector-graphics/darkSymbol.svg b/docs/res/a0-vector-graphics/darkSymbol.svg
old mode 100755
new mode 100644
index 893fc49b25..7e090eee64
--- a/docs/res/a0-vector-graphics/darkSymbol.svg
+++ b/docs/res/a0-vector-graphics/darkSymbol.svg
@@ -1,5 +1 @@
-
-
\ No newline at end of file
+§§include(/a0/docs/res/a0-vector-graphics/darkSymbol.svg)
\ No newline at end of file
diff --git a/docs/res/a0-vector-graphics/light.svg b/docs/res/a0-vector-graphics/light.svg
old mode 100755
new mode 100644
index b8148ecf0d..3725732396
--- a/docs/res/a0-vector-graphics/light.svg
+++ b/docs/res/a0-vector-graphics/light.svg
@@ -1,20 +1 @@
-
-
\ No newline at end of file
+§§include(/a0/docs/res/a0-vector-graphics/light.svg)
\ No newline at end of file
diff --git a/docs/res/a0-vector-graphics/lightSymbol.svg b/docs/res/a0-vector-graphics/lightSymbol.svg
old mode 100755
new mode 100644
index c988a103ea..6643292991
--- a/docs/res/a0-vector-graphics/lightSymbol.svg
+++ b/docs/res/a0-vector-graphics/lightSymbol.svg
@@ -1,5 +1 @@
-
-
\ No newline at end of file
+§§include(/a0/docs/res/a0-vector-graphics/lightSymbol.svg)
\ No newline at end of file
diff --git a/models.py b/models.py
index fbc2694dfd..2423bf15a9 100644
--- a/models.py
+++ b/models.py
@@ -1,919 +1 @@
-from dataclasses import dataclass, field
-from enum import Enum
-import logging
-import os
-from typing import (
-    Any,
-    Awaitable,
-    Callable,
-    List,
-    Optional,
-    Iterator,
-    AsyncIterator,
-    Tuple,
-    TypedDict,
-)
-
-from litellm import completion, acompletion, embedding
-import litellm
-import openai
-from litellm.types.utils import ModelResponse
-
-from python.helpers import dotenv
-from python.helpers import settings, dirty_json
-from python.helpers.dotenv import load_dotenv
-from python.helpers.providers import get_provider_config
-from python.helpers.rate_limiter import RateLimiter
-from python.helpers.tokens import approximate_tokens
-from python.helpers import dirty_json, browser_use_monkeypatch
-
-from langchain_core.language_models.chat_models import SimpleChatModel
-from langchain_core.outputs.chat_generation import ChatGenerationChunk
-from langchain_core.callbacks.manager import (
-    CallbackManagerForLLMRun,
-    AsyncCallbackManagerForLLMRun,
-)
-from langchain_core.messages import (
-    BaseMessage,
-    AIMessageChunk,
-    HumanMessage,
-    SystemMessage,
-)
-from langchain.embeddings.base import Embeddings
-from sentence_transformers import SentenceTransformer
-from pydantic import ConfigDict
-
-
-# disable extra logging, must be done repeatedly, otherwise browser-use will turn it back on for some reason
-def turn_off_logging():
-    os.environ["LITELLM_LOG"] = "ERROR"  # only errors
-    litellm.suppress_debug_info = True
-    # Silence **all** LiteLLM sub-loggers (utils, cost_calculator…)
-    for name in logging.Logger.manager.loggerDict:
-        if name.lower().startswith("litellm"):
-            logging.getLogger(name).setLevel(logging.ERROR)
-
-
-# init
-load_dotenv()
-turn_off_logging()
-browser_use_monkeypatch.apply()
-
-litellm.modify_params = True  # helps fix anthropic tool calls by browser-use
-
-class ModelType(Enum):
-    CHAT = "Chat"
-    EMBEDDING = "Embedding"
-
-
-@dataclass
-class ModelConfig:
-    type: ModelType
-    provider: str
-    name: str
-    api_base: str = ""
-    ctx_length: int = 0
-    limit_requests: int = 0
-    limit_input: int = 0
-    limit_output: int = 0
-    vision: bool = False
-    kwargs: dict = field(default_factory=dict)
-
-    def build_kwargs(self):
-        kwargs = self.kwargs.copy() or {}
-        if self.api_base and "api_base" not in kwargs:
-            kwargs["api_base"] = self.api_base
-        return kwargs
-
-
-class ChatChunk(TypedDict):
-    """Simplified response chunk for chat models."""
-    response_delta: str
-    reasoning_delta: str
-
-class ChatGenerationResult:
-    """Chat generation result object"""
-    def __init__(self, chunk: ChatChunk|None = None):
-        self.reasoning = ""
-        self.response = ""
-        self.thinking = False
-        self.thinking_tag = ""
-        self.unprocessed = ""
-        self.native_reasoning = False
-        self.thinking_pairs = [("<think>", "</think>"), ("<reasoning>", "</reasoning>")]
-        if chunk:
-            self.add_chunk(chunk)
-
-    def add_chunk(self, chunk: ChatChunk) -> ChatChunk:
-        if chunk["reasoning_delta"]:
-            self.native_reasoning = True
-
-        # if native reasoning detection works, there's no need to worry about thinking tags
-        if self.native_reasoning:
-            processed_chunk = ChatChunk(response_delta=chunk["response_delta"], reasoning_delta=chunk["reasoning_delta"])
-        else:
-            # if the model outputs thinking tags, we need to parse them manually as reasoning
-            processed_chunk = self._process_thinking_chunk(chunk)
-
-        self.reasoning += processed_chunk["reasoning_delta"]
-        self.response += processed_chunk["response_delta"]
-
-        return processed_chunk
-
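-    # The two helpers below form a small streaming state machine: text enclosed
-    # in a thinking tag pair is routed into `reasoning`, everything else into
-    # `response`, and a partially received tag at a chunk boundary is buffered
-    # in `self.unprocessed` until the next chunk arrives.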
-    def _process_thinking_chunk(self, chunk: ChatChunk) -> ChatChunk:
-        response_delta = self.unprocessed + chunk["response_delta"]
-        self.unprocessed = ""
-        return self._process_thinking_tags(response_delta, chunk["reasoning_delta"])
-
-    def _process_thinking_tags(self, response: str, reasoning: str) -> ChatChunk:
-        if self.thinking:
-            close_pos = response.find(self.thinking_tag)
-            if close_pos != -1:
-                reasoning += response[:close_pos]
-                response = response[close_pos + len(self.thinking_tag):]
-                self.thinking = False
-                self.thinking_tag = ""
-            else:
-                if self._is_partial_closing_tag(response):
-                    self.unprocessed = response
-                    response = ""
-                else:
-                    reasoning += response
-                    response = ""
-        else:
-            for opening_tag, closing_tag in self.thinking_pairs:
-                if response.startswith(opening_tag):
-                    response = response[len(opening_tag):]
-                    self.thinking = True
-                    self.thinking_tag = closing_tag
-
-                    close_pos = response.find(closing_tag)
-                    if close_pos != -1:
-                        reasoning += response[:close_pos]
-                        response = response[close_pos + len(closing_tag):]
-                        self.thinking = False
-                        self.thinking_tag = ""
-                    else:
-                        if self._is_partial_closing_tag(response):
-                            self.unprocessed = response
-                            response = ""
-                        else:
-                            reasoning += response
-                            response = ""
-                    break
-                elif len(response) < len(opening_tag) and self._is_partial_opening_tag(response, opening_tag):
-                    self.unprocessed = response
-                    response = ""
-                    break
-
-        return ChatChunk(response_delta=response, reasoning_delta=reasoning)
-
-    def _is_partial_opening_tag(self, text: str, opening_tag: str) -> bool:
-        for i in range(1, len(opening_tag)):
-            if text == opening_tag[:i]:
-                return True
-        return False
-
-    def _is_partial_closing_tag(self, text: str) -> bool:
-        if not self.thinking_tag or not text:
-            return False
-        max_check = min(len(text), len(self.thinking_tag) - 1)
-        for i in range(1, max_check + 1):
-            if text.endswith(self.thinking_tag[:i]):
-                return True
-        return False
-
-    def output(self) -> ChatChunk:
-        response = self.response
-        reasoning = self.reasoning
-        if self.unprocessed:
-            if reasoning and not response:
-                reasoning += self.unprocessed
-            else:
-                response += self.unprocessed
-        return ChatChunk(response_delta=response, reasoning_delta=reasoning)
-
-
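-# Module-level registries: one RateLimiter per provider/model pair, plus a
-# per-service counter used to rotate comma-separated API keys round-robin.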
-rate_limiters: dict[str, RateLimiter] = {}
-api_keys_round_robin: dict[str, int] = {}
-
-
-def get_api_key(service: str) -> str:
-    # get api key for the service
-    key = (
-        dotenv.get_dotenv_value(f"API_KEY_{service.upper()}")
-        or dotenv.get_dotenv_value(f"{service.upper()}_API_KEY")
-        or dotenv.get_dotenv_value(f"{service.upper()}_API_TOKEN")
-        or "None"
-    )
-    # if the key contains a comma, use round-robin
-    if "," in key:
-        api_keys = [k.strip() for k in key.split(",") if k.strip()]
-        api_keys_round_robin[service] = api_keys_round_robin.get(service, -1) + 1
-        key = api_keys[api_keys_round_robin[service] % len(api_keys)]
-    return key
-
-
-def get_rate_limiter(
-    provider: str, name: str, requests: int, input: int, output: int
-) -> RateLimiter:
-    key = f"{provider}\\{name}"
-    rate_limiters[key] = limiter = rate_limiters.get(key, RateLimiter(seconds=60))
-    limiter.limits["requests"] = requests or 0
-    limiter.limits["input"] = input or 0
-    limiter.limits["output"] = output or 0
-    return limiter
-
-
-def _is_transient_litellm_error(exc: Exception) -> bool:
-    """Uses status_code when available, else falls back to exception types"""
-    # Prefer explicit status codes if present
-    status_code = getattr(exc, "status_code", None)
-    if isinstance(status_code, int):
-        if status_code in (408, 429, 500, 502, 503, 504):
-            return True
-        # Treat other 5xx as retriable
-        if status_code >= 500:
-            return True
-        return False
-
-    # Fallback to exception classes mapped by LiteLLM/OpenAI
-    transient_types = (
-        getattr(openai, "APITimeoutError", Exception),
-        getattr(openai, "APIConnectionError", Exception),
-        getattr(openai, "RateLimitError", Exception),
-        getattr(openai, "APIError", Exception),
-        getattr(openai, "InternalServerError", Exception),
-        # Some providers map overloads to ServiceUnavailable-like errors
-        getattr(openai, "APIStatusError", Exception),
-    )
-    return isinstance(exc, transient_types)
-
-
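-# Registers the pending request (and its approximate input tokens) with the
-# model's limiter and waits until the 60-second window has capacity; the
-# limiter is returned so callers can add output tokens as they stream in.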
-async def apply_rate_limiter(
-    model_config: ModelConfig | None,
-    input_text: str,
-    rate_limiter_callback: (
-        Callable[[str, str, int, int], Awaitable[bool]] | None
-    ) = None,
-):
-    if not model_config:
-        return
-    limiter = get_rate_limiter(
-        model_config.provider,
-        model_config.name,
-        model_config.limit_requests,
-        model_config.limit_input,
-        model_config.limit_output,
-    )
-    limiter.add(input=approximate_tokens(input_text))
-    limiter.add(requests=1)
-    await limiter.wait(rate_limiter_callback)
-    return limiter
-
-
-def apply_rate_limiter_sync(
-    model_config: ModelConfig | None,
-    input_text: str,
-    rate_limiter_callback: (
-        Callable[[str, str, int, int], Awaitable[bool]] | None
-    ) = None,
-):
-    if not model_config:
-        return
-    import asyncio, nest_asyncio
-
-    nest_asyncio.apply()
-    return asyncio.run(
-        apply_rate_limiter(model_config, input_text, rate_limiter_callback)
-    )
-
-
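-# LangChain-compatible chat model backed by LiteLLM; the model is addressed as
-# "<provider>/<model>" and optional A0 rate-limit config is applied per call.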
-class LiteLLMChatWrapper(SimpleChatModel):
-    model_name: str
-    provider: str
-    kwargs: dict = {}
-
-    model_config = ConfigDict(
-        arbitrary_types_allowed=True,
-        extra="allow",
-        validate_assignment=False,
-    )
-
-    def __init__(
-        self,
-        model: str,
-        provider: str,
-        model_config: Optional[ModelConfig] = None,
-        **kwargs: Any,
-    ):
-        model_value = f"{provider}/{model}"
-        super().__init__(model_name=model_value, provider=provider, kwargs=kwargs)  # type: ignore
-        # Set A0 model config as instance attribute after parent init
-        self.a0_model_conf = model_config
-
-    @property
-    def _llm_type(self) -> str:
-        return "litellm-chat"
-
-    def _convert_messages(self, messages: List[BaseMessage]) -> List[dict]:
-        result = []
-        # Map LangChain message types to LiteLLM roles
-        role_mapping = {
-            "human": "user",
-            "ai": "assistant",
-            "system": "system",
-            "tool": "tool",
-        }
-        for m in messages:
-            role = role_mapping.get(m.type, m.type)
-            message_dict = {"role": role, "content": m.content}
-
-            # Handle tool calls for AI messages
-            tool_calls = getattr(m, "tool_calls", None)
-            if tool_calls:
-                # Convert LangChain tool calls to LiteLLM format
-                new_tool_calls = []
-                for tool_call in tool_calls:
-                    # Ensure arguments is a JSON string
-                    args = tool_call["args"]
-                    if isinstance(args, dict):
-                        import json
-
-                        args_str = json.dumps(args)
-                    else:
-                        args_str = str(args)
-
-                    new_tool_calls.append(
-                        {
-                            "id": tool_call.get("id", ""),
-                            "type": "function",
-                            "function": {
-                                "name": tool_call["name"],
-                                "arguments": args_str,
-                            },
-                        }
-                    )
-                message_dict["tool_calls"] = new_tool_calls
-
-            # Handle tool call ID for ToolMessage
-            tool_call_id = getattr(m, "tool_call_id", None)
-            if tool_call_id:
-                message_dict["tool_call_id"] = tool_call_id
-
-            result.append(message_dict)
-        return result
-
-    def _call(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> str:
-        import asyncio
-
-        msgs = self._convert_messages(messages)
-
-        # Apply rate limiting if configured
-        apply_rate_limiter_sync(self.a0_model_conf, str(msgs))
-
-        # Call the model
-        resp = completion(
-            model=self.model_name, messages=msgs, stop=stop, **{**self.kwargs, **kwargs}
-        )
-
-        # Parse output
-        parsed = _parse_chunk(resp)
-        output = ChatGenerationResult(parsed).output()
-        return output["response_delta"]
-
-    def _stream(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[ChatGenerationChunk]:
-        import asyncio
-
-        msgs = self._convert_messages(messages)
-
-        # Apply rate limiting if configured
-        apply_rate_limiter_sync(self.a0_model_conf, str(msgs))
-
-        result = ChatGenerationResult()
-
-        for chunk in completion(
-            model=self.model_name,
-            messages=msgs,
-            stream=True,
-            stop=stop,
-            **{**self.kwargs, **kwargs},
-        ):
-            # parse chunk
-            parsed = _parse_chunk(chunk)  # chunk parsing
-            output = result.add_chunk(parsed)  # chunk processing
-
-            # Only yield chunks with non-None content
-            if output["response_delta"]:
-                yield ChatGenerationChunk(
-                    message=AIMessageChunk(content=output["response_delta"])
-                )
-
-    async def _astream(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> AsyncIterator[ChatGenerationChunk]:
-        msgs = self._convert_messages(messages)
-
-        # Apply rate limiting if configured
-        await apply_rate_limiter(self.a0_model_conf, str(msgs))
-
-        result = ChatGenerationResult()
-
-        response = await acompletion(
-            model=self.model_name,
-            messages=msgs,
-            stream=True,
-            stop=stop,
-            **{**self.kwargs, **kwargs},
-        )
-        async for chunk in response:  # type: ignore
-            # parse chunk
-            parsed = _parse_chunk(chunk)  # chunk parsing
-            output = result.add_chunk(parsed)  # chunk processing
-
-            # Only yield chunks with non-None content
-            if output["response_delta"]:
-                yield ChatGenerationChunk(
-                    message=AIMessageChunk(content=output["response_delta"])
-                )
-
-    async def unified_call(
-        self,
-        system_message="",
-        user_message="",
-        messages: List[BaseMessage] | None = None,
-        response_callback: Callable[[str, str], Awaitable[None]] | None = None,
-        reasoning_callback: Callable[[str, str], Awaitable[None]] | None = None,
-        tokens_callback: Callable[[str, int], Awaitable[None]] | None = None,
-        rate_limiter_callback: (
-            Callable[[str, str, int, int], Awaitable[bool]] | None
-        ) = None,
-        **kwargs: Any,
-    ) -> Tuple[str, str]:
-
-        turn_off_logging()
-
-        if not messages:
-            messages = []
-        # construct messages
-        if system_message:
-            messages.insert(0, SystemMessage(content=system_message))
-        if user_message:
-            messages.append(HumanMessage(content=user_message))
-
-        # convert to litellm format
-        msgs_conv = self._convert_messages(messages)
-
-        # Apply rate limiting if configured
-        limiter = await apply_rate_limiter(
-            self.a0_model_conf, str(msgs_conv), rate_limiter_callback
-        )
-
-        # Prepare call kwargs and retry config (strip A0-only params before calling LiteLLM)
-        call_kwargs: dict[str, Any] = {**self.kwargs, **kwargs}
-        max_retries: int = int(call_kwargs.pop("a0_retry_attempts", 2))
-        retry_delay_s: float = float(call_kwargs.pop("a0_retry_delay_seconds", 1.5))
-        stream = reasoning_callback is not None or response_callback is not None or tokens_callback is not None
-
-        # results
-        result = ChatGenerationResult()
-
-        attempt = 0
-        while True:
-            got_any_chunk = False
-            try:
-                # call model
-                _completion = await acompletion(
-                    model=self.model_name,
-                    messages=msgs_conv,
-                    stream=stream,
-                    **call_kwargs,
-                )
-
-                if stream:
-                    # iterate over chunks
-                    async for chunk in _completion:  # type: ignore
-                        got_any_chunk = True
-                        # parse chunk
-                        parsed = _parse_chunk(chunk)
-                        output = result.add_chunk(parsed)
-
-                        # collect reasoning delta and call callbacks
-                        if output["reasoning_delta"]:
-                            if reasoning_callback:
-                                await reasoning_callback(output["reasoning_delta"], result.reasoning)
-                            if tokens_callback:
-                                await tokens_callback(
-                                    output["reasoning_delta"],
-                                    approximate_tokens(output["reasoning_delta"]),
-                                )
-                            # Add output tokens to rate limiter if configured
-                            if limiter:
-                                limiter.add(output=approximate_tokens(output["reasoning_delta"]))
-                        # collect response delta and call callbacks
-                        if output["response_delta"]:
-                            if response_callback:
-                                await response_callback(output["response_delta"], result.response)
-                            if tokens_callback:
-                                await tokens_callback(
-                                    output["response_delta"],
-                                    approximate_tokens(output["response_delta"]),
-                                )
-                            # Add output tokens to rate limiter if configured
-                            if limiter:
-                                limiter.add(output=approximate_tokens(output["response_delta"]))
-
-                # non-stream response
-                else:
-                    parsed = _parse_chunk(_completion)
-                    output = result.add_chunk(parsed)
-                    if limiter:
-                        if output["response_delta"]:
-                            limiter.add(output=approximate_tokens(output["response_delta"]))
-                        if output["reasoning_delta"]:
-                            limiter.add(output=approximate_tokens(output["reasoning_delta"]))
-
-                # Successful completion of stream
-                return result.response, result.reasoning
-
-            except Exception as e:
-                import asyncio
-
-                # Retry only if no chunks received and error is transient
-                if got_any_chunk or not _is_transient_litellm_error(e) or attempt >= max_retries:
-                    raise
-                attempt += 1
-                await asyncio.sleep(retry_delay_s)
-
-
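-# Minimal stand-in for the OpenAI async client surface that browser-use expects:
-# exposes `client.chat.completions.create(...)` and forwards it to the wrapper's
-# `_acall` method.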
-class AsyncAIChatReplacement:
-    class _Completions:
-        def __init__(self, wrapper):
-            self._wrapper = wrapper
-
-        async def create(self, *args, **kwargs):
-            # call the async _acall method on the wrapper
-            return await self._wrapper._acall(*args, **kwargs)
-
-    class _Chat:
-        def __init__(self, wrapper):
-            self.completions = AsyncAIChatReplacement._Completions(wrapper)
-
-    def __init__(self, wrapper, *args, **kwargs):
-        self._wrapper = wrapper
-        self.chat = AsyncAIChatReplacement._Chat(wrapper)
-
-
-from browser_use.llm import ChatOllama, ChatOpenRouter, ChatGoogle, ChatAnthropic, ChatGroq, ChatOpenAI
-
-class BrowserCompatibleChatWrapper(ChatOpenRouter):
-    """
-    A wrapper for browser agent that can filter/sanitize messages
-    before sending them to the LLM.
-    """
-
-    def __init__(self, *args, **kwargs):
-        turn_off_logging()
-        # Create the underlying LiteLLM wrapper
-        self._wrapper = LiteLLMChatWrapper(*args, **kwargs)
-        # Browser-use may expect a 'model' attribute
-        self.model = self._wrapper.model_name
-        self.kwargs = self._wrapper.kwargs
-
-    @property
-    def model_name(self) -> str:
-        return self._wrapper.model_name
-
-    @property
-    def provider(self) -> str:
-        return self._wrapper.provider
-
-    def get_client(self, *args, **kwargs):  # type: ignore
-        return AsyncAIChatReplacement(self, *args, **kwargs)
-
-    async def _acall(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ):
-        # Apply rate limiting if configured
-        apply_rate_limiter_sync(self._wrapper.a0_model_conf, str(messages))
-
-        # Call the model
-        try:
-            model = kwargs.pop("model", None)
-            kwrgs = {**self._wrapper.kwargs, **kwargs}
-
-            # hack from browser-use to fix json schema for gemini (additionalProperties, $defs, $ref)
-            if "response_format" in kwrgs and "json_schema" in kwrgs["response_format"] and model.startswith("gemini/"):
-                kwrgs["response_format"]["json_schema"] = ChatGoogle("")._fix_gemini_schema(kwrgs["response_format"]["json_schema"])
-
-            resp = await acompletion(
-                model=self._wrapper.model_name,
-                messages=messages,
-                stop=stop,
-                **kwrgs,
-            )
-
-            # Gemini: strip triple backticks and conform schema
-            try:
-                msg = resp.choices[0].message  # type: ignore
-                if self.provider == "gemini" and isinstance(getattr(msg, "content", None), str):
-                    cleaned = browser_use_monkeypatch.gemini_clean_and_conform(msg.content)  # type: ignore
-                    if cleaned:
-                        msg.content = cleaned
-            except Exception:
-                pass
-
-        except Exception as e:
-            raise e
-
-        # another hack for browser-use: post-process invalid JSON responses
-        try:
-            if "response_format" in kwrgs and ("json_schema" in kwrgs["response_format"] or "json_object" in kwrgs["response_format"]):
-                if resp.choices[0].message.content is not None and not resp.choices[0].message.content.startswith("{"):  # type: ignore
-                    js = dirty_json.parse(resp.choices[0].message.content)  # type: ignore
-                    resp.choices[0].message.content = dirty_json.stringify(js)  # type: ignore
-        except Exception as e:
-            pass
-
-        return resp
-
-class LiteLLMEmbeddingWrapper(Embeddings):
-    model_name: str
-    kwargs: dict = {}
-    a0_model_conf: Optional[ModelConfig] = None
-
-    def __init__(
-        self,
-        model: str,
-        provider: str,
-        model_config: Optional[ModelConfig] = None,
-        **kwargs: Any,
-    ):
-        self.model_name = f"{provider}/{model}" if provider != "openai" else model
-        self.kwargs = kwargs
-        self.a0_model_conf = model_config
-
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        # Apply rate limiting if configured
-        apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts))
-
-        resp = embedding(model=self.model_name, input=texts, **self.kwargs)
-        return [
-            item.get("embedding") if isinstance(item, dict) else item.embedding  # type: ignore
-            for item in resp.data  # type: ignore
-        ]
-
-    def embed_query(self, text: str) -> List[float]:
-        # Apply rate limiting if configured
-        apply_rate_limiter_sync(self.a0_model_conf, text)
-
-        resp = embedding(model=self.model_name, input=[text], **self.kwargs)
-        item = resp.data[0]  # type: ignore
-        return item.get("embedding") if isinstance(item, dict) else item.embedding  # type: ignore
-
-
-class LocalSentenceTransformerWrapper(Embeddings):
-    """Local wrapper for sentence-transformers models to avoid HuggingFace API calls"""
-
-    def __init__(
-        self,
-        provider: str,
-        model: str,
-        model_config: Optional[ModelConfig] = None,
-        **kwargs: Any,
-    ):
-        # Clean common user-input mistakes
-        model = model.strip().strip('"').strip("'")
-
-        # Remove the "sentence-transformers/" prefix if present
-        if model.startswith("sentence-transformers/"):
-            model = model[len("sentence-transformers/") :]
-
-        # Filter kwargs for SentenceTransformer only (no LiteLLM params like 'stream_timeout')
-        st_allowed_keys = {
-            "device",
-            "cache_folder",
-            "use_auth_token",
-            "revision",
-            "trust_remote_code",
-            "model_kwargs",
-        }
-        st_kwargs = {k: v for k, v in (kwargs or {}).items() if k in st_allowed_keys}
-
-        self.model = SentenceTransformer(model, **st_kwargs)
-        self.model_name = model
-        self.a0_model_conf = model_config
-
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        # Apply rate limiting if configured
-        apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts))
-
-        embeddings = self.model.encode(texts, convert_to_tensor=False)  # type: ignore
-        return embeddings.tolist() if hasattr(embeddings, "tolist") else embeddings  # type: ignore
-
-    def embed_query(self, text: str) -> List[float]:
-        # Apply rate limiting if configured
-        apply_rate_limiter_sync(self.a0_model_conf, text)
-
-        embedding = self.model.encode([text], convert_to_tensor=False)  # type: ignore
-        result = (
-            embedding[0].tolist() if hasattr(embedding[0], "tolist") else embedding[0]
-        )
-        return result  # type: ignore
-
-
-def _get_litellm_chat(
-    cls: type = LiteLLMChatWrapper,
-    model_name: str = "",
-    provider_name: str = "",
-    model_config: Optional[ModelConfig] = None,
-    **kwargs: Any,
-):
-    # use api key from kwargs or env
-    api_key = kwargs.pop("api_key", None) or get_api_key(provider_name)
-
-    # Only pass API key if key is not a placeholder
-    if api_key and api_key not in ("None", "NA"):
-        kwargs["api_key"] = api_key
-
-    provider_name, model_name, kwargs = _adjust_call_args(
-        provider_name, model_name, kwargs
-    )
-    return cls(
-        provider=provider_name, model=model_name, model_config=model_config, **kwargs
-    )
-
-
-def _get_litellm_embedding(
-    model_name: str,
-    provider_name: str,
-    model_config: Optional[ModelConfig] = None,
-    **kwargs: Any,
-):
-    # Check if this is a local sentence-transformers model
-    if provider_name == "huggingface" and model_name.startswith(
-        "sentence-transformers/"
-    ):
-        # Use local sentence-transformers instead of LiteLLM for local models
-        provider_name, model_name, kwargs = _adjust_call_args(
-            provider_name, model_name, kwargs
-        )
-        return LocalSentenceTransformerWrapper(
-            provider=provider_name,
-            model=model_name,
-            model_config=model_config,
-            **kwargs,
-        )
-
-    # use api key from kwargs or env
-    api_key = kwargs.pop("api_key", None) or get_api_key(provider_name)
-
-    # Only pass API key if key is not a placeholder
-    if api_key and api_key not in ("None", "NA"):
-        kwargs["api_key"] = api_key
-
-    provider_name, model_name, kwargs = _adjust_call_args(
-        provider_name, model_name, kwargs
-    )
-    return LiteLLMEmbeddingWrapper(
-        model=model_name, provider=provider_name, model_config=model_config, **kwargs
-    )
-
-
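-# Extracts `content` and `reasoning_content` from either a streaming delta or a
-# full (non-streaming) message, tolerating both dict-style and attribute-style
-# objects returned by LiteLLM.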
-def _parse_chunk(chunk: Any) -> ChatChunk:
-    delta = chunk["choices"][0].get("delta", {})
-    message = chunk["choices"][0].get("message", {}) or chunk["choices"][0].get(
-        "model_extra", {}
-    ).get("message", {})
-    response_delta = (
-        delta.get("content", "")
-        if isinstance(delta, dict)
-        else getattr(delta, "content", "")
-    ) or (
-        message.get("content", "")
-        if isinstance(message, dict)
-        else getattr(message, "content", "")
-    )
-    reasoning_delta = (
-        delta.get("reasoning_content", "")
-        if isinstance(delta, dict)
-        else getattr(delta, "reasoning_content", "")
-    ) or (
-        message.get("reasoning_content", "")
-        if isinstance(message, dict)
-        else getattr(message, "reasoning_content", "")
-    )
-
-    return ChatChunk(reasoning_delta=reasoning_delta, response_delta=response_delta)
-
-
-
-def _adjust_call_args(provider_name: str, model_name: str, kwargs: dict):
-    # for openrouter add app reference
-    if provider_name == "openrouter":
-        kwargs["extra_headers"] = {
-            "HTTP-Referer": "https://agent-zero.ai",
-            "X-Title": "Agent Zero",
-        }
-
-    # remap other to openai for litellm
-    if provider_name == "other":
-        provider_name = "openai"
-
-    return provider_name, model_name, kwargs
-
-
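-# Merges defaults from the provider catalog, the API key from .env and the
-# global LiteLLM kwargs from settings into the call kwargs; explicitly passed
-# values always win because defaults are only applied via setdefault().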
-def _merge_provider_defaults(
-    provider_type: str, original_provider: str, kwargs: dict
-) -> tuple[str, dict]:
-    # Normalize .env-style numeric strings (e.g., "timeout=30") into ints/floats for LiteLLM
-    def _normalize_values(values: dict) -> dict:
-        result: dict[str, Any] = {}
-        for k, v in values.items():
-            if isinstance(v, str):
-                try:
-                    result[k] = int(v)
-                except ValueError:
-                    try:
-                        result[k] = float(v)
-                    except ValueError:
-                        result[k] = v
-            else:
-                result[k] = v
-        return result
-
-    provider_name = original_provider  # default: unchanged
-    cfg = get_provider_config(provider_type, original_provider)
-    if cfg:
-        provider_name = cfg.get("litellm_provider", original_provider).lower()
-
-        # Extra arguments nested under `kwargs` for readability
-        extra_kwargs = cfg.get("kwargs") if isinstance(cfg, dict) else None  # type: ignore[arg-type]
-        if isinstance(extra_kwargs, dict):
-            for k, v in extra_kwargs.items():
-                kwargs.setdefault(k, v)
-
-    # Inject API key based on the *original* provider id if still missing
-    if "api_key" not in kwargs:
-        key = get_api_key(original_provider)
-        if key and key not in ("None", "NA"):
-            kwargs["api_key"] = key
-
-    # Merge LiteLLM global kwargs (timeouts, stream_timeout, etc.)
-    try:
-        global_kwargs = settings.get_settings().get("litellm_global_kwargs", {})  # type: ignore[union-attr]
-    except Exception:
-        global_kwargs = {}
-    if isinstance(global_kwargs, dict):
-        for k, v in _normalize_values(global_kwargs).items():
-            kwargs.setdefault(k, v)
-
-    return provider_name, kwargs
-
-
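-# Example usage of the factories below (provider and model names are
-# illustrative placeholders):
-#   chat = get_chat_model("openrouter", "some-vendor/some-model")
-#   embed = get_embedding_model("huggingface", "sentence-transformers/all-MiniLM-L6-v2")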
-def get_chat_model(
-    provider: str, name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any
-) -> LiteLLMChatWrapper:
-    orig = provider.lower()
-    provider_name, kwargs = _merge_provider_defaults("chat", orig, kwargs)
-    return _get_litellm_chat(
-        LiteLLMChatWrapper, name, provider_name, model_config, **kwargs
-    )
-
-
-def get_browser_model(
-    provider: str, name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any
-) -> BrowserCompatibleChatWrapper:
-    orig = provider.lower()
-    provider_name, kwargs = _merge_provider_defaults("chat", orig, kwargs)
-    return _get_litellm_chat(
-        BrowserCompatibleChatWrapper, name, provider_name, model_config, **kwargs
-    )
-
-
-def get_embedding_model(
-    provider: str, name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any
-) -> LiteLLMEmbeddingWrapper | LocalSentenceTransformerWrapper:
-    orig = provider.lower()
-    provider_name, kwargs = _merge_provider_defaults("embedding", orig, kwargs)
-    return _get_litellm_embedding(name, provider_name, model_config, **kwargs)
+§§include(/a0/models.py)
\ No newline at end of file
diff --git a/python/extensions/agent_init/_10_initial_message.py b/python/extensions/agent_init/_10_initial_message.py
index f64a3fce44..0da2814617 100644
--- a/python/extensions/agent_init/_10_initial_message.py
+++ b/python/extensions/agent_init/_10_initial_message.py
@@ -1,42 +1 @@
-import json
-from agent import LoopData
-from python.helpers.extension import Extension
-
-
-class InitialMessage(Extension):
-
-    async def execute(self, **kwargs):
-        """
-        Add an initial greeting message when the first user message is processed.
-        Called only once per session via _process_chain method.
-        """
-
-        # Only add initial message for main agent (A0), not subordinate agents
-        if self.agent.number != 0:
-            return
-
-        # If the context already contains log messages, do not add another initial message
-        if self.agent.context.log.logs:
-            return
-
-        # Construct the initial message from prompt template
-        initial_message = self.agent.read_prompt("fw.initial_message.md")
-
-        # add initial loop data to agent (for hist_add_ai_response)
-        self.agent.loop_data = LoopData(user_message=None)
-
-        # Add the message to history as an AI response
-        self.agent.hist_add_ai_response(initial_message)
-
-        # json parse the message, get the tool_args text
-        initial_message_json = json.loads(initial_message)
-        initial_message_text = initial_message_json.get("tool_args", {}).get("text", "Hello! How can I help you?")
-
-        # Add to log (green bubble) for immediate UI display
-        self.agent.context.log.log(
-            type="response",
-            heading=f"{self.agent.agent_name}: Welcome",
-            content=initial_message_text,
-            finished=True,
-            update_progress="none",
-        )
+§§include(/a0/python/extensions/agent_init/_10_initial_message.py)
\ No newline at end of file