77from acp import (
88 Agent ,
99 AgentSideConnection ,
10+ AuthenticateRequest ,
11+ AuthenticateResponse ,
1012 CancelNotification ,
1113 Client ,
1214 ClientSideConnection ,
1315 InitializeRequest ,
1416 InitializeResponse ,
1517 LoadSessionRequest ,
18+ LoadSessionResponse ,
1619 NewSessionRequest ,
1720 NewSessionResponse ,
1821 PromptRequest ,
1922 PromptResponse ,
2023 ReadTextFileRequest ,
2124 ReadTextFileResponse ,
25+ RequestError ,
2226 RequestPermissionRequest ,
2327 RequestPermissionResponse ,
2428 SessionNotification ,
29+ SetSessionModelRequest ,
30+ SetSessionModelResponse ,
2531 SetSessionModeRequest ,
32+ SetSessionModeResponse ,
2633 WriteTextFileRequest ,
34+ WriteTextFileResponse ,
35+ )
36+ from acp .schema import (
37+ ContentBlock1 ,
38+ RequestPermissionOutcome1 ,
39+ RequestPermissionOutcome2 ,
40+ SessionUpdate1 ,
41+ SessionUpdate2 ,
2742)
28- from acp .schema import ContentBlock1 , SessionUpdate1 , SessionUpdate2
2943
3044# --------------------- Test Utilities ---------------------
3145
@@ -77,18 +91,30 @@ class TestClient(Client):
7791 __test__ = False # prevent pytest from collecting this class
7892
7993 def __init__ (self ) -> None :
80- self .permission_outcomes : list [dict ] = []
94+ self .permission_outcomes : list [RequestPermissionResponse ] = []
8195 self .files : dict [str , str ] = {}
8296 self .notifications : list [SessionNotification ] = []
8397 self .ext_calls : list [tuple [str , dict ]] = []
8498 self .ext_notes : list [tuple [str , dict ]] = []
8599
100+ def queue_permission_cancelled (self ) -> None :
101+ self .permission_outcomes .append (
102+ RequestPermissionResponse (outcome = RequestPermissionOutcome1 (outcome = "cancelled" ))
103+ )
104+
105+ def queue_permission_selected (self , option_id : str ) -> None :
106+ self .permission_outcomes .append (
107+ RequestPermissionResponse (outcome = RequestPermissionOutcome2 (optionId = option_id , outcome = "selected" ))
108+ )
109+
86110 async def requestPermission (self , params : RequestPermissionRequest ) -> RequestPermissionResponse :
87- outcome = self .permission_outcomes .pop () if self .permission_outcomes else {"outcome" : "cancelled" }
88- return RequestPermissionResponse .model_validate ({"outcome" : outcome })
111+ if self .permission_outcomes :
112+ return self .permission_outcomes .pop ()
113+ return RequestPermissionResponse (outcome = RequestPermissionOutcome1 (outcome = "cancelled" ))
89114
90- async def writeTextFile (self , params : WriteTextFileRequest ) -> None :
115+ async def writeTextFile (self , params : WriteTextFileRequest ) -> WriteTextFileResponse :
91116 self .files [str (params .path )] = params .content
117+ return WriteTextFileResponse ()
92118
93119 async def readTextFile (self , params : ReadTextFileRequest ) -> ReadTextFileResponse :
94120 content = self .files .get (str (params .path ), "default content" )
@@ -98,24 +124,26 @@ async def sessionUpdate(self, params: SessionNotification) -> None:
98124 self .notifications .append (params )
99125
100126 # Optional terminal methods (not implemented in this test client)
101- async def createTerminal (self , params ) -> None : # pragma: no cover - placeholder
102- pass
127+ async def createTerminal (self , params ): # pragma: no cover - placeholder
128+ raise NotImplementedError
103129
104- async def terminalOutput (self , params ) -> None : # pragma: no cover - placeholder
105- pass
130+ async def terminalOutput (self , params ): # pragma: no cover - placeholder
131+ raise NotImplementedError
106132
107- async def releaseTerminal (self , params ) -> None : # pragma: no cover - placeholder
108- pass
133+ async def releaseTerminal (self , params ): # pragma: no cover - placeholder
134+ raise NotImplementedError
109135
110- async def waitForTerminalExit (self , params ) -> None : # pragma: no cover - placeholder
111- pass
136+ async def waitForTerminalExit (self , params ): # pragma: no cover - placeholder
137+ raise NotImplementedError
112138
113- async def killTerminal (self , params ) -> None : # pragma: no cover - placeholder
114- pass
139+ async def killTerminal (self , params ): # pragma: no cover - placeholder
140+ raise NotImplementedError
115141
116142 async def extMethod (self , method : str , params : dict ) -> dict :
117143 self .ext_calls .append ((method , params ))
118- return {"ok" : True , "method" : method }
144+ if method == "example.com/ping" :
145+ return {"response" : "pong" , "params" : params }
146+ raise RequestError .method_not_found (method )
119147
120148 async def extNotification (self , method : str , params : dict ) -> None :
121149 self .ext_notes .append ((method , params ))
@@ -137,11 +165,11 @@ async def initialize(self, params: InitializeRequest) -> InitializeResponse:
137165 async def newSession (self , params : NewSessionRequest ) -> NewSessionResponse :
138166 return NewSessionResponse (sessionId = "test-session-123" )
139167
140- async def loadSession (self , params : LoadSessionRequest ) -> None :
141- return None
168+ async def loadSession (self , params : LoadSessionRequest ) -> LoadSessionResponse :
169+ return LoadSessionResponse ()
142170
143- async def authenticate (self , params ) -> None :
144- return None
171+ async def authenticate (self , params : AuthenticateRequest ) -> AuthenticateResponse :
172+ return AuthenticateResponse ()
145173
146174 async def prompt (self , params : PromptRequest ) -> PromptResponse :
147175 self .prompts .append (params )
@@ -150,12 +178,17 @@ async def prompt(self, params: PromptRequest) -> PromptResponse:
150178 async def cancel (self , params : CancelNotification ) -> None :
151179 self .cancellations .append (params .sessionId )
152180
153- async def setSessionMode (self , params ):
154- return {}
181+ async def setSessionMode (self , params : SetSessionModeRequest ) -> SetSessionModeResponse :
182+ return SetSessionModeResponse ()
183+
184+ async def setSessionModel (self , params : SetSessionModelRequest ) -> SetSessionModelResponse :
185+ return SetSessionModelResponse ()
155186
156187 async def extMethod (self , method : str , params : dict ) -> dict :
157188 self .ext_calls .append ((method , params ))
158- return {"ok" : True , "method" : method }
189+ if method == "example.com/echo" :
190+ return {"echo" : params }
191+ raise RequestError .method_not_found (method )
159192
160193 async def extNotification (self , method : str , params : dict ) -> None :
161194 self .ext_notes .append ((method , params ))
@@ -180,6 +213,22 @@ async def test_initialize_and_new_session():
180213 new_sess = await agent_conn .newSession (NewSessionRequest (mcpServers = [], cwd = "/test" ))
181214 assert new_sess .sessionId == "test-session-123"
182215
216+ load_resp = await agent_conn .loadSession (
217+ LoadSessionRequest (sessionId = new_sess .sessionId , cwd = "/test" , mcpServers = [])
218+ )
219+ assert isinstance (load_resp , LoadSessionResponse )
220+
221+ auth_resp = await agent_conn .authenticate (AuthenticateRequest (methodId = "password" ))
222+ assert isinstance (auth_resp , AuthenticateResponse )
223+
224+ mode_resp = await agent_conn .setSessionMode (SetSessionModeRequest (sessionId = new_sess .sessionId , modeId = "ask" ))
225+ assert isinstance (mode_resp , SetSessionModeResponse )
226+
227+ model_resp = await agent_conn .setSessionModel (
228+ SetSessionModelRequest (sessionId = new_sess .sessionId , modelId = "gpt-4o" )
229+ )
230+ assert isinstance (model_resp , SetSessionModelResponse )
231+
183232
184233@pytest .mark .asyncio
185234async def test_bidirectional_file_ops ():
@@ -195,9 +244,10 @@ async def test_bidirectional_file_ops():
195244 assert res .content == "Hello, World!"
196245
197246 # Agent asks client to write
198- await client_conn .writeTextFile (
247+ write_result = await client_conn .writeTextFile (
199248 WriteTextFileRequest (sessionId = "sess" , path = "/test/file.txt" , content = "Updated" )
200249 )
250+ assert isinstance (write_result , WriteTextFileResponse )
201251 assert client .files ["/test/file.txt" ] == "Updated"
202252
203253
@@ -319,23 +369,30 @@ async def test_set_session_mode_and_extensions():
319369 agent = TestAgent ()
320370 client = TestClient ()
321371 agent_conn = ClientSideConnection (lambda _conn : client , s .client_writer , s .client_reader )
322- _client_conn = AgentSideConnection (lambda _conn : agent , s .server_writer , s .server_reader )
372+ client_conn = AgentSideConnection (lambda _conn : agent , s .server_writer , s .server_reader )
323373
324374 # setSessionMode
325375 resp = await agent_conn .setSessionMode (SetSessionModeRequest (sessionId = "sess" , modeId = "yolo" ))
326- # Either empty object or typed response depending on implementation
327- assert resp is None or resp .__class__ .__name__ == "SetSessionModeResponse"
376+ assert isinstance (resp , SetSessionModeResponse )
377+
378+ model_resp = await agent_conn .setSessionModel (SetSessionModelRequest (sessionId = "sess" , modelId = "gpt-4o-mini" ))
379+ assert isinstance (model_resp , SetSessionModelResponse )
328380
329381 # extMethod
330- res = await agent_conn .extMethod ("ping " , {"x" : 1 })
331- assert res . get ( "ok" ) is True
382+ echo = await agent_conn .extMethod ("example.com/echo " , {"x" : 1 })
383+ assert echo == { "echo" : { "x" : 1 }}
332384
333385 # extNotification
334386 await agent_conn .extNotification ("note" , {"y" : 2 })
335387 # allow dispatch
336388 await asyncio .sleep (0.05 )
337389 assert agent .ext_notes and agent .ext_notes [- 1 ][0 ] == "note"
338390
391+ # client extension method
392+ ping = await client_conn .extMethod ("example.com/ping" , {"k" : 3 })
393+ assert ping == {"response" : "pong" , "params" : {"k" : 3 }}
394+ assert client .ext_calls and client .ext_calls [- 1 ] == ("example.com/ping" , {"k" : 3 })
395+
339396
340397@pytest .mark .asyncio
341398async def test_ignore_invalid_messages ():
0 commit comments