From 989f0c05bfeca8178079a49bf4344bbf6850a627 Mon Sep 17 00:00:00 2001 From: Ayush Karupakula Date: Mon, 16 Feb 2026 05:28:09 -0800 Subject: [PATCH] force answer tool usage at max step limit Ensure MCPAgent enforces an answer submission when max steps are reached by setting a force-answer flag and appending a final instruction. Update Claude, OpenAI, and Gemini agents to honor that flag with provider-native tool-choice restrictions so runs terminate with an answer instead of stalling. Co-authored-by: Cursor --- hud/agents/base.py | 121 +++++++++++++++++++++++++++++-------------- hud/agents/claude.py | 31 +++++++++-- hud/agents/gemini.py | 37 ++++++++++--- hud/agents/openai.py | 9 +++- 4 files changed, 145 insertions(+), 53 deletions(-) diff --git a/hud/agents/base.py b/hud/agents/base.py index 5d91cec7..624f96e8 100644 --- a/hud/agents/base.py +++ b/hud/agents/base.py @@ -487,6 +487,8 @@ async def _run_context( self.console.debug(f"Messages: {messages}") step_count = 0 + answer_submitted = False + while max_steps == -1 or step_count < max_steps: step_count += 1 if max_steps == -1: @@ -494,52 +496,91 @@ async def _run_context( else: self.console.debug(f"Step {step_count}/{max_steps}") + # If we've reached max_steps without submitting an answer, force the + # next model call to use the `answer` tool only. Provider-specific + # agents (Claude/OpenAI/Gemini) implement this by honoring the + # `_force_answer_only` attribute via their native tool-choice APIs. + if max_steps != -1 and step_count >= max_steps and not answer_submitted: + try: + tools = self.get_available_tools() + except RuntimeError: + tools = self._available_tools or [] + + if tools and any(getattr(t, "name", None) == "answer" for t in tools): + # Set a flag checked by provider agents to restrict tools + setattr(self, "_force_answer_only", True) + self.console.warning_log( + f"Reached max_steps ({max_steps}) without submitting answer. " + "Prompting model to submit answer now." + ) + # Also inject a clear instruction so the model knows what to do + messages.extend( + await self.format_message( + ( + "You have reached the maximum number of steps allowed. " + "You MUST use the `answer` tool NOW to submit your final " + "answer based on all the information you have gathered so far. " + "Do not use any other tools. Submit your answer immediately." + ) + ) + ) + try: # 1. Get model response response = await self.get_response(messages) self.console.debug(f"Agent:\n{response}") - # Check if we should stop - if response.done or not response.tool_calls: - # Use auto_respond to decide whether to stop - decision: Literal["STOP", "CONTINUE"] = "STOP" - if self.auto_respond and response.content: - try: - from hud.agents.misc import ResponseAgent - - response_agent = ResponseAgent() - decision = await response_agent.determine_response(response.content) - except Exception as e: - self.console.warning_log(f"Auto-respond failed: {e}") - if decision == "STOP": - self.console.debug("Stopping execution") - final_response = response - break - else: - self.console.debug("Continuing execution") - messages.extend(await self.format_message(decision)) - continue - - # 2. Execute tools - tool_calls = response.tool_calls - tool_results = await self.call_tools(tool_calls) - - # 3. Format tool results and add to messages - tool_messages = await self.format_tool_results(tool_calls, tool_results) - messages.extend(tool_messages) - - # Compact step completion display - step_info = f"\n[bold]Step {step_count}" - if max_steps != -1: - step_info += f"/{max_steps}" - step_info += "[/bold]" - - # Show tool calls and results in compact format - for call, result in zip(tool_calls, tool_results, strict=False): - step_info += f"\n{call}\n{result}" - - self.console.info_log(step_info) + # Track whether an answer was submitted in this step + if response.tool_calls: + for tool_call in response.tool_calls: + if getattr(tool_call, "name", "") == "answer": + answer_submitted = True + # Clear forced-answer flag once we've seen an answer + if hasattr(self, "_force_answer_only"): + setattr(self, "_force_answer_only", False) + break + + # Check if we should stop + if response.done or not response.tool_calls: + # Use auto_respond to decide whether to stop + decision: Literal["STOP", "CONTINUE"] = "STOP" + if self.auto_respond and response.content: + try: + from hud.agents.misc import ResponseAgent + + response_agent = ResponseAgent() + decision = await response_agent.determine_response(response.content) + except Exception as e: + self.console.warning_log(f"Auto-respond failed: {e}") + if decision == "STOP": + self.console.debug("Stopping execution") + final_response = response + break + else: + self.console.debug("Continuing execution") + messages.extend(await self.format_message(decision)) + continue + + # 2. Execute tools + tool_calls = response.tool_calls + tool_results = await self.call_tools(tool_calls) + + # 3. Format tool results and add to messages + tool_messages = await self.format_tool_results(tool_calls, tool_results) + messages.extend(tool_messages) + + # Compact step completion display + step_info = f"\n[bold]Step {step_count}" + if max_steps != -1: + step_info += f"/{max_steps}" + step_info += "[/bold]" + + # Show tool calls and results in compact format + for call, result in zip(tool_calls, tool_results, strict=False): + step_info += f"\n{call}\n{result}" + + self.console.info_log(step_info) except Exception as e: self.console.error_log(f"Step failed: {e}") diff --git a/hud/agents/claude.py b/hud/agents/claude.py index b98e2bc3..f5ec19d9 100644 --- a/hud/agents/claude.py +++ b/hud/agents/claude.py @@ -199,6 +199,29 @@ async def get_response(self, messages: list[BetaMessageParam]) -> AgentResponse: """Get response from Claude including any tool calls.""" messages_cached = self._add_prompt_caching(messages) + # Support forced-answer behavior from MCPAgent by restricting tools and + # tool_choice when _force_answer_only is set. This mirrors sandbox's + # per-message forceToolOptions behavior. + force_answer_only = getattr(self, "_force_answer_only", False) + + tools_for_call: list[BetaToolUnionParam] = self.claude_tools + tool_choice: dict[str, Any] = { + "type": "auto", + "disable_parallel_tool_use": True, + } + + if force_answer_only: + answer_tools = [ + t for t in self.claude_tools if getattr(t, "name", None) == "answer" + ] + if answer_tools: + tools_for_call = answer_tools + tool_choice = { + "type": "tool", + "name": "answer", + "disable_parallel_tool_use": True, + } + # betas to use - collected during tool conversion based on native specs # Only pass betas when non-empty; an empty list can produce an empty # anthropic-beta header which the API rejects. @@ -212,8 +235,8 @@ async def get_response(self, messages: list[BetaMessageParam]) -> AgentResponse: system=self.system_prompt if self.system_prompt is not None else Omit(), max_tokens=self.max_tokens, messages=messages_cached, - tools=self.claude_tools, - tool_choice={"type": "auto", "disable_parallel_tool_use": True}, + tools=tools_for_call, + tool_choice=tool_choice, betas=betas, ) messages.append(BetaMessageParam(role="assistant", content=response.content)) @@ -228,8 +251,8 @@ async def get_response(self, messages: list[BetaMessageParam]) -> AgentResponse: system=self.system_prompt if self.system_prompt is not None else Omit(), max_tokens=self.max_tokens, messages=messages_cached, - tools=self.claude_tools, - tool_choice={"type": "auto", "disable_parallel_tool_use": True}, + tools=tools_for_call, + tool_choice=tool_choice, betas=betas, ) as stream: # allow backend to accumulate message content diff --git a/hud/agents/gemini.py b/hud/agents/gemini.py index 08b49eee..86534d8d 100644 --- a/hud/agents/gemini.py +++ b/hud/agents/gemini.py @@ -143,14 +143,35 @@ async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[genai_ty async def get_response(self, messages: list[genai_types.Content]) -> AgentResponse: """Get response from Gemini including any tool calls.""" # Build generate content config - generate_config = genai_types.GenerateContentConfig( - temperature=self.temperature, - top_p=self.top_p, - top_k=self.top_k, - max_output_tokens=self.max_output_tokens, - tools=self.gemini_tools, - system_instruction=self.system_prompt, - ) + force_answer_only = getattr(self, "_force_answer_only", False) + + tool_config: genai_types.ToolConfig | None = None + if force_answer_only: + # Restrict Gemini's function calling to the `answer` function only, + # matching sandbox's per-message forceToolOptions behavior. + try: + tool_config = genai_types.ToolConfig( + function_calling_config=genai_types.FunctionCallingConfig( + mode="ANY", + allowed_function_names=["answer"], + ) + ) + except Exception as exc: # pragma: no cover - older SDKs + logger.debug(f"Could not construct FunctionCallingConfig: {exc}") + tool_config = None + + generate_kwargs: dict[str, Any] = { + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "max_output_tokens": self.max_output_tokens, + "tools": self.gemini_tools, + "system_instruction": self.system_prompt, + } + if tool_config is not None: + generate_kwargs["tool_config"] = tool_config + + generate_config = genai_types.GenerateContentConfig(**generate_kwargs) # Use async API to avoid blocking the event loop response = await self.gemini_client.aio.models.generate_content( diff --git a/hud/agents/openai.py b/hud/agents/openai.py index fb9bf1a7..03b80797 100644 --- a/hud/agents/openai.py +++ b/hud/agents/openai.py @@ -344,7 +344,14 @@ async def get_response(self, messages: ResponseInputParam) -> AgentResponse: instructions=self.system_prompt, max_output_tokens=self.max_output_tokens, temperature=self.temperature, - tool_choice=self.tool_choice if self.tool_choice is not None else Omit(), + # Per-step forced-answer behavior: when the base agent sets + # `_force_answer_only`, override tool_choice so the model MUST call + # the `answer` function on this request. + tool_choice=( + ToolChoice(type="function", function={"name": "answer"}) + if getattr(self, "_force_answer_only", False) + else (self.tool_choice if self.tool_choice is not None else Omit()) + ), parallel_tool_calls=self.parallel_tool_calls, reasoning=self.reasoning if self.reasoning is not None else Omit(), tools=self._openai_tools if self._openai_tools else Omit(),