Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Pytest configuration and shared fixtures for pyECT tests."""

import pytest
import torch


@pytest.fixture
def device():
    """Torch device on which test tensors are allocated (always CPU, for portability)."""
    cpu = torch.device("cpu")
    return cpu


@pytest.fixture
def triangle_vertices():
    """Vertex coordinates (3 x 2 tensor) of a simple triangle in the plane."""
    coords = [[-1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
    return torch.tensor(coords)


@pytest.fixture
def triangle_edges():
    """Edge index pairs (3 x 2 tensor) connecting the triangle's vertices in a cycle."""
    index_pairs = [[0, 1], [1, 2], [2, 0]]
    return torch.tensor(index_pairs)


@pytest.fixture
def triangle_faces():
    """Face index tensor (1 x 3) for the single triangular face."""
    face = [[0, 1, 2]]
    return torch.tensor(face)


@pytest.fixture
def tetrahedron_vertices():
    """Vertex coordinates (4 x 3 tensor) of a simple tetrahedron in 3D."""
    coords = [
        [0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.5, 1.0, 0.0],
        [0.5, 0.5, 1.0],
    ]
    return torch.tensor(coords)


@pytest.fixture
def tetrahedron_edges():
    """All six edge index pairs (6 x 2 tensor) of a tetrahedron."""
    index_pairs = [
        [0, 1],
        [0, 2],
        [0, 3],
        [1, 2],
        [1, 3],
        [2, 3],
    ]
    return torch.tensor(index_pairs)


@pytest.fixture
def tetrahedron_faces():
    """All four triangular face index triples (4 x 3 tensor) of a tetrahedron."""
    triples = [
        [0, 1, 2],
        [0, 1, 3],
        [0, 2, 3],
        [1, 2, 3],
    ]
    return torch.tensor(triples)


@pytest.fixture
def tetrahedron_tets():
    """Tetrahedron index tensor (1 x 4) for the single 3-cell."""
    cell = [[0, 1, 2, 3]]
    return torch.tensor(cell)


def pytest_configure(config):
    """Register this suite's custom markers so pytest does not warn about them.

    Args:
        config: The pytest ``Config`` object passed in by the hook machinery.
    """
    # Each entry is a marker declaration in the ini-file "markers" format.
    marker_lines = (
        "slow: marks tests as slow (deselect with '-m \"not slow\"')",
        "cuda: marks tests that require CUDA",
    )
    for line in marker_lines:
        config.addinivalue_line("markers", line)


def has_cuda():
    """Return ``True`` when torch reports a usable CUDA device, ``False`` otherwise."""
    cuda_available = torch.cuda.is_available()
    return cuda_available


def has_mps():
    """Return ``True`` when the MPS (Apple Silicon) backend exists and is available."""
    # Older torch builds lack torch.backends.mps entirely; treat that as unavailable.
    backend = getattr(torch.backends, 'mps', None)
    return backend is not None and backend.is_available()
Loading