|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import argparse |
| 4 | +import asyncio |
| 5 | +import asyncio.subprocess |
| 6 | +import contextlib |
| 7 | +import json |
| 8 | +import os |
| 9 | +import shutil |
| 10 | +import sys |
| 11 | +from pathlib import Path |
| 12 | +from typing import Iterable |
| 13 | + |
| 14 | +from acp import ( |
| 15 | + Client, |
| 16 | + ClientSideConnection, |
| 17 | + PROTOCOL_VERSION, |
| 18 | + RequestError, |
| 19 | +) |
| 20 | +from acp.schema import ( |
| 21 | + AgentMessageChunk, |
| 22 | + AgentPlanUpdate, |
| 23 | + AgentThoughtChunk, |
| 24 | + AllowedOutcome, |
| 25 | + CancelNotification, |
| 26 | + ClientCapabilities, |
| 27 | + FileEditToolCallContent, |
| 28 | + FileSystemCapability, |
| 29 | + CreateTerminalRequest, |
| 30 | + CreateTerminalResponse, |
| 31 | + DeniedOutcome, |
| 32 | + EmbeddedResourceContentBlock, |
| 33 | + KillTerminalCommandRequest, |
| 34 | + KillTerminalCommandResponse, |
| 35 | + InitializeRequest, |
| 36 | + NewSessionRequest, |
| 37 | + PermissionOption, |
| 38 | + PromptRequest, |
| 39 | + ReadTextFileRequest, |
| 40 | + ReadTextFileResponse, |
| 41 | + RequestPermissionRequest, |
| 42 | + RequestPermissionResponse, |
| 43 | + ResourceContentBlock, |
| 44 | + ReleaseTerminalRequest, |
| 45 | + ReleaseTerminalResponse, |
| 46 | + SessionNotification, |
| 47 | + TerminalToolCallContent, |
| 48 | + TerminalOutputRequest, |
| 49 | + TerminalOutputResponse, |
| 50 | + TextContentBlock, |
| 51 | + ToolCallProgress, |
| 52 | + ToolCallStart, |
| 53 | + UserMessageChunk, |
| 54 | + WaitForTerminalExitRequest, |
| 55 | + WaitForTerminalExitResponse, |
| 56 | + WriteTextFileRequest, |
| 57 | + WriteTextFileResponse, |
| 58 | +) |
| 59 | + |
| 60 | + |
| 61 | +class GeminiClient(Client): |
| 62 | + """Minimal client implementation that can drive the Gemini CLI over ACP.""" |
| 63 | + |
| 64 | + def __init__(self, auto_approve: bool) -> None: |
| 65 | + self._auto_approve = auto_approve |
| 66 | + |
| 67 | + async def requestPermission( |
| 68 | + self, |
| 69 | + params: RequestPermissionRequest, |
| 70 | + ) -> RequestPermissionResponse: # type: ignore[override] |
| 71 | + if self._auto_approve: |
| 72 | + option = _pick_preferred_option(params.options) |
| 73 | + if option is None: |
| 74 | + return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) |
| 75 | + return RequestPermissionResponse(outcome=AllowedOutcome(optionId=option.optionId, outcome="selected")) |
| 76 | + |
| 77 | + title = params.toolCall.title or "<permission>" |
| 78 | + if not params.options: |
| 79 | + print(f"\n🔐 Permission requested: {title} (no options, cancelling)") |
| 80 | + return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) |
| 81 | + print(f"\n🔐 Permission requested: {title}") |
| 82 | + for idx, opt in enumerate(params.options, start=1): |
| 83 | + print(f" {idx}. {opt.name} ({opt.kind})") |
| 84 | + |
| 85 | + loop = asyncio.get_running_loop() |
| 86 | + while True: |
| 87 | + choice = await loop.run_in_executor(None, lambda: input("Select option: ").strip()) |
| 88 | + if not choice: |
| 89 | + continue |
| 90 | + if choice.isdigit(): |
| 91 | + idx = int(choice) - 1 |
| 92 | + if 0 <= idx < len(params.options): |
| 93 | + opt = params.options[idx] |
| 94 | + return RequestPermissionResponse(outcome=AllowedOutcome(optionId=opt.optionId, outcome="selected")) |
| 95 | + print("Invalid selection, try again.") |
| 96 | + |
| 97 | + async def writeTextFile( |
| 98 | + self, |
| 99 | + params: WriteTextFileRequest, |
| 100 | + ) -> WriteTextFileResponse: # type: ignore[override] |
| 101 | + path = Path(params.path) |
| 102 | + if not path.is_absolute(): |
| 103 | + raise RequestError.invalid_params({"path": params.path, "reason": "path must be absolute"}) |
| 104 | + path.parent.mkdir(parents=True, exist_ok=True) |
| 105 | + path.write_text(params.content) |
| 106 | + print(f"[Client] Wrote {path} ({len(params.content)} bytes)") |
| 107 | + return WriteTextFileResponse() |
| 108 | + |
| 109 | + async def readTextFile( |
| 110 | + self, |
| 111 | + params: ReadTextFileRequest, |
| 112 | + ) -> ReadTextFileResponse: # type: ignore[override] |
| 113 | + path = Path(params.path) |
| 114 | + if not path.is_absolute(): |
| 115 | + raise RequestError.invalid_params({"path": params.path, "reason": "path must be absolute"}) |
| 116 | + text = path.read_text() |
| 117 | + print(f"[Client] Read {path} ({len(text)} bytes)") |
| 118 | + if params.line is not None or params.limit is not None: |
| 119 | + text = _slice_text(text, params.line, params.limit) |
| 120 | + return ReadTextFileResponse(content=text) |
| 121 | + |
| 122 | + async def sessionUpdate( |
| 123 | + self, |
| 124 | + params: SessionNotification, |
| 125 | + ) -> None: # type: ignore[override] |
| 126 | + update = params.update |
| 127 | + if isinstance(update, AgentMessageChunk): |
| 128 | + _print_text_content(update.content) |
| 129 | + elif isinstance(update, AgentThoughtChunk): |
| 130 | + print("\n[agent_thought]") |
| 131 | + _print_text_content(update.content) |
| 132 | + elif isinstance(update, UserMessageChunk): |
| 133 | + print("\n[user_message]") |
| 134 | + _print_text_content(update.content) |
| 135 | + elif isinstance(update, AgentPlanUpdate): |
| 136 | + print("\n[plan]") |
| 137 | + for entry in update.entries: |
| 138 | + print(f" - {entry.status.upper():<10} {entry.content}") |
| 139 | + elif isinstance(update, ToolCallStart): |
| 140 | + print(f"\n🔧 {update.title} ({update.status or 'pending'})") |
| 141 | + elif isinstance(update, ToolCallProgress): |
| 142 | + status = update.status or "in_progress" |
| 143 | + print(f"\n🔧 Tool call `{update.toolCallId}` → {status}") |
| 144 | + if update.content: |
| 145 | + for item in update.content: |
| 146 | + if isinstance(item, FileEditToolCallContent): |
| 147 | + print(f" diff: {item.path}") |
| 148 | + elif isinstance(item, TerminalToolCallContent): |
| 149 | + print(f" terminal: {item.terminalId}") |
| 150 | + elif isinstance(item, dict): |
| 151 | + print(f" content: {json.dumps(item, indent=2)}") |
| 152 | + else: |
| 153 | + print(f"\n[session update] {update}") |
| 154 | + |
| 155 | + # Optional / terminal-related methods --------------------------------- |
| 156 | + async def createTerminal( |
| 157 | + self, |
| 158 | + params: CreateTerminalRequest, |
| 159 | + ) -> CreateTerminalResponse: # type: ignore[override] |
| 160 | + print(f"[Client] createTerminal: {params}") |
| 161 | + return CreateTerminalResponse(terminalId="term-1") |
| 162 | + |
| 163 | + async def terminalOutput( |
| 164 | + self, |
| 165 | + params: TerminalOutputRequest, |
| 166 | + ) -> TerminalOutputResponse: # type: ignore[override] |
| 167 | + print(f"[Client] terminalOutput: {params}") |
| 168 | + return TerminalOutputResponse(output="", truncated=False) |
| 169 | + |
| 170 | + async def releaseTerminal( |
| 171 | + self, |
| 172 | + params: ReleaseTerminalRequest, |
| 173 | + ) -> ReleaseTerminalResponse: # type: ignore[override] |
| 174 | + print(f"[Client] releaseTerminal: {params}") |
| 175 | + return ReleaseTerminalResponse() |
| 176 | + |
| 177 | + async def waitForTerminalExit( |
| 178 | + self, |
| 179 | + params: WaitForTerminalExitRequest, |
| 180 | + ) -> WaitForTerminalExitResponse: # type: ignore[override] |
| 181 | + print(f"[Client] waitForTerminalExit: {params}") |
| 182 | + return WaitForTerminalExitResponse() |
| 183 | + |
| 184 | + async def killTerminal( |
| 185 | + self, |
| 186 | + params: KillTerminalCommandRequest, |
| 187 | + ) -> KillTerminalCommandResponse: # type: ignore[override] |
| 188 | + print(f"[Client] killTerminal: {params}") |
| 189 | + return KillTerminalCommandResponse() |
| 190 | + |
| 191 | + |
| 192 | +def _pick_preferred_option(options: Iterable[PermissionOption]) -> PermissionOption | None: |
| 193 | + best: PermissionOption | None = None |
| 194 | + for option in options: |
| 195 | + if option.kind in {"allow_once", "allow_always"}: |
| 196 | + return option |
| 197 | + best = best or option |
| 198 | + return best |
| 199 | + |
| 200 | + |
| 201 | +def _slice_text(content: str, line: int | None, limit: int | None) -> str: |
| 202 | + lines = content.splitlines() |
| 203 | + start = 0 |
| 204 | + if line: |
| 205 | + start = max(line - 1, 0) |
| 206 | + end = len(lines) |
| 207 | + if limit: |
| 208 | + end = min(start + limit, end) |
| 209 | + return "\n".join(lines[start:end]) |
| 210 | + |
| 211 | + |
| 212 | +def _print_text_content(content: object) -> None: |
| 213 | + if isinstance(content, TextContentBlock): |
| 214 | + print(content.text) |
| 215 | + elif isinstance(content, ResourceContentBlock): |
| 216 | + print(f"{content.name or content.uri}") |
| 217 | + elif isinstance(content, EmbeddedResourceContentBlock): |
| 218 | + resource = content.resource |
| 219 | + text = getattr(resource, "text", None) |
| 220 | + if text: |
| 221 | + print(text) |
| 222 | + else: |
| 223 | + blob = getattr(resource, "blob", None) |
| 224 | + print(blob if blob else "<embedded resource>") |
| 225 | + elif isinstance(content, dict): |
| 226 | + text = content.get("text") # type: ignore[union-attr] |
| 227 | + if text: |
| 228 | + print(text) |
| 229 | + |
| 230 | + |
| 231 | +async def interactive_loop(conn: ClientSideConnection, session_id: str) -> None: |
| 232 | + print("Type a message and press Enter to send.") |
| 233 | + print("Commands: :cancel, :exit") |
| 234 | + |
| 235 | + loop = asyncio.get_running_loop() |
| 236 | + while True: |
| 237 | + try: |
| 238 | + line = await loop.run_in_executor(None, lambda: input("\n> ").strip()) |
| 239 | + except (EOFError, KeyboardInterrupt): |
| 240 | + print("\nExiting.") |
| 241 | + break |
| 242 | + |
| 243 | + if not line: |
| 244 | + continue |
| 245 | + if line in {":exit", ":quit"}: |
| 246 | + break |
| 247 | + if line == ":cancel": |
| 248 | + await conn.cancel(CancelNotification(sessionId=session_id)) |
| 249 | + continue |
| 250 | + |
| 251 | + try: |
| 252 | + await conn.prompt( |
| 253 | + PromptRequest( |
| 254 | + sessionId=session_id, |
| 255 | + prompt=[TextContentBlock(type="text", text=line)], |
| 256 | + ) |
| 257 | + ) |
| 258 | + except RequestError as err: |
| 259 | + _print_request_error("prompt", err) |
| 260 | + except Exception as exc: # noqa: BLE001 |
| 261 | + print(f"Prompt failed: {exc}", file=sys.stderr) |
| 262 | + |
| 263 | + |
| 264 | +def _resolve_gemini_cli(binary: str | None) -> str: |
| 265 | + if binary: |
| 266 | + return binary |
| 267 | + env_value = os.environ.get("ACP_GEMINI_BIN") |
| 268 | + if env_value: |
| 269 | + return env_value |
| 270 | + resolved = shutil.which("gemini") |
| 271 | + if resolved: |
| 272 | + return resolved |
| 273 | + raise FileNotFoundError("Unable to locate `gemini` CLI, provide --gemini path") |
| 274 | + |
| 275 | + |
| 276 | +async def run(argv: list[str]) -> int: |
| 277 | + parser = argparse.ArgumentParser(description="Interact with the Gemini CLI over ACP.") |
| 278 | + parser.add_argument("--gemini", help="Path to the Gemini CLI binary") |
| 279 | + parser.add_argument("--model", help="Model identifier to pass to Gemini") |
| 280 | + parser.add_argument("--sandbox", action="store_true", help="Enable Gemini sandbox mode") |
| 281 | + parser.add_argument("--debug", action="store_true", help="Pass --debug to Gemini") |
| 282 | + parser.add_argument("--yolo", action="store_true", help="Auto-approve permission prompts") |
| 283 | + args = parser.parse_args(argv[1:]) |
| 284 | + |
| 285 | + try: |
| 286 | + gemini_path = _resolve_gemini_cli(args.gemini) |
| 287 | + except FileNotFoundError as exc: |
| 288 | + print(exc, file=sys.stderr) |
| 289 | + return 1 |
| 290 | + |
| 291 | + cmd = [gemini_path, "--experimental-acp"] |
| 292 | + if args.model: |
| 293 | + cmd += ["--model", args.model] |
| 294 | + if args.sandbox: |
| 295 | + cmd.append("--sandbox") |
| 296 | + if args.debug: |
| 297 | + cmd.append("--debug") |
| 298 | + |
| 299 | + try: |
| 300 | + proc = await asyncio.create_subprocess_exec( |
| 301 | + *cmd, |
| 302 | + stdin=asyncio.subprocess.PIPE, |
| 303 | + stdout=asyncio.subprocess.PIPE, |
| 304 | + stderr=None, |
| 305 | + ) |
| 306 | + except FileNotFoundError as exc: |
| 307 | + print(f"Failed to start Gemini CLI: {exc}", file=sys.stderr) |
| 308 | + return 1 |
| 309 | + |
| 310 | + if proc.stdin is None or proc.stdout is None: |
| 311 | + print("Gemini process did not expose stdio pipes.", file=sys.stderr) |
| 312 | + proc.terminate() |
| 313 | + with contextlib.suppress(ProcessLookupError): |
| 314 | + await proc.wait() |
| 315 | + return 1 |
| 316 | + |
| 317 | + client_impl = GeminiClient(auto_approve=args.yolo) |
| 318 | + conn = ClientSideConnection(lambda _agent: client_impl, proc.stdin, proc.stdout) |
| 319 | + |
| 320 | + try: |
| 321 | + init_resp = await conn.initialize( |
| 322 | + InitializeRequest( |
| 323 | + protocolVersion=PROTOCOL_VERSION, |
| 324 | + clientCapabilities=ClientCapabilities( |
| 325 | + fs=FileSystemCapability(readTextFile=True, writeTextFile=True), |
| 326 | + terminal=True, |
| 327 | + ), |
| 328 | + ) |
| 329 | + ) |
| 330 | + except RequestError as err: |
| 331 | + _print_request_error("initialize", err) |
| 332 | + await _shutdown(proc, conn) |
| 333 | + return 1 |
| 334 | + except Exception as exc: # noqa: BLE001 |
| 335 | + print(f"initialize error: {exc}", file=sys.stderr) |
| 336 | + await _shutdown(proc, conn) |
| 337 | + return 1 |
| 338 | + |
| 339 | + print(f"✅ Connected to Gemini (protocol v{init_resp.protocolVersion})") |
| 340 | + |
| 341 | + try: |
| 342 | + session = await conn.newSession( |
| 343 | + NewSessionRequest( |
| 344 | + cwd=os.getcwd(), |
| 345 | + mcpServers=[], |
| 346 | + ) |
| 347 | + ) |
| 348 | + except RequestError as err: |
| 349 | + _print_request_error("new_session", err) |
| 350 | + await _shutdown(proc, conn) |
| 351 | + return 1 |
| 352 | + except Exception as exc: # noqa: BLE001 |
| 353 | + print(f"new_session error: {exc}", file=sys.stderr) |
| 354 | + await _shutdown(proc, conn) |
| 355 | + return 1 |
| 356 | + |
| 357 | + print(f"📝 Created session: {session.sessionId}") |
| 358 | + |
| 359 | + try: |
| 360 | + await interactive_loop(conn, session.sessionId) |
| 361 | + finally: |
| 362 | + await _shutdown(proc, conn) |
| 363 | + |
| 364 | + return 0 |
| 365 | + |
| 366 | + |
| 367 | +def _print_request_error(stage: str, err: RequestError) -> None: |
| 368 | + payload = err.to_error_obj() |
| 369 | + message = payload.get("message", "") |
| 370 | + code = payload.get("code") |
| 371 | + print(f"{stage} error ({code}): {message}", file=sys.stderr) |
| 372 | + data = payload.get("data") |
| 373 | + if data is not None: |
| 374 | + try: |
| 375 | + formatted = json.dumps(data, indent=2) |
| 376 | + except TypeError: |
| 377 | + formatted = str(data) |
| 378 | + print(formatted, file=sys.stderr) |
| 379 | + |
| 380 | + |
| 381 | +async def _shutdown(proc: asyncio.subprocess.Process, conn: ClientSideConnection) -> None: |
| 382 | + with contextlib.suppress(Exception): |
| 383 | + await conn.close() |
| 384 | + if proc.returncode is None: |
| 385 | + proc.terminate() |
| 386 | + try: |
| 387 | + await asyncio.wait_for(proc.wait(), timeout=5) |
| 388 | + except asyncio.TimeoutError: |
| 389 | + proc.kill() |
| 390 | + await proc.wait() |
| 391 | + |
| 392 | + |
| 393 | +def main(argv: list[str] | None = None) -> int: |
| 394 | + args = sys.argv if argv is None else argv |
| 395 | + return asyncio.run(run(list(args))) |
| 396 | + |
| 397 | + |
| 398 | +if __name__ == "__main__": |
| 399 | + raise SystemExit(main()) |
0 commit comments