29 changes: 23 additions & 6 deletions tests/test_benchmark.py
@@ -1,3 +1,16 @@
"""
Benchmark inference performance of the MACE medium model on GPU.

Run benchmarks:
pytest tests/test_benchmark.py --benchmark-save=<some name>

To also include torch.compile benchmarks:
MACE_FULL_BENCH=1 pytest tests/test_benchmark.py --benchmark-save=<some name>

Convert results to CSV:
python tests/test_benchmark.py > results.csv
"""

import json
import os
from pathlib import Path
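
Note: the body of `is_mace_full_bench()` sits outside the visible hunks; from the module docstring and the skip message below, it presumably just checks the `MACE_FULL_BENCH` environment variable. A minimal sketch, assuming that behavior:

```python
import os

def is_mace_full_bench() -> bool:
    # Assumed implementation: gate the long-running torch.compile
    # benchmarks on the MACE_FULL_BENCH environment variable, per the
    # module docstring ("MACE_FULL_BENCH=1 pytest ...").
    return os.environ.get("MACE_FULL_BENCH") == "1"
```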
@@ -21,17 +34,18 @@ def is_mace_full_bench():
@pytest.mark.benchmark(warmup=True, warmup_iterations=4, min_rounds=8)
@pytest.mark.parametrize("size", (3, 5, 7, 9))
@pytest.mark.parametrize("dtype", ["float32", "float64"])
@pytest.mark.parametrize("enable_cueq", [False, True])
@pytest.mark.parametrize("compile_mode", [None, "default"])
def test_inference(
benchmark, size: int, dtype: str, compile_mode: Optional[str], device: str = "cuda"
benchmark, size: int, dtype: str, enable_cueq: bool, compile_mode: Optional[str], device: str = "cuda"
):
if not is_mace_full_bench() and compile_mode is not None:
pytest.skip("Skipping long running benchmark, set MACE_FULL_BENCH=1 to execute")

with torch_tools.default_dtype(dtype):
model = load_mace_mp_medium(dtype, compile_mode, device)
model = load_mace_mp_medium(dtype, enable_cueq, compile_mode, device)
batch = create_batch(size, model, device, is_compiled=compile_mode is not None)
log_bench_info(benchmark, dtype, compile_mode, batch)
log_bench_info(benchmark, dtype, enable_cueq, compile_mode, batch)

def func():
torch.cuda.synchronize()
@@ -41,12 +55,13 @@ def func():
benchmark(func)


def load_mace_mp_medium(dtype, compile_mode, device):
def load_mace_mp_medium(dtype, enable_cueq, compile_mode, device):
calc = mace_mp(
model="medium",
default_dtype=dtype,
device=device,
enable_cueq=enable_cueq,
compile_mode=compile_mode,
device=device,
fullgraph=False,
)
model = calc.models[0].to(device)
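
Note: the new `enable_cueq` flag toggles MACE's cuEquivariance-accelerated kernels, which is what the added parametrize axis exercises. A minimal standalone sketch of the same toggle through the public calculator API (dtype choice and variable names are illustrative, not from this PR):

```python
from mace.calculators import mace_mp

# Load the medium foundation model twice, once per kernel backend,
# mirroring the benchmark's enable_cueq parametrization.
calc_plain = mace_mp(model="medium", default_dtype="float64",
                     device="cuda", enable_cueq=False)
calc_cueq = mace_mp(model="medium", default_dtype="float64",
                    device="cuda", enable_cueq=True)
```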
@@ -73,10 +88,11 @@ def create_batch(size: int, model: torch.nn.Module, device: str, is_compiled: bo
return batch.to_dict()


def log_bench_info(benchmark, dtype, compile_mode, batch):
def log_bench_info(benchmark, dtype, enable_cueq, compile_mode, batch):
benchmark.extra_info["num_atoms"] = int(batch["positions"].shape[0])
benchmark.extra_info["num_edges"] = int(batch["edge_index"].shape[1])
benchmark.extra_info["dtype"] = dtype
benchmark.extra_info["cueq_enabled"] = enable_cueq
benchmark.extra_info["is_compiled"] = compile_mode is not None
benchmark.extra_info["device_name"] = torch.cuda.get_device_name()

@@ -97,6 +113,7 @@ def process_benchmark_file(bench_file: Path) -> pd.DataFrame:
"num_atoms",
"num_edges",
"dtype",
"cueq_enabled",
"is_compiled",
"device_name",
"median",
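
Note: the CSV path advertised in the docstring (`python tests/test_benchmark.py > results.csv`) relies on pytest-benchmark's saved JSON. A minimal sketch of that conversion, assuming the default `.benchmarks/` save location and the standard `benchmarks[*].extra_info` / `stats.median` layout (the helper name is illustrative):

```python
import json
from pathlib import Path

import pandas as pd

def read_benchmarks(bench_file: Path) -> pd.DataFrame:
    # Flatten one saved pytest-benchmark JSON file into the columns
    # that process_benchmark_file() selects (num_atoms, num_edges,
    # dtype, cueq_enabled, is_compiled, device_name, median).
    data = json.loads(bench_file.read_text())
    rows = []
    for bench in data["benchmarks"]:
        row = dict(bench["extra_info"])
        row["median"] = bench["stats"]["median"]
        rows.append(row)
    return pd.DataFrame(rows)

if __name__ == "__main__":
    frames = [read_benchmarks(p) for p in Path(".benchmarks").rglob("*.json")]
    print(pd.concat(frames).to_csv(index=False))
```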