65 changes: 58 additions & 7 deletions astrbot/core/agent/context/compressor.py
@@ -1,3 +1,5 @@
import typing as T
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Protocol, runtime_checkable

from ..message import Message
@@ -154,6 +156,7 @@ def __init__(
keep_recent: int = 4,
instruction_text: str | None = None,
compression_threshold: float = 0.82,
use_compact_api: bool = True,
) -> None:
"""Initialize the LLM summary compressor.

@@ -162,10 +165,12 @@ def __init__(
keep_recent: The number of latest messages to keep (default: 4).
instruction_text: Custom instruction for summary generation.
compression_threshold: The compression trigger threshold (default: 0.82).
use_compact_api: Whether to prefer the provider's native compact API when available.
"""
self.provider = provider
self.keep_recent = keep_recent
self.compression_threshold = compression_threshold
self.use_compact_api = use_compact_api

self.instruction_text = instruction_text or (
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
@@ -193,38 +198,84 @@ def should_compress(
usage_rate = current_tokens / max_tokens
return usage_rate > self.compression_threshold

def _supports_native_compact(self) -> bool:
support_native_compact = getattr(self.provider, "supports_native_compact", None)
if not callable(support_native_compact):
return False
try:
return bool(support_native_compact())
except Exception:
return False

async def _try_native_compact(
self,
system_messages: list[Message],
messages_to_summarize: list[Message],
recent_messages: list[Message],
) -> list[Message] | None:
compact_context = getattr(self.provider, "compact_context", None)
if not callable(compact_context):
return None

compact_context_callable = T.cast(
"Callable[[list[Message]], Awaitable[list[Message]]]",
compact_context,
)

try:
compacted_messages = await compact_context_callable(messages_to_summarize)
except Exception as e:
logger.warning(
f"Native compact failed, fallback to summary compression: {e}"
)
return None

if not compacted_messages:
return None

result: list[Message] = []
result.extend(system_messages)
result.extend(compacted_messages)
result.extend(recent_messages)
return result

async def __call__(self, messages: list[Message]) -> list[Message]:
"""Use LLM to generate a summary of the conversation history.

Process:
1. Divide messages: keep the system message and the latest N messages.
2. Send the old messages + the instruction message to the LLM.
3. Reconstruct the message list: [system message, summary message, latest messages].
2. Prefer the provider's native compact API when it is available.
3. Fall back to LLM summary compression and reconstruct the message list.
"""
if len(messages) <= self.keep_recent + 1:
return messages

system_messages, messages_to_summarize, recent_messages = split_history(
messages, self.keep_recent
)

if not messages_to_summarize:
return messages

# build payload
# Only try native compact if the user allows it and the provider supports it
if self.use_compact_api and self._supports_native_compact():
compacted = await self._try_native_compact(
system_messages,
messages_to_summarize,
recent_messages,
)
if compacted is not None:
return compacted
instruction_message = Message(role="user", content=self.instruction_text)
llm_payload = messages_to_summarize + [instruction_message]

# generate summary
try:
response = await self.provider.text_chat(contexts=llm_payload)
summary_content = response.completion_text
except Exception as e:
logger.error(f"Failed to generate summary: {e}")
return messages

# build result
result = []
result: list[Message] = []
result.extend(system_messages)

result.append(
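The compressor discovers the native compact API purely by duck typing: it probes the provider for a callable supports_native_compact and an awaitable compact_context, and any exception or empty result falls through to summary compression. A minimal sketch of a provider exposing that interface (the class name and compaction logic are hypothetical, not part of this PR; the import path is inferred from the relative import above):

from astrbot.core.agent.message import Message  # path assumed


class CompactCapableProvider:
    """Hypothetical provider exposing the duck-typed compact interface."""

    def supports_native_compact(self) -> bool:
        # Capability probe; the compressor wraps this call in try/except,
        # so raising here is treated the same as returning False.
        return True

    async def compact_context(self, messages: list[Message]) -> list[Message]:
        # Return a shorter, semantically equivalent middle slice. An empty
        # list (or an exception) makes the compressor fall back to summaries.
        merged = " | ".join(str(m.content) for m in messages)
        return [Message(role="user", content=f"[compacted] {merged}")]

Because _try_native_compact splices the compacted messages between the preserved system and recent messages, the provider only ever sees the middle slice that would otherwise be summarized.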
2 changes: 2 additions & 0 deletions astrbot/core/agent/context/config.py
@@ -29,6 +29,8 @@ class ContextConfig:
"""Number of recent messages to keep during LLM-based compression."""
llm_compress_provider: "Provider | None" = None
"""LLM provider used for compression tasks. If None, truncation strategy is used."""
llm_compress_use_compact_api: bool = True
"""Whether to prefer provider native compact API when available."""
custom_token_counter: TokenCounter | None = None
"""Custom token counting method. If None, the default method is used."""
custom_compressor: ContextCompressor | None = None
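A hedged sketch of opting out of the compact API at the config layer; it assumes the remaining ContextConfig fields all have defaults, as the ones shown here do, and that the module path matches the file above:

from astrbot.core.agent.context.config import ContextConfig  # path assumed

# With llm_compress_provider=None the manager falls back to truncation,
# so llm_compress_use_compact_api only matters once a provider is set.
config = ContextConfig(
    llm_compress_provider=None,
    llm_compress_use_compact_api=False,  # force the LLM-summary path
)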
5 changes: 3 additions & 2 deletions astrbot/core/agent/context/manager.py
@@ -35,6 +35,7 @@ def __init__(
provider=config.llm_compress_provider,
keep_recent=config.llm_compress_keep_recent,
instruction_text=config.llm_compress_instruction,
use_compact_api=config.llm_compress_use_compact_api,
)
else:
self.compressor = TruncateByTurnsCompressor(
@@ -55,15 +56,15 @@ async def process(
try:
result = messages

# 1. Turn-based truncation (enforce max turns)
# 1. Enforce max turns
if self.config.enforce_max_turns != -1:
result = self.truncator.truncate_by_turns(
result,
keep_most_recent_turns=self.config.enforce_max_turns,
drop_turns=self.config.truncate_turns,
)

# 2. Token-based compression
# 2. Token-based compression
if self.config.max_context_tokens > 0:
total_tokens = self.token_counter.count_tokens(
result, trusted_token_usage
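For intuition on step 2: compression only runs once should_compress reports that token usage has crossed compression_threshold (0.82 by default in LLMSummaryCompressor). A quick worked check with illustrative numbers:

# Illustrative values; in the manager they come from the token counter and config.
compression_threshold = 0.82   # LLMSummaryCompressor default
max_context_tokens = 10_000    # assumed context window
total_tokens = 8_500           # assumed result of token_counter.count_tokens(...)

usage_rate = total_tokens / max_context_tokens  # 0.85
print(usage_rate > compression_threshold)       # True -> invoke the compressor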
37 changes: 20 additions & 17 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
@@ -85,6 +85,7 @@ async def reset(
llm_compress_instruction: str | None = None,
llm_compress_keep_recent: int = 0,
llm_compress_provider: Provider | None = None,
llm_compress_use_compact_api: bool = True,
# truncate by turns compressor
truncate_turns: int = 1,
# customize
@@ -99,6 +100,7 @@ async def reset(
self.llm_compress_instruction = llm_compress_instruction
self.llm_compress_keep_recent = llm_compress_keep_recent
self.llm_compress_provider = llm_compress_provider
self.llm_compress_use_compact_api = llm_compress_use_compact_api
self.truncate_turns = truncate_turns
self.custom_token_counter = custom_token_counter
self.custom_compressor = custom_compressor
@@ -114,6 +116,7 @@ async def reset(
llm_compress_instruction=self.llm_compress_instruction,
llm_compress_keep_recent=self.llm_compress_keep_recent,
llm_compress_provider=self.llm_compress_provider,
llm_compress_use_compact_api=self.llm_compress_use_compact_api,
custom_token_counter=self.custom_token_counter,
custom_compressor=self.custom_compressor,
)
@@ -659,24 +662,24 @@ async def _handle_function_tools(
),
)

# yield the last tool call result
if tool_call_result_blocks:
last_tcr_content = str(tool_call_result_blocks[-1].content)
yield _HandleFunctionToolsResult.from_message_chain(
MessageChain(
type="tool_call_result",
chain=[
Json(
data={
"id": func_tool_id,
"ts": time.time(),
"result": last_tcr_content,
}
)
],
# yield the tool call result
if tool_call_result_blocks:
last_tcr_content = str(tool_call_result_blocks[-1].content)
yield _HandleFunctionToolsResult.from_message_chain(
MessageChain(
type="tool_call_result",
chain=[
Json(
data={
"id": func_tool_id,
"ts": time.time(),
"result": last_tcr_content,
}
)
],
)
)
)
logger.info(f"Tool `{func_tool_name}` Result: {last_tcr_content}")
logger.info(f"Tool `{func_tool_name}` Result: {last_tcr_content}")

# Handle function call responses
if tool_call_result_blocks:
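The runner packages tool output as a tool_call_result MessageChain holding a single Json block. A sketch of that payload, with field names taken from the diff and illustrative values:

# Payload carried by the Json block of a "tool_call_result" chain.
payload = {
    "id": "call_abc123",   # func_tool_id of the originating tool call (illustrative)
    "ts": 1735689600.0,    # time.time() at yield
    "result": "42",        # str() of the last tool-call result block's content
}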
57 changes: 36 additions & 21 deletions astrbot/core/astr_main_agent.py
@@ -91,6 +91,8 @@ class MainAgentBuildConfig:
"""The number of most recent turns to keep during llm_compress strategy."""
llm_compress_provider_id: str = ""
"""The provider ID for the LLM used in context compression."""
llm_compress_use_compact_api: bool = True
"""Whether to prefer provider native compact API when available."""
max_context_length: int = -1
"""The maximum number of turns to keep in context. -1 means no limit.
This enforces the max-turns limit before compression."""
@@ -742,17 +744,22 @@ async def _handle_webchat(
if not user_prompt or not chatui_session_id or not session or session.display_name:
return

llm_resp = await prov.text_chat(
system_prompt=(
"You are a conversation title generator. "
"Generate a concise title in the same language as the user’s input, "
"no more than 10 words, capturing only the core topic."
"If the input is a greeting, small talk, or has no clear topic, "
"(e.g., “hi”, “hello”, “haha”), return <None>. "
"Output only the title itself or <None>, with no explanations."
),
prompt=f"Generate a concise title for the following user query:\n{user_prompt}",
)
try:
llm_resp = await prov.text_chat(
system_prompt=(
"You are a conversation title generator. "
"Generate a concise title in the same language as the user's input, "
"no more than 10 words, capturing only the core topic."
"If the input is a greeting, small talk, or has no clear topic, "
'(e.g., "hi", "hello", "haha"), return <None>. '
"Output only the title itself or <None>, with no explanations."
),
prompt=f"Generate a concise title for the following user query:\n{user_prompt}",
)
except Exception as e:
logger.warning("Failed to generate chatui title: %s", e)
return

if llm_resp and llm_resp.completion_text:
title = llm_resp.completion_text.strip()
if not title or "<None>" in title:
@@ -807,26 +814,33 @@ def _proactive_cron_job_tools(req: ProviderRequest) -> None:


def _get_compress_provider(
config: MainAgentBuildConfig, plugin_context: Context
config: MainAgentBuildConfig,
plugin_context: Context,
active_provider: Provider | None,
) -> Provider | None:
if not config.llm_compress_provider_id:
return None
if config.context_limit_reached_strategy != "llm_compress":
return None
provider = plugin_context.get_provider_by_id(config.llm_compress_provider_id)
if provider is None:

if not config.llm_compress_provider_id:
return None

selected_provider = plugin_context.get_provider_by_id(
config.llm_compress_provider_id
)
if selected_provider is None:
logger.warning(
"未找到指定的上下文压缩模型 %s,将跳过压缩。",
"Configured llm_compress_provider_id not found: %s. Skip compression.",
config.llm_compress_provider_id,
)
return None
if not isinstance(provider, Provider):
if not isinstance(selected_provider, Provider):
logger.warning(
"指定的上下文压缩模型 %s 不是对话模型,将跳过压缩。",
"Configured llm_compress_provider_id is not a Provider: %s. Skip compression.",
config.llm_compress_provider_id,
)
return None
return provider

return selected_provider


async def build_main_agent(
@@ -970,7 +984,8 @@ async def build_main_agent(
streaming=config.streaming_response,
llm_compress_instruction=config.llm_compress_instruction,
llm_compress_keep_recent=config.llm_compress_keep_recent,
llm_compress_provider=_get_compress_provider(config, plugin_context),
llm_compress_provider=_get_compress_provider(config, plugin_context, provider),
llm_compress_use_compact_api=config.llm_compress_use_compact_api,
truncate_turns=config.dequeue_context_length,
enforce_max_turns=config.max_context_length,
tool_schema_mode=config.tool_schema_mode,
23 changes: 23 additions & 0 deletions astrbot/core/config/default.py
@@ -94,6 +94,7 @@
),
"llm_compress_keep_recent": 6,
"llm_compress_provider_id": "",
"llm_compress_use_compact_api": True,
"max_context_length": -1,
"dequeue_context_length": 1,
"streaming_response": False,
@@ -929,6 +930,19 @@ class ChatProviderTemplate(TypedDict):
"proxy": "",
"custom_headers": {},
},
"OpenAI Responses": {
"id": "openai_responses",
"provider": "openai",
"type": "openai_responses",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.openai.com/v1",
"timeout": 120,
"proxy": "",
"custom_headers": {},
"custom_extra_body": {},
},
"Google Gemini": {
"id": "google_gemini",
"provider": "google",
@@ -2828,6 +2842,15 @@ class ChatProviderTemplate(TypedDict):
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.llm_compress_use_compact_api": {
"description": "Prefer compact API when available",
"type": "bool",
"hint": "When enabled, local runner first tries provider native compact API and falls back to LLM summary compression.",
"condition": {
"provider_settings.context_limit_reached_strategy": "llm_compress",
"provider_settings.agent_runner_type": "local",
},
},
},
"condition": {
"provider_settings.agent_runner_type": "local",
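On the settings side, the new key only takes effect under the llm_compress strategy with the local agent runner, per the condition blocks above. A hedged sketch of the relevant slice of provider_settings, assuming the dotted schema paths map to nested keys in the stored config:

# Assumed shape of the stored settings the new schema entry governs.
provider_settings = {
    "agent_runner_type": "local",                      # condition in the schema
    "context_limit_reached_strategy": "llm_compress",  # condition in the schema
    "llm_compress_provider_id": "openai_responses",    # e.g. the template id added above
    "llm_compress_use_compact_api": True,              # the new toggle (defaults to True)
}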
@@ -79,6 +79,9 @@ async def initialize(self, ctx: PipelineContext) -> None:
self.llm_compress_provider_id: str = settings.get(
"llm_compress_provider_id", ""
)
self.llm_compress_use_compact_api: bool = settings.get(
"llm_compress_use_compact_api", True
)
self.max_context_length = settings["max_context_length"] # int
self.dequeue_context_length: int = min(
max(1, settings["dequeue_context_length"]),
@@ -113,6 +116,7 @@ async def initialize(self, ctx: PipelineContext) -> None:
llm_compress_instruction=self.llm_compress_instruction,
llm_compress_keep_recent=self.llm_compress_keep_recent,
llm_compress_provider_id=self.llm_compress_provider_id,
llm_compress_use_compact_api=self.llm_compress_use_compact_api,
max_context_length=self.max_context_length,
dequeue_context_length=self.dequeue_context_length,
llm_safety_mode=self.llm_safety_mode,