diff --git a/src/mcp_cli/commands/__init__.py b/src/mcp_cli/commands/__init__.py index 58049d23..8f81c37d 100644 --- a/src/mcp_cli/commands/__init__.py +++ b/src/mcp_cli/commands/__init__.py @@ -109,6 +109,7 @@ def register_all_commands() -> None: from mcp_cli.commands.apps import AppsCommand from mcp_cli.commands.memory import MemoryCommand from mcp_cli.commands.plan import PlanCommand + from mcp_cli.commands.cmd import CmdCommand from mcp_cli.commands.attach import AttachCommand # Register basic commands @@ -161,6 +162,8 @@ def register_all_commands() -> None: # Register plan command registry.register(PlanCommand()) + # Register cmd command (CLI-only) + registry.register(CmdCommand()) # Register attach command (multi-modal file staging) registry.register(AttachCommand()) diff --git a/src/mcp_cli/commands/cmd/__init__.py b/src/mcp_cli/commands/cmd/__init__.py new file mode 100644 index 00000000..317a2c9f --- /dev/null +++ b/src/mcp_cli/commands/cmd/__init__.py @@ -0,0 +1,5 @@ +"""Command mode command.""" + +from mcp_cli.commands.cmd.cmd import CmdCommand + +__all__ = ["CmdCommand"] diff --git a/src/mcp_cli/commands/cmd/cmd.py b/src/mcp_cli/commands/cmd/cmd.py new file mode 100644 index 00000000..97bb48d0 --- /dev/null +++ b/src/mcp_cli/commands/cmd/cmd.py @@ -0,0 +1,472 @@ +# src/mcp_cli/commands/cmd/cmd.py +"""Command mode for Unix-friendly automation and scripting.""" + +from __future__ import annotations + +import json +import logging +import sys +from pathlib import Path +from typing import Any + +from chuk_term.ui import output +from mcp_cli.commands.base import ( + CommandMode, + CommandParameter, + CommandResult, + UnifiedCommand, +) +from mcp_cli.utils.serialization import to_serializable, unwrap_tool_result + +logger = logging.getLogger(__name__) + + +class CmdCommand(UnifiedCommand): + """Command mode for Unix-friendly automation and scripting.""" + + @property + def name(self) -> str: + return "cmd" + + @property + def description(self) -> str: + return 
"Command mode for Unix-friendly automation and scripting" + + @property + def modes(self) -> CommandMode: + return CommandMode.CLI + + @property + def aliases(self) -> list[str]: + return [] + + @property + def parameters(self) -> list[CommandParameter]: + return [ + CommandParameter( + name="input_file", type=str, required=False, + help="Input file (use - for stdin)", + ), + CommandParameter( + name="output_file", type=str, required=False, + help="Output file (use - for stdout)", + ), + CommandParameter( + name="prompt", type=str, required=False, + help="Prompt text", + ), + CommandParameter( + name="tool", type=str, required=False, + help="Tool name to execute", + ), + CommandParameter( + name="tool_args", type=str, required=False, + help="Tool arguments as JSON", + ), + CommandParameter( + name="system_prompt", type=str, required=False, + help="Custom system prompt", + ), + CommandParameter( + name="raw", type=bool, default=False, is_flag=True, + help="Raw output without formatting", + ), + CommandParameter( + name="single_turn", type=bool, default=False, is_flag=True, + help="Disable multi-turn conversation", + ), + CommandParameter( + name="max_turns", type=int, default=100, + help="Maximum conversation turns", + ), + ] + + @property + def help_text(self) -> str: + return """ +Command mode for Unix-friendly automation and scripting. + +Usage: + mcp-cli cmd --tool Execute a tool directly + mcp-cli cmd --tool --tool-args '{...}' Execute tool with arguments + mcp-cli cmd --prompt "Summarize this" Use LLM with a prompt + mcp-cli cmd --input data.txt --prompt "..." 
Combine file input with prompt + echo "text" | mcp-cli cmd --input - Read from stdin +""" + + @property + def requires_context(self) -> bool: + return True + + async def execute(self, **kwargs: Any) -> CommandResult: + """Execute the cmd command.""" + tool_manager = kwargs.get("tool_manager") + + tool = kwargs.get("tool") + tool_args = kwargs.get("tool_args") + input_file = kwargs.get("input_file") + output_file = kwargs.get("output_file") + prompt = kwargs.get("prompt") + system_prompt = kwargs.get("system_prompt") + raw = kwargs.get("raw", False) + single_turn = kwargs.get("single_turn", False) + max_turns = kwargs.get("max_turns", 100) + + # Branch 1: Tool direct execution + if tool: + return await self._execute_tool_direct( + tool_manager=tool_manager, + tool_name=tool, + tool_args_json=tool_args, + output_file=output_file, + raw=raw, + ) + + # Branch 2: Prompt / LLM mode + if prompt or input_file: + return await self._execute_prompt_mode( + tool_manager=tool_manager, + model_manager=kwargs.get("model_manager"), + input_file=input_file, + output_file=output_file, + prompt=prompt, + system_prompt=system_prompt, + raw=raw, + single_turn=single_turn, + max_turns=max_turns, + ) + + # Branch 3: No operation specified + return CommandResult( + success=False, + error="No operation specified. Use --tool or --prompt/--input", + ) + + # ── tool direct execution ──────────────────────────────────────── + + async def _execute_tool_direct( + self, + tool_manager: Any | None, + tool_name: str, + tool_args_json: str | None, + output_file: str | None, + raw: bool, + ) -> CommandResult: + """Execute a tool directly without LLM interaction.""" + if not tool_manager: + return CommandResult( + success=False, + error="Tool manager not available. 
Are servers connected?", + ) + + # Parse tool arguments + tool_args: dict[str, Any] = {} + if tool_args_json: + try: + tool_args = json.loads(tool_args_json) + except json.JSONDecodeError as e: + return CommandResult( + success=False, + error=f"Invalid JSON in tool arguments: {e}", + ) + + try: + if not raw: + output.info(f"Executing tool: {tool_name}") + + tool_call_result = await tool_manager.execute_tool(tool_name, tool_args) + + if not tool_call_result.success or tool_call_result.error: + return CommandResult( + success=False, + error=f"Tool execution failed: {tool_call_result.error}", + ) + + result_data = tool_call_result.result + + # Unwrap middleware ToolExecutionResult if present + result_data = unwrap_tool_result(result_data) + + # Convert to JSON-serializable form + result_data = to_serializable(result_data) + + # Format result + result_str = ( + json.dumps(result_data, indent=None if raw else 2) + if not isinstance(result_data, str) + else result_data + ) + + # Write output + if output_file and output_file != "-": + Path(output_file).write_text(result_str) + if not raw: + output.success(f"Output written to: {output_file}") + return CommandResult(success=True, data=result_data) + else: + print(result_str) + return CommandResult(success=True, data=result_data) + + except Exception as e: + return CommandResult( + success=False, + error=f"Tool execution failed: {e}", + ) + + # ── prompt / LLM mode ──────────────────────────────────────────── + + async def _execute_prompt_mode( + self, + tool_manager: Any | None, + model_manager: Any | None, + input_file: str | None, + output_file: str | None, + prompt: str | None, + system_prompt: str | None, + raw: bool, + single_turn: bool, + max_turns: int, + ) -> CommandResult: + """Execute prompt mode with LLM interaction.""" + from mcp_cli.context import get_context + + context = get_context() + + # Read input + input_text = "" + if input_file: + if input_file == "-": + input_text = sys.stdin.read() + else: + try: + 
input_text = Path(input_file).read_text() + except FileNotFoundError: + return CommandResult( + success=False, + error=f"Input file not found: {input_file}", + ) + + # Build full prompt + if prompt and input_text: + full_prompt = f"{prompt}\n\nInput:\n{input_text}" + elif prompt: + full_prompt = prompt + elif input_text: + full_prompt = input_text + else: + return CommandResult(success=False, error="No prompt or input provided") + + # Get LLM client + effective_mm = model_manager or (context.model_manager if context else None) + if not effective_mm: + return CommandResult( + success=False, error="Model manager not available." + ) + + if not context: + return CommandResult( + success=False, error="Context not initialized." + ) + + try: + client = effective_mm.get_client( + provider=context.provider, + model=context.model, + ) + except Exception as e: + return CommandResult( + success=False, + error=f"Failed to initialize LLM client: {e}", + ) + + # Build messages + messages: list[dict[str, Any]] = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": full_prompt}) + + # Get tools for LLM if available + tools = None + if tool_manager and not single_turn: + tools = await tool_manager.get_tools_for_llm() + + try: + response = await client.create_completion( + model=context.model, + messages=messages, + tools=tools, + max_tokens=4096, + ) + + result_text = response.get("response", "") + tool_calls = response.get("tool_calls", []) + + # Handle tool calls if present + if tool_calls and not single_turn and tool_manager: + result_text = await self._handle_tool_calls( + client=client, + model=context.model, + tool_manager=tool_manager, + messages=messages, + tool_calls=tool_calls, + response_text=result_text, + max_turns=max_turns, + raw=raw, + ) + + # Write output + if output_file and output_file != "-": + Path(output_file).write_text(result_text) + if not raw: + output.success(f"Output written 
to: {output_file}") + return CommandResult(success=True, data=result_text) + else: + print(result_text) + return CommandResult(success=True, data=result_text) + + except Exception as e: + return CommandResult( + success=False, + error=f"LLM execution failed: {e}", + ) + + # ── multi-turn tool call loop ──────────────────────────────────── + + async def _handle_tool_calls( + self, + client: Any, + model: str, + tool_manager: Any, + messages: list[dict[str, Any]], + tool_calls: list[Any], + response_text: str, + max_turns: int, + raw: bool, + ) -> str: + """Handle tool calls in multi-turn conversation.""" + messages.append({ + "role": "assistant", + "content": response_text, + "tool_calls": tool_calls, + }) + + await self._execute_tool_call_batch( + tool_manager, tool_calls, messages, raw + ) + + turns = 1 + while turns < max_turns: + tools = await tool_manager.get_tools_for_llm() + response = await client.create_completion( + model=model, + messages=messages, + tools=tools, + max_tokens=4096, + ) + + response_text = response.get("response", "") + new_tool_calls = response.get("tool_calls", []) + + if not new_tool_calls: + return response_text + + messages.append({ + "role": "assistant", + "content": response_text, + "tool_calls": new_tool_calls, + }) + + await self._execute_tool_call_batch( + tool_manager, new_tool_calls, messages, raw + ) + + turns += 1 + + # Max turns reached — get a final synthesized response without tools + if not raw: + output.warning(f"Max turns ({max_turns}) reached") + + try: + final_response = await client.create_completion( + model=model, + messages=messages, + max_tokens=4096, + ) + return final_response.get("response", response_text) + except Exception: + return response_text + + async def _execute_tool_call_batch( + self, + tool_manager: Any, + tool_calls: list[Any], + messages: list[dict[str, Any]], + raw: bool, + ) -> None: + """Execute a batch of tool calls and append results to messages.""" + for tool_call in tool_calls: + 
tool_name, tool_args_str, tool_call_id = _parse_tool_call(tool_call) + + try: + tool_args = ( + json.loads(tool_args_str) + if isinstance(tool_args_str, str) + else tool_args_str + ) + except json.JSONDecodeError: + messages.append({ + "role": "tool", + "tool_call_id": tool_call_id, + "name": tool_name, + "content": f"Error: Invalid JSON in tool arguments: {tool_args_str}", + }) + continue + + if not raw: + output.info(f"Executing tool: {tool_name}") + + try: + result = await tool_manager.execute_tool(tool_name, tool_args) + if result.success: + result_data = to_serializable( + unwrap_tool_result(result.result) + ) + else: + result_data = f"Error: {result.error}" + result_str = ( + json.dumps(result_data) + if not isinstance(result_data, str) + else result_data + ) + messages.append({ + "role": "tool", + "tool_call_id": tool_call_id, + "name": tool_name, + "content": result_str, + }) + except Exception as e: + output.error(f"Tool execution failed: {e}") + messages.append({ + "role": "tool", + "tool_call_id": tool_call_id, + "name": tool_name, + "content": f"Error: {e}", + }) + + +def _parse_tool_call(tool_call: Any) -> tuple[str, str, str]: + """Extract (tool_name, arguments_str, call_id) from a tool_call.""" + if isinstance(tool_call, dict): + return ( + tool_call.get("function", {}).get("name", ""), + tool_call.get("function", {}).get("arguments", "{}"), + tool_call.get("id", ""), + ) + try: + return ( + tool_call.function.name, + tool_call.function.arguments, + tool_call.id, + ) + except AttributeError: + return ("", "{}", "") diff --git a/src/mcp_cli/commands/tools/execute_tool.py b/src/mcp_cli/commands/tools/execute_tool.py index 1264c9b1..10a5f3c3 100644 --- a/src/mcp_cli/commands/tools/execute_tool.py +++ b/src/mcp_cli/commands/tools/execute_tool.py @@ -23,56 +23,10 @@ JSON_TYPE_STRING, ) from mcp_cli.tools.manager import ToolManager +from mcp_cli.utils.serialization import to_serializable, unwrap_tool_result from chuk_term.ui import output -def 
_to_serializable(obj: Any) -> Any: - """Convert an object to a JSON-serializable form. - - Handles MCP SDK ToolResult, Pydantic models, and other non-serializable types. - """ - # Handle None - if obj is None: - return None - - # Handle primitives - if isinstance(obj, (str, int, float, bool)): - return obj - - # Handle lists - if isinstance(obj, list): - return [_to_serializable(item) for item in obj] - - # Handle dicts - if isinstance(obj, dict): - return {k: _to_serializable(v) for k, v in obj.items()} - - # Handle Pydantic models (they have model_dump or dict method) - if hasattr(obj, "model_dump"): - return _to_serializable(obj.model_dump()) - if hasattr(obj, "dict"): - return _to_serializable(obj.dict()) - - # Handle MCP SDK ToolResult (has content attribute) - if hasattr(obj, "content"): - content = obj.content - # Content is typically a list of TextContent/ImageContent objects - if isinstance(content, list): - result_parts = [] - for item in content: - if hasattr(item, "text"): - result_parts.append(item.text) - elif hasattr(item, "model_dump"): - result_parts.append(_to_serializable(item.model_dump())) - else: - result_parts.append(str(item)) - return "\n".join(result_parts) if len(result_parts) == 1 else result_parts - return _to_serializable(content) - - # Fallback to string representation - return str(obj) - - class ExecuteToolCommand(UnifiedCommand): """Command to execute a tool with parameters.""" @@ -351,10 +305,10 @@ async def execute( if isinstance(result, ToolCallResult): if result.success and result.result is not None: - output.success("✅ Tool executed successfully") # Extract the actual result from ToolCallResult - # Use _to_serializable to handle MCP SDK types - serializable_result = _to_serializable(result.result) + # Unwrap middleware wrappers, then serialize + serializable_result = to_serializable(unwrap_tool_result(result.result)) + output.success("✅ Tool executed successfully") if isinstance(serializable_result, (dict, list)): 
output.print(json.dumps(serializable_result, indent=2)) else: @@ -395,8 +349,8 @@ async def execute( else: output.warning("Tool returned no result") elif isinstance(result, dict): # type: ignore[unreachable] + serializable_result = to_serializable(unwrap_tool_result(result)) output.success("✅ Tool executed successfully") - serializable_result = _to_serializable(result) output.print(json.dumps(serializable_result, indent=2)) else: output.success("✅ Tool executed successfully") diff --git a/src/mcp_cli/main.py b/src/mcp_cli/main.py index 0d8c1ef7..ccb353a5 100644 --- a/src/mcp_cli/main.py +++ b/src/mcp_cli/main.py @@ -1675,7 +1675,7 @@ async def _cmd_wrapper(**params): ) -direct_registered.append("cmd") +# "cmd" is registered via unified registry (CmdCommand), no need for direct_registered # Ping command - test connectivity diff --git a/src/mcp_cli/utils/serialization.py b/src/mcp_cli/utils/serialization.py new file mode 100644 index 00000000..fd2271a3 --- /dev/null +++ b/src/mcp_cli/utils/serialization.py @@ -0,0 +1,93 @@ +# src/mcp_cli/utils/serialization.py +"""Shared helpers for unwrapping and serializing MCP tool results.""" + +from __future__ import annotations + +from typing import Any + +_UNWRAP_MAX_DEPTH = 10 + + +def unwrap_tool_result(obj: Any, *, max_depth: int = _UNWRAP_MAX_DEPTH) -> Any: + """Unwrap middleware ``ToolExecutionResult`` wrappers and MCP result dicts. + + When middleware is enabled, ``ToolManager`` returns the result wrapped in + a ``ToolExecutionResult`` object (from ``chuk_tool_processor.mcp.middleware``). + The inner payload is typically ``{"isError": bool, "content": ToolResult}``. + This peels off those layers to get the actual content. + + Raises ``RuntimeError`` if any wrapper layer reports failure. 
+ """ + depth = 0 + while ( + hasattr(obj, "success") + and hasattr(obj, "result") + and not isinstance(obj, dict) + ): + if depth >= max_depth: + raise RuntimeError(f"Exceeded max unwrap depth ({max_depth})") + if not obj.success: + error = getattr(obj, "error", None) or "Unknown tool error" + raise RuntimeError(error) + obj = obj.result + depth += 1 + + # Unwrap MCP call_tool dict pattern: {"isError": ..., "content": ...} + if isinstance(obj, dict) and "content" in obj and "isError" in obj: + if obj["isError"]: + error_msg = obj.get("error") or obj.get("content") or "Tool returned an error" + if not isinstance(error_msg, str): + error_msg = str(error_msg) + raise RuntimeError(error_msg) + obj = obj["content"] + + return obj + + +def to_serializable(obj: Any) -> Any: + """Convert an object to a JSON-serializable form. + + Handles MCP SDK ``ToolResult``, Pydantic models, and other + non-serializable types. MCP ``ToolResult`` objects (identified by + a ``content`` list attribute) are checked *before* generic Pydantic + ``model_dump()`` so that text content is extracted directly. + """ + if obj is None: + return None + + if isinstance(obj, (str, int, float, bool)): + return obj + + if isinstance(obj, list): + return [to_serializable(item) for item in obj] + + if isinstance(obj, dict): + return {k: to_serializable(v) for k, v in obj.items()} + + # MCP SDK ToolResult (has content *list* of TextContent / ImageContent). + # Must be checked BEFORE generic Pydantic model_dump to extract text. 
+ content = getattr(obj, "content", None) + if isinstance(content, list): + parts: list[Any] = [] + for item in content: + if hasattr(item, "text"): + parts.append(item.text) + elif isinstance(item, dict) and "text" in item: + parts.append(item["text"]) + elif hasattr(item, "model_dump"): + parts.append(to_serializable(item.model_dump())) + else: + parts.append(str(item)) + return parts[0] if len(parts) == 1 else parts + + # Pydantic models (generic fallback) + if hasattr(obj, "model_dump"): + return to_serializable(obj.model_dump()) + if hasattr(obj, "dict"): + return to_serializable(obj.dict()) + + # Non-list .content attribute on non-Pydantic objects + if content is not None: + return to_serializable(content) + + return str(obj) diff --git a/tests/commands/cmd/__init__.py b/tests/commands/cmd/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/commands/cmd/test_cmd_command.py b/tests/commands/cmd/test_cmd_command.py new file mode 100644 index 00000000..f7857eec --- /dev/null +++ b/tests/commands/cmd/test_cmd_command.py @@ -0,0 +1,391 @@ +# tests/commands/cmd/test_cmd_command.py +"""Tests for the CmdCommand (commands/cmd/cmd.py).""" + +from __future__ import annotations + +import io +import json + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from mcp_cli.commands.cmd.cmd import CmdCommand, _parse_tool_call +from mcp_cli.commands.base import CommandMode + +# Patch targets (canonical import locations) +_GET_CONTEXT = "mcp_cli.context.get_context" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_tool_manager( + execute_result=None, tools_for_llm=None +) -> MagicMock: + tm = MagicMock() + if execute_result is None: + result = MagicMock() + result.success = True + result.error = None + result.result = {"status": "ok"} + execute_result = result + tm.execute_tool = 
AsyncMock(return_value=execute_result) + tm.get_tools_for_llm = AsyncMock(return_value=tools_for_llm or []) + return tm + + +def _make_context(provider="openai", model="gpt-4") -> MagicMock: + ctx = MagicMock() + ctx.provider = provider + ctx.model = model + ctx.model_manager = MagicMock() + ctx.tool_manager = _make_tool_manager() + return ctx + + +@pytest.fixture +def cmd(): + return CmdCommand() + + +# --------------------------------------------------------------------------- +# Properties +# --------------------------------------------------------------------------- + + +class TestCmdCommandProperties: + def test_name(self, cmd): + assert cmd.name == "cmd" + + def test_description(self, cmd): + assert "automation" in cmd.description.lower() + + def test_modes(self, cmd): + assert cmd.modes == CommandMode.CLI + + def test_requires_context(self, cmd): + assert cmd.requires_context is True + + def test_aliases_empty(self, cmd): + assert cmd.aliases == [] + + def test_parameters_defined(self, cmd): + names = {p.name for p in cmd.parameters} + assert "tool" in names + assert "prompt" in names + assert "input_file" in names + assert "raw" in names + + +# --------------------------------------------------------------------------- +# Execute guards +# --------------------------------------------------------------------------- + + +class TestCmdExecuteGuards: + @pytest.mark.asyncio + async def test_no_args_returns_error(self, cmd): + """Branch 3: no --tool, no --prompt, no --input → error.""" + result = await cmd.execute() + assert result.success is False + assert result.error is not None + assert "No operation specified" in result.error + + @pytest.mark.asyncio + async def test_tool_without_manager_returns_error(self, cmd): + result = await cmd.execute(tool="some_tool", tool_manager=None) + assert result.success is False + assert "Tool manager not available" in result.error + + +# --------------------------------------------------------------------------- +# Tool 
direct execution +# --------------------------------------------------------------------------- + + +class TestCmdToolDirect: + @pytest.mark.asyncio + async def test_successful_tool_execution(self, cmd): + tm = _make_tool_manager() + with patch("builtins.print") as mock_print: + result = await cmd.execute(tool="echo", tool_manager=tm) + assert result.success is True + assert result.data == {"status": "ok"} + tm.execute_tool.assert_awaited_once_with("echo", {}) + mock_print.assert_called_once() + + @pytest.mark.asyncio + async def test_tool_with_json_args(self, cmd): + tm = _make_tool_manager() + args_json = '{"query": "SELECT 1"}' + with patch("builtins.print"): + result = await cmd.execute( + tool="run_query", tool_args=args_json, tool_manager=tm + ) + assert result.success is True + tm.execute_tool.assert_awaited_once_with( + "run_query", {"query": "SELECT 1"} + ) + + @pytest.mark.asyncio + async def test_tool_with_invalid_json_args(self, cmd): + tm = _make_tool_manager() + result = await cmd.execute( + tool="echo", tool_args="not json", tool_manager=tm + ) + assert result.success is False + assert "Invalid JSON" in result.error + + @pytest.mark.asyncio + async def test_tool_execution_failure(self, cmd): + fail_result = MagicMock() + fail_result.success = False + fail_result.error = "tool crashed" + fail_result.result = None + tm = _make_tool_manager(execute_result=fail_result) + with patch("builtins.print"): + result = await cmd.execute(tool="bad_tool", tool_manager=tm) + assert result.success is False + assert "tool crashed" in result.error + + @pytest.mark.asyncio + async def test_tool_output_to_file(self, cmd, tmp_path): + tm = _make_tool_manager() + out_file = str(tmp_path / "out.json") + result = await cmd.execute( + tool="echo", tool_manager=tm, output_file=out_file + ) + assert result.success is True + content = (tmp_path / "out.json").read_text() + assert "ok" in content + + @pytest.mark.asyncio + async def test_tool_raw_output(self, cmd): + tm = 
_make_tool_manager() + with patch("builtins.print") as mock_print: + result = await cmd.execute( + tool="echo", tool_manager=tm, raw=True + ) + assert result.success is True + # Raw mode: json.dumps without indent + printed = mock_print.call_args[0][0] + assert "\n" not in printed # compact JSON + + +# --------------------------------------------------------------------------- +# Prompt mode +# --------------------------------------------------------------------------- + + +class TestCmdPromptMode: + @pytest.mark.asyncio + async def test_prompt_no_context_returns_error(self, cmd): + with patch(_GET_CONTEXT, return_value=None): + result = await cmd.execute(prompt="hello", model_manager=None) + assert result.success is False + + @pytest.mark.asyncio + async def test_prompt_basic(self, cmd): + ctx = _make_context() + mock_client = MagicMock() + mock_client.create_completion = AsyncMock( + return_value={"response": "Hello!", "tool_calls": []} + ) + ctx.model_manager.get_client.return_value = mock_client + + with patch(_GET_CONTEXT, return_value=ctx): + with patch("builtins.print") as mock_print: + result = await cmd.execute( + prompt="hi", model_manager=ctx.model_manager + ) + assert result.success is True + assert result.data == "Hello!" 
+ mock_print.assert_called_once_with("Hello!") + + @pytest.mark.asyncio + async def test_stdin_input(self, cmd): + """input_file='-' reads from sys.stdin and passes content to the LLM.""" + ctx = _make_context() + mock_client = MagicMock() + mock_client.create_completion = AsyncMock( + return_value={"response": "got it", "tool_calls": []} + ) + ctx.model_manager.get_client.return_value = mock_client + + with patch(_GET_CONTEXT, return_value=ctx), \ + patch("mcp_cli.commands.cmd.cmd.sys.stdin", new=io.StringIO("stdin data\n")), \ + patch("builtins.print"): + result = await cmd.execute( + input_file="-", model_manager=ctx.model_manager + ) + + assert result.success is True + call_kwargs = mock_client.create_completion.call_args + messages = call_kwargs.kwargs.get("messages") or call_kwargs[1].get("messages") + assert "stdin data" in messages[-1]["content"] + + @pytest.mark.asyncio + async def test_input_file_not_found(self, cmd): + ctx = _make_context() + with patch(_GET_CONTEXT, return_value=ctx): + result = await cmd.execute( + input_file="/nonexistent/path.txt", + model_manager=ctx.model_manager, + ) + assert result.success is False + assert "not found" in result.error + + @pytest.mark.asyncio + async def test_input_file_read(self, cmd, tmp_path): + input_file = tmp_path / "input.txt" + input_file.write_text("test data") + + ctx = _make_context() + mock_client = MagicMock() + mock_client.create_completion = AsyncMock( + return_value={"response": "analyzed", "tool_calls": []} + ) + ctx.model_manager.get_client.return_value = mock_client + + with patch(_GET_CONTEXT, return_value=ctx): + with patch("builtins.print"): + result = await cmd.execute( + input_file=str(input_file), + model_manager=ctx.model_manager, + ) + assert result.success is True + # Check that input text was used in the prompt + call_kwargs = mock_client.create_completion.call_args + messages = call_kwargs.kwargs.get("messages") or call_kwargs[1].get("messages") + assert "test data" in 
messages[-1]["content"] + + +# --------------------------------------------------------------------------- +# _execute_tool_call_batch — JSON parse error handling (fix #2) +# --------------------------------------------------------------------------- + + +class TestCmdToolCallBatch: + @pytest.mark.asyncio + async def test_invalid_json_in_tool_args_adds_error_message(self, cmd): + """When LLM returns invalid JSON arguments, error is appended to messages.""" + tm = _make_tool_manager() + messages: list[dict] = [] + tool_calls = [ + { + "id": "call_1", + "function": {"name": "echo", "arguments": "not valid json{"}, + } + ] + + await cmd._execute_tool_call_batch(tm, tool_calls, messages, raw=True) + + assert len(messages) == 1 + assert messages[0]["role"] == "tool" + assert "Invalid JSON" in messages[0]["content"] + # Tool should NOT have been called + tm.execute_tool.assert_not_awaited() + + @pytest.mark.asyncio + async def test_valid_tool_call_succeeds(self, cmd): + tm = _make_tool_manager() + messages: list[dict] = [] + tool_calls = [ + { + "id": "call_1", + "function": {"name": "echo", "arguments": '{"msg": "hi"}'}, + } + ] + + await cmd._execute_tool_call_batch(tm, tool_calls, messages, raw=True) + + assert len(messages) == 1 + assert messages[0]["role"] == "tool" + tm.execute_tool.assert_awaited_once_with("echo", {"msg": "hi"}) + + +# --------------------------------------------------------------------------- +# _parse_tool_call (fix #3) +# --------------------------------------------------------------------------- + + +class TestParseToolCall: + def test_dict_format(self): + tc = { + "id": "call_1", + "function": {"name": "echo", "arguments": '{"x": 1}'}, + } + name, args, call_id = _parse_tool_call(tc) + assert name == "echo" + assert args == '{"x": 1}' + assert call_id == "call_1" + + def test_dict_missing_fields(self): + name, args, call_id = _parse_tool_call({}) + assert name == "" + assert args == "{}" + assert call_id == "" + + def 
test_object_format(self): + tc = MagicMock() + tc.function.name = "search" + tc.function.arguments = '{"q": "test"}' + tc.id = "call_2" + name, args, call_id = _parse_tool_call(tc) + assert name == "search" + assert args == '{"q": "test"}' + assert call_id == "call_2" + + def test_invalid_object_returns_defaults(self): + """Objects without .function attribute return safe defaults.""" + tc = object() # no .function attr + name, args, call_id = _parse_tool_call(tc) + assert name == "" + assert args == "{}" + assert call_id == "" + + +# --------------------------------------------------------------------------- +# Max turns — final response (fix #4) +# --------------------------------------------------------------------------- + + +class TestCmdMaxTurns: + @pytest.mark.asyncio + async def test_max_turns_gets_final_response(self, cmd): + """After hitting max_turns, a final create_completion is called without tools.""" + ctx = _make_context() + tm = _make_tool_manager() + + # With max_turns=1: initial batch runs, while loop skips (1 < 1 = False), + # then final completion is called. 
+ # Call sequence: initial from _execute_prompt_mode → final from _handle_tool_calls + mock_client = MagicMock() + mock_client.create_completion = AsyncMock( + side_effect=[ + # Initial call from _execute_prompt_mode → returns tool_calls + { + "response": "Let me check...", + "tool_calls": [ + {"id": "c1", "function": {"name": "t", "arguments": "{}"}}, + ], + }, + # Final completion after max_turns (no tools param) + {"response": "Final synthesized answer"}, + ] + ) + ctx.model_manager.get_client.return_value = mock_client + + with patch(_GET_CONTEXT, return_value=ctx): + with patch("builtins.print"): + result = await cmd.execute( + prompt="test", + tool_manager=tm, + model_manager=ctx.model_manager, + max_turns=1, + ) + + assert result.success is True + assert result.data == "Final synthesized answer" diff --git a/tests/commands/definitions/test_execute_tool_coverage.py b/tests/commands/definitions/test_execute_tool_coverage.py index 9eebfe9d..a94cdd5d 100644 --- a/tests/commands/definitions/test_execute_tool_coverage.py +++ b/tests/commands/definitions/test_execute_tool_coverage.py @@ -3,7 +3,8 @@ import pytest from unittest.mock import AsyncMock, MagicMock, patch -from mcp_cli.commands.tools.execute_tool import ExecuteToolCommand, _to_serializable +from mcp_cli.commands.tools.execute_tool import ExecuteToolCommand +from mcp_cli.utils.serialization import to_serializable as _to_serializable from mcp_cli.tools.models import ToolCallResult diff --git a/tests/utils/test_serialization.py b/tests/utils/test_serialization.py new file mode 100644 index 00000000..88b60663 --- /dev/null +++ b/tests/utils/test_serialization.py @@ -0,0 +1,124 @@ +# tests/utils/test_serialization.py +"""Tests for unwrap_tool_result and to_serializable in mcp_cli.utils.serialization.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from mcp_cli.utils.serialization import to_serializable, unwrap_tool_result + 
class TestUnwrapToolResult:
    """Unwrapping of MCP-style result dicts (the ``isError`` protocol)."""

    def test_success_returns_content(self):
        # A non-error dict yields its "content" payload directly.
        assert unwrap_tool_result({"isError": False, "content": "hello"}) == "hello"

    def test_error_raises_with_content_fallback(self):
        # With no "error" key present, the content doubles as the message.
        with pytest.raises(RuntimeError, match="boom"):
            unwrap_tool_result({"isError": True, "content": "boom"})

    def test_error_prefers_error_field_over_content(self):
        payload = {"isError": True, "error": "specific error", "content": "fallback"}
        with pytest.raises(RuntimeError, match="specific error"):
            unwrap_tool_result(payload)

    def test_error_without_content_uses_default_message(self):
        # Empty content falls back to the generic error message.
        with pytest.raises(RuntimeError, match="Tool returned an error"):
            unwrap_tool_result({"isError": True, "content": ""})

    def test_passthrough_plain_values(self):
        # Non-dict, non-wrapper values pass through untouched.
        assert unwrap_tool_result("plain") == "plain"
        assert unwrap_tool_result(42) == 42
        assert unwrap_tool_result(None) is None

    def test_passthrough_dict_without_iserror(self):
        plain_dict = {"foo": "bar"}
        assert unwrap_tool_result(plain_dict) == {"foo": "bar"}


class TestUnwrapObjectWrapper:
    """Unwrapping of object wrappers exposing ``success``/``result`` attributes."""

    def test_object_wrapper_success(self):
        wrapped = SimpleNamespace(success=True, result={"data": 42})
        assert unwrap_tool_result(wrapped) == {"data": 42}

    def test_object_wrapper_failure_raises(self):
        failed = SimpleNamespace(success=False, result=None, error="msg")
        with pytest.raises(RuntimeError, match="msg"):
            unwrap_tool_result(failed)

    def test_object_wrapper_failure_default_message(self):
        # A failed wrapper without an "error" attribute gets the fallback text.
        failed = SimpleNamespace(success=False, result=None)
        with pytest.raises(RuntimeError, match="Unknown tool error"):
            unwrap_tool_result(failed)

    def test_nested_object_wrappers(self):
        nested = SimpleNamespace(
            success=True,
            result=SimpleNamespace(success=True, result="payload"),
        )
        assert unwrap_tool_result(nested) == "payload"

    def test_max_depth_exceeded_raises(self):
        # Three wrapper levels with a budget of only two must trip the guard.
        chain = "deep"
        for _ in range(3):
            chain = SimpleNamespace(success=True, result=chain)
        with pytest.raises(RuntimeError, match="Exceeded max unwrap depth"):
            unwrap_tool_result(chain, max_depth=2)

    def test_iserror_list_content_stringified(self):
        """isError=True with list content should stringify the list."""
        payload = {"isError": True, "content": ["err1", "err2"]}
        with pytest.raises(RuntimeError, match=r"\['err1', 'err2'\]"):
            unwrap_tool_result(payload)


class TestToSerializable:
    """Conversion of arbitrary objects into JSON-friendly values."""

    def test_primitives(self):
        # Scalars are returned unchanged (identity for singletons).
        assert to_serializable(None) is None
        assert to_serializable("hello") == "hello"
        assert to_serializable(42) == 42
        assert to_serializable(3.14) == 3.14
        assert to_serializable(True) is True

    def test_list_and_dict(self):
        assert to_serializable([1, "a", None]) == [1, "a", None]
        assert to_serializable({"k": [1, 2]}) == {"k": [1, 2]}

    def test_pydantic_model_dump(self):
        fake_model = MagicMock()
        fake_model.model_dump.return_value = {"field": "value"}
        # Deleting .dict makes hasattr(obj, "dict") fail on the mock.
        del fake_model.dict
        # A non-list content attribute keeps us off the MCP path first.
        fake_model.content = None
        assert to_serializable(fake_model) == {"field": "value"}

    def test_mcp_tool_result_single_text(self):
        # A single text item collapses to its bare string.
        result = SimpleNamespace(content=[SimpleNamespace(text="only line")])
        assert to_serializable(result) == "only line"

    def test_mcp_tool_result_multiple_text(self):
        # Multiple text items are collected into a list of strings.
        result = SimpleNamespace(
            content=[SimpleNamespace(text="a"), SimpleNamespace(text="b")]
        )
        assert to_serializable(result) == ["a", "b"]

    def test_dict_content_items(self):
        # Content items given as dicts are handled like text objects.
        result = SimpleNamespace(content=[{"text": "from_dict"}])
        assert to_serializable(result) == "from_dict"

    def test_fallback_to_str(self):
        class Custom:
            def __str__(self):
                return "custom_repr"

        assert to_serializable(Custom()) == "custom_repr"