From 9692d7111ed1256dd6b3a87af832d96a357e08ac Mon Sep 17 00:00:00 2001
From: Wang Siyuan
Date: Sun, 15 Feb 2026 21:17:50 +0800
Subject: [PATCH 1/4] test(cuda): strengthen VRAM consumption assertions

---
 tests/cuda_controller/test_context_manager.py | 68 +++++++++++++++++--
 .../cuda_controller/test_keep_and_release.py  | 68 ++++++++++++++++++-
 2 files changed, 130 insertions(+), 6 deletions(-)

diff --git a/tests/cuda_controller/test_context_manager.py b/tests/cuda_controller/test_context_manager.py
index 854f9fd..e48d7ca 100644
--- a/tests/cuda_controller/test_context_manager.py
+++ b/tests/cuda_controller/test_context_manager.py
@@ -10,19 +10,77 @@
     reason="Only run CUDA tests when CUDA is available",
 )
 def test_cuda_controller_context_manager():
+    """Validate VRAM target consumption during keep and recovery after release."""
     ctrl = CudaGPUController(
         rank=torch.cuda.device_count() - 1,
         interval=0.05,
-        vram_to_keep="8MB",
+        vram_to_keep="64MB",
         relu_iterations=64,
     )
     torch.cuda.set_device(ctrl.rank)
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()
+
+    target_bytes = int(ctrl.vram_to_keep) * 4  # float32 elements, 4 bytes each
+    free_bytes, _ = torch.cuda.mem_get_info(ctrl.rank)
+    if free_bytes < int(target_bytes * 1.2):
+        pytest.skip(
+            f"Insufficient free VRAM for assertion test: need ~{target_bytes}, have {free_bytes}"
+        )
+
+    alloc_tolerance = 8 * 1024 * 1024
+    reserve_tolerance = 16 * 1024 * 1024
     before_reserved = torch.cuda.memory_reserved(ctrl.rank)
+    before_allocated = torch.cuda.memory_allocated(ctrl.rank)
+
     with ctrl:
-        time.sleep(0.3)
         assert ctrl._thread and ctrl._thread.is_alive()
-        during_reserved = torch.cuda.memory_reserved(ctrl.rank)
-        assert during_reserved >= before_reserved
+        reached_target = False
+        peak_alloc_delta = 0
+        peak_reserved_delta = 0
+        deadline = time.time() + 3.0
+        while time.time() < deadline:
+            alloc_delta = max(
+                0, torch.cuda.memory_allocated(ctrl.rank) - before_allocated
+            )
+            reserved_delta = max(
+                0, torch.cuda.memory_reserved(ctrl.rank) - before_reserved
+            )
+            peak_alloc_delta = max(peak_alloc_delta, alloc_delta)
+            peak_reserved_delta = max(peak_reserved_delta, reserved_delta)
+
+            # allocated should track payload; reserved may be larger due to allocator blocks
+            if (
+                alloc_delta >= int(target_bytes * 0.95)
+                and reserved_delta >= alloc_delta
+            ):
+                reached_target = True
+                break
+            time.sleep(0.05)
 
-    assert not (ctrl._thread and ctrl._thread.is_alive())
+    assert reached_target, (
+        f"VRAM target not reached. target={target_bytes}, "
+        f"peak_alloc_delta={peak_alloc_delta}, peak_reserved_delta={peak_reserved_delta}"
+    )
+
+    release_deadline = time.time() + 3.0
+    released = False
+    while time.time() < release_deadline:
+        alloc_after = torch.cuda.memory_allocated(ctrl.rank)
+        reserved_after = torch.cuda.memory_reserved(ctrl.rank)
+        alloc_delta_after = max(0, alloc_after - before_allocated)
+        reserved_delta_after = max(0, reserved_after - before_reserved)
+        if (
+            alloc_delta_after <= alloc_tolerance
+            and reserved_delta_after <= reserve_tolerance
+            and not (ctrl._thread and ctrl._thread.is_alive())
+        ):
+            released = True
+            break
+        time.sleep(0.05)
+
+    assert released, (
+        "VRAM did not return near baseline after release. "
+        f"alloc_delta_after={alloc_delta_after}, reserved_delta_after={reserved_delta_after}"
+    )
diff --git a/tests/cuda_controller/test_keep_and_release.py b/tests/cuda_controller/test_keep_and_release.py
index 2e9970d..c13bbc9 100644
--- a/tests/cuda_controller/test_keep_and_release.py
+++ b/tests/cuda_controller/test_keep_and_release.py
@@ -1,9 +1,19 @@
 import time
-import torch
 import pytest
+import torch
+
 from keep_gpu.single_gpu_controller.cuda_gpu_controller import CudaGPUController
+
+
+def _wait_until(predicate, timeout_s: float = 3.0, interval_s: float = 0.05) -> bool:
+    deadline = time.time() + timeout_s
+    while time.time() < deadline:
+        if predicate():
+            return True
+        time.sleep(interval_s)
+    return False
+
+
 @pytest.mark.skipif(
     not torch.cuda.is_available(),
     reason="Only run CUDA tests when CUDA is available",
 )
@@ -32,3 +42,59 @@ def test_cuda_controller_basic():
     assert ctrl._thread and ctrl._thread.is_alive()
     time.sleep(0.2)
     assert not (ctrl._thread and ctrl._thread.is_alive())
+
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="Only run CUDA tests when CUDA is available",
+)
+def test_cuda_controller_respects_vram_target_during_keep():
+    """Ensure keep() consumes roughly the requested VRAM and release() frees it."""
+    ctrl = CudaGPUController(
+        rank=0,
+        interval=0.05,
+        vram_to_keep="32MB",
+        relu_iterations=32,
+    )
+    torch.cuda.set_device(ctrl.rank)
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()
+
+    target_bytes = int(ctrl.vram_to_keep) * 4  # float32 elements, 4 bytes each
+    free_bytes, _ = torch.cuda.mem_get_info(ctrl.rank)
+    if free_bytes < int(target_bytes * 1.2):
+        pytest.skip(
+            f"Insufficient free VRAM for assertion test: need ~{target_bytes}, have {free_bytes}"
+        )
+
+    before_alloc = torch.cuda.memory_allocated(ctrl.rank)
+    before_reserved = torch.cuda.memory_reserved(ctrl.rank)
+    alloc_tolerance = 8 * 1024 * 1024
+    reserve_tolerance = 16 * 1024 * 1024
+
+    ctrl.keep()
+    reached = _wait_until(
+        lambda: (
+            max(0, torch.cuda.memory_allocated(ctrl.rank) - before_alloc)
+            >= int(target_bytes * 0.95)
+        ),
+        timeout_s=3.0,
+    )
+    assert reached, "keep() did not reach expected VRAM allocation target in time"
+
+    alloc_delta = max(0, torch.cuda.memory_allocated(ctrl.rank) - before_alloc)
+    reserved_delta = max(0, torch.cuda.memory_reserved(ctrl.rank) - before_reserved)
+    assert alloc_delta >= int(target_bytes * 0.95)
+    assert reserved_delta >= alloc_delta
+
+    ctrl.release()
+    released = _wait_until(
+        lambda: (
+            max(0, torch.cuda.memory_allocated(ctrl.rank) - before_alloc)
+            <= alloc_tolerance
+            and max(0, torch.cuda.memory_reserved(ctrl.rank) - before_reserved)
+            <= reserve_tolerance
+        ),
+        timeout_s=3.0,
+    )
+    assert released, "VRAM did not return near baseline after release()"
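
The new assertions lean on three distinct PyTorch memory counters, and the tolerances only make sense once the distinction is clear: torch.cuda.memory_allocated() reports bytes occupied by live tensors, torch.cuda.memory_reserved() reports bytes held by the caching allocator (always at least the allocated amount, because requests are rounded up to allocator blocks), and torch.cuda.mem_get_info() reports driver-level (free, total) for the whole device, other processes included. A minimal sketch of the relationship, assuming a CUDA device at index 0:

    import torch

    dev = 0
    torch.cuda.set_device(dev)
    # ~64 MB of float32 payload: 16M elements at 4 bytes each
    x = torch.empty(16 * 1024 * 1024, dtype=torch.float32, device=dev)
    alloc = torch.cuda.memory_allocated(dev)    # live tensor payload, ~64 MB
    reserved = torch.cuda.memory_reserved(dev)  # caching-allocator pool, >= alloc
    free, total = torch.cuda.mem_get_info(dev)  # device-wide view from the driver
    assert reserved >= alloc
    del x
    torch.cuda.empty_cache()  # hands cached blocks back to the driver

This is why the tests assert alloc_delta against the payload target but only require reserved_delta >= alloc_delta, with looser byte tolerances when checking recovery.
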
" + f"alloc_delta_after={alloc_delta_after}, reserved_delta_after={reserved_delta_after}" + ) diff --git a/tests/cuda_controller/test_keep_and_release.py b/tests/cuda_controller/test_keep_and_release.py index 2e9970d..c13bbc9 100644 --- a/tests/cuda_controller/test_keep_and_release.py +++ b/tests/cuda_controller/test_keep_and_release.py @@ -1,9 +1,19 @@ import time -import torch import pytest +import torch + from keep_gpu.single_gpu_controller.cuda_gpu_controller import CudaGPUController +def _wait_until(predicate, timeout_s: float = 3.0, interval_s: float = 0.05) -> bool: + deadline = time.time() + timeout_s + while time.time() < deadline: + if predicate(): + return True + time.sleep(interval_s) + return False + + @pytest.mark.skipif( not torch.cuda.is_available(), reason="Only run CUDA tests when CUDA is available", @@ -32,3 +42,59 @@ def test_cuda_controller_basic(): assert ctrl._thread and ctrl._thread.is_alive() time.sleep(0.2) assert not (ctrl._thread and ctrl._thread.is_alive()) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="Only run CUDA tests when CUDA is available", +) +def test_cuda_controller_respects_vram_target_during_keep(): + """Ensure keep() consumes roughly requested VRAM and release() frees it.""" + ctrl = CudaGPUController( + rank=0, + interval=0.05, + vram_to_keep="32MB", + relu_iterations=32, + ) + torch.cuda.set_device(ctrl.rank) + torch.cuda.empty_cache() + torch.cuda.synchronize() + + target_bytes = int(ctrl.vram_to_keep) * 4 + free_bytes, _ = torch.cuda.mem_get_info(ctrl.rank) + if free_bytes < int(target_bytes * 1.2): + pytest.skip( + f"Insufficient free VRAM for assertion test: need ~{target_bytes}, have {free_bytes}" + ) + + before_alloc = torch.cuda.memory_allocated(ctrl.rank) + before_reserved = torch.cuda.memory_reserved(ctrl.rank) + alloc_tolerance = 8 * 1024 * 1024 + reserve_tolerance = 16 * 1024 * 1024 + + ctrl.keep() + reached = _wait_until( + lambda: ( + max(0, torch.cuda.memory_allocated(ctrl.rank) - before_alloc) + >= int(target_bytes * 0.95) + ), + timeout_s=3.0, + ) + assert reached, "keep() did not reach expected VRAM allocation target in time" + + alloc_delta = max(0, torch.cuda.memory_allocated(ctrl.rank) - before_alloc) + reserved_delta = max(0, torch.cuda.memory_reserved(ctrl.rank) - before_reserved) + assert alloc_delta >= int(target_bytes * 0.95) + assert reserved_delta >= alloc_delta + + ctrl.release() + released = _wait_until( + lambda: ( + max(0, torch.cuda.memory_allocated(ctrl.rank) - before_alloc) + <= alloc_tolerance + and max(0, torch.cuda.memory_reserved(ctrl.rank) - before_reserved) + <= reserve_tolerance + ), + timeout_s=3.0, + ) + assert released, "VRAM did not return near baseline after release()" From 7d1b46c67de1731b24657287c75dbc43df3a293e Mon Sep 17 00:00:00 2001 From: Wang Siyuan Date: Sun, 15 Feb 2026 22:07:09 +0800 Subject: [PATCH 2/4] Update tests/cuda_controller/test_context_manager.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- tests/cuda_controller/test_context_manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/cuda_controller/test_context_manager.py b/tests/cuda_controller/test_context_manager.py index e48d7ca..c7e4fd9 100644 --- a/tests/cuda_controller/test_context_manager.py +++ b/tests/cuda_controller/test_context_manager.py @@ -66,6 +66,8 @@ def test_cuda_controller_context_manager(): release_deadline = time.time() + 3.0 released = False + alloc_delta_after = -1 + reserved_delta_after = -1 while time.time() < release_deadline: 
From 6551fa175ed11404306f5f0965297222a9779819 Mon Sep 17 00:00:00 2001
From: Wang Siyuan
Date: Sun, 15 Feb 2026 22:07:26 +0800
Subject: [PATCH 3/4] Update tests/cuda_controller/test_keep_and_release.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 tests/cuda_controller/test_keep_and_release.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/cuda_controller/test_keep_and_release.py b/tests/cuda_controller/test_keep_and_release.py
index c13bbc9..2545be9 100644
--- a/tests/cuda_controller/test_keep_and_release.py
+++ b/tests/cuda_controller/test_keep_and_release.py
@@ -75,8 +75,8 @@ def test_cuda_controller_respects_vram_target_during_keep():
     ctrl.keep()
     reached = _wait_until(
         lambda: (
-            max(0, torch.cuda.memory_allocated(ctrl.rank) - before_alloc)
-            >= int(target_bytes * 0.95)
+            (alloc_delta := max(0, torch.cuda.memory_allocated(ctrl.rank) - before_alloc)) >= int(target_bytes * 0.95)
+            and max(0, torch.cuda.memory_reserved(ctrl.rank) - before_reserved) >= alloc_delta
        ),
         timeout_s=3.0,
     )
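
The reworked predicate samples both counters in one expression, which avoids checking allocation at one instant and reservation at another. The walrus operator makes this possible inside a lambda: a parenthesized assignment expression (Python 3.8+) binds a name that the rest of the same expression can reuse. Isolated with toy numbers:

    # Stand-in values mimicking one polling sample (illustrative only).
    allocated, reserved, target = 61_000_000, 64_000_000, 60_000_000
    ok = (
        (alloc_delta := max(0, allocated)) >= target  # bind and test in one step
        and reserved >= alloc_delta                   # reuse the bound value
    )
    print(ok)  # True: the sample meets the target and reserved covers it
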
From 56a39514439fc4b8a8476ec7cb206bd26f61f393 Mon Sep 17 00:00:00 2001
From: Wang Siyuan
Date: Sun, 15 Feb 2026 22:16:08 +0800
Subject: [PATCH 4/4] test(cuda): factor polling helper and dedupe VRAM checks

---
 tests/cuda_controller/test_context_manager.py | 33 +++++++++----------
 .../cuda_controller/test_keep_and_release.py  | 26 +++++++--------
 tests/polling.py                              | 16 +++++++++
 3 files changed, 43 insertions(+), 32 deletions(-)
 create mode 100644 tests/polling.py

diff --git a/tests/cuda_controller/test_context_manager.py b/tests/cuda_controller/test_context_manager.py
index c7e4fd9..a13d446 100644
--- a/tests/cuda_controller/test_context_manager.py
+++ b/tests/cuda_controller/test_context_manager.py
@@ -1,8 +1,8 @@
-import time
 import pytest
 import torch
 
 from keep_gpu.single_gpu_controller.cuda_gpu_controller import CudaGPUController
+from tests.polling import wait_until
 
 
 @pytest.mark.skipif(
@@ -36,11 +36,11 @@ def test_cuda_controller_context_manager():
 
     with ctrl:
         assert ctrl._thread and ctrl._thread.is_alive()
-        reached_target = False
         peak_alloc_delta = 0
         peak_reserved_delta = 0
-        deadline = time.time() + 3.0
-        while time.time() < deadline:
+
+        def _target_reached() -> bool:
+            nonlocal peak_alloc_delta, peak_reserved_delta
             alloc_delta = max(
                 0, torch.cuda.memory_allocated(ctrl.rank) - before_allocated
             )
             reserved_delta = max(
                 0, torch.cuda.memory_reserved(ctrl.rank) - before_reserved
             )
             peak_alloc_delta = max(peak_alloc_delta, alloc_delta)
             peak_reserved_delta = max(peak_reserved_delta, reserved_delta)
-
             # allocated should track payload; reserved may be larger due to allocator blocks
-            if (
+            return (
                 alloc_delta >= int(target_bytes * 0.95)
                 and reserved_delta >= alloc_delta
-            ):
-                reached_target = True
-                break
-            time.sleep(0.05)
+            )
+
+        reached_target = wait_until(_target_reached, timeout_s=3.0)
 
     assert reached_target, (
         f"VRAM target not reached. target={target_bytes}, "
         f"peak_alloc_delta={peak_alloc_delta}, peak_reserved_delta={peak_reserved_delta}"
     )
 
-    release_deadline = time.time() + 3.0
-    released = False
     alloc_delta_after = -1
     reserved_delta_after = -1
-    while time.time() < release_deadline:
+
+    def _released() -> bool:
+        nonlocal alloc_delta_after, reserved_delta_after
         alloc_after = torch.cuda.memory_allocated(ctrl.rank)
         reserved_after = torch.cuda.memory_reserved(ctrl.rank)
         alloc_delta_after = max(0, alloc_after - before_allocated)
         reserved_delta_after = max(0, reserved_after - before_reserved)
-        if (
+        return (
             alloc_delta_after <= alloc_tolerance
             and reserved_delta_after <= reserve_tolerance
             and not (ctrl._thread and ctrl._thread.is_alive())
-        ):
-            released = True
-            break
-        time.sleep(0.05)
+        )
+
+    released = wait_until(_released, timeout_s=3.0)
 
     assert released, (
         "VRAM did not return near baseline after release. "
diff --git a/tests/cuda_controller/test_keep_and_release.py b/tests/cuda_controller/test_keep_and_release.py
index 2545be9..85e845e 100644
--- a/tests/cuda_controller/test_keep_and_release.py
+++ b/tests/cuda_controller/test_keep_and_release.py
@@ -1,17 +1,9 @@
-import time
 import pytest
+import time
 import torch
 
 from keep_gpu.single_gpu_controller.cuda_gpu_controller import CudaGPUController
-
-
-def _wait_until(predicate, timeout_s: float = 3.0, interval_s: float = 0.05) -> bool:
-    deadline = time.time() + timeout_s
-    while time.time() < deadline:
-        if predicate():
-            return True
-        time.sleep(interval_s)
-    return False
+from tests.polling import wait_until
 
 
 @pytest.mark.skipif(
@@ -73,10 +65,16 @@ def test_cuda_controller_respects_vram_target_during_keep():
     reserve_tolerance = 16 * 1024 * 1024
 
     ctrl.keep()
-    reached = _wait_until(
+    reached = wait_until(
         lambda: (
-            (alloc_delta := max(0, torch.cuda.memory_allocated(ctrl.rank) - before_alloc)) >= int(target_bytes * 0.95)
-            and max(0, torch.cuda.memory_reserved(ctrl.rank) - before_reserved) >= alloc_delta
+            (
+                alloc_delta := max(
+                    0, torch.cuda.memory_allocated(ctrl.rank) - before_alloc
+                )
+            )
+            >= int(target_bytes * 0.95)
+            and max(0, torch.cuda.memory_reserved(ctrl.rank) - before_reserved)
+            >= alloc_delta
         ),
         timeout_s=3.0,
     )
@@ -88,7 +86,7 @@ def test_cuda_controller_respects_vram_target_during_keep():
     assert reserved_delta >= alloc_delta
 
     ctrl.release()
-    released = _wait_until(
+    released = wait_until(
         lambda: (
             max(0, torch.cuda.memory_allocated(ctrl.rank) - before_alloc)
             <= alloc_tolerance
             and max(0, torch.cuda.memory_reserved(ctrl.rank) - before_reserved)
             <= reserve_tolerance
         ),
         timeout_s=3.0,
     )
diff --git a/tests/polling.py b/tests/polling.py
new file mode 100644
index 0000000..899e5b5
--- /dev/null
+++ b/tests/polling.py
@@ -0,0 +1,16 @@
+import time
+from typing import Callable
+
+
+def wait_until(
+    predicate: Callable[[], bool],
+    timeout_s: float = 3.0,
+    interval_s: float = 0.05,
+) -> bool:
+    """Poll predicate until it returns True or timeout is reached."""
+    deadline = time.time() + timeout_s
+    while time.time() < deadline:
+        if predicate():
+            return True
+        time.sleep(interval_s)
+    return False
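
The extracted tests/polling.py helper is deliberately generic, so it can back any eventually-true condition in the suite, not just VRAM checks. A small usage sketch with a background thread (the worker here is illustrative, not part of the patch):

    import threading
    import time

    from tests.polling import wait_until

    done = threading.Event()
    worker = threading.Thread(target=lambda: (time.sleep(0.2), done.set()))
    worker.start()

    # Polls every 50 ms and returns True as soon as the event is set.
    assert wait_until(done.is_set, timeout_s=1.0)
    worker.join()

One design note: wait_until measures its deadline with time.time(), which can jump if the wall clock is adjusted; time.monotonic() would be the stricter choice for timeouts.
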