diff --git a/agents/github_specialist/_context.md b/agents/github_specialist/_context.md
new file mode 100644
index 0000000000..8a4ebb3d0a
--- /dev/null
+++ b/agents/github_specialist/_context.md
@@ -0,0 +1 @@
+§§include(/a0/agents/github_specialist/_context.md)
\ No newline at end of file
diff --git a/agents/github_specialist/prompts/agent.system.main.role.md b/agents/github_specialist/prompts/agent.system.main.role.md
new file mode 100644
index 0000000000..6d01d07fbc
--- /dev/null
+++ b/agents/github_specialist/prompts/agent.system.main.role.md
@@ -0,0 +1 @@
+§§include(/a0/agents/github_specialist/prompts/agent.system.main.role.md)
\ No newline at end of file
diff --git a/agents/tdd_test/_context.md b/agents/tdd_test/_context.md
new file mode 100644
index 0000000000..c2563905aa
--- /dev/null
+++ b/agents/tdd_test/_context.md
@@ -0,0 +1 @@
+§§include(/a0/agents/tdd_test/_context.md)
\ No newline at end of file
diff --git a/agents/tdd_test/prompts/agent.system.main.role.md b/agents/tdd_test/prompts/agent.system.main.role.md
new file mode 100644
index 0000000000..7619d42682
--- /dev/null
+++ b/agents/tdd_test/prompts/agent.system.main.role.md
@@ -0,0 +1 @@
+§§include(/a0/agents/tdd_test/prompts/agent.system.main.role.md)
\ No newline at end of file
diff --git a/docs/res/a0-vector-graphics/banner.svg b/docs/res/a0-vector-graphics/banner.svg
index 5a98670be1..2d2e751c39 100644
--- a/docs/res/a0-vector-graphics/banner.svg
+++ b/docs/res/a0-vector-graphics/banner.svg
@@ -1,51 +1 @@
-[51 lines of SVG markup removed; markup not preserved in this capture]
+§§include(/a0/docs/res/a0-vector-graphics/banner.svg)
\ No newline at end of file
diff --git a/docs/res/a0-vector-graphics/dark.svg b/docs/res/a0-vector-graphics/dark.svg
old mode 100755
new mode 100644
index 140246cd02..85ada8c5dd
--- a/docs/res/a0-vector-graphics/dark.svg
+++ b/docs/res/a0-vector-graphics/dark.svg
@@ -1,20 +1 @@
-[20 lines of SVG markup removed; markup not preserved in this capture]
\ No newline at end of file
+§§include(/a0/docs/res/a0-vector-graphics/dark.svg)
\ No newline at end of file
diff --git a/docs/res/a0-vector-graphics/darkSymbol.svg b/docs/res/a0-vector-graphics/darkSymbol.svg
old mode 100755
new mode 100644
index 893fc49b25..7e090eee64
--- a/docs/res/a0-vector-graphics/darkSymbol.svg
+++ b/docs/res/a0-vector-graphics/darkSymbol.svg
@@ -1,5 +1 @@
-[5 lines of SVG markup removed; markup not preserved in this capture]
\ No newline at end of file
+§§include(/a0/docs/res/a0-vector-graphics/darkSymbol.svg)
\ No newline at end of file
diff --git a/docs/res/a0-vector-graphics/light.svg b/docs/res/a0-vector-graphics/light.svg
old mode 100755
new mode 100644
index b8148ecf0d..3725732396
--- a/docs/res/a0-vector-graphics/light.svg
+++ b/docs/res/a0-vector-graphics/light.svg
@@ -1,20 +1 @@
-[20 lines of SVG markup removed; markup not preserved in this capture]
\ No newline at end of file
+§§include(/a0/docs/res/a0-vector-graphics/light.svg)
\ No newline at end of file
diff --git a/docs/res/a0-vector-graphics/lightSymbol.svg b/docs/res/a0-vector-graphics/lightSymbol.svg
old mode 100755
new mode 100644
index c988a103ea..6643292991
--- a/docs/res/a0-vector-graphics/lightSymbol.svg
+++ b/docs/res/a0-vector-graphics/lightSymbol.svg
@@ -1,5 +1 @@
-[5 lines of SVG markup removed; markup not preserved in this capture]
\ No newline at end of file
+§§include(/a0/docs/res/a0-vector-graphics/lightSymbol.svg)
\ No newline at end of file
diff --git a/models.py b/models.py
index fbc2694dfd..2423bf15a9 100644
--- a/models.py
+++ b/models.py
@@ -1,919 +1 @@
-from dataclasses import dataclass, field
-from enum import Enum
-import logging
-import os
-from typing import (
-    Any,
-    Awaitable,
-
Callable, - List, - Optional, - Iterator, - AsyncIterator, - Tuple, - TypedDict, -) - -from litellm import completion, acompletion, embedding -import litellm -import openai -from litellm.types.utils import ModelResponse - -from python.helpers import dotenv -from python.helpers import settings, dirty_json -from python.helpers.dotenv import load_dotenv -from python.helpers.providers import get_provider_config -from python.helpers.rate_limiter import RateLimiter -from python.helpers.tokens import approximate_tokens -from python.helpers import dirty_json, browser_use_monkeypatch - -from langchain_core.language_models.chat_models import SimpleChatModel -from langchain_core.outputs.chat_generation import ChatGenerationChunk -from langchain_core.callbacks.manager import ( - CallbackManagerForLLMRun, - AsyncCallbackManagerForLLMRun, -) -from langchain_core.messages import ( - BaseMessage, - AIMessageChunk, - HumanMessage, - SystemMessage, -) -from langchain.embeddings.base import Embeddings -from sentence_transformers import SentenceTransformer -from pydantic import ConfigDict - - -# disable extra logging, must be done repeatedly, otherwise browser-use will turn it back on for some reason -def turn_off_logging(): - os.environ["LITELLM_LOG"] = "ERROR" # only errors - litellm.suppress_debug_info = True - # Silence **all** LiteLLM sub-loggers (utils, cost_calculator…) - for name in logging.Logger.manager.loggerDict: - if name.lower().startswith("litellm"): - logging.getLogger(name).setLevel(logging.ERROR) - - -# init -load_dotenv() -turn_off_logging() -browser_use_monkeypatch.apply() - -litellm.modify_params = True # helps fix anthropic tool calls by browser-use - -class ModelType(Enum): - CHAT = "Chat" - EMBEDDING = "Embedding" - - -@dataclass -class ModelConfig: - type: ModelType - provider: str - name: str - api_base: str = "" - ctx_length: int = 0 - limit_requests: int = 0 - limit_input: int = 0 - limit_output: int = 0 - vision: bool = False - kwargs: dict = field(default_factory=dict) - - def build_kwargs(self): - kwargs = self.kwargs.copy() or {} - if self.api_base and "api_base" not in kwargs: - kwargs["api_base"] = self.api_base - return kwargs - - -class ChatChunk(TypedDict): - """Simplified response chunk for chat models.""" - response_delta: str - reasoning_delta: str - -class ChatGenerationResult: - """Chat generation result object""" - def __init__(self, chunk: ChatChunk|None = None): - self.reasoning = "" - self.response = "" - self.thinking = False - self.thinking_tag = "" - self.unprocessed = "" - self.native_reasoning = False - self.thinking_pairs = [("", ""), ("", "")] - if chunk: - self.add_chunk(chunk) - - def add_chunk(self, chunk: ChatChunk) -> ChatChunk: - if chunk["reasoning_delta"]: - self.native_reasoning = True - - # if native reasoning detection works, there's no need to worry about thinking tags - if self.native_reasoning: - processed_chunk = ChatChunk(response_delta=chunk["response_delta"], reasoning_delta=chunk["reasoning_delta"]) - else: - # if the model outputs thinking tags, we ned to parse them manually as reasoning - processed_chunk = self._process_thinking_chunk(chunk) - - self.reasoning += processed_chunk["reasoning_delta"] - self.response += processed_chunk["response_delta"] - - return processed_chunk - - def _process_thinking_chunk(self, chunk: ChatChunk) -> ChatChunk: - response_delta = self.unprocessed + chunk["response_delta"] - self.unprocessed = "" - return self._process_thinking_tags(response_delta, chunk["reasoning_delta"]) - - def 
_process_thinking_tags(self, response: str, reasoning: str) -> ChatChunk: - if self.thinking: - close_pos = response.find(self.thinking_tag) - if close_pos != -1: - reasoning += response[:close_pos] - response = response[close_pos + len(self.thinking_tag):] - self.thinking = False - self.thinking_tag = "" - else: - if self._is_partial_closing_tag(response): - self.unprocessed = response - response = "" - else: - reasoning += response - response = "" - else: - for opening_tag, closing_tag in self.thinking_pairs: - if response.startswith(opening_tag): - response = response[len(opening_tag):] - self.thinking = True - self.thinking_tag = closing_tag - - close_pos = response.find(closing_tag) - if close_pos != -1: - reasoning += response[:close_pos] - response = response[close_pos + len(closing_tag):] - self.thinking = False - self.thinking_tag = "" - else: - if self._is_partial_closing_tag(response): - self.unprocessed = response - response = "" - else: - reasoning += response - response = "" - break - elif len(response) < len(opening_tag) and self._is_partial_opening_tag(response, opening_tag): - self.unprocessed = response - response = "" - break - - return ChatChunk(response_delta=response, reasoning_delta=reasoning) - - def _is_partial_opening_tag(self, text: str, opening_tag: str) -> bool: - for i in range(1, len(opening_tag)): - if text == opening_tag[:i]: - return True - return False - - def _is_partial_closing_tag(self, text: str) -> bool: - if not self.thinking_tag or not text: - return False - max_check = min(len(text), len(self.thinking_tag) - 1) - for i in range(1, max_check + 1): - if text.endswith(self.thinking_tag[:i]): - return True - return False - - def output(self) -> ChatChunk: - response = self.response - reasoning = self.reasoning - if self.unprocessed: - if reasoning and not response: - reasoning += self.unprocessed - else: - response += self.unprocessed - return ChatChunk(response_delta=response, reasoning_delta=reasoning) - - -rate_limiters: dict[str, RateLimiter] = {} -api_keys_round_robin: dict[str, int] = {} - - -def get_api_key(service: str) -> str: - # get api key for the service - key = ( - dotenv.get_dotenv_value(f"API_KEY_{service.upper()}") - or dotenv.get_dotenv_value(f"{service.upper()}_API_KEY") - or dotenv.get_dotenv_value(f"{service.upper()}_API_TOKEN") - or "None" - ) - # if the key contains a comma, use round-robin - if "," in key: - api_keys = [k.strip() for k in key.split(",") if k.strip()] - api_keys_round_robin[service] = api_keys_round_robin.get(service, -1) + 1 - key = api_keys[api_keys_round_robin[service] % len(api_keys)] - return key - - -def get_rate_limiter( - provider: str, name: str, requests: int, input: int, output: int -) -> RateLimiter: - key = f"{provider}\\{name}" - rate_limiters[key] = limiter = rate_limiters.get(key, RateLimiter(seconds=60)) - limiter.limits["requests"] = requests or 0 - limiter.limits["input"] = input or 0 - limiter.limits["output"] = output or 0 - return limiter - - -def _is_transient_litellm_error(exc: Exception) -> bool: - """Uses status_code when available, else falls back to exception types""" - # Prefer explicit status codes if present - status_code = getattr(exc, "status_code", None) - if isinstance(status_code, int): - if status_code in (408, 429, 500, 502, 503, 504): - return True - # Treat other 5xx as retriable - if status_code >= 500: - return True - return False - - # Fallback to exception classes mapped by LiteLLM/OpenAI - transient_types = ( - getattr(openai, "APITimeoutError", Exception), - 
getattr(openai, "APIConnectionError", Exception), - getattr(openai, "RateLimitError", Exception), - getattr(openai, "APIError", Exception), - getattr(openai, "InternalServerError", Exception), - # Some providers map overloads to ServiceUnavailable-like errors - getattr(openai, "APIStatusError", Exception), - ) - return isinstance(exc, transient_types) - - -async def apply_rate_limiter( - model_config: ModelConfig | None, - input_text: str, - rate_limiter_callback: ( - Callable[[str, str, int, int], Awaitable[bool]] | None - ) = None, -): - if not model_config: - return - limiter = get_rate_limiter( - model_config.provider, - model_config.name, - model_config.limit_requests, - model_config.limit_input, - model_config.limit_output, - ) - limiter.add(input=approximate_tokens(input_text)) - limiter.add(requests=1) - await limiter.wait(rate_limiter_callback) - return limiter - - -def apply_rate_limiter_sync( - model_config: ModelConfig | None, - input_text: str, - rate_limiter_callback: ( - Callable[[str, str, int, int], Awaitable[bool]] | None - ) = None, -): - if not model_config: - return - import asyncio, nest_asyncio - - nest_asyncio.apply() - return asyncio.run( - apply_rate_limiter(model_config, input_text, rate_limiter_callback) - ) - - -class LiteLLMChatWrapper(SimpleChatModel): - model_name: str - provider: str - kwargs: dict = {} - - model_config = ConfigDict( - arbitrary_types_allowed=True, - extra="allow", - validate_assignment=False, - ) - - def __init__( - self, - model: str, - provider: str, - model_config: Optional[ModelConfig] = None, - **kwargs: Any, - ): - model_value = f"{provider}/{model}" - super().__init__(model_name=model_value, provider=provider, kwargs=kwargs) # type: ignore - # Set A0 model config as instance attribute after parent init - self.a0_model_conf = model_config - - @property - def _llm_type(self) -> str: - return "litellm-chat" - - def _convert_messages(self, messages: List[BaseMessage]) -> List[dict]: - result = [] - # Map LangChain message types to LiteLLM roles - role_mapping = { - "human": "user", - "ai": "assistant", - "system": "system", - "tool": "tool", - } - for m in messages: - role = role_mapping.get(m.type, m.type) - message_dict = {"role": role, "content": m.content} - - # Handle tool calls for AI messages - tool_calls = getattr(m, "tool_calls", None) - if tool_calls: - # Convert LangChain tool calls to LiteLLM format - new_tool_calls = [] - for tool_call in tool_calls: - # Ensure arguments is a JSON string - args = tool_call["args"] - if isinstance(args, dict): - import json - - args_str = json.dumps(args) - else: - args_str = str(args) - - new_tool_calls.append( - { - "id": tool_call.get("id", ""), - "type": "function", - "function": { - "name": tool_call["name"], - "arguments": args_str, - }, - } - ) - message_dict["tool_calls"] = new_tool_calls - - # Handle tool call ID for ToolMessage - tool_call_id = getattr(m, "tool_call_id", None) - if tool_call_id: - message_dict["tool_call_id"] = tool_call_id - - result.append(message_dict) - return result - - def _call( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - import asyncio - - msgs = self._convert_messages(messages) - - # Apply rate limiting if configured - apply_rate_limiter_sync(self.a0_model_conf, str(msgs)) - - # Call the model - resp = completion( - model=self.model_name, messages=msgs, stop=stop, **{**self.kwargs, **kwargs} - ) - - # Parse output - parsed = 
_parse_chunk(resp) - output = ChatGenerationResult(parsed).output() - return output["response_delta"] - - def _stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - import asyncio - - msgs = self._convert_messages(messages) - - # Apply rate limiting if configured - apply_rate_limiter_sync(self.a0_model_conf, str(msgs)) - - result = ChatGenerationResult() - - for chunk in completion( - model=self.model_name, - messages=msgs, - stream=True, - stop=stop, - **{**self.kwargs, **kwargs}, - ): - # parse chunk - parsed = _parse_chunk(chunk) # chunk parsing - output = result.add_chunk(parsed) # chunk processing - - # Only yield chunks with non-None content - if output["response_delta"]: - yield ChatGenerationChunk( - message=AIMessageChunk(content=output["response_delta"]) - ) - - async def _astream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> AsyncIterator[ChatGenerationChunk]: - msgs = self._convert_messages(messages) - - # Apply rate limiting if configured - await apply_rate_limiter(self.a0_model_conf, str(msgs)) - - result = ChatGenerationResult() - - response = await acompletion( - model=self.model_name, - messages=msgs, - stream=True, - stop=stop, - **{**self.kwargs, **kwargs}, - ) - async for chunk in response: # type: ignore - # parse chunk - parsed = _parse_chunk(chunk) # chunk parsing - output = result.add_chunk(parsed) # chunk processing - - # Only yield chunks with non-None content - if output["response_delta"]: - yield ChatGenerationChunk( - message=AIMessageChunk(content=output["response_delta"]) - ) - - async def unified_call( - self, - system_message="", - user_message="", - messages: List[BaseMessage] | None = None, - response_callback: Callable[[str, str], Awaitable[None]] | None = None, - reasoning_callback: Callable[[str, str], Awaitable[None]] | None = None, - tokens_callback: Callable[[str, int], Awaitable[None]] | None = None, - rate_limiter_callback: ( - Callable[[str, str, int, int], Awaitable[bool]] | None - ) = None, - **kwargs: Any, - ) -> Tuple[str, str]: - - turn_off_logging() - - if not messages: - messages = [] - # construct messages - if system_message: - messages.insert(0, SystemMessage(content=system_message)) - if user_message: - messages.append(HumanMessage(content=user_message)) - - # convert to litellm format - msgs_conv = self._convert_messages(messages) - - # Apply rate limiting if configured - limiter = await apply_rate_limiter( - self.a0_model_conf, str(msgs_conv), rate_limiter_callback - ) - - # Prepare call kwargs and retry config (strip A0-only params before calling LiteLLM) - call_kwargs: dict[str, Any] = {**self.kwargs, **kwargs} - max_retries: int = int(call_kwargs.pop("a0_retry_attempts", 2)) - retry_delay_s: float = float(call_kwargs.pop("a0_retry_delay_seconds", 1.5)) - stream = reasoning_callback is not None or response_callback is not None or tokens_callback is not None - - # results - result = ChatGenerationResult() - - attempt = 0 - while True: - got_any_chunk = False - try: - # call model - _completion = await acompletion( - model=self.model_name, - messages=msgs_conv, - stream=stream, - **call_kwargs, - ) - - if stream: - # iterate over chunks - async for chunk in _completion: # type: ignore - got_any_chunk = True - # parse chunk - parsed = _parse_chunk(chunk) - output = 
result.add_chunk(parsed) - - # collect reasoning delta and call callbacks - if output["reasoning_delta"]: - if reasoning_callback: - await reasoning_callback(output["reasoning_delta"], result.reasoning) - if tokens_callback: - await tokens_callback( - output["reasoning_delta"], - approximate_tokens(output["reasoning_delta"]), - ) - # Add output tokens to rate limiter if configured - if limiter: - limiter.add(output=approximate_tokens(output["reasoning_delta"])) - # collect response delta and call callbacks - if output["response_delta"]: - if response_callback: - await response_callback(output["response_delta"], result.response) - if tokens_callback: - await tokens_callback( - output["response_delta"], - approximate_tokens(output["response_delta"]), - ) - # Add output tokens to rate limiter if configured - if limiter: - limiter.add(output=approximate_tokens(output["response_delta"])) - - # non-stream response - else: - parsed = _parse_chunk(_completion) - output = result.add_chunk(parsed) - if limiter: - if output["response_delta"]: - limiter.add(output=approximate_tokens(output["response_delta"])) - if output["reasoning_delta"]: - limiter.add(output=approximate_tokens(output["reasoning_delta"])) - - # Successful completion of stream - return result.response, result.reasoning - - except Exception as e: - import asyncio - - # Retry only if no chunks received and error is transient - if got_any_chunk or not _is_transient_litellm_error(e) or attempt >= max_retries: - raise - attempt += 1 - await asyncio.sleep(retry_delay_s) - - -class AsyncAIChatReplacement: - class _Completions: - def __init__(self, wrapper): - self._wrapper = wrapper - - async def create(self, *args, **kwargs): - # call the async _acall method on the wrapper - return await self._wrapper._acall(*args, **kwargs) - - class _Chat: - def __init__(self, wrapper): - self.completions = AsyncAIChatReplacement._Completions(wrapper) - - def __init__(self, wrapper, *args, **kwargs): - self._wrapper = wrapper - self.chat = AsyncAIChatReplacement._Chat(wrapper) - - -from browser_use.llm import ChatOllama, ChatOpenRouter, ChatGoogle, ChatAnthropic, ChatGroq, ChatOpenAI - -class BrowserCompatibleChatWrapper(ChatOpenRouter): - """ - A wrapper for browser agent that can filter/sanitize messages - before sending them to the LLM. 
- """ - - def __init__(self, *args, **kwargs): - turn_off_logging() - # Create the underlying LiteLLM wrapper - self._wrapper = LiteLLMChatWrapper(*args, **kwargs) - # Browser-use may expect a 'model' attribute - self.model = self._wrapper.model_name - self.kwargs = self._wrapper.kwargs - - @property - def model_name(self) -> str: - return self._wrapper.model_name - - @property - def provider(self) -> str: - return self._wrapper.provider - - def get_client(self, *args, **kwargs): # type: ignore - return AsyncAIChatReplacement(self, *args, **kwargs) - - async def _acall( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ): - # Apply rate limiting if configured - apply_rate_limiter_sync(self._wrapper.a0_model_conf, str(messages)) - - # Call the model - try: - model = kwargs.pop("model", None) - kwrgs = {**self._wrapper.kwargs, **kwargs} - - # hack from browser-use to fix json schema for gemini (additionalProperties, $defs, $ref) - if "response_format" in kwrgs and "json_schema" in kwrgs["response_format"] and model.startswith("gemini/"): - kwrgs["response_format"]["json_schema"] = ChatGoogle("")._fix_gemini_schema(kwrgs["response_format"]["json_schema"]) - - resp = await acompletion( - model=self._wrapper.model_name, - messages=messages, - stop=stop, - **kwrgs, - ) - - # Gemini: strip triple backticks and conform schema - try: - msg = resp.choices[0].message # type: ignore - if self.provider == "gemini" and isinstance(getattr(msg, "content", None), str): - cleaned = browser_use_monkeypatch.gemini_clean_and_conform(msg.content) # type: ignore - if cleaned: - msg.content = cleaned - except Exception: - pass - - except Exception as e: - raise e - - # another hack for browser-use post process invalid jsons - try: - if "response_format" in kwrgs and "json_schema" in kwrgs["response_format"] or "json_object" in kwrgs["response_format"]: - if resp.choices[0].message.content is not None and not resp.choices[0].message.content.startswith("{"): # type: ignore - js = dirty_json.parse(resp.choices[0].message.content) # type: ignore - resp.choices[0].message.content = dirty_json.stringify(js) # type: ignore - except Exception as e: - pass - - return resp - -class LiteLLMEmbeddingWrapper(Embeddings): - model_name: str - kwargs: dict = {} - a0_model_conf: Optional[ModelConfig] = None - - def __init__( - self, - model: str, - provider: str, - model_config: Optional[ModelConfig] = None, - **kwargs: Any, - ): - self.model_name = f"{provider}/{model}" if provider != "openai" else model - self.kwargs = kwargs - self.a0_model_conf = model_config - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - # Apply rate limiting if configured - apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts)) - - resp = embedding(model=self.model_name, input=texts, **self.kwargs) - return [ - item.get("embedding") if isinstance(item, dict) else item.embedding # type: ignore - for item in resp.data # type: ignore - ] - - def embed_query(self, text: str) -> List[float]: - # Apply rate limiting if configured - apply_rate_limiter_sync(self.a0_model_conf, text) - - resp = embedding(model=self.model_name, input=[text], **self.kwargs) - item = resp.data[0] # type: ignore - return item.get("embedding") if isinstance(item, dict) else item.embedding # type: ignore - - -class LocalSentenceTransformerWrapper(Embeddings): - """Local wrapper for sentence-transformers models to avoid HuggingFace API calls""" - - def 
__init__( - self, - provider: str, - model: str, - model_config: Optional[ModelConfig] = None, - **kwargs: Any, - ): - # Clean common user-input mistakes - model = model.strip().strip('"').strip("'") - - # Remove the "sentence-transformers/" prefix if present - if model.startswith("sentence-transformers/"): - model = model[len("sentence-transformers/") :] - - # Filter kwargs for SentenceTransformer only (no LiteLLM params like 'stream_timeout') - st_allowed_keys = { - "device", - "cache_folder", - "use_auth_token", - "revision", - "trust_remote_code", - "model_kwargs", - } - st_kwargs = {k: v for k, v in (kwargs or {}).items() if k in st_allowed_keys} - - self.model = SentenceTransformer(model, **st_kwargs) - self.model_name = model - self.a0_model_conf = model_config - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - # Apply rate limiting if configured - apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts)) - - embeddings = self.model.encode(texts, convert_to_tensor=False) # type: ignore - return embeddings.tolist() if hasattr(embeddings, "tolist") else embeddings # type: ignore - - def embed_query(self, text: str) -> List[float]: - # Apply rate limiting if configured - apply_rate_limiter_sync(self.a0_model_conf, text) - - embedding = self.model.encode([text], convert_to_tensor=False) # type: ignore - result = ( - embedding[0].tolist() if hasattr(embedding[0], "tolist") else embedding[0] - ) - return result # type: ignore - - -def _get_litellm_chat( - cls: type = LiteLLMChatWrapper, - model_name: str = "", - provider_name: str = "", - model_config: Optional[ModelConfig] = None, - **kwargs: Any, -): - # use api key from kwargs or env - api_key = kwargs.pop("api_key", None) or get_api_key(provider_name) - - # Only pass API key if key is not a placeholder - if api_key and api_key not in ("None", "NA"): - kwargs["api_key"] = api_key - - provider_name, model_name, kwargs = _adjust_call_args( - provider_name, model_name, kwargs - ) - return cls( - provider=provider_name, model=model_name, model_config=model_config, **kwargs - ) - - -def _get_litellm_embedding( - model_name: str, - provider_name: str, - model_config: Optional[ModelConfig] = None, - **kwargs: Any, -): - # Check if this is a local sentence-transformers model - if provider_name == "huggingface" and model_name.startswith( - "sentence-transformers/" - ): - # Use local sentence-transformers instead of LiteLLM for local models - provider_name, model_name, kwargs = _adjust_call_args( - provider_name, model_name, kwargs - ) - return LocalSentenceTransformerWrapper( - provider=provider_name, - model=model_name, - model_config=model_config, - **kwargs, - ) - - # use api key from kwargs or env - api_key = kwargs.pop("api_key", None) or get_api_key(provider_name) - - # Only pass API key if key is not a placeholder - if api_key and api_key not in ("None", "NA"): - kwargs["api_key"] = api_key - - provider_name, model_name, kwargs = _adjust_call_args( - provider_name, model_name, kwargs - ) - return LiteLLMEmbeddingWrapper( - model=model_name, provider=provider_name, model_config=model_config, **kwargs - ) - - -def _parse_chunk(chunk: Any) -> ChatChunk: - delta = chunk["choices"][0].get("delta", {}) - message = chunk["choices"][0].get("message", {}) or chunk["choices"][0].get( - "model_extra", {} - ).get("message", {}) - response_delta = ( - delta.get("content", "") - if isinstance(delta, dict) - else getattr(delta, "content", "") - ) or ( - message.get("content", "") - if isinstance(message, dict) - else 
getattr(message, "content", "") - ) - reasoning_delta = ( - delta.get("reasoning_content", "") - if isinstance(delta, dict) - else getattr(delta, "reasoning_content", "") - ) or ( - message.get("reasoning_content", "") - if isinstance(message, dict) - else getattr(message, "reasoning_content", "") - ) - - return ChatChunk(reasoning_delta=reasoning_delta, response_delta=response_delta) - - - -def _adjust_call_args(provider_name: str, model_name: str, kwargs: dict): - # for openrouter add app reference - if provider_name == "openrouter": - kwargs["extra_headers"] = { - "HTTP-Referer": "https://agent-zero.ai", - "X-Title": "Agent Zero", - } - - # remap other to openai for litellm - if provider_name == "other": - provider_name = "openai" - - return provider_name, model_name, kwargs - - -def _merge_provider_defaults( - provider_type: str, original_provider: str, kwargs: dict -) -> tuple[str, dict]: - # Normalize .env-style numeric strings (e.g., "timeout=30") into ints/floats for LiteLLM - def _normalize_values(values: dict) -> dict: - result: dict[str, Any] = {} - for k, v in values.items(): - if isinstance(v, str): - try: - result[k] = int(v) - except ValueError: - try: - result[k] = float(v) - except ValueError: - result[k] = v - else: - result[k] = v - return result - - provider_name = original_provider # default: unchanged - cfg = get_provider_config(provider_type, original_provider) - if cfg: - provider_name = cfg.get("litellm_provider", original_provider).lower() - - # Extra arguments nested under `kwargs` for readability - extra_kwargs = cfg.get("kwargs") if isinstance(cfg, dict) else None # type: ignore[arg-type] - if isinstance(extra_kwargs, dict): - for k, v in extra_kwargs.items(): - kwargs.setdefault(k, v) - - # Inject API key based on the *original* provider id if still missing - if "api_key" not in kwargs: - key = get_api_key(original_provider) - if key and key not in ("None", "NA"): - kwargs["api_key"] = key - - # Merge LiteLLM global kwargs (timeouts, stream_timeout, etc.) 
-    try:
-        global_kwargs = settings.get_settings().get("litellm_global_kwargs", {})  # type: ignore[union-attr]
-    except Exception:
-        global_kwargs = {}
-    if isinstance(global_kwargs, dict):
-        for k, v in _normalize_values(global_kwargs).items():
-            kwargs.setdefault(k, v)
-
-    return provider_name, kwargs
-
-
-def get_chat_model(
-    provider: str, name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any
-) -> LiteLLMChatWrapper:
-    orig = provider.lower()
-    provider_name, kwargs = _merge_provider_defaults("chat", orig, kwargs)
-    return _get_litellm_chat(
-        LiteLLMChatWrapper, name, provider_name, model_config, **kwargs
-    )
-
-
-def get_browser_model(
-    provider: str, name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any
-) -> BrowserCompatibleChatWrapper:
-    orig = provider.lower()
-    provider_name, kwargs = _merge_provider_defaults("chat", orig, kwargs)
-    return _get_litellm_chat(
-        BrowserCompatibleChatWrapper, name, provider_name, model_config, **kwargs
-    )
-
-
-def get_embedding_model(
-    provider: str, name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any
-) -> LiteLLMEmbeddingWrapper | LocalSentenceTransformerWrapper:
-    orig = provider.lower()
-    provider_name, kwargs = _merge_provider_defaults("embedding", orig, kwargs)
-    return _get_litellm_embedding(name, provider_name, model_config, **kwargs)
+§§include(/a0/models.py)
\ No newline at end of file
diff --git a/python/extensions/agent_init/_10_initial_message.py b/python/extensions/agent_init/_10_initial_message.py
index f64a3fce44..0da2814617
--- a/python/extensions/agent_init/_10_initial_message.py
+++ b/python/extensions/agent_init/_10_initial_message.py
@@ -1,42 +1 @@
-import json
-from agent import LoopData
-from python.helpers.extension import Extension
-
-
-class InitialMessage(Extension):
-
-    async def execute(self, **kwargs):
-        """
-        Add an initial greeting message when first user message is processed.
-        Called only once per session via _process_chain method.
-        """
-
-        # Only add initial message for main agent (A0), not subordinate agents
-        if self.agent.number != 0:
-            return
-
-        # If the context already contains log messages, do not add another initial message
-        if self.agent.context.log.logs:
-            return
-
-        # Construct the initial message from prompt template
-        initial_message = self.agent.read_prompt("fw.initial_message.md")
-
-        # add initial loop data to agent (for hist_add_ai_response)
-        self.agent.loop_data = LoopData(user_message=None)
-
-        # Add the message to history as an AI response
-        self.agent.hist_add_ai_response(initial_message)
-
-        # json parse the message, get the tool_args text
-        initial_message_json = json.loads(initial_message)
-        initial_message_text = initial_message_json.get("tool_args", {}).get("text", "Hello! How can I help you?")
-
-        # Add to log (green bubble) for immediate UI display
-        self.agent.context.log.log(
-            type="response",
-            heading=f"{self.agent.agent_name}: Welcome",
-            content=initial_message_text,
-            finished=True,
-            update_progress="none",
-        )
+§§include(/a0/python/extensions/agent_init/_10_initial_message.py)
\ No newline at end of file
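
For orientation, the removed models.py above exposed get_chat_model() and LiteLLMChatWrapper.unified_call() as its main entry points. The sketch below is not part of the diff; it is a minimal usage example based only on the signatures visible here. The provider id, model name, and .env setup are assumptions for illustration.

# Usage sketch (illustrative, not part of this change). Assumes models.py is importable
# and that an API key for the chosen provider is set in .env
# (e.g. API_KEY_OPENROUTER or OPENROUTER_API_KEY, as read by get_api_key()).
import asyncio

from models import ModelConfig, ModelType, get_chat_model


async def main() -> None:
    # Hypothetical provider/model choice; any LiteLLM-supported pair should work.
    config = ModelConfig(
        type=ModelType.CHAT,
        provider="openrouter",
        name="openai/gpt-4o-mini",
        limit_requests=10,  # per-minute request budget enforced by the shared RateLimiter
    )
    chat = get_chat_model(config.provider, config.name, model_config=config)

    async def on_delta(delta: str, full_response: str) -> None:
        # Receives each streamed response fragment plus the text accumulated so far.
        print(delta, end="", flush=True)

    # Passing any callback makes unified_call stream; it returns (response, reasoning).
    response, reasoning = await chat.unified_call(
        system_message="You are a concise assistant.",
        user_message="Explain in one sentence what a rate limiter does.",
        response_callback=on_delta,
    )
    print("\n---\nreasoning characters:", len(reasoning))


asyncio.run(main())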