69 changes: 63 additions & 6 deletions tests/cuda_controller/test_context_manager.py
@@ -1,28 +1,85 @@
import time
import pytest
import torch

from keep_gpu.single_gpu_controller.cuda_gpu_controller import CudaGPUController
from tests.polling import wait_until


@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="Only run CUDA tests when CUDA is available",
)
def test_cuda_controller_context_manager():
"""Validate VRAM target consumption during keep and recovery after release."""
ctrl = CudaGPUController(
rank=torch.cuda.device_count() - 1,
interval=0.05,
-        vram_to_keep="8MB",
+        vram_to_keep="64MB",
relu_iterations=64,
)

torch.cuda.set_device(ctrl.rank)
torch.cuda.empty_cache()
torch.cuda.synchronize()

target_bytes = int(ctrl.vram_to_keep) * 4 # float32 bytes
free_bytes, _ = torch.cuda.mem_get_info(ctrl.rank)
if free_bytes < int(target_bytes * 1.2):
pytest.skip(
f"Insufficient free VRAM for assertion test: need ~{target_bytes}, have {free_bytes}"
)

alloc_tolerance = 8 * 1024 * 1024
reserve_tolerance = 16 * 1024 * 1024
before_reserved = torch.cuda.memory_reserved(ctrl.rank)
before_allocated = torch.cuda.memory_allocated(ctrl.rank)

with ctrl:
time.sleep(0.3)
assert ctrl._thread and ctrl._thread.is_alive()
during_reserved = torch.cuda.memory_reserved(ctrl.rank)
assert during_reserved >= before_reserved
peak_alloc_delta = 0
peak_reserved_delta = 0

def _target_reached() -> bool:
nonlocal peak_alloc_delta, peak_reserved_delta
alloc_delta = max(
0, torch.cuda.memory_allocated(ctrl.rank) - before_allocated
)
reserved_delta = max(
0, torch.cuda.memory_reserved(ctrl.rank) - before_reserved
)
peak_alloc_delta = max(peak_alloc_delta, alloc_delta)
peak_reserved_delta = max(peak_reserved_delta, reserved_delta)
            # allocated should track the payload; reserved may be larger due to allocator blocks
return (
alloc_delta >= int(target_bytes * 0.95)
and reserved_delta >= alloc_delta
)

reached_target = wait_until(_target_reached, timeout_s=3.0)

assert not (ctrl._thread and ctrl._thread.is_alive())
assert reached_target, (
f"VRAM target not reached. target={target_bytes}, "
f"peak_alloc_delta={peak_alloc_delta}, peak_reserved_delta={peak_reserved_delta}"
)

alloc_delta_after = -1
reserved_delta_after = -1

def _released() -> bool:
nonlocal alloc_delta_after, reserved_delta_after
alloc_after = torch.cuda.memory_allocated(ctrl.rank)
reserved_after = torch.cuda.memory_reserved(ctrl.rank)
alloc_delta_after = max(0, alloc_after - before_allocated)
reserved_delta_after = max(0, reserved_after - before_reserved)
return (
alloc_delta_after <= alloc_tolerance
and reserved_delta_after <= reserve_tolerance
and not (ctrl._thread and ctrl._thread.is_alive())
)

released = wait_until(_released, timeout_s=3.0)

assert released, (
"VRAM did not return near baseline after release. "
f"alloc_delta_after={alloc_delta_after}, reserved_delta_after={reserved_delta_after}"
)
66 changes: 65 additions & 1 deletion tests/cuda_controller/test_keep_and_release.py
@@ -1,7 +1,9 @@
-import pytest
import time
import torch
+import pytest

from keep_gpu.single_gpu_controller.cuda_gpu_controller import CudaGPUController
from tests.polling import wait_until


@pytest.mark.skipif(
@@ -32,3 +34,65 @@ def test_cuda_controller_basic():
assert ctrl._thread and ctrl._thread.is_alive()
time.sleep(0.2)
assert not (ctrl._thread and ctrl._thread.is_alive())


@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="Only run CUDA tests when CUDA is available",
)
def test_cuda_controller_respects_vram_target_during_keep():
"""Ensure keep() consumes roughly requested VRAM and release() frees it."""
ctrl = CudaGPUController(
rank=0,
interval=0.05,
vram_to_keep="32MB",
relu_iterations=32,
)
torch.cuda.set_device(ctrl.rank)
torch.cuda.empty_cache()
torch.cuda.synchronize()

    target_bytes = int(ctrl.vram_to_keep) * 4  # float32 bytes
free_bytes, _ = torch.cuda.mem_get_info(ctrl.rank)
if free_bytes < int(target_bytes * 1.2):
pytest.skip(
f"Insufficient free VRAM for assertion test: need ~{target_bytes}, have {free_bytes}"
)

before_alloc = torch.cuda.memory_allocated(ctrl.rank)
before_reserved = torch.cuda.memory_reserved(ctrl.rank)
alloc_tolerance = 8 * 1024 * 1024
reserve_tolerance = 16 * 1024 * 1024

ctrl.keep()

P1: Always release controller when post-keep assertions fail

test_cuda_controller_respects_vram_target_during_keep starts the background keep loop with ctrl.keep(), but cleanup is only done at the end of the happy path. If any assertion before ctrl.release() fails (for example the 3s target wait timing out on a busy CI GPU), the daemon thread keeps running and continues holding VRAM, which can cascade into unrelated CUDA test failures in the same pytest process. Use try/finally (or with ctrl:) so release is guaranteed on failure paths.

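A minimal sketch of the suggested guard, assuming the keep()/release() and context-manager API shown in this diff (_target_reached stands in for the lambda predicate this test passes to wait_until):

ctrl.keep()
try:
    # Assertions that can fail while the keep loop is still running.
    assert wait_until(_target_reached, timeout_s=3.0)
finally:
    ctrl.release()  # guaranteed cleanup, even when an assertion above fails

# Equivalent, using the controller's context-manager support:
# with ctrl:
#     assert wait_until(_target_reached, timeout_s=3.0)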

reached = wait_until(
lambda: (
(
alloc_delta := max(
0, torch.cuda.memory_allocated(ctrl.rank) - before_alloc
)
)
>= int(target_bytes * 0.95)
and max(0, torch.cuda.memory_reserved(ctrl.rank) - before_reserved)
>= alloc_delta
),
timeout_s=3.0,
)
assert reached, "keep() did not reach expected VRAM allocation target in time"

alloc_delta = max(0, torch.cuda.memory_allocated(ctrl.rank) - before_alloc)
reserved_delta = max(0, torch.cuda.memory_reserved(ctrl.rank) - before_reserved)
assert alloc_delta >= int(target_bytes * 0.95)
assert reserved_delta >= alloc_delta

ctrl.release()
released = wait_until(
lambda: (
max(0, torch.cuda.memory_allocated(ctrl.rank) - before_alloc)
<= alloc_tolerance
and max(0, torch.cuda.memory_reserved(ctrl.rank) - before_reserved)
<= reserve_tolerance
),
timeout_s=3.0,
)
assert released, "VRAM did not return near baseline after release()"
16 changes: 16 additions & 0 deletions tests/polling.py
@@ -0,0 +1,16 @@
import time
from typing import Callable


def wait_until(
predicate: Callable[[], bool],
timeout_s: float = 3.0,
interval_s: float = 0.05,
) -> bool:
"""Poll predicate until it returns True or timeout is reached."""
    deadline = time.monotonic() + timeout_s  # monotonic clock is immune to wall-clock jumps
    while time.monotonic() < deadline:
if predicate():
return True
time.sleep(interval_s)
return False
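
A minimal usage sketch of the helper (the predicate here is a hypothetical stand-in; the CUDA tests above pass closures over torch.cuda memory counters):

flag = {"done": False}

def _condition() -> bool:
    return flag["done"]

# Times out after ~1s and returns False, since nothing sets the flag yet.
assert wait_until(_condition, timeout_s=1.0, interval_s=0.05) is False

# Returns True on the first poll once the condition holds.
flag["done"] = True
assert wait_until(_condition, timeout_s=1.0) is True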