Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
python-version: '3.13'

- name: Install dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ jobs:

- name: Run pre-commit (no fix)
run: |
pre-commit run --all-files --hook-stage manual --show-diff-on-failure --color always
pre-commit run --all-files --show-diff-on-failure --color always
4 changes: 2 additions & 2 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install flake8 pytest
pip install torch --index-url https://download.pytorch.org/whl/cpu
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install -e .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
Expand Down
4 changes: 3 additions & 1 deletion src/keep_gpu/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def keep(rank, args):
toc = time.time()

logger.info(
f"benchmark {rank} matmul: time span: {(toc - tic) * 1000 / 5000:.2f}ms"
"benchmark %s matmul: time span: %.2fms",
rank,
(toc - tic) * 1000 / args.matmul_iterations,
)

time.sleep(args.interval)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
controller_cls(
rank=i,
interval=interval,
vram_to_keep=vram_to_keep,
vram_to_keep=self.vram_to_keep,
busy_threshold=busy_threshold,
)
for i in self.gpu_ids
Expand Down
4 changes: 2 additions & 2 deletions src/keep_gpu/single_gpu_controller/base_gpu_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def monitor(self):

def keep(self):
"""
Method to keep the specified amount of VRAM free.
Method to keep the specified amount of VRAM busy/occupied.
Should be implemented by subclasses.
"""
raise NotImplementedError("Subclasses must implement this method.")
Expand All @@ -46,7 +46,7 @@ def rest(self):

async def _keep(self):
"""
Asynchronous method to keep the specified amount of VRAM free.
Asynchronous method to keep the specified amount of VRAM busy/occupied.
This is a placeholder for subclasses to implement their logic.
"""
raise NotImplementedError("Subclasses must implement this method.")
Expand Down
43 changes: 31 additions & 12 deletions src/keep_gpu/single_gpu_controller/cuda_gpu_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
class CudaGPUController(BaseGPUController):
"""CudaGPUController
Keep a single CUDA GPU busy by repeatedly running lightweight
matrix-multiplication workloads in a background thread.
elementwise workloads in a background thread.

Typical usage:

Expand All @@ -42,16 +42,20 @@ def __init__(
*,
rank: int,
interval: float = 1.0,
matmul_iterations: int = 5000,
relu_iterations: int = 5000,
matmul_iterations: Optional[int] = None,
vram_to_keep: str | int = "1000 MB",
busy_threshold: int = 10,
):
"""
Args:
rank (int): Local CUDA device index to occupy.
interval (float, optional): Sleep time (seconds) between workload
batches. Defaults to 0.5.
matmul_iterations (int, optional): Number of matmul ops per batch.
batches. Defaults to 1.0.
relu_iterations (int, optional): Number of in-place ReLU ops per
batch.
matmul_iterations (int, optional): Legacy alias for
`relu_iterations`. When set, it overrides `relu_iterations`.
vram_to_keep (int or str, optional): Amount of VRAM to keep busy,
e.g. `"1000 MB"`, `"20 GB"`, or an integer like `1000 * 1000`.
This represents the total size of the matrix allocated to
Expand All @@ -65,12 +69,17 @@ def __init__(
self.rank = rank
self.device = torch.device(f"cuda:{rank}")
self.interval = interval
self.matmul_iterations = matmul_iterations
if matmul_iterations is not None:
relu_iterations = matmul_iterations
if relu_iterations <= 0:
raise ValueError("relu_iterations must be positive")
self.relu_iterations = relu_iterations
self.busy_threshold = busy_threshold
self.platform = ComputingPlatform.CUDA

self._stop_evt: Optional[threading.Event] = None
self._thread: Optional[threading.Thread] = None
self._num_elements: Optional[int] = None

@staticmethod
def parse_size(text: str) -> int:
Expand All @@ -86,6 +95,10 @@ def keep(self) -> None:
logger.warning("rank %s: keep thread already running", self.rank)
return

self._num_elements = int(self.vram_to_keep)
if self._num_elements <= 0:
raise ValueError("vram_to_keep must be positive")

self._stop_evt = threading.Event()
self._thread = threading.Thread(
target=self._keep_loop,
Expand Down Expand Up @@ -123,11 +136,17 @@ def __exit__(self, exc_type, exc, tb):
def _keep_loop(self) -> None:
"""Internal: run workloads until stop event is set."""
torch.cuda.set_device(self.rank)
num_elements = self._num_elements if self._num_elements is not None else 0
if num_elements <= 0:
logger.error(
"rank %s: invalid vram_to_keep=%s", self.rank, self.vram_to_keep
)
return
matrix = None
while not self._stop_evt.is_set():
try:
matrix = torch.rand(
self.vram_to_keep,
num_elements,
device=self.device,
dtype=torch.float32,
requires_grad=False,
Expand All @@ -149,7 +168,7 @@ def _keep_loop(self) -> None:
gpu_utilization,
)
else:
self._run_mat_batch(matrix)
self._run_relu_batch(matrix)
time.sleep(self.interval)
except RuntimeError as e:
# Handle OOM by clearing cache; then sleep and continue
Expand All @@ -165,21 +184,21 @@ def _keep_loop(self) -> None:
# Workload implementation
# ------------------------------------------------------------------
@torch.no_grad()
def _run_mat_batch(self, matrix: torch.Tensor) -> None:
"""Run a batch of dummy matmuls to keep GPU busy."""
def _run_relu_batch(self, matrix: torch.Tensor) -> None:
"""Run a batch of in-place ReLU ops to keep GPU busy."""

tic = time.time()
for _ in range(self.matmul_iterations):
for _ in range(self.relu_iterations):
torch.relu_(matrix)
if self._stop_evt.is_set():
break
torch.cuda.synchronize()
toc = time.time()

logger.debug(
"rank %s: mat ops batch done - avg %.2f ms",
"rank %s: relu ops batch done - avg %.2f ms",
self.rank,
(toc - tic) * 1000 / self.matmul_iterations,
(toc - tic) * 1000 / max(1, self.relu_iterations),
)

# ------------------------------------------------------------------
Expand Down
62 changes: 51 additions & 11 deletions src/keep_gpu/single_gpu_controller/rocm_gpu_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,18 @@ def __init__(
vram_to_keep: str | int = "1000 MB",
busy_threshold: int = 10,
iterations: int = 5000,
max_allocation_retries: Optional[int] = None,
):
super().__init__(vram_to_keep=vram_to_keep, interval=interval)
self.rank = rank
self.device = torch.device(f"cuda:{rank}")
self.busy_threshold = busy_threshold
self.iterations = iterations
self.max_allocation_retries = max_allocation_retries
self._stop_evt: Optional[threading.Event] = None
self._thread: Optional[threading.Thread] = None
self._failure_exc: Optional[Exception] = None
self._num_elements: Optional[int] = None

# Lazy rocm_smi import; keep handle for reuse
try:
Expand All @@ -46,6 +50,10 @@ def keep(self) -> None:
if self._thread and self._thread.is_alive():
logger.warning("rank %s: keep thread already running", self.rank)
return
self._failure_exc = None
self._num_elements = int(self.vram_to_keep)
if self._num_elements <= 0:
raise ValueError("vram_to_keep must be positive")
if self._rocm_smi:
try:
self._rocm_smi.rsmi_init()
Expand All @@ -62,12 +70,12 @@ def keep(self) -> None:
logger.info("rank %s: ROCm keep thread started", self.rank)

def release(self) -> None:
if not (self._thread and self._thread.is_alive()):
if self._thread and self._thread.is_alive():
self._stop_evt.set()
self._thread.join()
torch.cuda.empty_cache()
else:
logger.warning("rank %s: keep thread not running", self.rank)
return
self._stop_evt.set()
self._thread.join()
torch.cuda.empty_cache()
if self._rocm_smi:
try:
self._rocm_smi.rsmi_shut_down()
Expand Down Expand Up @@ -95,21 +103,45 @@ def _query_utilization(self) -> Optional[int]:
def _keep_loop(self) -> None:
torch.cuda.set_device(self.rank)
tensor = None
attempts = 0
num_elements = self._num_elements if self._num_elements is not None else 0
if num_elements <= 0:
logger.error(
"rank %s: invalid vram_to_keep=%s", self.rank, self.vram_to_keep
)
return
while not self._stop_evt.is_set():
try:
tensor = torch.rand(
self.vram_to_keep,
num_elements,
device=self.device,
dtype=torch.float32,
requires_grad=False,
)
break
except RuntimeError:
logger.exception("rank %s: failed to allocate tensor", self.rank)
except RuntimeError as exc:
attempts += 1
logger.error(
"rank %s: failed to allocate tensor (attempt %d%s): %s",
self.rank,
attempts,
(
f"/{self.max_allocation_retries}"
if self.max_allocation_retries is not None
else ""
),
exc,
)
if (
self.max_allocation_retries is not None
and attempts >= self.max_allocation_retries
):
self._failure_exc = RuntimeError(
f"rank {self.rank}: failed to allocate tensor after {attempts} attempts"
)
logger.error("%s", self._failure_exc)
return
time.sleep(self.interval)
if tensor is None:
logger.error("rank %s: failed to allocate tensor, exiting loop", self.rank)
raise RuntimeError("Failed to allocate tensor for ROCm GPU keeping")

while not self._stop_evt.is_set():
try:
Expand All @@ -127,6 +159,14 @@ def _keep_loop(self) -> None:
logger.exception("rank %s: unexpected error", self.rank)
time.sleep(self.interval)

def allocation_status(self) -> Optional[Exception]:
"""
Return allocation failure captured in the worker thread, if any.

The reference assignment/read is thread-safe for CPython's GIL model.
"""
return self._failure_exc

@torch.no_grad()
def _run_batch(self, tensor: torch.Tensor) -> None:
tic = time.time()
Expand Down
11 changes: 10 additions & 1 deletion src/keep_gpu/utilities/humanized_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,23 @@


def parse_size(text: str) -> int:
"""
Parse human-readable memory strings into float32 element counts.

The return value is the number of float32 elements needed to occupy the
requested memory size. When no unit is provided, the default unit is GB.
Supported units are the keys in `_UNITS`.
"""
text = text.strip().replace(" ", "")
m = re.fullmatch(r"([0-9]*\.?[0-9]+)([A-Za-z]*)", text)
if not m:
raise ValueError(f"invalid format: {text}, should be like '1000 MB'")
value, unit = m.groups()
unit = unit or "GB"
if len(unit) > 1:
unit = unit[:-1].upper() + unit[-1]
# Treat all-lowercase units as byte units ("gb" -> "GB", "gib" -> "GIB")
# while preserving explicit mixed-case bit forms ("Gb", "GIb").
unit = unit.upper() if unit.islower() else unit[:-1].upper() + unit[-1]
if unit not in _UNITS:
raise ValueError(f"unknown unit: {unit}, should be one of {_UNITS.keys()}")
return int(float(value) * _UNITS[unit] / 4) # float32 takes 4 bytes
11 changes: 7 additions & 4 deletions src/keep_gpu/utilities/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,13 @@ def _build_console_handler(level: int) -> logging.Handler:
"""Create a colored console handler with filename:lineno."""
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(level)
fmt = "%(log_color)s%(asctime)s [%(levelname)s] %(name)s (%(filename)s:%(lineno)d): %(message)s"
color_fmt = "%(log_color)s%(asctime)s [%(levelname)s] %(name)s (%(filename)s:%(lineno)d): %(message)s"
plain_fmt = (
"%(asctime)s [%(levelname)s] %(name)s (%(filename)s:%(lineno)d): %(message)s"
)
if ColoredFormatter:
formatter = ColoredFormatter(
fmt,
color_fmt,
datefmt="%H:%M:%S",
log_colors={
"DEBUG": "cyan",
Expand All @@ -43,7 +46,7 @@ def _build_console_handler(level: int) -> logging.Handler:
},
)
else:
formatter = logging.Formatter(fmt, "%H:%M:%S")
formatter = logging.Formatter(plain_fmt, "%H:%M:%S")
handler.setFormatter(formatter)
return handler

Expand All @@ -53,7 +56,7 @@ def _build_file_handler(
) -> logging.Handler:
"""Create a file handler with filename:lineno."""
log_dir = Path(log_dir)
log_dir.mkdir(exist_ok=True)
log_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
file_path = log_dir / f"{name}_{timestamp}.log"

Expand Down
10 changes: 8 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
import torch


def pytest_addoption(parser):
Expand All @@ -23,4 +22,11 @@ def pytest_collection_modifyitems(config, items):

@pytest.fixture
def rocm_available():
return bool(torch.cuda.is_available() and getattr(torch.version, "hip", None))
try:
import torch
except Exception:
return False
try:
return bool(torch.cuda.is_available() and getattr(torch.version, "hip", None))
except Exception:
return False
Loading