From 62d7b32d47149ff2eae5a3a91952c45e2de14e51 Mon Sep 17 00:00:00 2001 From: 6uclz1 <9139177+6uclz1@users.noreply.github.com> Date: Mon, 23 Feb 2026 12:41:47 +0900 Subject: [PATCH 1/3] reduce command layer duplication --- src/ableton_cli/commands/_validation.py | 80 ++++++++- src/ableton_cli/commands/effect.py | 80 +++++---- src/ableton_cli/commands/scenes.py | 168 +++++++++++++----- src/ableton_cli/commands/synth.py | 80 +++++---- src/ableton_cli/commands/track.py | 160 ++++++++--------- tests/commands/test_effect_command_adapter.py | 73 ++++++++ tests/commands/test_scenes_command_adapter.py | 102 +++++++++++ tests/commands/test_synth_command_adapter.py | 73 ++++++++ tests/commands/test_track_command_adapter.py | 101 +++++++++++ tests/commands/test_validation.py | 113 ++++++++++++ 10 files changed, 825 insertions(+), 205 deletions(-) create mode 100644 tests/commands/test_effect_command_adapter.py create mode 100644 tests/commands/test_scenes_command_adapter.py create mode 100644 tests/commands/test_synth_command_adapter.py create mode 100644 tests/commands/test_track_command_adapter.py diff --git a/src/ableton_cli/commands/_validation.py b/src/ableton_cli/commands/_validation.py index 1d70da3..beceed8 100644 --- a/src/ableton_cli/commands/_validation.py +++ b/src/ableton_cli/commands/_validation.py @@ -2,13 +2,23 @@ import json from pathlib import Path, PurePosixPath, PureWindowsPath -from typing import Any +from typing import Any, TypeVar from ..errors import AppError, ExitCode NOTE_KEYS = {"pitch", "start_time", "duration", "velocity", "mute"} TRACK_INDEX_HINT = "Use a valid track index from 'ableton-cli tracks list'." DEVICE_INDEX_HINT = "Use a valid device index from 'ableton-cli track info'." +SCENE_INDEX_HINT = "Use a valid scene index from 'scenes list'." +SCENE_SOURCE_HINT = "Use a valid source scene index from 'scenes list'." +SCENE_DESTINATION_HINT = "Use a valid destination scene index from 'scenes list'." +SCENE_NAME_HINT = "Pass a non-empty scene name." +SCENE_INSERT_INDEX_HINT = "Use -1 for append or a non-negative insertion index." +TRACK_NAME_HINT = "Pass a non-empty track name." +VOLUME_VALUE_HINT = "Use a normalized volume value such as 0.75." +PAN_VALUE_HINT = "Use a normalized panning value such as -0.25." + +TValue = TypeVar("TValue") def invalid_argument(message: str, hint: str) -> AppError: @@ -34,6 +44,10 @@ def require_device_index(value: int, *, hint: str = DEVICE_INDEX_HINT) -> int: return require_non_negative("device", value, hint=hint) +def require_scene_index(value: int, *, hint: str = SCENE_INDEX_HINT) -> int: + return require_non_negative("scene", value, hint=hint) + + def require_parameter_index(value: int, *, hint: str) -> int: return require_non_negative("parameter", value, hint=hint) @@ -72,6 +86,70 @@ def require_float_in_range( return value +def require_track_and_value(track: int, value: TValue) -> tuple[int, TValue]: + return require_track_index(track), value + + +def require_optional_track_index(track: int | None) -> int | None: + if track is None: + return None + return require_track_index(track) + + +def require_track_and_device(track: int, device: int) -> tuple[int, int]: + return require_track_index(track), require_device_index(device) + + +def require_scene_and_value(scene: int, value: TValue) -> tuple[int, TValue]: + return require_scene_index(scene), value + + +def require_track_and_name(track: int, value: str) -> tuple[int, str]: + valid_track = require_track_index(track) + valid_name = require_non_empty_string("name", value, hint=TRACK_NAME_HINT) + return valid_track, valid_name + + +def require_scene_and_name(scene: int, value: str) -> tuple[int, str]: + valid_scene = require_scene_index(scene) + valid_name = require_non_empty_string("name", value, hint=SCENE_NAME_HINT) + return valid_scene, valid_name + + +def require_scene_move(from_scene: int, to_scene: int) -> tuple[int, int]: + valid_from_scene = require_non_negative("from", from_scene, hint=SCENE_SOURCE_HINT) + valid_to_scene = require_non_negative("to", to_scene, hint=SCENE_DESTINATION_HINT) + return valid_from_scene, valid_to_scene + + +def require_scene_insert_index(index: int) -> int: + return require_minus_one_or_non_negative("index", index, hint=SCENE_INSERT_INDEX_HINT) + + +def require_track_and_volume(track: int, value: float) -> tuple[int, float]: + valid_track = require_track_index(track) + valid_value = require_float_in_range( + "value", + value, + minimum=0.0, + maximum=1.0, + hint=VOLUME_VALUE_HINT, + ) + return valid_track, valid_value + + +def require_track_and_pan(track: int, value: float) -> tuple[int, float]: + valid_track = require_track_index(track) + valid_value = require_float_in_range( + "value", + value, + minimum=-1.0, + maximum=1.0, + hint=PAN_VALUE_HINT, + ) + return valid_track, valid_value + + def require_non_empty_string(name: str, value: str, *, hint: str) -> str: stripped = value.strip() if not stripped: diff --git a/src/ableton_cli/commands/effect.py b/src/ableton_cli/commands/effect.py index 324da26..c38b65d 100644 --- a/src/ableton_cli/commands/effect.py +++ b/src/ableton_cli/commands/effect.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Sequence from typing import Annotated import typer @@ -8,10 +8,10 @@ from ..runtime import execute_command, get_client from ._validation import ( invalid_argument, - require_device_index, require_non_empty_string, + require_optional_track_index, require_parameter_index, - require_track_index, + require_track_and_device, ) _SUPPORTED_EFFECT_TYPES = ( @@ -31,6 +31,9 @@ parameters_app = typer.Typer(help="Effect parameter listing commands", no_args_is_help=True) parameter_app = typer.Typer(help="Effect parameter write commands", no_args_is_help=True) +TrackDeviceValidator = Callable[[int, int], tuple[int, int]] +TrackDeviceAction = Callable[[object, int, int], dict[str, object]] + def _normalize_effect_type(value: str) -> str: parsed = require_non_empty_string("effect_type", value, hint="Pass a non-empty effect type.") @@ -43,19 +46,6 @@ def _normalize_effect_type(value: str) -> str: return normalized -def _require_optional_track_index(track: int | None) -> int | None: - if track is None: - return None - return require_track_index(track) - - -def _require_track_and_device_index(track: int, device: int) -> tuple[int, int]: - return ( - require_track_index(track), - require_device_index(device), - ) - - def _require_effect_parameter_index(parameter: int) -> int: return require_parameter_index( parameter, @@ -63,21 +53,28 @@ def _require_effect_parameter_index(parameter: int) -> int: ) -def _execute_track_device_command( +def run_track_device_command( ctx: typer.Context, *, - command: str, + command_name: str, track: int, device: int, - action: Callable[[int, int], dict[str, object]], + fn: TrackDeviceAction, + validators: Sequence[TrackDeviceValidator] | None = None, ) -> None: + active_validators = validators if validators is not None else (require_track_and_device,) + def _run() -> dict[str, object]: - valid_track, valid_device = _require_track_and_device_index(track, device) - return action(valid_track, valid_device) + valid_track = track + valid_device = device + for validator in active_validators: + valid_track, valid_device = validator(valid_track, valid_device) + client = get_client(ctx) + return fn(client, valid_track, valid_device) execute_command( ctx, - command=command, + command=command_name, args={"track": track, "device": device}, action=_run, ) @@ -99,9 +96,10 @@ def effect_find( ] = None, ) -> None: def _run() -> dict[str, object]: - valid_track = _require_optional_track_index(track) + valid_track = require_optional_track_index(track) valid_type = _normalize_effect_type(effect_type) if effect_type is not None else None - return get_client(ctx).find_effect_devices(track=valid_track, effect_type=valid_type) + client = get_client(ctx) + return client.find_effect_devices(track=valid_track, effect_type=valid_type) execute_command( ctx, @@ -117,12 +115,12 @@ def effect_parameters_list( track: TrackArgument, device: DeviceArgument, ) -> None: - _execute_track_device_command( + run_track_device_command( ctx, - command="effect parameters list", + command_name="effect parameters list", track=track, device=device, - action=lambda valid_track, valid_device: get_client(ctx).list_effect_parameters( + fn=lambda client, valid_track, valid_device: client.list_effect_parameters( track=valid_track, device=valid_device, ), @@ -138,9 +136,10 @@ def effect_parameter_set( value: Annotated[float, typer.Argument(help="Target parameter value")], ) -> None: def _run() -> dict[str, object]: - valid_track, valid_device = _require_track_and_device_index(track, device) + valid_track, valid_device = require_track_and_device(track, device) valid_parameter = _require_effect_parameter_index(parameter) - return get_client(ctx).set_effect_parameter_safe( + client = get_client(ctx) + return client.set_effect_parameter_safe( track=valid_track, device=valid_device, parameter=valid_parameter, @@ -161,12 +160,12 @@ def effect_observe( track: TrackArgument, device: DeviceArgument, ) -> None: - _execute_track_device_command( + run_track_device_command( ctx, - command="effect observe", + command_name="effect observe", track=track, device=device, - action=lambda valid_track, valid_device: get_client(ctx).observe_effect_parameters( + fn=lambda client, valid_track, valid_device: client.observe_effect_parameters( track=valid_track, device=valid_device, ), @@ -181,11 +180,15 @@ def _build_standard_effect_app(effect_type: str, cli_name: str) -> typer.Typer: @standard_app.command("keys") def keys(ctx: typer.Context) -> None: + def _run() -> dict[str, object]: + client = get_client(ctx) + return client.list_standard_effect_keys(effect_type) + execute_command( ctx, command=f"effect {cli_name} keys", args={}, - action=lambda: get_client(ctx).list_standard_effect_keys(effect_type), + action=_run, ) @standard_app.command("set") @@ -197,13 +200,14 @@ def standard_set( value: Annotated[float, typer.Argument(help="Target parameter value")], ) -> None: def _run() -> dict[str, object]: - valid_track, valid_device = _require_track_and_device_index(track, device) + valid_track, valid_device = require_track_and_device(track, device) valid_key = require_non_empty_string( "key", key, hint="Pass a non-empty stable effect key.", ) - return get_client(ctx).set_standard_effect_parameter_safe( + client = get_client(ctx) + return client.set_standard_effect_parameter_safe( effect_type=effect_type, track=valid_track, device=valid_device, @@ -224,12 +228,12 @@ def standard_observe( track: TrackArgument, device: DeviceArgument, ) -> None: - _execute_track_device_command( + run_track_device_command( ctx, - command=f"effect {cli_name} observe", + command_name=f"effect {cli_name} observe", track=track, device=device, - action=lambda valid_track, valid_device: get_client(ctx).observe_standard_effect_state( + fn=lambda client, valid_track, valid_device: client.observe_standard_effect_state( effect_type=effect_type, track=valid_track, device=valid_device, diff --git a/src/ableton_cli/commands/scenes.py b/src/ableton_cli/commands/scenes.py index c177979..6d46d6d 100644 --- a/src/ableton_cli/commands/scenes.py +++ b/src/ableton_cli/commands/scenes.py @@ -1,27 +1,124 @@ from __future__ import annotations -from typing import Annotated +from collections.abc import Callable, Sequence +from typing import Annotated, TypeVar import typer from ..runtime import execute_command, get_client from ._validation import ( - require_minus_one_or_non_negative, - require_non_empty_string, - require_non_negative, + require_scene_and_name, + require_scene_and_value, + require_scene_index, + require_scene_insert_index, + require_scene_move, ) scenes_app = typer.Typer(help="Scenes commands", no_args_is_help=True) name_app = typer.Typer(help="Scenes naming commands", no_args_is_help=True) +TValue = TypeVar("TValue") + +SceneValidator = Callable[[int], int] +SceneValueValidator = Callable[[int, TValue], tuple[int, TValue]] +SceneMoveValidator = Callable[[int, int], tuple[int, int]] + +SceneAction = Callable[[object, int], dict[str, object]] +SceneValueAction = Callable[[object, int, TValue], dict[str, object]] +SceneMoveAction = Callable[[object, int, int], dict[str, object]] + + +def run_scene_command( + ctx: typer.Context, + *, + command_name: str, + scene: int, + fn: SceneAction, + validators: Sequence[SceneValidator] | None = None, +) -> None: + active_validators = validators if validators is not None else (require_scene_index,) + + def _run() -> dict[str, object]: + valid_scene = scene + for validator in active_validators: + valid_scene = validator(valid_scene) + client = get_client(ctx) + return fn(client, valid_scene) + + execute_command( + ctx, + command=command_name, + args={"scene": scene}, + action=_run, + ) + + +def run_scene_value_command( + ctx: typer.Context, + *, + command_name: str, + scene: int, + value: TValue, + fn: SceneValueAction[TValue], + value_name: str = "value", + validators: Sequence[SceneValueValidator[TValue]] | None = None, +) -> None: + active_validators = validators if validators is not None else (require_scene_and_value,) + + def _run() -> dict[str, object]: + valid_scene = scene + valid_value = value + for validator in active_validators: + valid_scene, valid_value = validator(valid_scene, valid_value) + client = get_client(ctx) + return fn(client, valid_scene, valid_value) + + execute_command( + ctx, + command=command_name, + args={"scene": scene, value_name: value}, + action=_run, + ) + + +def run_scene_move_command( + ctx: typer.Context, + *, + command_name: str, + from_scene: int, + to_scene: int, + fn: SceneMoveAction, + validators: Sequence[SceneMoveValidator] | None = None, +) -> None: + active_validators = validators if validators is not None else (require_scene_move,) + + def _run() -> dict[str, object]: + valid_from_scene = from_scene + valid_to_scene = to_scene + for validator in active_validators: + valid_from_scene, valid_to_scene = validator(valid_from_scene, valid_to_scene) + client = get_client(ctx) + return fn(client, valid_from_scene, valid_to_scene) + + execute_command( + ctx, + command=command_name, + args={"from": from_scene, "to": to_scene}, + action=_run, + ) + @scenes_app.command("list") def scenes_list(ctx: typer.Context) -> None: + def _run() -> dict[str, object]: + client = get_client(ctx) + return client.scenes_list() + execute_command( ctx, command="scenes list", args={}, - action=lambda: get_client(ctx).scenes_list(), + action=_run, ) @@ -37,12 +134,9 @@ def scenes_create( ] = -1, ) -> None: def _run() -> dict[str, object]: - valid_index = require_minus_one_or_non_negative( - "index", - index, - hint="Use -1 for append or a non-negative insertion index.", - ) - return get_client(ctx).create_scene(valid_index) + valid_index = require_scene_insert_index(index) + client = get_client(ctx) + return client.create_scene(valid_index) execute_command( ctx, @@ -58,16 +152,14 @@ def scenes_name_set( scene: Annotated[int, typer.Argument(help="Scene index (0-based)")], name: Annotated[str, typer.Argument(help="New scene name")], ) -> None: - def _run() -> dict[str, object]: - require_non_negative("scene", scene, hint="Use a valid scene index from 'scenes list'.") - valid_name = require_non_empty_string("name", name, hint="Pass a non-empty scene name.") - return get_client(ctx).set_scene_name(scene, valid_name) - - execute_command( + run_scene_value_command( ctx, - command="scenes name set", - args={"scene": scene, "name": name}, - action=_run, + command_name="scenes name set", + scene=scene, + value=name, + value_name="name", + validators=[require_scene_and_name], + fn=lambda client, valid_scene, valid_name: client.set_scene_name(valid_scene, valid_name), ) @@ -76,15 +168,11 @@ def scenes_fire( ctx: typer.Context, scene: Annotated[int, typer.Argument(help="Scene index (0-based)")], ) -> None: - def _run() -> dict[str, object]: - require_non_negative("scene", scene, hint="Use a valid scene index from 'scenes list'.") - return get_client(ctx).fire_scene(scene) - - execute_command( + run_scene_command( ctx, - command="scenes fire", - args={"scene": scene}, - action=_run, + command_name="scenes fire", + scene=scene, + fn=lambda client, valid_scene: client.fire_scene(valid_scene), ) @@ -94,24 +182,14 @@ def scenes_move( from_scene: Annotated[int, typer.Argument(help="Source scene index (0-based)")], to_scene: Annotated[int, typer.Argument(help="Destination scene index (0-based)")], ) -> None: - def _run() -> dict[str, object]: - require_non_negative( - "from", - from_scene, - hint="Use a valid source scene index from 'scenes list'.", - ) - require_non_negative( - "to", - to_scene, - hint="Use a valid destination scene index from 'scenes list'.", - ) - return get_client(ctx).scenes_move(from_scene, to_scene) - - execute_command( + run_scene_move_command( ctx, - command="scenes move", - args={"from": from_scene, "to": to_scene}, - action=_run, + command_name="scenes move", + from_scene=from_scene, + to_scene=to_scene, + fn=lambda client, valid_from_scene, valid_to_scene: client.scenes_move( + valid_from_scene, valid_to_scene + ), ) diff --git a/src/ableton_cli/commands/synth.py b/src/ableton_cli/commands/synth.py index 3153552..51f666e 100644 --- a/src/ableton_cli/commands/synth.py +++ b/src/ableton_cli/commands/synth.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Sequence from typing import Annotated import typer @@ -8,10 +8,10 @@ from ..runtime import execute_command, get_client from ._validation import ( invalid_argument, - require_device_index, require_non_empty_string, + require_optional_track_index, require_parameter_index, - require_track_index, + require_track_and_device, ) _SUPPORTED_SYNTH_TYPES = ("wavetable", "drift", "meld") @@ -24,6 +24,9 @@ parameters_app = typer.Typer(help="Synth parameter listing commands", no_args_is_help=True) parameter_app = typer.Typer(help="Synth parameter write commands", no_args_is_help=True) +TrackDeviceValidator = Callable[[int, int], tuple[int, int]] +TrackDeviceAction = Callable[[object, int, int], dict[str, object]] + def _normalize_synth_type(value: str) -> str: parsed = require_non_empty_string("synth_type", value, hint="Pass a non-empty synth type.") @@ -36,19 +39,6 @@ def _normalize_synth_type(value: str) -> str: return normalized -def _require_optional_track_index(track: int | None) -> int | None: - if track is None: - return None - return require_track_index(track) - - -def _require_track_and_device_index(track: int, device: int) -> tuple[int, int]: - return ( - require_track_index(track), - require_device_index(device), - ) - - def _require_synth_parameter_index(parameter: int) -> int: return require_parameter_index( parameter, @@ -56,21 +46,28 @@ def _require_synth_parameter_index(parameter: int) -> int: ) -def _execute_track_device_command( +def run_track_device_command( ctx: typer.Context, *, - command: str, + command_name: str, track: int, device: int, - action: Callable[[int, int], dict[str, object]], + fn: TrackDeviceAction, + validators: Sequence[TrackDeviceValidator] | None = None, ) -> None: + active_validators = validators if validators is not None else (require_track_and_device,) + def _run() -> dict[str, object]: - valid_track, valid_device = _require_track_and_device_index(track, device) - return action(valid_track, valid_device) + valid_track = track + valid_device = device + for validator in active_validators: + valid_track, valid_device = validator(valid_track, valid_device) + client = get_client(ctx) + return fn(client, valid_track, valid_device) execute_command( ctx, - command=command, + command=command_name, args={"track": track, "device": device}, action=_run, ) @@ -89,9 +86,10 @@ def synth_find( ] = None, ) -> None: def _run() -> dict[str, object]: - valid_track = _require_optional_track_index(track) + valid_track = require_optional_track_index(track) valid_type = _normalize_synth_type(synth_type) if synth_type is not None else None - return get_client(ctx).find_synth_devices(track=valid_track, synth_type=valid_type) + client = get_client(ctx) + return client.find_synth_devices(track=valid_track, synth_type=valid_type) execute_command( ctx, @@ -107,12 +105,12 @@ def synth_parameters_list( track: TrackArgument, device: DeviceArgument, ) -> None: - _execute_track_device_command( + run_track_device_command( ctx, - command="synth parameters list", + command_name="synth parameters list", track=track, device=device, - action=lambda valid_track, valid_device: get_client(ctx).list_synth_parameters( + fn=lambda client, valid_track, valid_device: client.list_synth_parameters( track=valid_track, device=valid_device, ), @@ -128,9 +126,10 @@ def synth_parameter_set( value: Annotated[float, typer.Argument(help="Target parameter value")], ) -> None: def _run() -> dict[str, object]: - valid_track, valid_device = _require_track_and_device_index(track, device) + valid_track, valid_device = require_track_and_device(track, device) valid_parameter = _require_synth_parameter_index(parameter) - return get_client(ctx).set_synth_parameter_safe( + client = get_client(ctx) + return client.set_synth_parameter_safe( track=valid_track, device=valid_device, parameter=valid_parameter, @@ -151,12 +150,12 @@ def synth_observe( track: TrackArgument, device: DeviceArgument, ) -> None: - _execute_track_device_command( + run_track_device_command( ctx, - command="synth observe", + command_name="synth observe", track=track, device=device, - action=lambda valid_track, valid_device: get_client(ctx).observe_synth_parameters( + fn=lambda client, valid_track, valid_device: client.observe_synth_parameters( track=valid_track, device=valid_device, ), @@ -171,11 +170,15 @@ def _build_standard_synth_app(synth_type: str) -> typer.Typer: @standard_app.command("keys") def keys(ctx: typer.Context) -> None: + def _run() -> dict[str, object]: + client = get_client(ctx) + return client.list_standard_synth_keys(synth_type) + execute_command( ctx, command=f"synth {synth_type} keys", args={}, - action=lambda: get_client(ctx).list_standard_synth_keys(synth_type), + action=_run, ) @standard_app.command("set") @@ -187,13 +190,14 @@ def standard_set( value: Annotated[float, typer.Argument(help="Target parameter value")], ) -> None: def _run() -> dict[str, object]: - valid_track, valid_device = _require_track_and_device_index(track, device) + valid_track, valid_device = require_track_and_device(track, device) valid_key = require_non_empty_string( "key", key, hint="Pass a non-empty stable synth key.", ) - return get_client(ctx).set_standard_synth_parameter_safe( + client = get_client(ctx) + return client.set_standard_synth_parameter_safe( synth_type=synth_type, track=valid_track, device=valid_device, @@ -214,12 +218,12 @@ def standard_observe( track: TrackArgument, device: DeviceArgument, ) -> None: - _execute_track_device_command( + run_track_device_command( ctx, - command=f"synth {synth_type} observe", + command_name=f"synth {synth_type} observe", track=track, device=device, - action=lambda valid_track, valid_device: get_client(ctx).observe_standard_synth_state( + fn=lambda client, valid_track, valid_device: client.observe_standard_synth_state( synth_type=synth_type, track=valid_track, device=valid_device, diff --git a/src/ableton_cli/commands/track.py b/src/ableton_cli/commands/track.py index 10e8594..af5ae07 100644 --- a/src/ableton_cli/commands/track.py +++ b/src/ableton_cli/commands/track.py @@ -1,12 +1,18 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Sequence from typing import Annotated, TypeVar import typer from ..runtime import execute_command, get_client -from ._validation import require_float_in_range, require_non_empty_string, require_track_index +from ._validation import ( + require_track_and_name, + require_track_and_pan, + require_track_and_value, + require_track_and_volume, + require_track_index, +) TValue = TypeVar("TValue") @@ -14,77 +20,65 @@ VolumeValueArgument = Annotated[float, typer.Argument(help="Volume value in [0.0, 1.0]")] PanningValueArgument = Annotated[float, typer.Argument(help="Panning value in [-1.0, 1.0]")] +TrackValidator = Callable[[int], int] +TrackValueValidator = Callable[[int, TValue], tuple[int, TValue]] +TrackAction = Callable[[object, int], dict[str, object]] +TrackValueAction = Callable[[object, int, TValue], dict[str, object]] -def _execute_track_get( + +def run_track_command( ctx: typer.Context, *, - command: str, + command_name: str, track: int, - action: Callable[[int], dict[str, object]], + fn: TrackAction, + validators: Sequence[TrackValidator] | None = None, ) -> None: + active_validators = validators if validators is not None else (require_track_index,) + def _run() -> dict[str, object]: - valid_track = require_track_index(track) - return action(valid_track) + valid_track = track + for validator in active_validators: + valid_track = validator(valid_track) + client = get_client(ctx) + return fn(client, valid_track) execute_command( ctx, - command=command, + command=command_name, args={"track": track}, action=_run, ) -def _execute_track_set( +def run_track_value_command( ctx: typer.Context, *, - command: str, + command_name: str, track: int, value: TValue, - action: Callable[[int, TValue], dict[str, object]], + fn: TrackValueAction[TValue], value_name: str = "value", - validator: Callable[[TValue], TValue] | None = None, + validators: Sequence[TrackValueValidator[TValue]] | None = None, ) -> None: + active_validators = validators if validators is not None else (require_track_and_value,) + def _run() -> dict[str, object]: - valid_track = require_track_index(track) - valid_value = validator(value) if validator is not None else value - return action(valid_track, valid_value) + valid_track = track + valid_value = value + for validator in active_validators: + valid_track, valid_value = validator(valid_track, valid_value) + client = get_client(ctx) + return fn(client, valid_track, valid_value) execute_command( ctx, - command=command, + command=command_name, args={"track": track, value_name: value}, action=_run, ) -def _require_volume_value(value: float) -> float: - return require_float_in_range( - "value", - value, - minimum=0.0, - maximum=1.0, - hint="Use a normalized volume value such as 0.75.", - ) - - -def _require_panning_value(value: float) -> float: - return require_float_in_range( - "value", - value, - minimum=-1.0, - maximum=1.0, - hint="Use a normalized panning value such as -0.25.", - ) - - -def _require_track_name(value: str) -> str: - return require_non_empty_string( - "name", - value, - hint="Pass a non-empty track name.", - ) - - track_app = typer.Typer(help="Single-track commands", no_args_is_help=True) volume_app = typer.Typer(help="Track volume commands", no_args_is_help=True) name_app = typer.Typer(help="Track naming commands", no_args_is_help=True) @@ -99,11 +93,11 @@ def track_info( ctx: typer.Context, track: TrackArgument, ) -> None: - _execute_track_get( + run_track_command( ctx, - command="track info", + command_name="track info", track=track, - action=lambda valid_track: get_client(ctx).get_track_info(valid_track), + fn=lambda client, valid_track: client.get_track_info(valid_track), ) @@ -112,11 +106,11 @@ def volume_get( ctx: typer.Context, track: TrackArgument, ) -> None: - _execute_track_get( + run_track_command( ctx, - command="track volume get", + command_name="track volume get", track=track, - action=lambda valid_track: get_client(ctx).track_volume_get(valid_track), + fn=lambda client, valid_track: client.track_volume_get(valid_track), ) @@ -126,13 +120,13 @@ def volume_set( track: TrackArgument, value: VolumeValueArgument, ) -> None: - _execute_track_set( + run_track_value_command( ctx, - command="track volume set", + command_name="track volume set", track=track, value=value, - validator=_require_volume_value, - action=lambda valid_track, valid_value: get_client(ctx).track_volume_set( + validators=[require_track_and_volume], + fn=lambda client, valid_track, valid_value: client.track_volume_set( valid_track, valid_value, ), @@ -145,14 +139,14 @@ def track_name_set( track: TrackArgument, name: Annotated[str, typer.Argument(help="New track name")], ) -> None: - _execute_track_set( + run_track_value_command( ctx, - command="track name set", + command_name="track name set", track=track, value=name, value_name="name", - validator=_require_track_name, - action=lambda valid_track, valid_name: get_client(ctx).set_track_name( + validators=[require_track_and_name], + fn=lambda client, valid_track, valid_name: client.set_track_name( valid_track, valid_name, ), @@ -164,11 +158,11 @@ def mute_get( ctx: typer.Context, track: TrackArgument, ) -> None: - _execute_track_get( + run_track_command( ctx, - command="track mute get", + command_name="track mute get", track=track, - action=lambda valid_track: get_client(ctx).track_mute_get(valid_track), + fn=lambda client, valid_track: client.track_mute_get(valid_track), ) @@ -178,12 +172,12 @@ def mute_set( track: TrackArgument, value: Annotated[bool, typer.Argument(help="Mute value: true|false")], ) -> None: - _execute_track_set( + run_track_value_command( ctx, - command="track mute set", + command_name="track mute set", track=track, value=value, - action=lambda valid_track, valid_value: get_client(ctx).track_mute_set( + fn=lambda client, valid_track, valid_value: client.track_mute_set( valid_track, valid_value, ), @@ -195,11 +189,11 @@ def solo_get( ctx: typer.Context, track: TrackArgument, ) -> None: - _execute_track_get( + run_track_command( ctx, - command="track solo get", + command_name="track solo get", track=track, - action=lambda valid_track: get_client(ctx).track_solo_get(valid_track), + fn=lambda client, valid_track: client.track_solo_get(valid_track), ) @@ -209,12 +203,12 @@ def solo_set( track: TrackArgument, value: Annotated[bool, typer.Argument(help="Solo value: true|false")], ) -> None: - _execute_track_set( + run_track_value_command( ctx, - command="track solo set", + command_name="track solo set", track=track, value=value, - action=lambda valid_track, valid_value: get_client(ctx).track_solo_set( + fn=lambda client, valid_track, valid_value: client.track_solo_set( valid_track, valid_value, ), @@ -226,11 +220,11 @@ def arm_get( ctx: typer.Context, track: TrackArgument, ) -> None: - _execute_track_get( + run_track_command( ctx, - command="track arm get", + command_name="track arm get", track=track, - action=lambda valid_track: get_client(ctx).track_arm_get(valid_track), + fn=lambda client, valid_track: client.track_arm_get(valid_track), ) @@ -240,12 +234,12 @@ def arm_set( track: TrackArgument, value: Annotated[bool, typer.Argument(help="Arm value: true|false")], ) -> None: - _execute_track_set( + run_track_value_command( ctx, - command="track arm set", + command_name="track arm set", track=track, value=value, - action=lambda valid_track, valid_value: get_client(ctx).track_arm_set( + fn=lambda client, valid_track, valid_value: client.track_arm_set( valid_track, valid_value, ), @@ -257,11 +251,11 @@ def panning_get( ctx: typer.Context, track: TrackArgument, ) -> None: - _execute_track_get( + run_track_command( ctx, - command="track panning get", + command_name="track panning get", track=track, - action=lambda valid_track: get_client(ctx).track_panning_get(valid_track), + fn=lambda client, valid_track: client.track_panning_get(valid_track), ) @@ -271,13 +265,13 @@ def panning_set( track: TrackArgument, value: PanningValueArgument, ) -> None: - _execute_track_set( + run_track_value_command( ctx, - command="track panning set", + command_name="track panning set", track=track, value=value, - validator=_require_panning_value, - action=lambda valid_track, valid_value: get_client(ctx).track_panning_set( + validators=[require_track_and_pan], + fn=lambda client, valid_track, valid_value: client.track_panning_set( valid_track, valid_value, ), diff --git a/tests/commands/test_effect_command_adapter.py b/tests/commands/test_effect_command_adapter.py new file mode 100644 index 0000000..8aba656 --- /dev/null +++ b/tests/commands/test_effect_command_adapter.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import pytest + +from ableton_cli.errors import AppError + + +def test_run_track_device_command_validates_indices_before_client_lookup(monkeypatch) -> None: + from ableton_cli.commands import effect + + get_client_calls = {"count": 0} + + def _get_client(_ctx): # noqa: ANN202 + get_client_calls["count"] += 1 + return object() + + def _execute_command(_ctx, *, command, args, action, human_formatter=None): # noqa: ANN202 + del command, args, human_formatter + action() + + monkeypatch.setattr(effect, "get_client", _get_client) + monkeypatch.setattr(effect, "execute_command", _execute_command) + + with pytest.raises(AppError) as exc: + effect.run_track_device_command( + ctx=object(), + command_name="effect test", + track=1, + device=-1, + fn=lambda _client, _track, _device: {"ok": True}, + ) + + assert exc.value.message == "device must be >= 0, got -1" + assert get_client_calls["count"] == 0 + + +def test_run_track_device_command_applies_custom_validator(monkeypatch) -> None: + from ableton_cli.commands import effect + + captured: dict[str, object] = {} + client = object() + + def _get_client(_ctx): # noqa: ANN202 + return client + + def _execute_command(_ctx, *, command, args, action, human_formatter=None): # noqa: ANN202 + del human_formatter + captured["command"] = command + captured["args"] = args + captured["result"] = action() + + def _validator(track_index: int, device_index: int) -> tuple[int, int]: + return track_index + 2, device_index + 2 + + monkeypatch.setattr(effect, "get_client", _get_client) + monkeypatch.setattr(effect, "execute_command", _execute_command) + + effect.run_track_device_command( + ctx=object(), + command_name="effect test", + track=0, + device=1, + fn=lambda resolved_client, valid_track, valid_device: { + "same_client": resolved_client is client, + "track": valid_track, + "device": valid_device, + }, + validators=[_validator], + ) + + assert captured["command"] == "effect test" + assert captured["args"] == {"track": 0, "device": 1} + assert captured["result"] == {"same_client": True, "track": 2, "device": 3} diff --git a/tests/commands/test_scenes_command_adapter.py b/tests/commands/test_scenes_command_adapter.py new file mode 100644 index 0000000..7778417 --- /dev/null +++ b/tests/commands/test_scenes_command_adapter.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import pytest + +from ableton_cli.errors import AppError + + +def test_run_scene_command_validates_scene_before_client_lookup(monkeypatch) -> None: + from ableton_cli.commands import scenes + + get_client_calls = {"count": 0} + + def _get_client(_ctx): # noqa: ANN202 + get_client_calls["count"] += 1 + return object() + + def _execute_command(_ctx, *, command, args, action, human_formatter=None): # noqa: ANN202 + del command, args, human_formatter + action() + + monkeypatch.setattr(scenes, "get_client", _get_client) + monkeypatch.setattr(scenes, "execute_command", _execute_command) + + with pytest.raises(AppError) as exc: + scenes.run_scene_command( + ctx=object(), + command_name="scenes test", + scene=-1, + fn=lambda _client, _scene: {"ok": True}, + ) + + assert exc.value.message == "scene must be >= 0, got -1" + assert get_client_calls["count"] == 0 + + +def test_run_scene_value_command_applies_custom_validator(monkeypatch) -> None: + from ableton_cli.commands import scenes + + captured: dict[str, object] = {} + client = object() + + def _get_client(_ctx): # noqa: ANN202 + return client + + def _execute_command(_ctx, *, command, args, action, human_formatter=None): # noqa: ANN202 + del human_formatter + captured["command"] = command + captured["args"] = args + captured["result"] = action() + + def _validator(scene_index: int, value: str) -> tuple[int, str]: + return scene_index + 1, value.strip().upper() + + monkeypatch.setattr(scenes, "get_client", _get_client) + monkeypatch.setattr(scenes, "execute_command", _execute_command) + + scenes.run_scene_value_command( + ctx=object(), + command_name="scenes test set", + scene=2, + value=" build ", + value_name="name", + fn=lambda resolved_client, valid_scene, valid_name: { + "same_client": resolved_client is client, + "scene": valid_scene, + "name": valid_name, + }, + validators=[_validator], + ) + + assert captured["command"] == "scenes test set" + assert captured["args"] == {"scene": 2, "name": " build "} + assert captured["result"] == {"same_client": True, "scene": 3, "name": "BUILD"} + + +def test_run_scene_move_command_validates_indices_before_client_lookup(monkeypatch) -> None: + from ableton_cli.commands import scenes + + get_client_calls = {"count": 0} + + def _get_client(_ctx): # noqa: ANN202 + get_client_calls["count"] += 1 + return object() + + def _execute_command(_ctx, *, command, args, action, human_formatter=None): # noqa: ANN202 + del command, args, human_formatter + action() + + monkeypatch.setattr(scenes, "get_client", _get_client) + monkeypatch.setattr(scenes, "execute_command", _execute_command) + + with pytest.raises(AppError) as exc: + scenes.run_scene_move_command( + ctx=object(), + command_name="scenes move test", + from_scene=-1, + to_scene=1, + fn=lambda _client, _from_scene, _to_scene: {"ok": True}, + ) + + assert exc.value.message == "from must be >= 0, got -1" + assert get_client_calls["count"] == 0 diff --git a/tests/commands/test_synth_command_adapter.py b/tests/commands/test_synth_command_adapter.py new file mode 100644 index 0000000..7f83d3f --- /dev/null +++ b/tests/commands/test_synth_command_adapter.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import pytest + +from ableton_cli.errors import AppError + + +def test_run_track_device_command_validates_indices_before_client_lookup(monkeypatch) -> None: + from ableton_cli.commands import synth + + get_client_calls = {"count": 0} + + def _get_client(_ctx): # noqa: ANN202 + get_client_calls["count"] += 1 + return object() + + def _execute_command(_ctx, *, command, args, action, human_formatter=None): # noqa: ANN202 + del command, args, human_formatter + action() + + monkeypatch.setattr(synth, "get_client", _get_client) + monkeypatch.setattr(synth, "execute_command", _execute_command) + + with pytest.raises(AppError) as exc: + synth.run_track_device_command( + ctx=object(), + command_name="synth test", + track=-1, + device=2, + fn=lambda _client, _track, _device: {"ok": True}, + ) + + assert exc.value.message == "track must be >= 0, got -1" + assert get_client_calls["count"] == 0 + + +def test_run_track_device_command_applies_custom_validator(monkeypatch) -> None: + from ableton_cli.commands import synth + + captured: dict[str, object] = {} + client = object() + + def _get_client(_ctx): # noqa: ANN202 + return client + + def _execute_command(_ctx, *, command, args, action, human_formatter=None): # noqa: ANN202 + del human_formatter + captured["command"] = command + captured["args"] = args + captured["result"] = action() + + def _validator(track_index: int, device_index: int) -> tuple[int, int]: + return track_index + 1, device_index + 1 + + monkeypatch.setattr(synth, "get_client", _get_client) + monkeypatch.setattr(synth, "execute_command", _execute_command) + + synth.run_track_device_command( + ctx=object(), + command_name="synth test", + track=2, + device=3, + fn=lambda resolved_client, valid_track, valid_device: { + "same_client": resolved_client is client, + "track": valid_track, + "device": valid_device, + }, + validators=[_validator], + ) + + assert captured["command"] == "synth test" + assert captured["args"] == {"track": 2, "device": 3} + assert captured["result"] == {"same_client": True, "track": 3, "device": 4} diff --git a/tests/commands/test_track_command_adapter.py b/tests/commands/test_track_command_adapter.py new file mode 100644 index 0000000..d81ff86 --- /dev/null +++ b/tests/commands/test_track_command_adapter.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import pytest + +from ableton_cli.errors import AppError + + +def test_run_track_command_validates_track_before_client_lookup(monkeypatch) -> None: + from ableton_cli.commands import track + + get_client_calls = {"count": 0} + + def _get_client(_ctx): # noqa: ANN202 + get_client_calls["count"] += 1 + return object() + + def _execute_command(_ctx, *, command, args, action, human_formatter=None): # noqa: ANN202 + del command, args, human_formatter + action() + + monkeypatch.setattr(track, "get_client", _get_client) + monkeypatch.setattr(track, "execute_command", _execute_command) + + with pytest.raises(AppError) as exc: + track.run_track_command( + ctx=object(), + command_name="track test", + track=-1, + fn=lambda _client, _track: {"ok": True}, + ) + + assert exc.value.message == "track must be >= 0, got -1" + assert get_client_calls["count"] == 0 + + +def test_run_track_value_command_applies_custom_validators(monkeypatch) -> None: + from ableton_cli.commands import track + + captured: dict[str, object] = {} + client = object() + + def _get_client(_ctx): # noqa: ANN202 + return client + + def _execute_command(_ctx, *, command, args, action, human_formatter=None): # noqa: ANN202 + del human_formatter + captured["command"] = command + captured["args"] = args + captured["result"] = action() + + def _validator(track_index: int, value: float) -> tuple[int, float]: + return track_index + 1, value + 0.25 + + monkeypatch.setattr(track, "get_client", _get_client) + monkeypatch.setattr(track, "execute_command", _execute_command) + + track.run_track_value_command( + ctx=object(), + command_name="track test set", + track=1, + value=0.5, + fn=lambda resolved_client, valid_track, valid_value: { + "same_client": resolved_client is client, + "track": valid_track, + "value": valid_value, + }, + validators=[_validator], + ) + + assert captured["command"] == "track test set" + assert captured["args"] == {"track": 1, "value": 0.5} + assert captured["result"] == {"same_client": True, "track": 2, "value": 0.75} + + +def test_run_track_value_command_defaults_to_track_validation(monkeypatch) -> None: + from ableton_cli.commands import track + + get_client_calls = {"count": 0} + + def _get_client(_ctx): # noqa: ANN202 + get_client_calls["count"] += 1 + return object() + + def _execute_command(_ctx, *, command, args, action, human_formatter=None): # noqa: ANN202 + del command, args, human_formatter + action() + + monkeypatch.setattr(track, "get_client", _get_client) + monkeypatch.setattr(track, "execute_command", _execute_command) + + with pytest.raises(AppError) as exc: + track.run_track_value_command( + ctx=object(), + command_name="track test set", + track=-1, + value=True, + fn=lambda _client, _track, _value: {"ok": True}, + ) + + assert exc.value.message == "track must be >= 0, got -1" + assert get_client_calls["count"] == 0 diff --git a/tests/commands/test_validation.py b/tests/commands/test_validation.py index 52e5a2f..93686e5 100644 --- a/tests/commands/test_validation.py +++ b/tests/commands/test_validation.py @@ -98,3 +98,116 @@ def test_require_float_in_range_rejects_out_of_range_values() -> None: _assert_invalid_argument(exc) assert exc.value.message == "value must be between 0.0 and 1.0, got 1.2" assert exc.value.hint == "Use a normalized value such as 0.75." + + +def test_require_track_and_volume_accepts_valid_track_and_value() -> None: + from ableton_cli.commands._validation import require_track_and_volume + + assert require_track_and_volume(2, 0.75) == (2, 0.75) + + +def test_require_track_and_volume_rejects_out_of_range_value() -> None: + from ableton_cli.commands._validation import require_track_and_volume + + with pytest.raises(AppError) as exc: + require_track_and_volume(1, 1.1) + + _assert_invalid_argument(exc) + assert exc.value.message == "value must be between 0.0 and 1.0, got 1.1" + assert exc.value.hint == "Use a normalized volume value such as 0.75." + + +def test_require_track_and_pan_accepts_valid_track_and_value() -> None: + from ableton_cli.commands._validation import require_track_and_pan + + assert require_track_and_pan(1, -0.25) == (1, -0.25) + + +def test_require_track_and_pan_rejects_out_of_range_value() -> None: + from ableton_cli.commands._validation import require_track_and_pan + + with pytest.raises(AppError) as exc: + require_track_and_pan(0, -1.2) + + _assert_invalid_argument(exc) + assert exc.value.message == "value must be between -1.0 and 1.0, got -1.2" + assert exc.value.hint == "Use a normalized panning value such as -0.25." + + +def test_require_scene_and_name_accepts_valid_values() -> None: + from ableton_cli.commands._validation import require_scene_and_name + + assert require_scene_and_name(3, " Build ") == (3, "Build") + + +def test_require_scene_and_name_rejects_negative_scene() -> None: + from ableton_cli.commands._validation import require_scene_and_name + + with pytest.raises(AppError) as exc: + require_scene_and_name(-1, "Build") + + _assert_invalid_argument(exc) + assert exc.value.message == "scene must be >= 0, got -1" + assert exc.value.hint == "Use a valid scene index from 'scenes list'." + + +def test_require_scene_move_accepts_non_negative_values() -> None: + from ableton_cli.commands._validation import require_scene_move + + assert require_scene_move(2, 5) == (2, 5) + + +def test_require_scene_move_rejects_negative_source_index() -> None: + from ableton_cli.commands._validation import require_scene_move + + with pytest.raises(AppError) as exc: + require_scene_move(-1, 1) + + _assert_invalid_argument(exc) + assert exc.value.message == "from must be >= 0, got -1" + assert exc.value.hint == "Use a valid source scene index from 'scenes list'." + + +def test_require_scene_insert_index_rejects_values_below_minus_one() -> None: + from ableton_cli.commands._validation import require_scene_insert_index + + with pytest.raises(AppError) as exc: + require_scene_insert_index(-2) + + _assert_invalid_argument(exc) + assert exc.value.message == "index must be >= -1, got -2" + assert exc.value.hint == "Use -1 for append or a non-negative insertion index." + + +def test_require_optional_track_index_accepts_none() -> None: + from ableton_cli.commands._validation import require_optional_track_index + + assert require_optional_track_index(None) is None + + +def test_require_optional_track_index_rejects_negative_track() -> None: + from ableton_cli.commands._validation import require_optional_track_index + + with pytest.raises(AppError) as exc: + require_optional_track_index(-1) + + _assert_invalid_argument(exc) + assert exc.value.message == "track must be >= 0, got -1" + assert exc.value.hint == "Use a valid track index from 'ableton-cli tracks list'." + + +def test_require_track_and_device_accepts_non_negative_indices() -> None: + from ableton_cli.commands._validation import require_track_and_device + + assert require_track_and_device(1, 2) == (1, 2) + + +def test_require_track_and_device_rejects_negative_device_index() -> None: + from ableton_cli.commands._validation import require_track_and_device + + with pytest.raises(AppError) as exc: + require_track_and_device(1, -1) + + _assert_invalid_argument(exc) + assert exc.value.message == "device must be >= 0, got -1" + assert exc.value.hint == "Use a valid device index from 'ableton-cli track info'." From f596eb20d4f95234e5c441ff265a929d79385eb9 Mon Sep 17 00:00:00 2001 From: 6uclz1 <9139177+6uclz1@users.noreply.github.com> Date: Mon, 23 Feb 2026 12:48:33 +0900 Subject: [PATCH 2/3] =?UTF-8?q?=E3=83=97=E3=83=AD=E3=83=88=E3=82=B3?= =?UTF-8?q?=E3=83=AB=E3=82=A8=E3=83=A9=E3=83=BC=E5=88=86=E9=A1=9E=E3=81=AE?= =?UTF-8?q?=E7=B2=92=E5=BA=A6=E6=94=B9=E5=96=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/ableton_cli/client/protocol.py | 47 ++++++++++++++-------------- src/ableton_cli/errors.py | 2 ++ tests/test_exit_codes.py | 2 ++ tests/test_protocol.py | 50 +++++++++++++++++++++++++++++- 4 files changed, 77 insertions(+), 24 deletions(-) diff --git a/src/ableton_cli/client/protocol.py b/src/ableton_cli/client/protocol.py index 107664a..8ba5063 100644 --- a/src/ableton_cli/client/protocol.py +++ b/src/ableton_cli/client/protocol.py @@ -39,6 +39,15 @@ class Response: REQUIRED_RESPONSE_KEYS = {"ok", "request_id", "protocol_version"} +def _raise_protocol_error(error_code: str, message: str, hint: str) -> None: + raise AppError( + error_code=error_code, + message=message, + hint=hint, + exit_code=ExitCode.PROTOCOL_MISMATCH, + ) + + def make_request( name: str, args: dict[str, Any], @@ -60,26 +69,24 @@ def parse_response( ) -> Response: missing = REQUIRED_RESPONSE_KEYS.difference(payload) if missing: - raise AppError( - error_code="PROTOCOL_VERSION_MISMATCH", + _raise_protocol_error( + error_code="PROTOCOL_INVALID_RESPONSE", message=f"Invalid response payload, missing keys: {sorted(missing)}", hint="Ensure the Remote Script protocol implementation matches the CLI.", - exit_code=ExitCode.PROTOCOL_MISMATCH, ) response_protocol = payload.get("protocol_version") if not isinstance(response_protocol, int): - raise AppError( - error_code="PROTOCOL_VERSION_MISMATCH", + _raise_protocol_error( + error_code="PROTOCOL_INVALID_RESPONSE", message="protocol_version must be an integer", hint=( "Set matching protocol versions on both sides " "(--protocol-version or 'ableton-cli config set protocol_version ')." ), - exit_code=ExitCode.PROTOCOL_MISMATCH, ) if response_protocol != expected_protocol: - raise AppError( + _raise_protocol_error( error_code="PROTOCOL_VERSION_MISMATCH", message=( f"Protocol version mismatch (cli={expected_protocol}, remote={response_protocol})" @@ -88,51 +95,45 @@ def parse_response( "Align protocol_version in CLI and Remote Script " "(--protocol-version or 'ableton-cli config set protocol_version ')." ), - exit_code=ExitCode.PROTOCOL_MISMATCH, ) request_id = payload.get("request_id") if request_id != expected_request_id: - raise AppError( - error_code="PROTOCOL_VERSION_MISMATCH", + _raise_protocol_error( + error_code="PROTOCOL_REQUEST_ID_MISMATCH", message=(f"request_id mismatch (expected={expected_request_id}, actual={request_id})"), hint="Check request routing in the Remote Script server.", - exit_code=ExitCode.PROTOCOL_MISMATCH, ) ok = payload.get("ok") if not isinstance(ok, bool): - raise AppError( - error_code="PROTOCOL_VERSION_MISMATCH", + _raise_protocol_error( + error_code="PROTOCOL_INVALID_RESPONSE", message="'ok' must be a boolean in response payload", hint="Update Remote Script response format.", - exit_code=ExitCode.PROTOCOL_MISMATCH, ) result = payload.get("result") if result is not None and not isinstance(result, dict): - raise AppError( - error_code="PROTOCOL_VERSION_MISMATCH", + _raise_protocol_error( + error_code="PROTOCOL_INVALID_RESPONSE", message="'result' must be an object when provided", hint="Return JSON object for result payloads.", - exit_code=ExitCode.PROTOCOL_MISMATCH, ) error = payload.get("error") if error is not None and not isinstance(error, dict): - raise AppError( - error_code="PROTOCOL_VERSION_MISMATCH", + _raise_protocol_error( + error_code="PROTOCOL_INVALID_RESPONSE", message="'error' must be an object when provided", hint="Return structured error payload with code/message.", - exit_code=ExitCode.PROTOCOL_MISMATCH, ) if isinstance(error, dict) and "details" in error and error["details"] is not None: if not isinstance(error["details"], dict): - raise AppError( - error_code="PROTOCOL_VERSION_MISMATCH", + _raise_protocol_error( + error_code="PROTOCOL_INVALID_RESPONSE", message="'error.details' must be an object when provided", hint="Return structured error details as a JSON object.", - exit_code=ExitCode.PROTOCOL_MISMATCH, ) return Response( diff --git a/src/ableton_cli/errors.py b/src/ableton_cli/errors.py index 9031e40..5cec127 100644 --- a/src/ableton_cli/errors.py +++ b/src/ableton_cli/errors.py @@ -41,6 +41,8 @@ def to_payload(self) -> dict[str, Any]: "REMOTE_SCRIPT_NOT_INSTALLED": ExitCode.REMOTE_SCRIPT_NOT_DETECTED, "REMOTE_SCRIPT_INCOMPATIBLE": ExitCode.PROTOCOL_MISMATCH, "PROTOCOL_VERSION_MISMATCH": ExitCode.PROTOCOL_MISMATCH, + "PROTOCOL_INVALID_RESPONSE": ExitCode.PROTOCOL_MISMATCH, + "PROTOCOL_REQUEST_ID_MISMATCH": ExitCode.PROTOCOL_MISMATCH, "TIMEOUT": ExitCode.TIMEOUT, "BATCH_STEP_FAILED": ExitCode.EXECUTION_FAILED, "REMOTE_BUSY": ExitCode.EXECUTION_FAILED, diff --git a/tests/test_exit_codes.py b/tests/test_exit_codes.py index 229295e..8a33c4c 100644 --- a/tests/test_exit_codes.py +++ b/tests/test_exit_codes.py @@ -25,6 +25,8 @@ def test_remote_error_to_exit_code_mapping() -> None: ) assert exit_code_from_error_code("REMOTE_SCRIPT_INCOMPATIBLE") == ExitCode.PROTOCOL_MISMATCH assert exit_code_from_error_code("PROTOCOL_VERSION_MISMATCH") == ExitCode.PROTOCOL_MISMATCH + assert exit_code_from_error_code("PROTOCOL_INVALID_RESPONSE") == ExitCode.PROTOCOL_MISMATCH + assert exit_code_from_error_code("PROTOCOL_REQUEST_ID_MISMATCH") == ExitCode.PROTOCOL_MISMATCH assert exit_code_from_error_code("TIMEOUT") == ExitCode.TIMEOUT assert exit_code_from_error_code("BATCH_STEP_FAILED") == ExitCode.EXECUTION_FAILED assert exit_code_from_error_code("REMOTE_BUSY") == ExitCode.EXECUTION_FAILED diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 770b559..840e52e 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -54,6 +54,54 @@ def test_parse_response_protocol_mismatch_raises() -> None: assert exc_info.value.exit_code == ExitCode.PROTOCOL_MISMATCH +def test_parse_response_missing_keys_raises_invalid_response() -> None: + request = make_request(name="ping", args={}, protocol_version=2) + payload = { + "ok": True, + "request_id": request.request_id, + } + + with pytest.raises(AppError) as exc_info: + parse_response(payload, expected_request_id=request.request_id, expected_protocol=2) + + assert exc_info.value.error_code == "PROTOCOL_INVALID_RESPONSE" + assert exc_info.value.exit_code == ExitCode.PROTOCOL_MISMATCH + + +def test_parse_response_request_id_mismatch_raises() -> None: + request = make_request(name="ping", args={}, protocol_version=2) + payload = { + "ok": True, + "request_id": "other-request-id", + "protocol_version": 2, + "result": {"pong": True}, + "error": None, + } + + with pytest.raises(AppError) as exc_info: + parse_response(payload, expected_request_id=request.request_id, expected_protocol=2) + + assert exc_info.value.error_code == "PROTOCOL_REQUEST_ID_MISMATCH" + assert exc_info.value.exit_code == ExitCode.PROTOCOL_MISMATCH + + +def test_parse_response_rejects_non_integer_protocol_version() -> None: + request = make_request(name="ping", args={}, protocol_version=2) + payload = { + "ok": True, + "request_id": request.request_id, + "protocol_version": "2", + "result": {"pong": True}, + "error": None, + } + + with pytest.raises(AppError) as exc_info: + parse_response(payload, expected_request_id=request.request_id, expected_protocol=2) + + assert exc_info.value.error_code == "PROTOCOL_INVALID_RESPONSE" + assert exc_info.value.exit_code == ExitCode.PROTOCOL_MISMATCH + + def test_parse_response_rejects_non_object_error_details() -> None: request = make_request(name="ping", args={}, protocol_version=2) payload = { @@ -67,7 +115,7 @@ def test_parse_response_rejects_non_object_error_details() -> None: with pytest.raises(AppError) as exc_info: parse_response(payload, expected_request_id=request.request_id, expected_protocol=2) - assert exc_info.value.error_code == "PROTOCOL_VERSION_MISMATCH" + assert exc_info.value.error_code == "PROTOCOL_INVALID_RESPONSE" def test_parse_response_accepts_error_details_object() -> None: From 1a2800930c1af6f078989b6ee317ade632b83f25 Mon Sep 17 00:00:00 2001 From: 6uclz1 <9139177+6uclz1@users.noreply.github.com> Date: Mon, 23 Feb 2026 13:01:06 +0900 Subject: [PATCH 3/3] reuse runtime client per context --- src/ableton_cli/runtime.py | 8 +++++++- tests/test_runtime_quiet.py | 20 +++++++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/ableton_cli/runtime.py b/src/ableton_cli/runtime.py index 9d79af3..39e2eee 100644 --- a/src/ableton_cli/runtime.py +++ b/src/ableton_cli/runtime.py @@ -28,6 +28,12 @@ class RuntimeContext: output_mode: OutputMode quiet: bool no_color: bool + _client: AbletonClient | None = None + + def client(self) -> AbletonClient: + if self._client is None: + self._client = AbletonClient(self.settings) + return self._client def get_runtime(ctx: typer.Context) -> RuntimeContext: @@ -39,7 +45,7 @@ def get_runtime(ctx: typer.Context) -> RuntimeContext: def get_client(ctx: typer.Context) -> AbletonClient: runtime = get_runtime(ctx) - return AbletonClient(runtime.settings) + return runtime.client() def execute_command( diff --git a/tests/test_runtime_quiet.py b/tests/test_runtime_quiet.py index ef736ff..a8a2cdd 100644 --- a/tests/test_runtime_quiet.py +++ b/tests/test_runtime_quiet.py @@ -8,7 +8,7 @@ from ableton_cli.config import Settings from ableton_cli.output import OutputMode -from ableton_cli.runtime import RuntimeContext, execute_command +from ableton_cli.runtime import RuntimeContext, execute_command, get_client def _context(*, quiet: bool) -> SimpleNamespace: @@ -65,3 +65,21 @@ def test_execute_command_not_quiet_emits_custom_human_formatter(monkeypatch) -> assert exc_info.value.exit_code == 0 assert len(emitted) == 1 assert emitted[0][0][0] == "Doctor Results" + + +def test_get_client_reuses_client_for_same_runtime(monkeypatch) -> None: + created_with: list[Settings] = [] + + class FakeClient: + def __init__(self, settings: Settings) -> None: + self.settings = settings + created_with.append(settings) + + monkeypatch.setattr("ableton_cli.runtime.AbletonClient", FakeClient) + + ctx = _context(quiet=False) + first = get_client(ctx) + second = get_client(ctx) + + assert first is second + assert created_with == [ctx.obj.settings]