Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ on:
branches: [main]
pull_request:
branches: [main]
workflow_dispatch: {}

jobs:
test:
Expand All @@ -15,6 +16,7 @@ jobs:
matrix:
os: [macos-latest, ubuntu-latest]
python-version: ["3.11", "3.12", "3.13"]
backend: [numpy, numba, jax, torch]

steps:
- uses: actions/checkout@v3
Expand All @@ -31,10 +33,15 @@ jobs:
python-version: ${{ matrix.python-version }}

- name: Install the project
run: uv sync --all-extras --dev
run: |
if [ "${{ matrix.backend }}" = "numpy" ]; then
uv sync
else
uv sync --extra ${{ matrix.backend }}
fi

- name: Run tests
run: uv run pytest tests/
run: uv run pytest tests/ --backend ${{matrix.backend }}

coverage:
runs-on: ubuntu-latest
Expand Down
45 changes: 40 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,53 @@
from scoringrules.backend import backends

DATA_DIR = Path(__file__).parent / "data"
RUN_TESTS = ["numpy", "numba", "jax", "torch"]
BACKENDS = [b for b in backends.available_backends if b in RUN_TESTS]

if os.getenv("SR_TEST_OUTPUT", "False").lower() in ("true", "1", "t"):
OUT_DIR = Path(__file__).parent / "output"
OUT_DIR.mkdir(exist_ok=True)
else:
OUT_DIR = None

for backend in RUN_TESTS:
backends.register_backend(backend)

def pytest_addoption(parser):
"""Add custom command-line options for pytest."""
parser.addoption(
"--backend",
action="store",
default=None,
help="Specify backend to test",
)


def get_test_backends(config):
"""Determine which backends to test."""
backend_option = config.getoption("--backend")

if backend_option:
requested = backend_option.split(",")
else:
requested = ["numpy", "numba", "jax", "torch"]

available = backends.available_backends
test_backends = [b for b in requested if b in available]

# Register backends
for b in test_backends:
try:
backends.register_backend(b)
except Exception as e:
print(f"Warning: Could not register backend '{b}': {e}")

return test_backends


# This generates the parametrization
def pytest_generate_tests(metafunc):
if "backend" in metafunc.fixturenames:
backends_to_test = get_test_backends(metafunc.config)
if not backends_to_test:
pytest.fail("No backends available for testing")
metafunc.parametrize("backend", backends_to_test)


@pytest.fixture()
Expand All @@ -27,5 +63,4 @@ def probability_forecasts():
skip_header=1,
usecols=(1, -1),
)

return data
6 changes: 0 additions & 6 deletions tests/test_brier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@

import scoringrules as sr

from .conftest import BACKENDS


@pytest.mark.parametrize("backend", BACKENDS)
def test_brier(backend):
# test exceptions
with pytest.raises(ValueError):
Expand All @@ -25,7 +22,6 @@ def test_brier(backend):
assert np.isclose(res, expected)


@pytest.mark.parametrize("backend", BACKENDS)
def test_rps(backend):
# test exceptions
with pytest.raises(ValueError):
Expand Down Expand Up @@ -54,7 +50,6 @@ def test_rps(backend):
assert np.allclose(res1, res2)


@pytest.mark.parametrize("backend", BACKENDS)
def test_logs(backend):
# test exceptions
with pytest.raises(ValueError):
Expand All @@ -73,7 +68,6 @@ def test_logs(backend):
assert np.isclose(res, expected)


@pytest.mark.parametrize("backend", BACKENDS)
def test_rls(backend):
# test exceptions
with pytest.raises(ValueError):
Expand Down
Loading
Loading