diff --git a/src/ableton_cli/client/protocol.py b/src/ableton_cli/client/protocol.py index 107664a..8ba5063 100644 --- a/src/ableton_cli/client/protocol.py +++ b/src/ableton_cli/client/protocol.py @@ -39,6 +39,15 @@ class Response: REQUIRED_RESPONSE_KEYS = {"ok", "request_id", "protocol_version"} +def _raise_protocol_error(error_code: str, message: str, hint: str) -> None: + raise AppError( + error_code=error_code, + message=message, + hint=hint, + exit_code=ExitCode.PROTOCOL_MISMATCH, + ) + + def make_request( name: str, args: dict[str, Any], @@ -60,26 +69,24 @@ def parse_response( ) -> Response: missing = REQUIRED_RESPONSE_KEYS.difference(payload) if missing: - raise AppError( - error_code="PROTOCOL_VERSION_MISMATCH", + _raise_protocol_error( + error_code="PROTOCOL_INVALID_RESPONSE", message=f"Invalid response payload, missing keys: {sorted(missing)}", hint="Ensure the Remote Script protocol implementation matches the CLI.", - exit_code=ExitCode.PROTOCOL_MISMATCH, ) response_protocol = payload.get("protocol_version") if not isinstance(response_protocol, int): - raise AppError( - error_code="PROTOCOL_VERSION_MISMATCH", + _raise_protocol_error( + error_code="PROTOCOL_INVALID_RESPONSE", message="protocol_version must be an integer", hint=( "Set matching protocol versions on both sides " "(--protocol-version or 'ableton-cli config set protocol_version ')." ), - exit_code=ExitCode.PROTOCOL_MISMATCH, ) if response_protocol != expected_protocol: - raise AppError( + _raise_protocol_error( error_code="PROTOCOL_VERSION_MISMATCH", message=( f"Protocol version mismatch (cli={expected_protocol}, remote={response_protocol})" @@ -88,51 +95,45 @@ def parse_response( "Align protocol_version in CLI and Remote Script " "(--protocol-version or 'ableton-cli config set protocol_version ')." ), - exit_code=ExitCode.PROTOCOL_MISMATCH, ) request_id = payload.get("request_id") if request_id != expected_request_id: - raise AppError( - error_code="PROTOCOL_VERSION_MISMATCH", + _raise_protocol_error( + error_code="PROTOCOL_REQUEST_ID_MISMATCH", message=(f"request_id mismatch (expected={expected_request_id}, actual={request_id})"), hint="Check request routing in the Remote Script server.", - exit_code=ExitCode.PROTOCOL_MISMATCH, ) ok = payload.get("ok") if not isinstance(ok, bool): - raise AppError( - error_code="PROTOCOL_VERSION_MISMATCH", + _raise_protocol_error( + error_code="PROTOCOL_INVALID_RESPONSE", message="'ok' must be a boolean in response payload", hint="Update Remote Script response format.", - exit_code=ExitCode.PROTOCOL_MISMATCH, ) result = payload.get("result") if result is not None and not isinstance(result, dict): - raise AppError( - error_code="PROTOCOL_VERSION_MISMATCH", + _raise_protocol_error( + error_code="PROTOCOL_INVALID_RESPONSE", message="'result' must be an object when provided", hint="Return JSON object for result payloads.", - exit_code=ExitCode.PROTOCOL_MISMATCH, ) error = payload.get("error") if error is not None and not isinstance(error, dict): - raise AppError( - error_code="PROTOCOL_VERSION_MISMATCH", + _raise_protocol_error( + error_code="PROTOCOL_INVALID_RESPONSE", message="'error' must be an object when provided", hint="Return structured error payload with code/message.", - exit_code=ExitCode.PROTOCOL_MISMATCH, ) if isinstance(error, dict) and "details" in error and error["details"] is not None: if not isinstance(error["details"], dict): - raise AppError( - error_code="PROTOCOL_VERSION_MISMATCH", + _raise_protocol_error( + error_code="PROTOCOL_INVALID_RESPONSE", message="'error.details' must be an object when provided", hint="Return structured error details as a JSON object.", - exit_code=ExitCode.PROTOCOL_MISMATCH, ) return Response( diff --git a/src/ableton_cli/errors.py b/src/ableton_cli/errors.py index 9031e40..5cec127 100644 --- a/src/ableton_cli/errors.py +++ b/src/ableton_cli/errors.py @@ -41,6 +41,8 @@ def to_payload(self) -> dict[str, Any]: "REMOTE_SCRIPT_NOT_INSTALLED": ExitCode.REMOTE_SCRIPT_NOT_DETECTED, "REMOTE_SCRIPT_INCOMPATIBLE": ExitCode.PROTOCOL_MISMATCH, "PROTOCOL_VERSION_MISMATCH": ExitCode.PROTOCOL_MISMATCH, + "PROTOCOL_INVALID_RESPONSE": ExitCode.PROTOCOL_MISMATCH, + "PROTOCOL_REQUEST_ID_MISMATCH": ExitCode.PROTOCOL_MISMATCH, "TIMEOUT": ExitCode.TIMEOUT, "BATCH_STEP_FAILED": ExitCode.EXECUTION_FAILED, "REMOTE_BUSY": ExitCode.EXECUTION_FAILED, diff --git a/src/ableton_cli/runtime.py b/src/ableton_cli/runtime.py index 9d79af3..39e2eee 100644 --- a/src/ableton_cli/runtime.py +++ b/src/ableton_cli/runtime.py @@ -28,6 +28,12 @@ class RuntimeContext: output_mode: OutputMode quiet: bool no_color: bool + _client: AbletonClient | None = None + + def client(self) -> AbletonClient: + if self._client is None: + self._client = AbletonClient(self.settings) + return self._client def get_runtime(ctx: typer.Context) -> RuntimeContext: @@ -39,7 +45,7 @@ def get_runtime(ctx: typer.Context) -> RuntimeContext: def get_client(ctx: typer.Context) -> AbletonClient: runtime = get_runtime(ctx) - return AbletonClient(runtime.settings) + return runtime.client() def execute_command( diff --git a/tests/test_exit_codes.py b/tests/test_exit_codes.py index 229295e..8a33c4c 100644 --- a/tests/test_exit_codes.py +++ b/tests/test_exit_codes.py @@ -25,6 +25,8 @@ def test_remote_error_to_exit_code_mapping() -> None: ) assert exit_code_from_error_code("REMOTE_SCRIPT_INCOMPATIBLE") == ExitCode.PROTOCOL_MISMATCH assert exit_code_from_error_code("PROTOCOL_VERSION_MISMATCH") == ExitCode.PROTOCOL_MISMATCH + assert exit_code_from_error_code("PROTOCOL_INVALID_RESPONSE") == ExitCode.PROTOCOL_MISMATCH + assert exit_code_from_error_code("PROTOCOL_REQUEST_ID_MISMATCH") == ExitCode.PROTOCOL_MISMATCH assert exit_code_from_error_code("TIMEOUT") == ExitCode.TIMEOUT assert exit_code_from_error_code("BATCH_STEP_FAILED") == ExitCode.EXECUTION_FAILED assert exit_code_from_error_code("REMOTE_BUSY") == ExitCode.EXECUTION_FAILED diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 770b559..840e52e 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -54,6 +54,54 @@ def test_parse_response_protocol_mismatch_raises() -> None: assert exc_info.value.exit_code == ExitCode.PROTOCOL_MISMATCH +def test_parse_response_missing_keys_raises_invalid_response() -> None: + request = make_request(name="ping", args={}, protocol_version=2) + payload = { + "ok": True, + "request_id": request.request_id, + } + + with pytest.raises(AppError) as exc_info: + parse_response(payload, expected_request_id=request.request_id, expected_protocol=2) + + assert exc_info.value.error_code == "PROTOCOL_INVALID_RESPONSE" + assert exc_info.value.exit_code == ExitCode.PROTOCOL_MISMATCH + + +def test_parse_response_request_id_mismatch_raises() -> None: + request = make_request(name="ping", args={}, protocol_version=2) + payload = { + "ok": True, + "request_id": "other-request-id", + "protocol_version": 2, + "result": {"pong": True}, + "error": None, + } + + with pytest.raises(AppError) as exc_info: + parse_response(payload, expected_request_id=request.request_id, expected_protocol=2) + + assert exc_info.value.error_code == "PROTOCOL_REQUEST_ID_MISMATCH" + assert exc_info.value.exit_code == ExitCode.PROTOCOL_MISMATCH + + +def test_parse_response_rejects_non_integer_protocol_version() -> None: + request = make_request(name="ping", args={}, protocol_version=2) + payload = { + "ok": True, + "request_id": request.request_id, + "protocol_version": "2", + "result": {"pong": True}, + "error": None, + } + + with pytest.raises(AppError) as exc_info: + parse_response(payload, expected_request_id=request.request_id, expected_protocol=2) + + assert exc_info.value.error_code == "PROTOCOL_INVALID_RESPONSE" + assert exc_info.value.exit_code == ExitCode.PROTOCOL_MISMATCH + + def test_parse_response_rejects_non_object_error_details() -> None: request = make_request(name="ping", args={}, protocol_version=2) payload = { @@ -67,7 +115,7 @@ def test_parse_response_rejects_non_object_error_details() -> None: with pytest.raises(AppError) as exc_info: parse_response(payload, expected_request_id=request.request_id, expected_protocol=2) - assert exc_info.value.error_code == "PROTOCOL_VERSION_MISMATCH" + assert exc_info.value.error_code == "PROTOCOL_INVALID_RESPONSE" def test_parse_response_accepts_error_details_object() -> None: diff --git a/tests/test_runtime_quiet.py b/tests/test_runtime_quiet.py index ef736ff..a8a2cdd 100644 --- a/tests/test_runtime_quiet.py +++ b/tests/test_runtime_quiet.py @@ -8,7 +8,7 @@ from ableton_cli.config import Settings from ableton_cli.output import OutputMode -from ableton_cli.runtime import RuntimeContext, execute_command +from ableton_cli.runtime import RuntimeContext, execute_command, get_client def _context(*, quiet: bool) -> SimpleNamespace: @@ -65,3 +65,21 @@ def test_execute_command_not_quiet_emits_custom_human_formatter(monkeypatch) -> assert exc_info.value.exit_code == 0 assert len(emitted) == 1 assert emitted[0][0][0] == "Doctor Results" + + +def test_get_client_reuses_client_for_same_runtime(monkeypatch) -> None: + created_with: list[Settings] = [] + + class FakeClient: + def __init__(self, settings: Settings) -> None: + self.settings = settings + created_with.append(settings) + + monkeypatch.setattr("ableton_cli.runtime.AbletonClient", FakeClient) + + ctx = _context(quiet=False) + first = get_client(ctx) + second = get_client(ctx) + + assert first is second + assert created_with == [ctx.obj.settings]