From 0bfa7b477e3c123515f763816c23624e71c20317 Mon Sep 17 00:00:00 2001 From: Wang Siyuan Date: Tue, 9 Dec 2025 13:06:35 +0800 Subject: [PATCH 01/11] Add minimal MCP JSON-RPC server and docs --- README.md | 12 +++ docs/getting-started.md | 16 ++++ pyproject.toml | 1 + src/keep_gpu/mcp/server.py | 149 +++++++++++++++++++++++++++++++++++++ tests/conftest.py | 5 -- tests/mcp/test_server.py | 50 +++++++++++++ 6 files changed, 228 insertions(+), 5 deletions(-) create mode 100644 src/keep_gpu/mcp/server.py create mode 100644 tests/mcp/test_server.py diff --git a/README.md b/README.md index d0a89a9..1e22000 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,18 @@ with GlobalGPUController(gpu_ids=[0, 1], vram_to_keep="750MB", interval=90, busy - ROCm-only tests carry `@pytest.mark.rocm`; run with `pytest --run-rocm tests/rocm_controller`. - Markers: `rocm` (needs ROCm stack) and `large_memory` (opt-in locally). +### MCP endpoint (experimental) + +- Start a simple JSON-RPC server on stdin/stdout: + ```bash + keep-gpu-mcp-server + ``` +- Example request (one per line): + ```json + {"id": 1, "method": "start_keep", "params": {"gpu_ids": [0], "vram": "512MB", "interval": 60, "busy_threshold": 20}} + ``` +- Methods: `start_keep`, `stop_keep` (optional `job_id`, default stops all), `status` (optional `job_id`). + ## Contributing Contributions are welcome—especially around ROCm support, platform fallbacks, and scheduler-specific recipes. Open an issue or PR if you hit edge cases on your cluster. diff --git a/docs/getting-started.md b/docs/getting-started.md index f3949fe..6fcc6e5 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -45,6 +45,22 @@ understand the minimum knobs you need to keep a GPU occupied. - Fast CUDA checks: `pytest tests/cuda_controller tests/global_controller tests/utilities/test_platform_manager.py tests/test_cli_thresholds.py` - ROCm-only tests are marked `rocm`; run with `pytest --run-rocm tests/rocm_controller`. +## MCP endpoint (experimental) + +For automation clients that speak JSON-RPC (MCP-style), KeepGPU ships a tiny +stdin/stdout server: + +```bash +keep-gpu-mcp-server +# each request is a single JSON line; example: +echo '{"id":1,"method":"start_keep","params":{"gpu_ids":[0],"vram":"512MB","interval":60,"busy_threshold":20}}' | keep-gpu-mcp-server +``` + +Supported methods: +- `start_keep(gpu_ids?, vram?, interval?, busy_threshold?, job_id?)` +- `status(job_id?)` +- `stop_keep(job_id?)` (no job_id stops all) + === "Editable dev install" ```bash git clone https://github.com/Wangmerlyn/KeepGPU.git diff --git a/pyproject.toml b/pyproject.toml index dedf135..c5b859b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ [project.scripts] keep-gpu = "keep_gpu.cli:app" +keep-gpu-mcp-server = "keep_gpu.mcp.server:main" [project.optional-dependencies] dev = [ diff --git a/src/keep_gpu/mcp/server.py b/src/keep_gpu/mcp/server.py new file mode 100644 index 0000000..9bde14f --- /dev/null +++ b/src/keep_gpu/mcp/server.py @@ -0,0 +1,149 @@ +""" +Minimal MCP-style JSON-RPC server for KeepGPU. + +The server reads JSON lines from stdin and writes JSON responses to stdout. 
+Supported methods: + - start_keep(gpu_ids, vram, interval, busy_threshold, job_id) + - stop_keep(job_id=None) # None stops all + - status(job_id=None) # None lists all +""" + +from __future__ import annotations + +import atexit +import json +import sys +import uuid +from dataclasses import dataclass, asdict +from typing import Any, Callable, Dict, List, Optional + +from keep_gpu.global_gpu_controller.global_gpu_controller import GlobalGPUController +from keep_gpu.utilities.logger import setup_logger + +logger = setup_logger(__name__) + + +@dataclass +class Session: + controller: GlobalGPUController + params: Dict[str, Any] + + +class KeepGPUServer: + def __init__( + self, + controller_factory: Optional[Callable[..., GlobalGPUController]] = None, + ) -> None: + self._sessions: Dict[str, Session] = {} + self._controller_factory = controller_factory or GlobalGPUController + atexit.register(self.shutdown) + + def start_keep( + self, + gpu_ids: Optional[List[int]] = None, + vram: str = "1GiB", + interval: int = 300, + busy_threshold: int = -1, + job_id: Optional[str] = None, + ) -> Dict[str, Any]: + job_id = job_id or str(uuid.uuid4()) + if job_id in self._sessions: + raise ValueError(f"job_id {job_id} already exists") + + controller = self._controller_factory( + gpu_ids=gpu_ids, + interval=interval, + vram_to_keep=vram, + busy_threshold=busy_threshold, + ) + controller.keep() + self._sessions[job_id] = Session( + controller=controller, + params={ + "gpu_ids": gpu_ids, + "vram": vram, + "interval": interval, + "busy_threshold": busy_threshold, + }, + ) + logger.info("Started keep session %s on GPUs %s", job_id, gpu_ids) + return {"job_id": job_id} + + def stop_keep(self, job_id: Optional[str] = None) -> Dict[str, Any]: + if job_id: + session = self._sessions.pop(job_id, None) + if session: + session.controller.release() + logger.info("Stopped keep session %s", job_id) + return {"stopped": [job_id]} + return {"stopped": [], "message": "job_id not found"} + + stopped: List[str] = [] + for jid, session in list(self._sessions.items()): + session.controller.release() + stopped.append(jid) + del self._sessions[jid] + if stopped: + logger.info("Stopped sessions: %s", stopped) + return {"stopped": stopped} + + def status(self, job_id: Optional[str] = None) -> Dict[str, Any]: + if job_id: + session = self._sessions.get(job_id) + if not session: + return {"active": False, "job_id": job_id} + return { + "active": True, + "job_id": job_id, + "params": session.params, + } + return { + "active_jobs": [ + {"job_id": jid, **asdict(sess)} for jid, sess in self._sessions.items() + ] + } + + def shutdown(self) -> None: + try: + self.stop_keep(None) + except Exception: # pragma: no cover - defensive + # Avoid noisy errors during interpreter teardown + return + + +def _handle_request(server: KeepGPUServer, payload: Dict[str, Any]) -> Dict[str, Any]: + method = payload.get("method") + params = payload.get("params", {}) or {} + req_id = payload.get("id") + try: + if method == "start_keep": + result = server.start_keep(**params) + elif method == "stop_keep": + result = server.stop_keep(**params) + elif method == "status": + result = server.status(**params) + else: + raise ValueError(f"Unknown method: {method}") + return {"id": req_id, "result": result} + except Exception as exc: # pragma: no cover - defensive + logger.exception("Request failed") + return {"id": req_id, "error": {"message": str(exc)}} + + +def main() -> None: + server = KeepGPUServer() + for line in sys.stdin: + line = line.strip() + if not line: + 
continue + try: + payload = json.loads(line) + response = _handle_request(server, payload) + except Exception as exc: + response = {"error": {"message": str(exc)}} + sys.stdout.write(json.dumps(response) + "\n") + sys.stdout.flush() + + +if __name__ == "__main__": + main() diff --git a/tests/conftest.py b/tests/conftest.py index 22d1680..33fd177 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,11 +11,6 @@ def pytest_addoption(parser): ) -def pytest_configure(config): - config.addinivalue_line("markers", "rocm: tests that require ROCm stack") - config.addinivalue_line("markers", "large_memory: tests that use large VRAM") - - def pytest_collection_modifyitems(config, items): if config.getoption("--run-rocm"): return diff --git a/tests/mcp/test_server.py b/tests/mcp/test_server.py new file mode 100644 index 0000000..4834e7c --- /dev/null +++ b/tests/mcp/test_server.py @@ -0,0 +1,50 @@ +from keep_gpu.mcp.server import KeepGPUServer + + +class DummyController: + def __init__(self, gpu_ids=None, interval=0, vram_to_keep=None, busy_threshold=0): + self.gpu_ids = gpu_ids + self.interval = interval + self.vram_to_keep = vram_to_keep + self.busy_threshold = busy_threshold + self.kept = False + self.released = False + + def keep(self): + self.kept = True + + def release(self): + self.released = True + + +def dummy_factory(**kwargs): + return DummyController(**kwargs) + + +def test_start_status_stop_cycle(): + server = KeepGPUServer(controller_factory=dummy_factory) + res = server.start_keep(gpu_ids=[1], vram="2GiB", interval=5, busy_threshold=20) + job_id = res["job_id"] + + status = server.status(job_id) + assert status["active"] + assert status["params"]["gpu_ids"] == [1] + assert status["params"]["vram"] == "2GiB" + assert status["params"]["interval"] == 5 + assert status["params"]["busy_threshold"] == 20 + + stopped = server.stop_keep(job_id) + assert job_id in stopped["stopped"] + assert server.status(job_id)["active"] is False + + +def test_stop_all(): + server = KeepGPUServer(controller_factory=dummy_factory) + job_a = server.start_keep()["job_id"] + job_b = server.start_keep()["job_id"] + + stopped = server.stop_keep() + assert set(stopped["stopped"]) == {job_a, job_b} + assert server.status(job_a)["active"] is False + assert server.status(job_b)["active"] is False + From 9f8ada34e79820a3c118014996d1fcbaf2acb302 Mon Sep 17 00:00:00 2001 From: Wang Siyuan Date: Tue, 9 Dec 2025 13:17:52 +0800 Subject: [PATCH 02/11] Add GPU listing method to MCP server and docs --- README.md | 2 +- docs/getting-started.md | 1 + src/keep_gpu/mcp/server.py | 14 ++++++++++++++ tests/mcp/test_server.py | 6 ++++++ 4 files changed, 22 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 1e22000..e71aece 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,7 @@ with GlobalGPUController(gpu_ids=[0, 1], vram_to_keep="750MB", interval=90, busy ```json {"id": 1, "method": "start_keep", "params": {"gpu_ids": [0], "vram": "512MB", "interval": 60, "busy_threshold": 20}} ``` -- Methods: `start_keep`, `stop_keep` (optional `job_id`, default stops all), `status` (optional `job_id`). +- Methods: `start_keep`, `stop_keep` (optional `job_id`, default stops all), `status` (optional `job_id`), `list_gpus` (basic info). 
## Contributing diff --git a/docs/getting-started.md b/docs/getting-started.md index 6fcc6e5..dd30e44 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -60,6 +60,7 @@ Supported methods: - `start_keep(gpu_ids?, vram?, interval?, busy_threshold?, job_id?)` - `status(job_id?)` - `stop_keep(job_id?)` (no job_id stops all) +- `list_gpus()` (basic info) === "Editable dev install" ```bash diff --git a/src/keep_gpu/mcp/server.py b/src/keep_gpu/mcp/server.py index 9bde14f..48b2e2a 100644 --- a/src/keep_gpu/mcp/server.py +++ b/src/keep_gpu/mcp/server.py @@ -103,6 +103,18 @@ def status(self, job_id: Optional[str] = None) -> Dict[str, Any]: ] } + def list_gpus(self) -> Dict[str, Any]: + """Return basic GPU info using torch (count + names).""" + try: + import torch + + count = torch.cuda.device_count() + names = [torch.cuda.get_device_name(i) for i in range(count)] + return {"count": count, "names": names} + except Exception as exc: # pragma: no cover - env-specific + logger.debug("list_gpus failed: %s", exc) + return {"count": 0, "names": [], "error": str(exc)} + def shutdown(self) -> None: try: self.stop_keep(None) @@ -122,6 +134,8 @@ def _handle_request(server: KeepGPUServer, payload: Dict[str, Any]) -> Dict[str, result = server.stop_keep(**params) elif method == "status": result = server.status(**params) + elif method == "list_gpus": + result = server.list_gpus() else: raise ValueError(f"Unknown method: {method}") return {"id": req_id, "result": result} diff --git a/tests/mcp/test_server.py b/tests/mcp/test_server.py index 4834e7c..999db8d 100644 --- a/tests/mcp/test_server.py +++ b/tests/mcp/test_server.py @@ -48,3 +48,9 @@ def test_stop_all(): assert server.status(job_a)["active"] is False assert server.status(job_b)["active"] is False + +def test_list_gpus(): + server = KeepGPUServer(controller_factory=dummy_factory) + info = server.list_gpus() + assert "count" in info + assert "names" in info From 58169dd2f4bf21e1f9404a6e3797cdd637a3f43b Mon Sep 17 00:00:00 2001 From: Wang Siyuan Date: Tue, 9 Dec 2025 13:31:43 +0800 Subject: [PATCH 03/11] Add NVML-backed GPU info to MCP server --- src/keep_gpu/mcp/server.py | 14 ++--- src/keep_gpu/utilities/gpu_info.py | 91 ++++++++++++++++++++++++++++++ tests/mcp/test_server.py | 3 +- tests/utilities/test_gpu_info.py | 59 +++++++++++++++++++ 4 files changed, 155 insertions(+), 12 deletions(-) create mode 100644 src/keep_gpu/utilities/gpu_info.py create mode 100644 tests/utilities/test_gpu_info.py diff --git a/src/keep_gpu/mcp/server.py b/src/keep_gpu/mcp/server.py index 48b2e2a..e8a143b 100644 --- a/src/keep_gpu/mcp/server.py +++ b/src/keep_gpu/mcp/server.py @@ -18,6 +18,7 @@ from typing import Any, Callable, Dict, List, Optional from keep_gpu.global_gpu_controller.global_gpu_controller import GlobalGPUController +from keep_gpu.utilities.gpu_info import get_gpu_info from keep_gpu.utilities.logger import setup_logger logger = setup_logger(__name__) @@ -104,16 +105,9 @@ def status(self, job_id: Optional[str] = None) -> Dict[str, Any]: } def list_gpus(self) -> Dict[str, Any]: - """Return basic GPU info using torch (count + names).""" - try: - import torch - - count = torch.cuda.device_count() - names = [torch.cuda.get_device_name(i) for i in range(count)] - return {"count": count, "names": names} - except Exception as exc: # pragma: no cover - env-specific - logger.debug("list_gpus failed: %s", exc) - return {"count": 0, "names": [], "error": str(exc)} + """Return detailed GPU info (id, name, memory, utilization).""" + infos = get_gpu_info() + 
return {"gpus": infos} def shutdown(self) -> None: try: diff --git a/src/keep_gpu/utilities/gpu_info.py b/src/keep_gpu/utilities/gpu_info.py new file mode 100644 index 0000000..6753a82 --- /dev/null +++ b/src/keep_gpu/utilities/gpu_info.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from typing import Any, Dict, List + +import torch + +from keep_gpu.utilities.logger import setup_logger + +logger = setup_logger(__name__) + + +def _query_nvml() -> List[Dict[str, Any]]: + import pynvml + + pynvml.nvmlInit() + infos: List[Dict[str, Any]] = [] + try: + count = pynvml.nvmlDeviceGetCount() + for idx in range(count): + handle = pynvml.nvmlDeviceGetHandleByIndex(idx) + mem = pynvml.nvmlDeviceGetMemoryInfo(handle) + util = pynvml.nvmlDeviceGetUtilizationRates(handle).gpu + name = pynvml.nvmlDeviceGetName(handle) + if isinstance(name, bytes): + name = name.decode(errors="ignore") + infos.append( + { + "id": idx, + "platform": "cuda", + "name": name, + "memory_total": int(mem.total), + "memory_used": int(mem.used), + "utilization": int(util), + } + ) + finally: + try: + pynvml.nvmlShutdown() + except Exception: + pass + return infos + + +def _query_torch() -> List[Dict[str, Any]]: + infos: List[Dict[str, Any]] = [] + if not torch.cuda.is_available(): + return infos + try: + count = torch.cuda.device_count() + for idx in range(count): + torch.cuda.set_device(idx) + try: + free, total = torch.cuda.mem_get_info() + used = total - free + except Exception: + total = used = None + try: + name = torch.cuda.get_device_name(idx) + except Exception: + name = f"cuda:{idx}" + infos.append( + { + "id": idx, + "platform": "cuda" if torch.version.hip is None else "rocm", + "name": name, + "memory_total": int(total) if total is not None else None, + "memory_used": int(used) if used is not None else None, + "utilization": None, + } + ) + except Exception as exc: # pragma: no cover - defensive + logger.debug("Torch GPU info failed: %s", exc) + return infos + + +def get_gpu_info() -> List[Dict[str, Any]]: + """ + Return a list of GPU info dicts: id, platform, name, memory_total, memory_used, utilization. + Tries NVML first (CUDA), then falls back to torch.cuda data. 
+ """ + try: + infos = _query_nvml() + if infos: + return infos + except Exception as exc: + logger.debug("NVML info failed: %s", exc) + + return _query_torch() + + +__all__ = ["get_gpu_info"] diff --git a/tests/mcp/test_server.py b/tests/mcp/test_server.py index 999db8d..765874c 100644 --- a/tests/mcp/test_server.py +++ b/tests/mcp/test_server.py @@ -52,5 +52,4 @@ def test_stop_all(): def test_list_gpus(): server = KeepGPUServer(controller_factory=dummy_factory) info = server.list_gpus() - assert "count" in info - assert "names" in info + assert "gpus" in info diff --git a/tests/utilities/test_gpu_info.py b/tests/utilities/test_gpu_info.py new file mode 100644 index 0000000..3a0c6cb --- /dev/null +++ b/tests/utilities/test_gpu_info.py @@ -0,0 +1,59 @@ +import sys + +from keep_gpu.utilities import gpu_info + + +class DummyNVMLMemory: + def __init__(self, total: int, used: int): + self.total = total + self.used = used + + +class DummyNVMLUtil: + def __init__(self, gpu: int): + self.gpu = gpu + + +def test_get_gpu_info_nvml(monkeypatch): + class DummyNVML: + def __init__(self): + self.shutdown_calls = 0 + + @staticmethod + def nvmlInit(): + return None + + @staticmethod + def nvmlDeviceGetCount(): + return 1 + + @staticmethod + def nvmlDeviceGetHandleByIndex(index): + assert index == 0 + return "handle" + + @staticmethod + def nvmlDeviceGetMemoryInfo(handle): + return DummyNVMLMemory(total=2048, used=1024) + + @staticmethod + def nvmlDeviceGetUtilizationRates(handle): + return DummyNVMLUtil(gpu=55) + + @staticmethod + def nvmlDeviceGetName(handle): + return b"Mock GPU" + + def nvmlShutdown(self): + self.shutdown_calls += 1 + + dummy_nvml = DummyNVML() + monkeypatch.setitem(sys.modules, "pynvml", dummy_nvml) + + infos = gpu_info.get_gpu_info() + assert len(infos) == 1 + info = infos[0] + assert info["name"] == "Mock GPU" + assert info["memory_total"] == 2048 + assert info["memory_used"] == 1024 + assert info["utilization"] == 55 From 14002d996a1249a4bdd9f1352f87e4ad89248857 Mon Sep 17 00:00:00 2001 From: Wang Siyuan Date: Tue, 9 Dec 2025 13:36:31 +0800 Subject: [PATCH 04/11] Support ROCm SMI in GPU info listing --- src/keep_gpu/utilities/gpu_info.py | 58 ++++++++++++++++++++++++++++- tests/utilities/test_gpu_info.py | 60 ++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 1 deletion(-) diff --git a/src/keep_gpu/utilities/gpu_info.py b/src/keep_gpu/utilities/gpu_info.py index 6753a82..6a964ef 100644 --- a/src/keep_gpu/utilities/gpu_info.py +++ b/src/keep_gpu/utilities/gpu_info.py @@ -41,6 +41,55 @@ def _query_nvml() -> List[Dict[str, Any]]: return infos +def _query_rocm() -> List[Dict[str, Any]]: + try: + import rocm_smi # type: ignore + except Exception as exc: # pragma: no cover - env-specific + logger.debug("rocm_smi import failed: %s", exc) + return [] + + infos: List[Dict[str, Any]] = [] + try: + rocm_smi.rsmi_init() + # Use torch to enumerate devices for names/memory + count = torch.cuda.device_count() if torch.cuda.is_available() else 0 + for idx in range(count): + util = None + try: + util = int(rocm_smi.rsmi_dev_busy_percent_get(idx)) + except Exception as exc: + logger.debug("ROCm util query failed for %s: %s", idx, exc) + + try: + torch.cuda.set_device(idx) + free, total = torch.cuda.mem_get_info() + used = total - free + except Exception: + total = used = None + + try: + name = torch.cuda.get_device_name(idx) + except Exception: + name = f"rocm:{idx}" + + infos.append( + { + "id": idx, + "platform": "rocm", + "name": name, + "memory_total": int(total) if total is not 
None else None, + "memory_used": int(used) if used is not None else None, + "utilization": util, + } + ) + finally: + try: + rocm_smi.rsmi_shut_down() + except Exception: + pass + return infos + + def _query_torch() -> List[Dict[str, Any]]: infos: List[Dict[str, Any]] = [] if not torch.cuda.is_available(): @@ -76,7 +125,7 @@ def _query_torch() -> List[Dict[str, Any]]: def get_gpu_info() -> List[Dict[str, Any]]: """ Return a list of GPU info dicts: id, platform, name, memory_total, memory_used, utilization. - Tries NVML first (CUDA), then falls back to torch.cuda data. + Tries NVML first (CUDA), then ROCm SMI, then falls back to torch.cuda data. """ try: infos = _query_nvml() @@ -85,6 +134,13 @@ def get_gpu_info() -> List[Dict[str, Any]]: except Exception as exc: logger.debug("NVML info failed: %s", exc) + try: + infos = _query_rocm() + if infos: + return infos + except Exception as exc: + logger.debug("ROCm info failed: %s", exc) + return _query_torch() diff --git a/tests/utilities/test_gpu_info.py b/tests/utilities/test_gpu_info.py index 3a0c6cb..d0d74ea 100644 --- a/tests/utilities/test_gpu_info.py +++ b/tests/utilities/test_gpu_info.py @@ -57,3 +57,63 @@ def nvmlShutdown(self): assert info["memory_total"] == 2048 assert info["memory_used"] == 1024 assert info["utilization"] == 55 + + +def test_get_gpu_info_rocm(monkeypatch): + # remove nvml so ROCm path is used + monkeypatch.setitem(sys.modules, "pynvml", None) + + class DummyTorchCuda: + @staticmethod + def is_available(): + return True + + @staticmethod + def device_count(): + return 1 + + @staticmethod + def mem_get_info(): + return (50, 100) + + @staticmethod + def get_device_name(idx): + return f"ROCm {idx}" + + @staticmethod + def set_device(idx): + return None + + monkeypatch.setattr( + gpu_info, + "torch", + type( + "T", (), {"cuda": DummyTorchCuda, "version": type("V", (), {"hip": "6.0"})} + ), + ) + + class DummyROCM: + calls = 0 + + @classmethod + def rsmi_init(cls): + cls.calls += 1 + + @classmethod + def rsmi_dev_busy_percent_get(cls, idx): + assert idx == 0 + return 77 + + @classmethod + def rsmi_shut_down(cls): + cls.calls += 1 + + monkeypatch.setitem(sys.modules, "rocm_smi", DummyROCM) + + infos = gpu_info.get_gpu_info() + assert len(infos) == 1 + info = infos[0] + assert info["platform"] == "rocm" + assert info["utilization"] == 77 + assert info["memory_total"] == 100 + assert info["memory_used"] == 50 From 0c70e08c1794ec62aa57b1a61fcc7db053baccc8 Mon Sep 17 00:00:00 2001 From: Wang Siyuan Date: Tue, 9 Dec 2025 13:42:27 +0800 Subject: [PATCH 05/11] Harden GPU info tests and JSON-RPC flow --- tests/mcp/test_server.py | 25 ++++++++++++++++++++++++- tests/utilities/test_gpu_info.py | 7 +++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/tests/mcp/test_server.py b/tests/mcp/test_server.py index 765874c..d5ad2b8 100644 --- a/tests/mcp/test_server.py +++ b/tests/mcp/test_server.py @@ -1,4 +1,4 @@ -from keep_gpu.mcp.server import KeepGPUServer +from keep_gpu.mcp.server import KeepGPUServer, _handle_request class DummyController: @@ -53,3 +53,26 @@ def test_list_gpus(): server = KeepGPUServer(controller_factory=dummy_factory) info = server.list_gpus() assert "gpus" in info + + +def test_end_to_end_jsonrpc(): + server = KeepGPUServer(controller_factory=dummy_factory) + # start_keep + req = { + "id": 1, + "method": "start_keep", + "params": {"gpu_ids": [0], "vram": "256MB", "interval": 1, "busy_threshold": 5}, + } + resp = _handle_request(server, req) + assert "result" in resp and "job_id" in resp["result"] + 
job_id = resp["result"]["job_id"] + + # status + status_req = {"id": 2, "method": "status", "params": {"job_id": job_id}} + status_resp = _handle_request(server, status_req) + assert status_resp["result"]["active"] is True + + # stop_keep + stop_req = {"id": 3, "method": "stop_keep", "params": {"job_id": job_id}} + stop_resp = _handle_request(server, stop_req) + assert job_id in stop_resp["result"]["stopped"] diff --git a/tests/utilities/test_gpu_info.py b/tests/utilities/test_gpu_info.py index d0d74ea..f07eae9 100644 --- a/tests/utilities/test_gpu_info.py +++ b/tests/utilities/test_gpu_info.py @@ -1,5 +1,7 @@ import sys +import pytest + from keep_gpu.utilities import gpu_info @@ -14,6 +16,10 @@ def __init__(self, gpu: int): self.gpu = gpu +@pytest.mark.skipif( + not hasattr(gpu_info, "torch") or not gpu_info.torch.cuda.is_available(), + reason="CUDA not available for NVML path", +) def test_get_gpu_info_nvml(monkeypatch): class DummyNVML: def __init__(self): @@ -59,6 +65,7 @@ def nvmlShutdown(self): assert info["utilization"] == 55 +@pytest.mark.rocm def test_get_gpu_info_rocm(monkeypatch): # remove nvml so ROCm path is used monkeypatch.setitem(sys.modules, "pynvml", None) From da690f919536c3da2c8df5f1f43d43dd8bb3846a Mon Sep 17 00:00:00 2001 From: Wang Siyuan Date: Tue, 9 Dec 2025 13:50:51 +0800 Subject: [PATCH 06/11] Add contributor guide and link from docs --- README.md | 1 + docs/contributing.md | 62 ++++++++++++++++++++++++++++++++++++++++++++ mkdocs.yml | 2 ++ 3 files changed, 65 insertions(+) create mode 100644 docs/contributing.md diff --git a/README.md b/README.md index e71aece..00208b9 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,7 @@ with GlobalGPUController(gpu_ids=[0, 1], vram_to_keep="750MB", interval=90, busy ## Contributing Contributions are welcome—especially around ROCm support, platform fallbacks, and scheduler-specific recipes. Open an issue or PR if you hit edge cases on your cluster. +See `docs/contributing.md` for dev setup, test commands, and PR tips. ## Credits diff --git a/docs/contributing.md b/docs/contributing.md new file mode 100644 index 0000000..5cd5d3f --- /dev/null +++ b/docs/contributing.md @@ -0,0 +1,62 @@ +# Contributing & Development + +Thanks for helping improve KeepGPU! This page collects the key commands and +expectations so you can get productive quickly and avoid surprises in CI. + +## Setup + +- Clone and install dev extras: + ```bash + git clone https://github.com/Wangmerlyn/KeepGPU.git + cd KeepGPU + pip install -e ".[dev]" # add .[rocm] if you need ROCm SMI + ``` +- Ensure you have the right torch build for your platform (CUDA/ROCm/CPU). +- Optional: install `nvidia-ml-py` (CUDA) or `rocm-smi` (ROCm) for telemetry. + +## Tests + +- Fast CUDA suite: + ```bash + pytest tests/cuda_controller tests/global_controller \ + tests/utilities/test_platform_manager.py tests/test_cli_thresholds.py + ``` +- ROCm-only tests are marked `rocm` and skipped by default; run with: + ```bash + pytest --run-rocm tests/rocm_controller + ``` +- MCP + utilities: + ```bash + pytest tests/mcp tests/utilities/test_gpu_info.py + ``` +- All tests honor markers `rocm` and `large_memory`; avoid enabling + `large_memory` in CI. 
+ +## Lint/format + +- Run pre-commit hooks locally before pushing: + ```bash + pre-commit run --all-files + ``` + +## MCP server (experimental) + +- Start: `keep-gpu-mcp-server` (stdin/stdout JSON-RPC) +- Methods: `start_keep`, `stop_keep`, `status`, `list_gpus` +- Example request: + ```json + {"id":1,"method":"start_keep","params":{"gpu_ids":[0],"vram":"512MB","interval":60,"busy_threshold":20}} + ``` + +## Pull requests + +- Keep changesets focused; small commits are welcome. +- Add/adjust tests for new behavior; skip GPU-specific tests in CI via markers. +- Update docs/README when behavior or interfaces change. +- Stick to the existing style (Typer CLI, Rich logging) and keep code paths + simple—avoid over-engineering. + +## Support + +- Issues/PRs: https://github.com/Wangmerlyn/KeepGPU +- Code of Conduct: see `CODE_OF_CONDUCT.rst` diff --git a/mkdocs.yml b/mkdocs.yml index 6c3a200..d26c9b8 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -23,6 +23,8 @@ nav: - Reference: - CLI Reference: reference/cli.md - API Reference: reference/api.md + - Project: + - Contributing: contributing.md plugins: - search From b73e53251be8356fef40f58652c6380538f23d8f Mon Sep 17 00:00:00 2001 From: Wang Siyuan Date: Tue, 9 Dec 2025 13:55:00 +0800 Subject: [PATCH 07/11] Add MCP client config examples --- README.md | 7 +++++++ docs/contributing.md | 2 +- docs/getting-started.md | 15 +++++++++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 00208b9..a92cc77 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,13 @@ with GlobalGPUController(gpu_ids=[0, 1], vram_to_keep="750MB", interval=90, busy {"id": 1, "method": "start_keep", "params": {"gpu_ids": [0], "vram": "512MB", "interval": 60, "busy_threshold": 20}} ``` - Methods: `start_keep`, `stop_keep` (optional `job_id`, default stops all), `status` (optional `job_id`), `list_gpus` (basic info). +- Minimal client config (stdio MCP): + ```yaml + servers: + keepgpu: + command: ["keep-gpu-mcp-server"] + adapter: stdio + ``` ## Contributing diff --git a/docs/contributing.md b/docs/contributing.md index 5cd5d3f..9754a33 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -51,7 +51,7 @@ expectations so you can get productive quickly and avoid surprises in CI. ## Pull requests - Keep changesets focused; small commits are welcome. -- Add/adjust tests for new behavior; skip GPU-specific tests in CI via markers. +- Add/adjust tests for new behavior; skip GPU-specific tests in CI by way of markers. - Update docs/README when behavior or interfaces change. - Stick to the existing style (Typer CLI, Rich logging) and keep code paths simple—avoid over-engineering. diff --git a/docs/getting-started.md b/docs/getting-started.md index dd30e44..00b8f55 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -62,6 +62,21 @@ Supported methods: - `stop_keep(job_id?)` (no job_id stops all) - `list_gpus()` (basic info) +### Example MCP client config (stdio) + +If your agent expects an MCP server definition, a minimal stdio config looks like: + +```yaml +servers: + keepgpu: + description: "KeepGPU MCP server" + command: ["keep-gpu-mcp-server"] + adapter: stdio +``` + +Tools exposed: `start_keep`, `stop_keep`, `status`, `list_gpus`. Each request is +a single JSON line; see above for an example payload. 
+ === "Editable dev install" ```bash git clone https://github.com/Wangmerlyn/KeepGPU.git From 69c64dad7396fa36291c9f252f93e7420d3ad41c Mon Sep 17 00:00:00 2001 From: Wang Siyuan Date: Tue, 9 Dec 2025 14:15:36 +0800 Subject: [PATCH 08/11] Apply suggestions from code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/keep_gpu/mcp/server.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/keep_gpu/mcp/server.py b/src/keep_gpu/mcp/server.py index e8a143b..cc8af1e 100644 --- a/src/keep_gpu/mcp/server.py +++ b/src/keep_gpu/mcp/server.py @@ -79,14 +79,13 @@ def stop_keep(self, job_id: Optional[str] = None) -> Dict[str, Any]: return {"stopped": [job_id]} return {"stopped": [], "message": "job_id not found"} - stopped: List[str] = [] - for jid, session in list(self._sessions.items()): + stopped_ids = list(self._sessions.keys()) + for job_id in stopped_ids: + session = self._sessions.pop(job_id) session.controller.release() - stopped.append(jid) - del self._sessions[jid] - if stopped: - logger.info("Stopped sessions: %s", stopped) - return {"stopped": stopped} + if stopped_ids: + logger.info("Stopped sessions: %s", stopped_ids) + return {"stopped": stopped_ids} def status(self, job_id: Optional[str] = None) -> Dict[str, Any]: if job_id: @@ -100,7 +99,7 @@ def status(self, job_id: Optional[str] = None) -> Dict[str, Any]: } return { "active_jobs": [ - {"job_id": jid, **asdict(sess)} for jid, sess in self._sessions.items() + {"job_id": jid, "params": sess.params} for jid, sess in self._sessions.items() ] } From c4ff82d0fa3f27849641a2f6f0faeea528dcc68e Mon Sep 17 00:00:00 2001 From: Wang Siyuan Date: Tue, 9 Dec 2025 14:21:23 +0800 Subject: [PATCH 09/11] Address MCP review: status all and restore CUDA device --- src/keep_gpu/mcp/server.py | 16 ++++++++++------ src/keep_gpu/utilities/gpu_info.py | 14 ++++++++++++++ tests/mcp/test_server.py | 17 +++++++++++++++++ 3 files changed, 41 insertions(+), 6 deletions(-) diff --git a/src/keep_gpu/mcp/server.py b/src/keep_gpu/mcp/server.py index e8a143b..9da82a7 100644 --- a/src/keep_gpu/mcp/server.py +++ b/src/keep_gpu/mcp/server.py @@ -14,7 +14,7 @@ import json import sys import uuid -from dataclasses import dataclass, asdict +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional from keep_gpu.global_gpu_controller.global_gpu_controller import GlobalGPUController @@ -70,12 +70,15 @@ def start_keep( logger.info("Started keep session %s on GPUs %s", job_id, gpu_ids) return {"job_id": job_id} - def stop_keep(self, job_id: Optional[str] = None) -> Dict[str, Any]: + def stop_keep( + self, job_id: Optional[str] = None, *, quiet: bool = False + ) -> Dict[str, Any]: if job_id: session = self._sessions.pop(job_id, None) if session: session.controller.release() - logger.info("Stopped keep session %s", job_id) + if not quiet: + logger.info("Stopped keep session %s", job_id) return {"stopped": [job_id]} return {"stopped": [], "message": "job_id not found"} @@ -84,7 +87,7 @@ def stop_keep(self, job_id: Optional[str] = None) -> Dict[str, Any]: session.controller.release() stopped.append(jid) del self._sessions[jid] - if stopped: + if stopped and not quiet: logger.info("Stopped sessions: %s", stopped) return {"stopped": stopped} @@ -100,7 +103,8 @@ def status(self, job_id: Optional[str] = None) -> Dict[str, Any]: } return { "active_jobs": [ - {"job_id": jid, **asdict(sess)} for jid, sess in self._sessions.items() + {"job_id": jid, "params": 
sess.params} + for jid, sess in self._sessions.items() ] } @@ -111,7 +115,7 @@ def list_gpus(self) -> Dict[str, Any]: def shutdown(self) -> None: try: - self.stop_keep(None) + self.stop_keep(None, quiet=True) except Exception: # pragma: no cover - defensive # Avoid noisy errors during interpreter teardown return diff --git a/src/keep_gpu/utilities/gpu_info.py b/src/keep_gpu/utilities/gpu_info.py index 6a964ef..babf540 100644 --- a/src/keep_gpu/utilities/gpu_info.py +++ b/src/keep_gpu/utilities/gpu_info.py @@ -49,8 +49,11 @@ def _query_rocm() -> List[Dict[str, Any]]: return [] infos: List[Dict[str, Any]] = [] + current_device = None try: rocm_smi.rsmi_init() + if torch.cuda.is_available(): + current_device = torch.cuda.current_device() # Use torch to enumerate devices for names/memory count = torch.cuda.device_count() if torch.cuda.is_available() else 0 for idx in range(count): @@ -83,6 +86,11 @@ def _query_rocm() -> List[Dict[str, Any]]: } ) finally: + if current_device is not None: + try: + torch.cuda.set_device(current_device) + except Exception: + pass try: rocm_smi.rsmi_shut_down() except Exception: @@ -94,6 +102,7 @@ def _query_torch() -> List[Dict[str, Any]]: infos: List[Dict[str, Any]] = [] if not torch.cuda.is_available(): return infos + current_device = torch.cuda.current_device() try: count = torch.cuda.device_count() for idx in range(count): @@ -119,6 +128,11 @@ def _query_torch() -> List[Dict[str, Any]]: ) except Exception as exc: # pragma: no cover - defensive logger.debug("Torch GPU info failed: %s", exc) + finally: + try: + torch.cuda.set_device(current_device) + except Exception: + pass return infos diff --git a/tests/mcp/test_server.py b/tests/mcp/test_server.py index d5ad2b8..494deb4 100644 --- a/tests/mcp/test_server.py +++ b/tests/mcp/test_server.py @@ -76,3 +76,20 @@ def test_end_to_end_jsonrpc(): stop_req = {"id": 3, "method": "stop_keep", "params": {"job_id": job_id}} stop_resp = _handle_request(server, stop_req) assert job_id in stop_resp["result"]["stopped"] + + +def test_status_all(): + server = KeepGPUServer(controller_factory=dummy_factory) + job_a = server.start_keep(gpu_ids=[0])["job_id"] + job_b = server.start_keep(gpu_ids=[1])["job_id"] + + status = server.status() + assert "active_jobs" in status + assert len(status["active_jobs"]) == 2 + + job_statuses = {job["job_id"]: job for job in status["active_jobs"]} + assert job_a in job_statuses + assert job_b in job_statuses + assert job_statuses[job_a]["params"]["gpu_ids"] == [0] + assert job_statuses[job_b]["params"]["gpu_ids"] == [1] + assert "controller" not in job_statuses[job_a] From dfcb31f5b8c84f09dd94ac2992aa042ec4efafd6 Mon Sep 17 00:00:00 2001 From: Wang Siyuan Date: Tue, 9 Dec 2025 14:30:28 +0800 Subject: [PATCH 10/11] Update src/keep_gpu/mcp/server.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- src/keep_gpu/mcp/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/keep_gpu/mcp/server.py b/src/keep_gpu/mcp/server.py index cc8af1e..c5aa8bd 100644 --- a/src/keep_gpu/mcp/server.py +++ b/src/keep_gpu/mcp/server.py @@ -14,7 +14,7 @@ import json import sys import uuid -from dataclasses import dataclass, asdict +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional from keep_gpu.global_gpu_controller.global_gpu_controller import GlobalGPUController From 683010bc905a11ec6406e610096453a7791b97c4 Mon Sep 17 00:00:00 2001 From: Wang Siyuan Date: Tue, 9 Dec 2025 14:34:40 +0800 Subject: [PATCH 11/11] fix 
with pre-commit --- src/keep_gpu/mcp/server.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/keep_gpu/mcp/server.py b/src/keep_gpu/mcp/server.py index cc8af1e..47a06f4 100644 --- a/src/keep_gpu/mcp/server.py +++ b/src/keep_gpu/mcp/server.py @@ -14,7 +14,7 @@ import json import sys import uuid -from dataclasses import dataclass, asdict +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional from keep_gpu.global_gpu_controller.global_gpu_controller import GlobalGPUController @@ -99,7 +99,8 @@ def status(self, job_id: Optional[str] = None) -> Dict[str, Any]: } return { "active_jobs": [ - {"job_id": jid, "params": sess.params} for jid, sess in self._sessions.items() + {"job_id": jid, "params": sess.params} + for jid, sess in self._sessions.items() ] }
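
For reference, the stdio JSON-RPC protocol introduced in this series can be exercised end to end from a short Python client. This is a minimal sketch, assuming `keep-gpu-mcp-server` is on PATH (the pyproject script entry added in PATCH 01) and that a working GPU stack is available so `start_keep` succeeds; the `call` helper is purely illustrative and not part of the keep_gpu package.

```python
# Minimal client sketch for the stdio JSON-RPC server introduced in this series.
# Assumptions: `keep-gpu-mcp-server` is on PATH and a GPU stack is available.
# The `call` helper is illustrative, not part of the keep_gpu package.
import json
import subprocess

proc = subprocess.Popen(
    ["keep-gpu-mcp-server"],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    text=True,
)


def call(method, params=None, req_id=1):
    """Write one JSON request line and read one JSON response line."""
    request = {"id": req_id, "method": method, "params": params or {}}
    proc.stdin.write(json.dumps(request) + "\n")
    proc.stdin.flush()
    response = json.loads(proc.stdout.readline())
    if "error" in response:  # server replies {"id": ..., "error": {"message": ...}} on failure
        raise RuntimeError(response["error"]["message"])
    return response["result"]


print(call("list_gpus"))  # {"gpus": [...]} after PATCH 03
job = call(
    "start_keep",
    {"gpu_ids": [0], "vram": "512MB", "interval": 60, "busy_threshold": 20},
    req_id=2,
)
print(call("status", {"job_id": job["job_id"]}, req_id=3))   # echoes the stored params
print(call("stop_keep", {"job_id": job["job_id"]}, req_id=4))

proc.stdin.close()  # EOF ends the server's stdin loop; atexit releases any leftover sessions
proc.wait()
```

Each call blocks on a single response line, mirroring the one-request-per-line protocol the server implements in `main()`.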