Skip to content

Commit 939158c

Browse files
author
Chojan Shang
committed
feat: align impl with spec and tests
Signed-off-by: Chojan Shang <chojan.shang@vesoft.com>
1 parent ce4bf5d commit 939158c

File tree

3 files changed

+139
-44
lines changed

3 files changed

+139
-44
lines changed

src/acp/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
KillTerminalCommandRequest,
2323
KillTerminalCommandResponse,
2424
LoadSessionRequest,
25+
LoadSessionResponse,
2526
NewSessionRequest,
2627
NewSessionResponse,
2728
PromptRequest,
@@ -33,6 +34,8 @@
3334
RequestPermissionRequest,
3435
RequestPermissionResponse,
3536
SessionNotification,
37+
SetSessionModelRequest,
38+
SetSessionModelResponse,
3639
SetSessionModeRequest,
3740
SetSessionModeResponse,
3841
TerminalOutputRequest,
@@ -55,6 +58,7 @@
5558
"NewSessionRequest",
5659
"NewSessionResponse",
5760
"LoadSessionRequest",
61+
"LoadSessionResponse",
5862
"AuthenticateRequest",
5963
"AuthenticateResponse",
6064
"PromptRequest",
@@ -69,6 +73,8 @@
6973
"SessionNotification",
7074
"SetSessionModeRequest",
7175
"SetSessionModeResponse",
76+
"SetSessionModelRequest",
77+
"SetSessionModelResponse",
7278
# terminal types
7379
"CreateTerminalRequest",
7480
"CreateTerminalResponse",

src/acp/core.py

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
KillTerminalCommandRequest,
2323
KillTerminalCommandResponse,
2424
LoadSessionRequest,
25+
LoadSessionResponse,
2526
NewSessionRequest,
2627
NewSessionResponse,
2728
PromptRequest,
@@ -33,6 +34,8 @@
3334
RequestPermissionRequest,
3435
RequestPermissionResponse,
3536
SessionNotification,
37+
SetSessionModelRequest,
38+
SetSessionModelResponse,
3639
SetSessionModeRequest,
3740
SetSessionModeResponse,
3841
TerminalOutputRequest,
@@ -79,6 +82,11 @@ def internal_error(data: dict | None = None) -> RequestError:
7982
def auth_required(data: dict | None = None) -> RequestError:
8083
return RequestError(-32000, "Authentication required", data)
8184

85+
@staticmethod
86+
def resource_not_found(uri: str | None = None) -> RequestError:
87+
data = {"uri": uri} if uri is not None else None
88+
return RequestError(-32002, "Resource not found", data)
89+
8290
def to_error_obj(self) -> dict:
8391
return {"code": self.code, "message": str(self), "data": self.data}
8492

@@ -253,16 +261,18 @@ async def initialize(self, params: InitializeRequest) -> InitializeResponse: ...
253261

254262
async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: ...
255263

256-
async def loadSession(self, params: LoadSessionRequest) -> None: ...
264+
async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse | None: ...
265+
266+
async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse | None: ...
267+
268+
async def setSessionModel(self, params: SetSessionModelRequest) -> SetSessionModelResponse | None: ...
257269

258270
async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse | None: ...
259271

260272
async def prompt(self, params: PromptRequest) -> PromptResponse: ...
261273

262274
async def cancel(self, params: CancelNotification) -> None: ...
263275

264-
async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse | None: ...
265-
266276
# Extension hooks (optional)
267277
async def extMethod(self, method: str, params: dict) -> dict: ...
268278

@@ -332,7 +342,10 @@ async def _handle_agent_session_methods(self, agent: Agent, method: str, params:
332342
if not hasattr(agent, "loadSession"):
333343
raise RequestError.method_not_found(method)
334344
p = LoadSessionRequest.model_validate(params)
335-
return await agent.loadSession(p)
345+
result = await agent.loadSession(p)
346+
if isinstance(result, BaseModel):
347+
return result.model_dump()
348+
return result or {}
336349
if method == AGENT_METHODS["session_set_mode"]:
337350
if not hasattr(agent, "setSessionMode"):
338351
raise RequestError.method_not_found(method)
@@ -342,6 +355,12 @@ async def _handle_agent_session_methods(self, agent: Agent, method: str, params:
342355
if method == AGENT_METHODS["session_prompt"]:
343356
p = PromptRequest.model_validate(params)
344357
return await agent.prompt(p)
358+
if method == AGENT_METHODS["session_set_model"]:
359+
if not hasattr(agent, "setSessionModel"):
360+
raise RequestError.method_not_found(method)
361+
p = SetSessionModelRequest.model_validate(params)
362+
result = await agent.setSessionModel(p)
363+
return result.model_dump() if isinstance(result, BaseModel) else (result or {})
345364
if method == AGENT_METHODS["session_cancel"]:
346365
p = CancelNotification.model_validate(params)
347366
return await agent.cancel(p)
@@ -548,26 +567,37 @@ async def newSession(self, params: NewSessionRequest) -> NewSessionResponse:
548567
)
549568
return NewSessionResponse.model_validate(resp)
550569

551-
async def loadSession(self, params: LoadSessionRequest) -> None:
552-
await self._conn.send_request(
570+
async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse:
571+
resp = await self._conn.send_request(
553572
AGENT_METHODS["session_load"],
554573
params.model_dump(exclude_none=True, exclude_defaults=True),
555574
)
575+
payload = resp if isinstance(resp, dict) else {}
576+
return LoadSessionResponse.model_validate(payload)
556577

557-
async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse | None:
578+
async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse:
558579
resp = await self._conn.send_request(
559580
AGENT_METHODS["session_set_mode"],
560581
params.model_dump(exclude_none=True, exclude_defaults=True),
561582
)
562-
# May be empty object
563-
return SetSessionModeResponse.model_validate(resp) if isinstance(resp, dict) else None
583+
payload = resp if isinstance(resp, dict) else {}
584+
return SetSessionModeResponse.model_validate(payload)
585+
586+
async def setSessionModel(self, params: SetSessionModelRequest) -> SetSessionModelResponse:
587+
resp = await self._conn.send_request(
588+
AGENT_METHODS["session_set_model"],
589+
params.model_dump(exclude_none=True, exclude_defaults=True),
590+
)
591+
payload = resp if isinstance(resp, dict) else {}
592+
return SetSessionModelResponse.model_validate(payload)
564593

565-
async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse | None:
594+
async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse:
566595
resp = await self._conn.send_request(
567596
AGENT_METHODS["authenticate"],
568597
params.model_dump(exclude_none=True, exclude_defaults=True),
569598
)
570-
return AuthenticateResponse.model_validate(resp) if isinstance(resp, dict) else None
599+
payload = resp if isinstance(resp, dict) else {}
600+
return AuthenticateResponse.model_validate(payload)
571601

572602
async def prompt(self, params: PromptRequest) -> PromptResponse:
573603
resp = await self._conn.send_request(
@@ -609,16 +639,18 @@ async def wait_for_exit(self) -> WaitForTerminalExitResponse:
609639
)
610640
return WaitForTerminalExitResponse.model_validate(resp)
611641

612-
async def kill(self) -> KillTerminalCommandResponse | None:
642+
async def kill(self) -> KillTerminalCommandResponse:
613643
resp = await self._conn.send_request(
614644
CLIENT_METHODS["terminal_kill"],
615645
{"sessionId": self._session_id, "terminalId": self.id},
616646
)
617-
return KillTerminalCommandResponse.model_validate(resp) if isinstance(resp, dict) else None
647+
payload = resp if isinstance(resp, dict) else {}
648+
return KillTerminalCommandResponse.model_validate(payload)
618649

619-
async def release(self) -> ReleaseTerminalResponse | None:
650+
async def release(self) -> ReleaseTerminalResponse:
620651
resp = await self._conn.send_request(
621652
CLIENT_METHODS["terminal_release"],
622653
{"sessionId": self._session_id, "terminalId": self.id},
623654
)
624-
return ReleaseTerminalResponse.model_validate(resp) if isinstance(resp, dict) else None
655+
payload = resp if isinstance(resp, dict) else {}
656+
return ReleaseTerminalResponse.model_validate(payload)

tests/test_rpc.py

Lines changed: 86 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,39 @@
77
from acp import (
88
Agent,
99
AgentSideConnection,
10+
AuthenticateRequest,
11+
AuthenticateResponse,
1012
CancelNotification,
1113
Client,
1214
ClientSideConnection,
1315
InitializeRequest,
1416
InitializeResponse,
1517
LoadSessionRequest,
18+
LoadSessionResponse,
1619
NewSessionRequest,
1720
NewSessionResponse,
1821
PromptRequest,
1922
PromptResponse,
2023
ReadTextFileRequest,
2124
ReadTextFileResponse,
25+
RequestError,
2226
RequestPermissionRequest,
2327
RequestPermissionResponse,
2428
SessionNotification,
29+
SetSessionModelRequest,
30+
SetSessionModelResponse,
2531
SetSessionModeRequest,
32+
SetSessionModeResponse,
2633
WriteTextFileRequest,
34+
WriteTextFileResponse,
35+
)
36+
from acp.schema import (
37+
ContentBlock1,
38+
RequestPermissionOutcome1,
39+
RequestPermissionOutcome2,
40+
SessionUpdate1,
41+
SessionUpdate2,
2742
)
28-
from acp.schema import ContentBlock1, SessionUpdate1, SessionUpdate2
2943

3044
# --------------------- Test Utilities ---------------------
3145

@@ -77,18 +91,30 @@ class TestClient(Client):
7791
__test__ = False # prevent pytest from collecting this class
7892

7993
def __init__(self) -> None:
80-
self.permission_outcomes: list[dict] = []
94+
self.permission_outcomes: list[RequestPermissionResponse] = []
8195
self.files: dict[str, str] = {}
8296
self.notifications: list[SessionNotification] = []
8397
self.ext_calls: list[tuple[str, dict]] = []
8498
self.ext_notes: list[tuple[str, dict]] = []
8599

100+
def queue_permission_cancelled(self) -> None:
101+
self.permission_outcomes.append(
102+
RequestPermissionResponse(outcome=RequestPermissionOutcome1(outcome="cancelled"))
103+
)
104+
105+
def queue_permission_selected(self, option_id: str) -> None:
106+
self.permission_outcomes.append(
107+
RequestPermissionResponse(outcome=RequestPermissionOutcome2(optionId=option_id, outcome="selected"))
108+
)
109+
86110
async def requestPermission(self, params: RequestPermissionRequest) -> RequestPermissionResponse:
87-
outcome = self.permission_outcomes.pop() if self.permission_outcomes else {"outcome": "cancelled"}
88-
return RequestPermissionResponse.model_validate({"outcome": outcome})
111+
if self.permission_outcomes:
112+
return self.permission_outcomes.pop()
113+
return RequestPermissionResponse(outcome=RequestPermissionOutcome1(outcome="cancelled"))
89114

90-
async def writeTextFile(self, params: WriteTextFileRequest) -> None:
115+
async def writeTextFile(self, params: WriteTextFileRequest) -> WriteTextFileResponse:
91116
self.files[str(params.path)] = params.content
117+
return WriteTextFileResponse()
92118

93119
async def readTextFile(self, params: ReadTextFileRequest) -> ReadTextFileResponse:
94120
content = self.files.get(str(params.path), "default content")
@@ -98,24 +124,26 @@ async def sessionUpdate(self, params: SessionNotification) -> None:
98124
self.notifications.append(params)
99125

100126
# Optional terminal methods (not implemented in this test client)
101-
async def createTerminal(self, params) -> None: # pragma: no cover - placeholder
102-
pass
127+
async def createTerminal(self, params): # pragma: no cover - placeholder
128+
raise NotImplementedError
103129

104-
async def terminalOutput(self, params) -> None: # pragma: no cover - placeholder
105-
pass
130+
async def terminalOutput(self, params): # pragma: no cover - placeholder
131+
raise NotImplementedError
106132

107-
async def releaseTerminal(self, params) -> None: # pragma: no cover - placeholder
108-
pass
133+
async def releaseTerminal(self, params): # pragma: no cover - placeholder
134+
raise NotImplementedError
109135

110-
async def waitForTerminalExit(self, params) -> None: # pragma: no cover - placeholder
111-
pass
136+
async def waitForTerminalExit(self, params): # pragma: no cover - placeholder
137+
raise NotImplementedError
112138

113-
async def killTerminal(self, params) -> None: # pragma: no cover - placeholder
114-
pass
139+
async def killTerminal(self, params): # pragma: no cover - placeholder
140+
raise NotImplementedError
115141

116142
async def extMethod(self, method: str, params: dict) -> dict:
117143
self.ext_calls.append((method, params))
118-
return {"ok": True, "method": method}
144+
if method == "example.com/ping":
145+
return {"response": "pong", "params": params}
146+
raise RequestError.method_not_found(method)
119147

120148
async def extNotification(self, method: str, params: dict) -> None:
121149
self.ext_notes.append((method, params))
@@ -137,11 +165,11 @@ async def initialize(self, params: InitializeRequest) -> InitializeResponse:
137165
async def newSession(self, params: NewSessionRequest) -> NewSessionResponse:
138166
return NewSessionResponse(sessionId="test-session-123")
139167

140-
async def loadSession(self, params: LoadSessionRequest) -> None:
141-
return None
168+
async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse:
169+
return LoadSessionResponse()
142170

143-
async def authenticate(self, params) -> None:
144-
return None
171+
async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse:
172+
return AuthenticateResponse()
145173

146174
async def prompt(self, params: PromptRequest) -> PromptResponse:
147175
self.prompts.append(params)
@@ -150,12 +178,17 @@ async def prompt(self, params: PromptRequest) -> PromptResponse:
150178
async def cancel(self, params: CancelNotification) -> None:
151179
self.cancellations.append(params.sessionId)
152180

153-
async def setSessionMode(self, params):
154-
return {}
181+
async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse:
182+
return SetSessionModeResponse()
183+
184+
async def setSessionModel(self, params: SetSessionModelRequest) -> SetSessionModelResponse:
185+
return SetSessionModelResponse()
155186

156187
async def extMethod(self, method: str, params: dict) -> dict:
157188
self.ext_calls.append((method, params))
158-
return {"ok": True, "method": method}
189+
if method == "example.com/echo":
190+
return {"echo": params}
191+
raise RequestError.method_not_found(method)
159192

160193
async def extNotification(self, method: str, params: dict) -> None:
161194
self.ext_notes.append((method, params))
@@ -180,6 +213,22 @@ async def test_initialize_and_new_session():
180213
new_sess = await agent_conn.newSession(NewSessionRequest(mcpServers=[], cwd="/test"))
181214
assert new_sess.sessionId == "test-session-123"
182215

216+
load_resp = await agent_conn.loadSession(
217+
LoadSessionRequest(sessionId=new_sess.sessionId, cwd="/test", mcpServers=[])
218+
)
219+
assert isinstance(load_resp, LoadSessionResponse)
220+
221+
auth_resp = await agent_conn.authenticate(AuthenticateRequest(methodId="password"))
222+
assert isinstance(auth_resp, AuthenticateResponse)
223+
224+
mode_resp = await agent_conn.setSessionMode(SetSessionModeRequest(sessionId=new_sess.sessionId, modeId="ask"))
225+
assert isinstance(mode_resp, SetSessionModeResponse)
226+
227+
model_resp = await agent_conn.setSessionModel(
228+
SetSessionModelRequest(sessionId=new_sess.sessionId, modelId="gpt-4o")
229+
)
230+
assert isinstance(model_resp, SetSessionModelResponse)
231+
183232

184233
@pytest.mark.asyncio
185234
async def test_bidirectional_file_ops():
@@ -195,9 +244,10 @@ async def test_bidirectional_file_ops():
195244
assert res.content == "Hello, World!"
196245

197246
# Agent asks client to write
198-
await client_conn.writeTextFile(
247+
write_result = await client_conn.writeTextFile(
199248
WriteTextFileRequest(sessionId="sess", path="/test/file.txt", content="Updated")
200249
)
250+
assert isinstance(write_result, WriteTextFileResponse)
201251
assert client.files["/test/file.txt"] == "Updated"
202252

203253

@@ -319,23 +369,30 @@ async def test_set_session_mode_and_extensions():
319369
agent = TestAgent()
320370
client = TestClient()
321371
agent_conn = ClientSideConnection(lambda _conn: client, s.client_writer, s.client_reader)
322-
_client_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader)
372+
client_conn = AgentSideConnection(lambda _conn: agent, s.server_writer, s.server_reader)
323373

324374
# setSessionMode
325375
resp = await agent_conn.setSessionMode(SetSessionModeRequest(sessionId="sess", modeId="yolo"))
326-
# Either empty object or typed response depending on implementation
327-
assert resp is None or resp.__class__.__name__ == "SetSessionModeResponse"
376+
assert isinstance(resp, SetSessionModeResponse)
377+
378+
model_resp = await agent_conn.setSessionModel(SetSessionModelRequest(sessionId="sess", modelId="gpt-4o-mini"))
379+
assert isinstance(model_resp, SetSessionModelResponse)
328380

329381
# extMethod
330-
res = await agent_conn.extMethod("ping", {"x": 1})
331-
assert res.get("ok") is True
382+
echo = await agent_conn.extMethod("example.com/echo", {"x": 1})
383+
assert echo == {"echo": {"x": 1}}
332384

333385
# extNotification
334386
await agent_conn.extNotification("note", {"y": 2})
335387
# allow dispatch
336388
await asyncio.sleep(0.05)
337389
assert agent.ext_notes and agent.ext_notes[-1][0] == "note"
338390

391+
# client extension method
392+
ping = await client_conn.extMethod("example.com/ping", {"k": 3})
393+
assert ping == {"response": "pong", "params": {"k": 3}}
394+
assert client.ext_calls and client.ext_calls[-1] == ("example.com/ping", {"k": 3})
395+
339396

340397
@pytest.mark.asyncio
341398
async def test_ignore_invalid_messages():

0 commit comments

Comments
 (0)