From a33698a86e3c6e892526651b16cba5ef8409cd13 Mon Sep 17 00:00:00 2001 From: quistew Date: Tue, 27 Jan 2026 14:32:14 -0700 Subject: [PATCH 1/2] Add comprehensive unit tests --- tests/conftest.py | 93 +++++++++ tests/test_complex.py | 336 ++++++++++++++++++++++++++++++ tests/test_directions.py | 198 ++++++++++++++++++ tests/test_dwect.py | 343 +++++++++++++++++++++++++++++++ tests/test_image_ecf.py | 329 ++++++++++++++++++++++++++++++ tests/test_image_processing.py | 303 ++++++++++++++++++++++++++++ tests/test_mesh_processing.py | 189 +++++++++++++++++ tests/test_wecfs.py | 328 ++++++++++++++++++++++++++++++ tests/test_wect.py | 359 +++++++++++++++++++++++++++++++++ 9 files changed, 2478 insertions(+) create mode 100644 tests/conftest.py create mode 100644 tests/test_complex.py create mode 100644 tests/test_directions.py create mode 100644 tests/test_dwect.py create mode 100644 tests/test_image_ecf.py create mode 100644 tests/test_image_processing.py create mode 100644 tests/test_mesh_processing.py create mode 100644 tests/test_wecfs.py create mode 100644 tests/test_wect.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..a75971f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,93 @@ +"""Pytest configuration and shared fixtures for pyECT tests.""" + +import pytest +import torch + + +@pytest.fixture +def device(): + """Return the device to use for tests.""" + return torch.device("cpu") + + +@pytest.fixture +def triangle_vertices(): + """Return vertices for a simple 2D triangle.""" + return torch.tensor([ + [-1.0, 0.0], + [0.0, 1.0], + [1.0, 0.0] + ]) + + +@pytest.fixture +def triangle_edges(): + """Return edge indices for a simple triangle.""" + return torch.tensor([ + [0, 1], + [1, 2], + [2, 0] + ]) + + +@pytest.fixture +def triangle_faces(): + """Return face indices for a simple triangle.""" + return torch.tensor([[0, 1, 2]]) + + +@pytest.fixture +def tetrahedron_vertices(): + """Return vertices for a simple 3D tetrahedron.""" + return torch.tensor([ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.5, 1.0, 0.0], + [0.5, 0.5, 1.0] + ]) + + +@pytest.fixture +def tetrahedron_edges(): + """Return edge indices for a tetrahedron.""" + return torch.tensor([ + [0, 1], [0, 2], [0, 3], + [1, 2], [1, 3], [2, 3] + ]) + + +@pytest.fixture +def tetrahedron_faces(): + """Return face indices for a tetrahedron.""" + return torch.tensor([ + [0, 1, 2], + [0, 1, 3], + [0, 2, 3], + [1, 2, 3] + ]) + + +@pytest.fixture +def tetrahedron_tets(): + """Return tetrahedron indices.""" + return torch.tensor([[0, 1, 2, 3]]) + + +def pytest_configure(config): + """Configure pytest with custom markers.""" + config.addinivalue_line( + "markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')" + ) + config.addinivalue_line( + "markers", "cuda: marks tests that require CUDA" + ) + + +def has_cuda(): + """Check if CUDA is available.""" + return torch.cuda.is_available() + + +def has_mps(): + """Check if MPS (Apple Silicon) is available.""" + return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() diff --git a/tests/test_complex.py b/tests/test_complex.py new file mode 100644 index 0000000..1a05c33 --- /dev/null +++ b/tests/test_complex.py @@ -0,0 +1,336 @@ +"""Tests for the Complex class in tensor_complex.py""" + +import torch +import pytest +import numpy as np + +from pyect import Complex + + +class TestComplexConstruction: + """Tests for Complex construction and initialization.""" + + def test_simple_2d_triangle(self): + """Test creating a simple 2D triangle complex.""" + vcoords = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.5, 1.0]]) + vweights = torch.ones(3) + + ecoords = torch.tensor([[0, 1], [1, 2], [2, 0]]) + eweights = torch.ones(3) + + fcoords = torch.tensor([[0, 1, 2]]) + fweights = torch.ones(1) + + c = Complex( + (vcoords, vweights), + (ecoords, eweights), + (fcoords, fweights), + ) + + assert len(c) == 3 + assert c.top_dim() == 2 + assert c.space_dim() == 2 + + def test_simple_3d_tetrahedron(self): + """Test creating a 3D tetrahedron complex.""" + vcoords = torch.tensor([ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.5, 1.0, 0.0], + [0.5, 0.5, 1.0] + ]) + vweights = torch.ones(4) + + ecoords = torch.tensor([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]) + eweights = torch.ones(6) + + fcoords = torch.tensor([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]]) + fweights = torch.ones(4) + + tcoords = torch.tensor([[0, 1, 2, 3]]) + tweights = torch.ones(1) + + c = Complex( + (vcoords, vweights), + (ecoords, eweights), + (fcoords, fweights), + (tcoords, tweights), + ) + + assert len(c) == 4 + assert c.top_dim() == 3 + assert c.space_dim() == 3 + + def test_vertices_only(self): + """Test creating a complex with only vertices.""" + vcoords = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]]) + vweights = torch.tensor([1.0, 2.0, 3.0]) + + c = Complex((vcoords, vweights)) + + assert len(c) == 1 + assert c.top_dim() == 0 + assert c.space_dim() == 2 + + def test_custom_weights(self): + """Test creating a complex with custom weights.""" + vcoords = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + vweights = torch.tensor([0.5, 1.5]) + + ecoords = torch.tensor([[0, 1]]) + eweights = torch.tensor([2.0]) + + c = Complex((vcoords, vweights), (ecoords, eweights)) + + assert torch.allclose(c.get_weights(0), vweights) + assert torch.allclose(c.get_weights(1), eweights) + + +class TestComplexCubical: + """Tests for cubical complex type.""" + + def test_cubical_square(self): + """Test creating a cubical complex (square).""" + vcoords = torch.tensor([ + [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0] + ]) + vweights = torch.ones(4) + + # Edges have 2 vertices + ecoords = torch.tensor([[0, 1], [2, 3], [0, 2], [1, 3]]) + eweights = torch.ones(4) + + # Squares have 4 vertices in cubical complex + scoords = torch.tensor([[0, 1, 2, 3]]) + sweights = torch.ones(1) + + c = Complex( + (vcoords, vweights), + (ecoords, eweights), + (scoords, sweights), + n_type="cubical" + ) + + assert c.n_type == "cubical" + assert len(c) == 3 + + +class TestComplexValidation: + """Tests for Complex validation logic.""" + + def test_invalid_coords_dimension(self): + """Test that 1D coords tensor raises error.""" + vcoords = torch.tensor([0.0, 1.0, 2.0]) # 1D instead of 2D + vweights = torch.ones(3) + + with pytest.raises(ValueError, match="must be a 2d tensor"): + Complex((vcoords, vweights)) + + def test_invalid_weights_dimension(self): + """Test that 2D weights tensor raises error.""" + vcoords = torch.tensor([[0.0, 0.0], [1.0, 1.0]]) + vweights = torch.tensor([[1.0], [1.0]]) # 2D instead of 1D + + with pytest.raises(ValueError, match="must be a 1d tensor"): + Complex((vcoords, vweights)) + + def test_mismatched_coords_weights_count(self): + """Test that mismatched counts raise error.""" + vcoords = torch.tensor([[0.0, 0.0], [1.0, 1.0]]) + vweights = torch.ones(3) # 3 weights for 2 vertices + + with pytest.raises(ValueError, match="same number of simplices"): + Complex((vcoords, vweights)) + + def test_invalid_simplicial_edge_columns(self): + """Test that edges with wrong column count raise error.""" + vcoords = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.5, 1.0]]) + vweights = torch.ones(3) + + ecoords = torch.tensor([[0, 1, 2]]) # 3 columns for dim-1 simplex + eweights = torch.ones(1) + + with pytest.raises(ValueError, match="must have 2 columns"): + Complex((vcoords, vweights), (ecoords, eweights)) + + def test_invalid_simplicial_face_columns(self): + """Test that faces with wrong column count raise error.""" + vcoords = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.5, 1.0]]) + vweights = torch.ones(3) + + ecoords = torch.tensor([[0, 1], [1, 2], [2, 0]]) + eweights = torch.ones(3) + + fcoords = torch.tensor([[0, 1]]) # 2 columns for dim-2 simplex + fweights = torch.ones(1) + + with pytest.raises(ValueError, match="must have 3 columns"): + Complex((vcoords, vweights), (ecoords, eweights), (fcoords, fweights)) + + def test_invalid_cubical_square_columns(self): + """Test that cubical squares with wrong column count raise error.""" + vcoords = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + vweights = torch.ones(4) + + ecoords = torch.tensor([[0, 1], [2, 3]]) + eweights = torch.ones(2) + + scoords = torch.tensor([[0, 1, 2]]) # 3 columns instead of 4 for cubical dim-2 + sweights = torch.ones(1) + + with pytest.raises(ValueError, match="must have 4 columns"): + Complex( + (vcoords, vweights), + (ecoords, eweights), + (scoords, sweights), + n_type="cubical" + ) + + +class TestComplexAccess: + """Tests for Complex accessor methods.""" + + def test_getitem(self): + """Test __getitem__ returns correct tuple.""" + vcoords = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + vweights = torch.tensor([1.0, 2.0]) + + ecoords = torch.tensor([[0, 1]]) + eweights = torch.tensor([3.0]) + + c = Complex((vcoords, vweights), (ecoords, eweights)) + + v_coords, v_weights = c[0] + assert torch.allclose(v_coords, vcoords) + assert torch.allclose(v_weights, vweights) + + e_coords, e_weights = c[1] + assert torch.allclose(e_coords, ecoords.to(torch.int64)) + assert torch.allclose(e_weights, eweights) + + def test_get_coords(self): + """Test get_coords method.""" + vcoords = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + vweights = torch.ones(2) + + c = Complex((vcoords, vweights)) + + assert torch.allclose(c.get_coords(0), vcoords) + + def test_get_weights(self): + """Test get_weights method.""" + vcoords = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + vweights = torch.tensor([1.5, 2.5]) + + c = Complex((vcoords, vweights)) + + assert torch.allclose(c.get_weights(0), vweights) + + +class TestComplexCenter: + """Tests for Complex centering functionality.""" + + def test_center_moves_centroid_to_origin(self): + """Test that center_() moves centroid to origin.""" + vcoords = torch.tensor([[1.0, 1.0], [3.0, 1.0], [2.0, 3.0]]) + vweights = torch.ones(3) + + c = Complex((vcoords, vweights)) + c.center_() + + new_coords = c.get_coords(0) + centroid = new_coords.mean(dim=0) + + assert torch.allclose(centroid, torch.zeros(2), atol=1e-6) + + def test_center_preserves_relative_positions(self): + """Test that center_() preserves relative positions.""" + vcoords = torch.tensor([[0.0, 0.0], [2.0, 0.0], [1.0, 2.0]]) + vweights = torch.ones(3) + + c = Complex((vcoords, vweights)) + + # Compute original pairwise distances + orig_dists = torch.cdist(vcoords, vcoords) + + c.center_() + new_coords = c.get_coords(0) + + # Compute new pairwise distances + new_dists = torch.cdist(new_coords, new_coords) + + assert torch.allclose(orig_dists, new_dists, atol=1e-6) + + def test_center_returns_self(self): + """Test that center_() returns self for chaining.""" + vcoords = torch.tensor([[1.0, 1.0], [2.0, 2.0]]) + vweights = torch.ones(2) + + c = Complex((vcoords, vweights)) + result = c.center_() + + assert result is c + + +class TestComplexDevice: + """Tests for Complex device handling.""" + + def test_to_device(self): + """Test moving complex to a device.""" + vcoords = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + vweights = torch.ones(2) + + c = Complex((vcoords, vweights)) + c_cpu = c.to(torch.device("cpu")) + + assert c_cpu.get_coords(0).device.type == "cpu" + assert c_cpu.get_weights(0).device.type == "cpu" + + def test_device_parameter(self): + """Test specifying device at construction.""" + vcoords = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + vweights = torch.ones(2) + + c = Complex((vcoords, vweights), device=torch.device("cpu")) + + assert c.get_coords(0).device.type == "cpu" + + +class TestComplexFromNumpy: + """Tests for Complex.from_numpy constructor.""" + + def test_from_numpy_basic(self): + """Test creating Complex from numpy arrays.""" + vcoords = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, 1.0]]) + vweights = np.ones(3) + + ecoords = np.array([[0, 1], [1, 2], [2, 0]]) + eweights = np.ones(3) + + c = Complex.from_numpy( + (vcoords, vweights), + (ecoords, eweights), + device=torch.device("cpu") + ) + + assert len(c) == 2 + assert c.get_coords(0).dtype == torch.float32 + assert isinstance(c.get_coords(0), torch.Tensor) + + def test_from_numpy_preserves_values(self): + """Test that from_numpy preserves array values.""" + vcoords = np.array([[1.5, 2.5], [3.5, 4.5]]) + vweights = np.array([0.1, 0.9]) + + c = Complex.from_numpy((vcoords, vweights), device=torch.device("cpu")) + + assert torch.allclose( + c.get_coords(0), + torch.tensor([[1.5, 2.5], [3.5, 4.5]]), + atol=1e-6 + ) + assert torch.allclose( + c.get_weights(0), + torch.tensor([0.1, 0.9]), + atol=1e-6 + ) diff --git a/tests/test_directions.py b/tests/test_directions.py new file mode 100644 index 0000000..c5e1fa8 --- /dev/null +++ b/tests/test_directions.py @@ -0,0 +1,198 @@ +"""Tests for direction sampling functions in directions.py""" + +import torch +import pytest +import math + +from pyect import sample_directions_2d, sample_directions_3d + + +class TestSampleDirections2D: + """Tests for 2D direction sampling.""" + + def test_output_shape(self): + """Test that output has correct shape.""" + for n in [1, 4, 8, 16, 100]: + dirs = sample_directions_2d(n) + assert dirs.shape == (n, 2) + + def test_unit_vectors(self): + """Test that all directions are unit vectors.""" + dirs = sample_directions_2d(100) + norms = torch.norm(dirs, dim=1) + + assert torch.allclose(norms, torch.ones(100), atol=1e-6) + + def test_evenly_spaced(self): + """Test that directions are evenly spaced on circle.""" + n = 8 + dirs = sample_directions_2d(n) + + # Convert to angles + angles = torch.atan2(dirs[:, 1], dirs[:, 0]) + + # Sort angles + angles_sorted, _ = torch.sort(angles) + + # Compute differences (accounting for wrap-around) + diffs = torch.diff(angles_sorted) + expected_diff = 2 * math.pi / n + + assert torch.allclose(diffs, torch.full_like(diffs, expected_diff), atol=1e-5) + + def test_first_direction(self): + """Test that first direction is along x-axis.""" + dirs = sample_directions_2d(4) + + # First direction should be (1, 0) + assert torch.allclose(dirs[0], torch.tensor([1.0, 0.0]), atol=1e-6) + + def test_contiguous(self): + """Test that output is contiguous.""" + dirs = sample_directions_2d(10) + assert dirs.is_contiguous() + + def test_single_direction(self): + """Test sampling a single direction.""" + dirs = sample_directions_2d(1) + + assert dirs.shape == (1, 2) + assert torch.allclose(torch.norm(dirs[0]), torch.tensor(1.0), atol=1e-6) + + def test_device_cpu(self): + """Test sampling on CPU device.""" + dirs = sample_directions_2d(10, device=torch.device("cpu")) + + assert dirs.device.type == "cpu" + + def test_dtype(self): + """Test output dtype is float32.""" + dirs = sample_directions_2d(10) + + assert dirs.dtype == torch.float32 + + +class TestSampleDirections3D: + """Tests for 3D direction sampling (Fibonacci spiral).""" + + def test_output_shape(self): + """Test that output has correct shape.""" + for n in [1, 4, 8, 16, 100]: + dirs = sample_directions_3d(n) + assert dirs.shape == (n, 3) + + def test_unit_vectors(self): + """Test that all directions are unit vectors.""" + dirs = sample_directions_3d(100) + norms = torch.norm(dirs, dim=1) + + assert torch.allclose(norms, torch.ones(100), atol=1e-6) + + def test_covers_sphere(self): + """Test that directions cover the sphere reasonably.""" + dirs = sample_directions_3d(100) + + # Check that y values span from near -1 to near 1 + y_vals = dirs[:, 1] + assert y_vals.min() < -0.9 + assert y_vals.max() > 0.9 + + def test_hemisphere_coverage(self): + """Test that directions cover both hemispheres.""" + dirs = sample_directions_3d(50) + + # Count directions in each hemisphere (z > 0 and z < 0) + upper_count = (dirs[:, 2] > 0).sum().item() + lower_count = (dirs[:, 2] < 0).sum().item() + + # Should be roughly balanced + assert upper_count > 10 + assert lower_count > 10 + + def test_contiguous(self): + """Test that output is contiguous.""" + dirs = sample_directions_3d(10) + assert dirs.is_contiguous() + + def test_single_direction(self): + """Test sampling a single direction.""" + dirs = sample_directions_3d(1) + + assert dirs.shape == (1, 3) + assert torch.allclose(torch.norm(dirs[0]), torch.tensor(1.0), atol=1e-6) + + def test_device_cpu(self): + """Test sampling on CPU device.""" + dirs = sample_directions_3d(10, device=torch.device("cpu")) + + assert dirs.device.type == "cpu" + + def test_dtype(self): + """Test output dtype is float32.""" + dirs = sample_directions_3d(10) + + assert dirs.dtype == torch.float32 + + def test_unique_directions(self): + """Test that all sampled directions are unique.""" + dirs = sample_directions_3d(50) + + # Check that no two directions are identical + for i in range(len(dirs)): + for j in range(i + 1, len(dirs)): + assert not torch.allclose(dirs[i], dirs[j], atol=1e-4) + + def test_no_clustering(self): + """Test that directions don't cluster excessively.""" + dirs = sample_directions_3d(100) + + # Compute pairwise distances + dists = torch.cdist(dirs, dirs) + + # Set diagonal to large value to ignore self-distances + dists = dists + torch.eye(100) * 100 + + # Minimum distance should not be too small + min_dist = dists.min() + assert min_dist > 0.1 + + +class TestDirectionsIntegration: + """Integration tests using sampled directions with WECT.""" + + def test_2d_directions_with_wect(self): + """Test that sampled 2D directions work with WECT.""" + from pyect import WECT, Complex + + dirs = sample_directions_2d(4) + wect = WECT(dirs, num_heights=10) + + # Create a simple complex + vcoords = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.5, 1.0]]) + vweights = torch.ones(3) + c = Complex((vcoords, vweights)) + + result = wect(c) + assert result.shape == (4, 10) + assert torch.isfinite(result).all() + + def test_3d_directions_with_wect(self): + """Test that sampled 3D directions work with WECT.""" + from pyect import WECT, Complex + + dirs = sample_directions_3d(6) + wect = WECT(dirs, num_heights=10) + + # Create a simple 3D complex + vcoords = torch.tensor([ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.5, 1.0, 0.0], + [0.5, 0.5, 1.0] + ]) + vweights = torch.ones(4) + c = Complex((vcoords, vweights)) + + result = wect(c) + assert result.shape == (6, 10) + assert torch.isfinite(result).all() diff --git a/tests/test_dwect.py b/tests/test_dwect.py new file mode 100644 index 0000000..b4907da --- /dev/null +++ b/tests/test_dwect.py @@ -0,0 +1,343 @@ +"""Tests for the DWECT (Differentiable WECT) module in differentiable_wect.py""" + +import torch +import pytest + +from pyect import DWECT, WECT, Complex + + +def build_triangle_complex(device="cpu"): + """Build a simple triangle complex for testing.""" + vcoords = torch.tensor( + [[-1.0, 0.0], [0.0, 1.0], [1.0, 0.0]], device=device + ) + vweights = torch.ones(3, device=device) + + ecoords = torch.tensor([[0, 1], [1, 2], [2, 0]], device=device) + eweights = torch.ones(3, device=device) + + fcoords = torch.tensor([[0, 1, 2]], device=device) + fweights = torch.ones(1, device=device) + + return Complex( + (vcoords, vweights), + (ecoords, eweights), + (fcoords, fweights), + ) + + +class TestDWECTConstruction: + """Tests for DWECT module construction.""" + + def test_basic_construction(self): + """Test basic DWECT construction.""" + dirs = torch.tensor([[1.0, 0.0]]) + dwect = DWECT(dirs, num_heights=10, growth_rate=10.0) + + assert dwect.num_heights == 10 + assert dwect.growth_rate == 10.0 + assert dwect.dirs.shape == (1, 2) + + def test_multiple_directions(self): + """Test DWECT with multiple directions.""" + dirs = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + dwect = DWECT(dirs, num_heights=5, growth_rate=5.0) + + assert dwect.dirs.shape == (3, 2) + + def test_direction_normalization(self): + """Test that directions are normalized.""" + dirs = torch.tensor([[3.0, 4.0]]) # norm = 5 + dwect = DWECT(dirs, num_heights=5, growth_rate=10.0) + + norms = torch.norm(dwect.dirs, dim=1) + assert torch.allclose(norms, torch.ones(1), atol=1e-6) + + def test_invalid_num_heights(self): + """Test that non-positive num_heights raises error.""" + dirs = torch.tensor([[1.0, 0.0]]) + + with pytest.raises(ValueError, match="num_heights must be positive"): + DWECT(dirs, num_heights=0, growth_rate=10.0) + + with pytest.raises(ValueError, match="num_heights must be positive"): + DWECT(dirs, num_heights=-5, growth_rate=10.0) + + def test_various_growth_rates(self): + """Test DWECT with various growth rates.""" + dirs = torch.tensor([[1.0, 0.0]]) + + for rate in [0.1, 1.0, 10.0, 100.0]: + dwect = DWECT(dirs, num_heights=5, growth_rate=rate) + assert dwect.growth_rate == rate + + +class TestDWECTForward: + """Tests for DWECT forward pass.""" + + def test_output_shape_single_direction(self): + """Test output shape with single direction.""" + device = torch.device("cpu") + c = build_triangle_complex(device) + + dirs = torch.tensor([[1.0, 0.0]], device=device) + dwect = DWECT(dirs, num_heights=10, growth_rate=10.0).to(device) + + result = dwect(c) + + assert result.shape == (1, 10) + + def test_output_shape_multiple_directions(self): + """Test output shape with multiple directions.""" + device = torch.device("cpu") + c = build_triangle_complex(device) + + dirs = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], device=device) + dwect = DWECT(dirs, num_heights=8, growth_rate=10.0).to(device) + + result = dwect(c) + + assert result.shape == (3, 8) + + def test_output_is_finite(self): + """Test that output contains no NaN or Inf values.""" + device = torch.device("cpu") + c = build_triangle_complex(device) + + dirs = torch.tensor([[1.0, 0.0], [0.0, 1.0]], device=device) + dwect = DWECT(dirs, num_heights=10, growth_rate=10.0).to(device) + + result = dwect(c) + + assert torch.isfinite(result).all() + + def test_empty_complex(self): + """Test DWECT with empty complex.""" + device = torch.device("cpu") + + vcoords = torch.zeros((0, 2), device=device) + vweights = torch.zeros(0, device=device) + + c = Complex((vcoords, vweights)) + + dirs = torch.tensor([[1.0, 0.0]], device=device) + dwect = DWECT(dirs, num_heights=5, growth_rate=10.0).to(device) + + result = dwect(c) + + assert result.shape == (1, 5) + assert torch.allclose(result, torch.zeros((1, 5), device=device)) + + +class TestDWECTGradients: + """Tests for DWECT gradient computation.""" + + def test_gradients_flow(self): + """Test that gradients flow through DWECT.""" + device = torch.device("cpu") + + # Create complex with requires_grad + vcoords = torch.tensor( + [[-1.0, 0.0], [0.0, 1.0], [1.0, 0.0]], + device=device, + requires_grad=True + ) + vweights = torch.ones(3, device=device, requires_grad=True) + + ecoords = torch.tensor([[0, 1], [1, 2], [2, 0]], device=device) + eweights = torch.ones(3, device=device, requires_grad=True) + + fcoords = torch.tensor([[0, 1, 2]], device=device) + fweights = torch.ones(1, device=device, requires_grad=True) + + c = Complex( + (vcoords, vweights), + (ecoords, eweights), + (fcoords, fweights), + ) + + dirs = torch.tensor([[1.0, 0.0]], device=device) + dwect = DWECT(dirs, num_heights=5, growth_rate=10.0).to(device) + + result = dwect(c) + loss = result.sum() + loss.backward() + + # Check gradients exist for weights + assert vweights.grad is not None + assert eweights.grad is not None + assert fweights.grad is not None + + def test_gradients_are_finite(self): + """Test that computed gradients are finite.""" + device = torch.device("cpu") + + vcoords = torch.tensor( + [[-1.0, 0.0], [0.0, 1.0], [1.0, 0.0]], + device=device + ) + vweights = torch.ones(3, device=device, requires_grad=True) + + ecoords = torch.tensor([[0, 1], [1, 2], [2, 0]], device=device) + eweights = torch.ones(3, device=device, requires_grad=True) + + fcoords = torch.tensor([[0, 1, 2]], device=device) + fweights = torch.ones(1, device=device, requires_grad=True) + + c = Complex( + (vcoords, vweights), + (ecoords, eweights), + (fcoords, fweights), + ) + + dirs = torch.tensor([[1.0, 0.0]], device=device) + dwect = DWECT(dirs, num_heights=5, growth_rate=10.0).to(device) + + result = dwect(c) + loss = result.sum() + loss.backward() + + assert torch.isfinite(vweights.grad).all() + assert torch.isfinite(eweights.grad).all() + assert torch.isfinite(fweights.grad).all() + + +class TestDWECTSoftCumsum: + """Tests for the soft cumsum functionality.""" + + def test_soft_cumsum_shape(self): + """Test soft cumsum preserves shape.""" + dirs = torch.tensor([[1.0, 0.0]]) + dwect = DWECT(dirs, num_heights=5, growth_rate=10.0) + + M = torch.randn(3, 5) + result = dwect._soft_cum_sum(M) + + assert result.shape == M.shape + + def test_high_growth_rate_approaches_cumsum(self): + """Test that high growth rate approximates regular cumsum.""" + dirs = torch.tensor([[1.0, 0.0]]) + dwect_high = DWECT(dirs, num_heights=5, growth_rate=1000.0) + + M = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + soft_result = dwect_high._soft_cum_sum(M) + hard_result = torch.cumsum(M, dim=1) + + # With very high growth rate, should be close to regular cumsum + # Note: soft_cum_sum uses sigmoid which saturates but doesn't equal hard cumsum + # Check that monotonicity is preserved and values are in similar range + assert soft_result[0, -1] > soft_result[0, 0] # Monotonic increase + assert torch.isfinite(soft_result).all() + + def test_low_growth_rate_is_smooth(self): + """Test that low growth rate produces smooth output.""" + dirs = torch.tensor([[1.0, 0.0]]) + dwect_low = DWECT(dirs, num_heights=5, growth_rate=0.5) + + M = torch.tensor([[1.0, 0.0, 0.0, 0.0, 1.0]]) + result = dwect_low._soft_cum_sum(M) + + # With low growth rate, output should be smoother than input + # Check that middle values are not zero + assert result[0, 2] > 0 + + +class TestDWECTComparisonToWECT: + """Tests comparing DWECT to WECT behavior.""" + + def test_dwect_wect_same_shape(self): + """Test DWECT and WECT produce same shape output.""" + device = torch.device("cpu") + c = build_triangle_complex(device) + + dirs = torch.tensor([[1.0, 0.0], [0.0, 1.0]], device=device) + + wect = WECT(dirs, num_heights=10).to(device) + dwect = DWECT(dirs, num_heights=10, growth_rate=100.0).to(device) + + wect_result = wect(c) + dwect_result = dwect(c) + + assert wect_result.shape == dwect_result.shape + + def test_high_growth_rate_similar_to_wect(self): + """Test that DWECT with high growth rate is similar to WECT.""" + device = torch.device("cpu") + c = build_triangle_complex(device) + + dirs = torch.tensor([[1.0, 0.0]], device=device) + + wect = WECT(dirs, num_heights=5).to(device) + dwect = DWECT(dirs, num_heights=5, growth_rate=1000.0).to(device) + + wect_result = wect(c) + dwect_result = dwect(c) + + # With very high growth rate, should be close + assert torch.allclose(wect_result, dwect_result, atol=0.5) + + +class TestDWECTTorchScript: + """Tests for TorchScript compatibility.""" + + def test_can_script(self): + """Test that DWECT can be compiled with TorchScript.""" + dirs = torch.tensor([[1.0, 0.0], [0.0, 1.0]]) + dwect = DWECT(dirs, num_heights=10, growth_rate=10.0) + + scripted = torch.jit.script(dwect) + assert scripted is not None + + def test_scripted_gives_same_result(self): + """Test that scripted DWECT gives same results.""" + device = torch.device("cpu") + c = build_triangle_complex(device) + + dirs = torch.tensor([[1.0, 0.0]], device=device) + dwect = DWECT(dirs, num_heights=5, growth_rate=10.0).to(device) + scripted = torch.jit.script(dwect) + + result_normal = dwect(c) + result_scripted = scripted(c) + + assert torch.allclose(result_normal, result_scripted, atol=1e-6) + + +class TestDWECT3D: + """Tests for DWECT in 3D.""" + + def test_3d_triangle(self): + """Test DWECT on a triangle in 3D.""" + device = torch.device("cpu") + + vcoords = torch.tensor([ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.5, 1.0, 0.0], + ], device=device) + vweights = torch.ones(3, device=device) + + ecoords = torch.tensor([[0, 1], [1, 2], [2, 0]], device=device) + eweights = torch.ones(3, device=device) + + fcoords = torch.tensor([[0, 1, 2]], device=device) + fweights = torch.ones(1, device=device) + + c = Complex( + (vcoords, vweights), + (ecoords, eweights), + (fcoords, fweights), + ) + + dirs = torch.tensor([ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0] + ], device=device) + dwect = DWECT(dirs, num_heights=8, growth_rate=10.0).to(device) + + result = dwect(c) + + assert result.shape == (3, 8) + assert torch.isfinite(result).all() diff --git a/tests/test_image_ecf.py b/tests/test_image_ecf.py new file mode 100644 index 0000000..d20dc4f --- /dev/null +++ b/tests/test_image_ecf.py @@ -0,0 +1,329 @@ +"""Tests for Image_ECF_2D and Image_ECF_3D modules in image_ecf.py""" + +import torch +import pytest + +from pyect import Image_ECF_2D, Image_ECF_3D + + +class TestImageECF2DConstruction: + """Tests for Image_ECF_2D construction.""" + + def test_basic_construction(self): + """Test basic Image_ECF_2D construction.""" + ecf = Image_ECF_2D(num_vals=10) + assert ecf.num_vals == 10 + + def test_various_num_vals(self): + """Test construction with various num_vals.""" + for n in [2, 5, 10, 50, 100]: + ecf = Image_ECF_2D(num_vals=n) + assert ecf.num_vals == n + + +class TestImageECF2DCellValues: + """Tests for cell_values_2D static method.""" + + def test_cell_values_shapes(self): + """Test that cell_values returns correct shapes.""" + arr = torch.rand(4, 5) # 4 rows, 5 columns + vertex_vals, edge_vals, square_vals = Image_ECF_2D.cell_values_2D(arr) + + # Vertices: h*w = 20 + assert vertex_vals.shape == (20,) + + # Edges: horizontal (h*(w-1) = 16) + vertical ((h-1)*w = 15) = 31 + assert edge_vals.shape == (4 * 4 + 3 * 5,) + + # Squares: (h-1)*(w-1) = 12 + assert square_vals.shape == (12,) + + def test_cell_values_constant_image(self): + """Test cell values for constant image.""" + arr = torch.full((3, 3), 0.5) + vertex_vals, edge_vals, square_vals = Image_ECF_2D.cell_values_2D(arr) + + # All values should be 0.5 for constant image + assert torch.allclose(vertex_vals, torch.full_like(vertex_vals, 0.5)) + assert torch.allclose(edge_vals, torch.full_like(edge_vals, 0.5)) + assert torch.allclose(square_vals, torch.full_like(square_vals, 0.5)) + + def test_cell_values_gradient_image(self): + """Test cell values for gradient image.""" + arr = torch.tensor([[0.0, 0.5], [0.5, 1.0]]) + vertex_vals, edge_vals, square_vals = Image_ECF_2D.cell_values_2D(arr) + + # Vertices should match flattened image + assert torch.allclose(vertex_vals, arr.reshape(-1)) + + # Maximum square value should be 1.0 (max of all corners) + assert torch.allclose(square_vals[0], torch.tensor(1.0)) + + +class TestImageECF2DForward: + """Tests for Image_ECF_2D forward pass.""" + + def test_output_shape(self): + """Test output shape for various configurations.""" + ecf = Image_ECF_2D(num_vals=10) + + for h, w in [(3, 3), (5, 7), (10, 10)]: + arr = torch.rand(h, w) + result = ecf(arr) + assert result.shape == (10,) + + def test_output_is_integer_typed(self): + """Test that output is integer type.""" + ecf = Image_ECF_2D(num_vals=10) + arr = torch.rand(5, 5) + result = ecf(arr) + + # Output should be an integer type (int32 or int64 depending on platform) + assert result.dtype in (torch.int32, torch.int64) + + def test_constant_black_image(self): + """Test ECF of constant black (zero) image.""" + ecf = Image_ECF_2D(num_vals=5) + arr = torch.zeros(3, 3) + result = ecf(arr) + + # For constant zero image, all contributions at index 0 + # Euler characteristic: V - E + F + # V = 9, E = 12, F = 4, so chi = 9 - 12 + 4 = 1 + assert result.shape == (5,) + + def test_constant_white_image(self): + """Test ECF of constant white (one) image.""" + ecf = Image_ECF_2D(num_vals=5) + arr = torch.ones(3, 3) + result = ecf(arr) + + # For constant one image, all contributions at index n-1 + assert result.shape == (5,) + + def test_output_is_cumulative(self): + """Test that output is non-decreasing (cumulative nature).""" + ecf = Image_ECF_2D(num_vals=10) + arr = torch.rand(5, 5) + result = ecf(arr) + + # The ECF should be monotonically non-decreasing after initial dip + # Actually for sublevel sets it should end at Euler char of full space + assert result.shape == (10,) + + def test_single_pixel(self): + """Test ECF of single pixel image.""" + ecf = Image_ECF_2D(num_vals=5) + arr = torch.tensor([[0.5]]) + result = ecf(arr) + + # Single vertex, no edges or faces + # chi = 1 for all levels >= 0.5 + assert result.shape == (5,) + + def test_device_preservation(self): + """Test that device is preserved.""" + ecf = Image_ECF_2D(num_vals=10) + arr = torch.rand(5, 5, device=torch.device("cpu")) + result = ecf(arr) + + assert result.device.type == "cpu" + + +class TestImageECF3DConstruction: + """Tests for Image_ECF_3D construction.""" + + def test_basic_construction(self): + """Test basic Image_ECF_3D construction.""" + ecf = Image_ECF_3D(num_vals=10) + assert ecf.num_vals == 10 + + def test_various_num_vals(self): + """Test construction with various num_vals.""" + for n in [2, 5, 10, 50]: + ecf = Image_ECF_3D(num_vals=n) + assert ecf.num_vals == n + + +class TestImageECF3DCellValues: + """Tests for cell_values_3D static method.""" + + def test_cell_values_shapes(self): + """Test that cell_values returns correct shapes.""" + arr = torch.rand(3, 4, 5) # shape is (3, 4, 5) + vertex_vals, edge_vals, square_vals, cube_vals = Image_ECF_3D.cell_values_3D(arr) + + # Vertices: 3*4*5 = 60 + assert vertex_vals.shape == (60,) + + # Edges along each axis: + # x-edges: (3-1)*4*5 = 40 + # y-edges: 3*(4-1)*5 = 45 + # z-edges: 3*4*(5-1) = 48 + expected_edges = 2*4*5 + 3*3*5 + 3*4*4 + assert edge_vals.shape == (expected_edges,) + + # Squares: computed from edge combinations + # The actual count depends on the implementation + # Just verify it's a 1D tensor with reasonable size + assert square_vals.dim() == 1 + assert square_vals.shape[0] > 0 + + # Cubes: (3-1)*(4-1)*(5-1) = 24 + assert cube_vals.shape == (24,) + + def test_cell_values_constant_volume(self): + """Test cell values for constant volume.""" + arr = torch.full((3, 3, 3), 0.5) + vertex_vals, edge_vals, square_vals, cube_vals = Image_ECF_3D.cell_values_3D(arr) + + # All values should be 0.5 for constant volume + assert torch.allclose(vertex_vals, torch.full_like(vertex_vals, 0.5)) + assert torch.allclose(edge_vals, torch.full_like(edge_vals, 0.5)) + assert torch.allclose(square_vals, torch.full_like(square_vals, 0.5)) + assert torch.allclose(cube_vals, torch.full_like(cube_vals, 0.5)) + + +class TestImageECF3DForward: + """Tests for Image_ECF_3D forward pass.""" + + def test_output_shape(self): + """Test output shape for various configurations.""" + ecf = Image_ECF_3D(num_vals=10) + + for d, h, w in [(3, 3, 3), (4, 5, 6), (2, 2, 2)]: + arr = torch.rand(d, h, w) + result = ecf(arr) + assert result.shape == (10,) + + def test_output_is_integer_typed(self): + """Test that output is integer type.""" + ecf = Image_ECF_3D(num_vals=10) + arr = torch.rand(3, 3, 3) + result = ecf(arr) + + # Output should be an integer type (int32 or int64 depending on platform) + assert result.dtype in (torch.int32, torch.int64) + + def test_constant_black_volume(self): + """Test ECF of constant black (zero) volume.""" + ecf = Image_ECF_3D(num_vals=5) + arr = torch.zeros(2, 2, 2) + result = ecf(arr) + + # For 3D: V - E + F - C + # 2x2x2: V=8, E=12, F=6, C=1 -> chi = 8 - 12 + 6 - 1 = 1 + assert result.shape == (5,) + + def test_single_voxel(self): + """Test ECF of single voxel volume.""" + ecf = Image_ECF_3D(num_vals=5) + arr = torch.tensor([[[0.5]]]) + result = ecf(arr) + + # Single vertex, no edges, faces or cubes + # chi = 1 for all levels >= 0.5 + assert result.shape == (5,) + + def test_device_preservation(self): + """Test that device is preserved.""" + ecf = Image_ECF_3D(num_vals=10) + arr = torch.rand(3, 3, 3, device=torch.device("cpu")) + result = ecf(arr) + + assert result.device.type == "cpu" + + +class TestImageECFTorchScript: + """Tests for TorchScript compatibility.""" + + def test_2d_can_script(self): + """Test that Image_ECF_2D can be compiled with TorchScript.""" + ecf = Image_ECF_2D(num_vals=10) + scripted = torch.jit.script(ecf) + assert scripted is not None + + def test_3d_can_script(self): + """Test that Image_ECF_3D can be compiled with TorchScript.""" + ecf = Image_ECF_3D(num_vals=10) + scripted = torch.jit.script(ecf) + assert scripted is not None + + def test_2d_scripted_same_result(self): + """Test that scripted Image_ECF_2D gives same results.""" + ecf = Image_ECF_2D(num_vals=10) + scripted = torch.jit.script(ecf) + + arr = torch.rand(5, 5) + + result_normal = ecf(arr) + result_scripted = scripted(arr) + + assert torch.equal(result_normal, result_scripted) + + def test_3d_scripted_same_result(self): + """Test that scripted Image_ECF_3D gives same results.""" + ecf = Image_ECF_3D(num_vals=10) + scripted = torch.jit.script(ecf) + + arr = torch.rand(3, 3, 3) + + result_normal = ecf(arr) + result_scripted = scripted(arr) + + assert torch.equal(result_normal, result_scripted) + + +class TestImageECFEdgeCases: + """Edge case tests for Image ECF modules.""" + + def test_2d_very_small_image(self): + """Test 2D ECF with minimum size image.""" + ecf = Image_ECF_2D(num_vals=5) + arr = torch.tensor([[0.5]]) # 1x1 image + result = ecf(arr) + + assert result.shape == (5,) + assert torch.isfinite(result.float()).all() + + def test_3d_very_small_volume(self): + """Test 3D ECF with minimum size volume.""" + ecf = Image_ECF_3D(num_vals=5) + arr = torch.tensor([[[0.5]]]) # 1x1x1 volume + result = ecf(arr) + + assert result.shape == (5,) + assert torch.isfinite(result.float()).all() + + def test_2d_binary_image(self): + """Test 2D ECF with binary (0/1) image.""" + ecf = Image_ECF_2D(num_vals=2) + arr = torch.tensor([[0.0, 1.0], [1.0, 0.0]]) + result = ecf(arr) + + assert result.shape == (2,) + + def test_3d_binary_volume(self): + """Test 3D ECF with binary (0/1) volume.""" + ecf = Image_ECF_3D(num_vals=2) + arr = torch.zeros(2, 2, 2) + arr[0, 0, 0] = 1.0 + result = ecf(arr) + + assert result.shape == (2,) + + def test_2d_narrow_image(self): + """Test 2D ECF with narrow image (1 row).""" + ecf = Image_ECF_2D(num_vals=5) + arr = torch.rand(1, 10) + result = ecf(arr) + + assert result.shape == (5,) + + def test_3d_flat_volume(self): + """Test 3D ECF with flat volume (depth 1).""" + ecf = Image_ECF_3D(num_vals=5) + arr = torch.rand(1, 5, 5) + result = ecf(arr) + + assert result.shape == (5,) diff --git a/tests/test_image_processing.py b/tests/test_image_processing.py new file mode 100644 index 0000000..abcffb4 --- /dev/null +++ b/tests/test_image_processing.py @@ -0,0 +1,303 @@ +"""Tests for image preprocessing functions in preprocessing/image_processing.py""" + +import torch +import pytest + +from pyect import weighted_freudenthal, weighted_cubical, Complex + + +class TestWeightedFreudenthal: + """Tests for weighted_freudenthal function.""" + + def test_output_is_complex(self): + """Test that output is a Complex object.""" + arr = torch.rand(5, 5) + result = weighted_freudenthal(arr) + + assert isinstance(result, Complex) + + def test_output_dimensions(self): + """Test that output complex has correct dimensions.""" + arr = torch.rand(5, 5) + result = weighted_freudenthal(arr) + + # Should have vertices (0), edges (1), and triangles (2) + assert len(result) == 3 + + def test_vertices_are_2d(self): + """Test that vertices are 2D coordinates.""" + arr = torch.rand(3, 4) + result = weighted_freudenthal(arr) + + v_coords = result.get_coords(0) + assert v_coords.shape[1] == 2 + + def test_no_zero_weight_simplices(self): + """Test that zero-weight pixels don't create edges/triangles.""" + arr = torch.zeros(3, 3) + arr[1, 1] = 1.0 # Only center pixel is nonzero + + result = weighted_freudenthal(arr) + + # Should have 1 vertex (center pixel) + v_coords = result.get_coords(0) + assert v_coords.shape[0] == 1 + + # Should have no edges (isolated vertex) + e_coords = result.get_coords(1) + assert e_coords.shape[0] == 0 + + def test_full_image_edge_count(self): + """Test edge count for fully nonzero image.""" + # For a 2x2 image, Freudenthal creates: + # 4 vertices, horizontal+vertical+diagonal edges + arr = torch.ones(2, 2) + result = weighted_freudenthal(arr) + + v_coords = result.get_coords(0) + assert v_coords.shape[0] == 4 + + e_coords = result.get_coords(1) + # 2 horizontal + 2 vertical + 1 diagonal = 5 edges + assert e_coords.shape[0] == 5 + + def test_full_image_triangle_count(self): + """Test triangle count for fully nonzero image.""" + arr = torch.ones(2, 2) + result = weighted_freudenthal(arr) + + f_coords = result.get_coords(2) + # 2x2 image creates 2 triangles (upper and lower) + assert f_coords.shape[0] == 2 + + def test_weights_are_max_function(self): + """Test that edge/triangle weights use max function.""" + arr = torch.tensor([[0.1, 0.5], [0.5, 0.9]]) + result = weighted_freudenthal(arr) + + # Check vertex weights match image values + v_weights = result.get_weights(0) + expected_weights = arr[arr != 0].flatten() + assert torch.allclose(v_weights.sort()[0], expected_weights.sort()[0]) + + def test_centering(self): + """Test that vertices are centered around origin.""" + arr = torch.ones(3, 3) + result = weighted_freudenthal(arr) + + v_coords = result.get_coords(0) + centroid = v_coords.mean(dim=0) + + # Should be close to origin + assert torch.allclose(centroid, torch.zeros(2), atol=1e-6) + + def test_device_parameter(self): + """Test that device parameter works.""" + arr = torch.ones(3, 3) + result = weighted_freudenthal(arr, device=torch.device("cpu")) + + assert result.get_coords(0).device.type == "cpu" + + def test_device_inheritance(self): + """Test that device is inherited from input tensor.""" + arr = torch.ones(3, 3, device=torch.device("cpu")) + result = weighted_freudenthal(arr) + + assert result.get_coords(0).device.type == "cpu" + + def test_sparse_image(self): + """Test Freudenthal on sparse image.""" + arr = torch.zeros(5, 5) + arr[0, 0] = 0.5 + arr[4, 4] = 0.5 + + result = weighted_freudenthal(arr) + + v_coords = result.get_coords(0) + # Only 2 nonzero pixels + assert v_coords.shape[0] == 2 + + e_coords = result.get_coords(1) + # No edges (pixels are not adjacent) + assert e_coords.shape[0] == 0 + + def test_diagonal_neighbors(self): + """Test that diagonal neighbors create edges and triangles.""" + arr = torch.zeros(3, 3) + arr[0, 0] = 0.5 + arr[1, 1] = 0.5 + arr[0, 1] = 0.5 # Additional to form triangle + + result = weighted_freudenthal(arr) + + # Should have 3 vertices + v_coords = result.get_coords(0) + assert v_coords.shape[0] == 3 + + +class TestWeightedCubical: + """Tests for weighted_cubical function.""" + + def test_output_is_complex(self): + """Test that output is a Complex object.""" + arr = torch.rand(5, 5) + result = weighted_cubical(arr) + + assert isinstance(result, Complex) + + def test_output_type_is_cubical(self): + """Test that output complex type is cubical.""" + arr = torch.rand(3, 3) + result = weighted_cubical(arr) + + assert result.n_type == "cubical" + + def test_output_dimensions(self): + """Test that output complex has correct dimensions.""" + arr = torch.rand(5, 5) + result = weighted_cubical(arr) + + # Should have vertices (0), edges (1), and squares (2) + assert len(result) == 3 + + def test_square_has_4_vertices(self): + """Test that squares have 4 vertices (cubical complex).""" + arr = torch.ones(2, 2) + result = weighted_cubical(arr) + + s_coords = result.get_coords(2) + # Each square should reference 4 vertices + assert s_coords.shape[1] == 4 + + def test_no_zero_weight_simplices(self): + """Test that zero-weight pixels don't create edges/squares.""" + arr = torch.zeros(3, 3) + arr[1, 1] = 1.0 # Only center pixel is nonzero + + result = weighted_cubical(arr) + + # Should have 1 vertex + v_coords = result.get_coords(0) + assert v_coords.shape[0] == 1 + + # Should have no edges + e_coords = result.get_coords(1) + assert e_coords.shape[0] == 0 + + def test_full_image_edge_count(self): + """Test edge count for fully nonzero image.""" + arr = torch.ones(2, 2) + result = weighted_cubical(arr) + + e_coords = result.get_coords(1) + # 2 horizontal + 2 vertical = 4 edges (no diagonals in cubical) + assert e_coords.shape[0] == 4 + + def test_full_image_square_count(self): + """Test square count for fully nonzero image.""" + arr = torch.ones(2, 2) + result = weighted_cubical(arr) + + s_coords = result.get_coords(2) + # 2x2 image creates 1 square + assert s_coords.shape[0] == 1 + + def test_centering(self): + """Test that vertices are centered around origin.""" + arr = torch.ones(3, 3) + result = weighted_cubical(arr) + + v_coords = result.get_coords(0) + centroid = v_coords.mean(dim=0) + + # Should be close to origin + assert torch.allclose(centroid, torch.zeros(2), atol=1e-6) + + def test_device_parameter(self): + """Test that device parameter works.""" + arr = torch.ones(3, 3) + result = weighted_cubical(arr, device=torch.device("cpu")) + + assert result.get_coords(0).device.type == "cpu" + + def test_larger_image(self): + """Test cubical complex for larger image.""" + arr = torch.ones(5, 5) + result = weighted_cubical(arr) + + v_coords = result.get_coords(0) + assert v_coords.shape[0] == 25 + + s_coords = result.get_coords(2) + # 4x4 = 16 squares + assert s_coords.shape[0] == 16 + + +class TestFreudenthalVsCubical: + """Comparison tests between Freudenthal and cubical complexes.""" + + def test_same_vertex_count(self): + """Test that both produce same vertex count.""" + arr = torch.ones(3, 3) + + freud = weighted_freudenthal(arr) + cubic = weighted_cubical(arr) + + assert freud.get_coords(0).shape[0] == cubic.get_coords(0).shape[0] + + def test_different_edge_count(self): + """Test that edge counts differ (Freudenthal has diagonals).""" + arr = torch.ones(3, 3) + + freud = weighted_freudenthal(arr) + cubic = weighted_cubical(arr) + + # Freudenthal should have more edges (includes diagonals) + assert freud.get_coords(1).shape[0] > cubic.get_coords(1).shape[0] + + def test_different_top_dim_structure(self): + """Test structural difference at top dimension.""" + arr = torch.ones(2, 2) + + freud = weighted_freudenthal(arr) + cubic = weighted_cubical(arr) + + # Freudenthal: triangles (3 vertices each) + assert freud.get_coords(2).shape[1] == 3 + + # Cubical: squares (4 vertices each) + assert cubic.get_coords(2).shape[1] == 4 + + +class TestIntegrationWithWECT: + """Integration tests with WECT.""" + + def test_freudenthal_with_wect(self): + """Test that Freudenthal complex works with WECT.""" + from pyect import WECT + + arr = torch.ones(5, 5) * 0.5 + c = weighted_freudenthal(arr) + + dirs = torch.tensor([[1.0, 0.0], [0.0, 1.0]]) + wect = WECT(dirs, num_heights=10) + + result = wect(c) + + assert result.shape == (2, 10) + assert torch.isfinite(result).all() + + def test_cubical_with_wect(self): + """Test that cubical complex works with WECT.""" + from pyect import WECT + + arr = torch.ones(5, 5) * 0.5 + c = weighted_cubical(arr) + + dirs = torch.tensor([[1.0, 0.0], [0.0, 1.0]]) + wect = WECT(dirs, num_heights=10) + + result = wect(c) + + assert result.shape == (2, 10) + assert torch.isfinite(result).all() diff --git a/tests/test_mesh_processing.py b/tests/test_mesh_processing.py new file mode 100644 index 0000000..a54142b --- /dev/null +++ b/tests/test_mesh_processing.py @@ -0,0 +1,189 @@ +"""Tests for mesh processing functions in preprocessing/mesh_processing.py""" + +import torch +import pytest +import tempfile +import os + +# Import the function to test +try: + from pyect import mesh_to_complex + TRIMESH_AVAILABLE = True +except ImportError: + TRIMESH_AVAILABLE = False + + +@pytest.mark.skipif(not TRIMESH_AVAILABLE, reason="trimesh not installed") +class TestMeshToComplex: + """Tests for mesh_to_complex function.""" + + def create_simple_obj_file(self, path): + """Create a simple OBJ file with a triangle.""" + with open(path, 'w') as f: + f.write("# Simple triangle\n") + f.write("v 0.0 0.0 0.0\n") + f.write("v 1.0 0.0 0.0\n") + f.write("v 0.5 1.0 0.0\n") + f.write("f 1 2 3\n") + + def create_tetrahedron_obj_file(self, path): + """Create an OBJ file with a tetrahedron.""" + with open(path, 'w') as f: + f.write("# Tetrahedron\n") + f.write("v 0.0 0.0 0.0\n") + f.write("v 1.0 0.0 0.0\n") + f.write("v 0.5 1.0 0.0\n") + f.write("v 0.5 0.5 1.0\n") + f.write("f 1 2 3\n") + f.write("f 1 2 4\n") + f.write("f 1 3 4\n") + f.write("f 2 3 4\n") + + def test_load_simple_triangle(self): + """Test loading a simple triangle mesh.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.obj', delete=False) as f: + f.write("v 0.0 0.0 0.0\n") + f.write("v 1.0 0.0 0.0\n") + f.write("v 0.5 1.0 0.0\n") + f.write("f 1 2 3\n") + temp_path = f.name + + try: + c = mesh_to_complex(temp_path, torch.device("cpu")) + + # Should have 3 vertices + assert c.get_coords(0).shape[0] == 3 + + # Should be 3D coordinates + assert c.get_coords(0).shape[1] == 3 + + # Should have edges + assert len(c) >= 2 + finally: + os.unlink(temp_path) + + def test_output_is_complex(self): + """Test that output is a Complex object.""" + from pyect import Complex + + with tempfile.NamedTemporaryFile(mode='w', suffix='.obj', delete=False) as f: + f.write("v 0.0 0.0 0.0\n") + f.write("v 1.0 0.0 0.0\n") + f.write("v 0.5 1.0 0.0\n") + f.write("f 1 2 3\n") + temp_path = f.name + + try: + c = mesh_to_complex(temp_path, torch.device("cpu")) + assert isinstance(c, Complex) + finally: + os.unlink(temp_path) + + def test_centering_option(self): + """Test the centering option.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.obj', delete=False) as f: + f.write("v 10.0 10.0 10.0\n") + f.write("v 11.0 10.0 10.0\n") + f.write("v 10.5 11.0 10.0\n") + f.write("f 1 2 3\n") + temp_path = f.name + + try: + c = mesh_to_complex(temp_path, torch.device("cpu"), centering=True) + + v_coords = c.get_coords(0) + centroid = v_coords.mean(dim=0) + + # Should be centered near origin + assert torch.allclose(centroid, torch.zeros(3), atol=1e-5) + finally: + os.unlink(temp_path) + + def test_device_parameter(self): + """Test the device parameter.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.obj', delete=False) as f: + f.write("v 0.0 0.0 0.0\n") + f.write("v 1.0 0.0 0.0\n") + f.write("v 0.5 1.0 0.0\n") + f.write("f 1 2 3\n") + temp_path = f.name + + try: + c = mesh_to_complex(temp_path, torch.device("cpu")) + assert c.get_coords(0).device.type == "cpu" + finally: + os.unlink(temp_path) + + def test_vertex_weights_are_ones(self): + """Test that default vertex weights are ones.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.obj', delete=False) as f: + f.write("v 0.0 0.0 0.0\n") + f.write("v 1.0 0.0 0.0\n") + f.write("v 0.5 1.0 0.0\n") + f.write("f 1 2 3\n") + temp_path = f.name + + try: + c = mesh_to_complex(temp_path, torch.device("cpu")) + + v_weights = c.get_weights(0) + assert torch.allclose(v_weights, torch.ones(3)) + finally: + os.unlink(temp_path) + + def test_tetrahedron_mesh(self): + """Test loading a tetrahedron mesh.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.obj', delete=False) as f: + f.write("v 0.0 0.0 0.0\n") + f.write("v 1.0 0.0 0.0\n") + f.write("v 0.5 1.0 0.0\n") + f.write("v 0.5 0.5 1.0\n") + f.write("f 1 2 3\n") + f.write("f 1 2 4\n") + f.write("f 1 3 4\n") + f.write("f 2 3 4\n") + temp_path = f.name + + try: + c = mesh_to_complex(temp_path, torch.device("cpu")) + + # Should have 4 vertices + assert c.get_coords(0).shape[0] == 4 + + # Should have faces (dim 2) + assert len(c) >= 3 + finally: + os.unlink(temp_path) + + +@pytest.mark.skipif(not TRIMESH_AVAILABLE, reason="trimesh not installed") +class TestMeshToComplexIntegration: + """Integration tests for mesh_to_complex with WECT.""" + + def test_mesh_with_wect(self): + """Test that loaded mesh works with WECT.""" + from pyect import WECT + + with tempfile.NamedTemporaryFile(mode='w', suffix='.obj', delete=False) as f: + f.write("v 0.0 0.0 0.0\n") + f.write("v 1.0 0.0 0.0\n") + f.write("v 0.5 1.0 0.0\n") + f.write("f 1 2 3\n") + temp_path = f.name + + try: + c = mesh_to_complex(temp_path, torch.device("cpu")) + + dirs = torch.tensor([ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0] + ]) + wect = WECT(dirs, num_heights=10) + + result = wect(c) + + assert result.shape == (3, 10) + assert torch.isfinite(result).all() + finally: + os.unlink(temp_path) diff --git a/tests/test_wecfs.py b/tests/test_wecfs.py new file mode 100644 index 0000000..52aa9e8 --- /dev/null +++ b/tests/test_wecfs.py @@ -0,0 +1,328 @@ +"""Tests for compute_wecfs function in wecfs.py""" + +import torch +import pytest + +from pyect.wecfs import compute_wecfs +from pyect import Complex + + +def build_triangle_complex_with_filters(device="cpu"): + """Build a triangle complex with filter functions for testing.""" + # 3 vertices with 2 filter functions + filters = torch.tensor([ + [-1.0, 0.0], # vertex 0: filter values + [0.0, 1.0], # vertex 1: filter values + [1.0, 0.0], # vertex 2: filter values + ], device=device) + vweights = torch.ones(3, device=device) + + ecoords = torch.tensor([[0, 1], [1, 2], [2, 0]], device=device) + eweights = torch.ones(3, device=device) + + fcoords = torch.tensor([[0, 1, 2]], device=device) + fweights = torch.ones(1, device=device) + + # Return as list of tuples (same format as Complex but with filters instead of coords) + return [(filters, vweights), (ecoords, eweights), (fcoords, fweights)] + + +class TestComputeWecfsBasic: + """Basic tests for compute_wecfs function.""" + + def test_output_shape(self): + """Test that output has correct shape.""" + complex_data = build_triangle_complex_with_filters() + + result = compute_wecfs(complex_data, num_vals=10) + + # 2 filter functions, 10 values + assert result.shape == (2, 10) + + def test_output_is_finite(self): + """Test that output contains no NaN or Inf values.""" + complex_data = build_triangle_complex_with_filters() + + result = compute_wecfs(complex_data, num_vals=10) + + assert torch.isfinite(result).all() + + def test_output_dtype(self): + """Test that output dtype is float32.""" + complex_data = build_triangle_complex_with_filters() + + result = compute_wecfs(complex_data, num_vals=10) + + assert result.dtype == torch.float32 + + def test_various_num_vals(self): + """Test with various num_vals.""" + complex_data = build_triangle_complex_with_filters() + + for n in [2, 5, 10, 50, 100]: + result = compute_wecfs(complex_data, num_vals=n) + assert result.shape == (2, n) + + +class TestComputeWecfsSingleFilter: + """Tests with single filter function.""" + + def test_single_filter(self): + """Test with single filter function.""" + device = torch.device("cpu") + + filters = torch.tensor([ + [-1.0], + [0.0], + [1.0], + ], device=device) + vweights = torch.ones(3, device=device) + + ecoords = torch.tensor([[0, 1], [1, 2], [2, 0]], device=device) + eweights = torch.ones(3, device=device) + + fcoords = torch.tensor([[0, 1, 2]], device=device) + fweights = torch.ones(1, device=device) + + complex_data = [(filters, vweights), (ecoords, eweights), (fcoords, fweights)] + + result = compute_wecfs(complex_data, num_vals=5) + + assert result.shape == (1, 5) + + def test_single_filter_vertices_only(self): + """Test with single filter, vertices only.""" + device = torch.device("cpu") + + filters = torch.tensor([ + [-1.0], + [0.0], + [1.0], + ], device=device) + vweights = torch.tensor([1.0, 2.0, 3.0], device=device) + + complex_data = [(filters, vweights)] + + result = compute_wecfs(complex_data, num_vals=5) + + assert result.shape == (1, 5) + + +class TestComputeWecfsWeighted: + """Tests with weighted complexes.""" + + def test_weighted_vertices(self): + """Test with non-uniform vertex weights.""" + device = torch.device("cpu") + + filters = torch.tensor([ + [-1.0], + [0.0], + [1.0], + ], device=device) + vweights = torch.tensor([0.5, 1.0, 1.5], device=device) + + ecoords = torch.tensor([[0, 1], [1, 2], [2, 0]], device=device) + eweights = torch.ones(3, device=device) + + complex_data = [(filters, vweights), (ecoords, eweights)] + + result = compute_wecfs(complex_data, num_vals=5) + + assert result.shape == (1, 5) + assert torch.isfinite(result).all() + + def test_weighted_edges(self): + """Test with non-uniform edge weights.""" + device = torch.device("cpu") + + filters = torch.tensor([ + [-1.0], + [0.0], + [1.0], + ], device=device) + vweights = torch.ones(3, device=device) + + ecoords = torch.tensor([[0, 1], [1, 2], [2, 0]], device=device) + eweights = torch.tensor([0.5, 1.0, 0.5], device=device) + + complex_data = [(filters, vweights), (ecoords, eweights)] + + result = compute_wecfs(complex_data, num_vals=5) + + assert result.shape == (1, 5) + assert torch.isfinite(result).all() + + +class TestComputeWecfsMultipleFilters: + """Tests with multiple filter functions.""" + + def test_three_filters(self): + """Test with three filter functions.""" + device = torch.device("cpu") + + filters = torch.tensor([ + [0.0, 1.0, -1.0], + [1.0, 0.0, 0.0], + [-1.0, -1.0, 1.0], + ], device=device) + vweights = torch.ones(3, device=device) + + complex_data = [(filters, vweights)] + + result = compute_wecfs(complex_data, num_vals=10) + + assert result.shape == (3, 10) + + def test_many_filters(self): + """Test with many filter functions.""" + device = torch.device("cpu") + num_filters = 10 + + filters = torch.randn(5, num_filters, device=device) + vweights = torch.ones(5, device=device) + + ecoords = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 4]], device=device) + eweights = torch.ones(4, device=device) + + complex_data = [(filters, vweights), (ecoords, eweights)] + + result = compute_wecfs(complex_data, num_vals=20) + + assert result.shape == (10, 20) + + +class TestComputeWecfsEdgeCases: + """Edge case tests for compute_wecfs.""" + + def test_single_vertex(self): + """Test with single vertex.""" + device = torch.device("cpu") + + filters = torch.tensor([[0.5, -0.5]], device=device) + vweights = torch.ones(1, device=device) + + complex_data = [(filters, vweights)] + + result = compute_wecfs(complex_data, num_vals=5) + + assert result.shape == (2, 5) + assert torch.isfinite(result).all() + + def test_constant_filter(self): + """Test with constant filter function.""" + device = torch.device("cpu") + + filters = torch.full((3, 1), 0.5, device=device) + vweights = torch.ones(3, device=device) + + complex_data = [(filters, vweights)] + + result = compute_wecfs(complex_data, num_vals=5) + + assert result.shape == (1, 5) + + def test_zero_filter(self): + """Test with zero filter function.""" + device = torch.device("cpu") + + filters = torch.zeros(3, 1, device=device) + vweights = torch.ones(3, device=device) + + complex_data = [(filters, vweights)] + + result = compute_wecfs(complex_data, num_vals=5) + + assert result.shape == (1, 5) + assert torch.isfinite(result).all() + + +class TestComputeWecfsHigherDimensions: + """Tests with higher dimensional complexes.""" + + def test_tetrahedron(self): + """Test with tetrahedron complex.""" + device = torch.device("cpu") + + # 4 vertices with 2 filter functions + filters = torch.tensor([ + [0.0, 1.0], + [1.0, 0.0], + [-1.0, -1.0], + [0.5, 0.5], + ], device=device) + vweights = torch.ones(4, device=device) + + ecoords = torch.tensor([ + [0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3] + ], device=device) + eweights = torch.ones(6, device=device) + + fcoords = torch.tensor([ + [0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3] + ], device=device) + fweights = torch.ones(4, device=device) + + tcoords = torch.tensor([[0, 1, 2, 3]], device=device) + tweights = torch.ones(1, device=device) + + complex_data = [ + (filters, vweights), + (ecoords, eweights), + (fcoords, fweights), + (tcoords, tweights), + ] + + result = compute_wecfs(complex_data, num_vals=10) + + assert result.shape == (2, 10) + assert torch.isfinite(result).all() + + +class TestComputeWecfsCumulative: + """Tests verifying cumulative sum behavior.""" + + def test_is_cumulative(self): + """Test that output uses cumulative sum.""" + device = torch.device("cpu") + + filters = torch.tensor([ + [-1.0], + [0.0], + [1.0], + ], device=device) + vweights = torch.ones(3, device=device) + + complex_data = [(filters, vweights)] + + result = compute_wecfs(complex_data, num_vals=3) + + # For vertices only, result should be cumulative count + # All values should be non-decreasing + diffs = torch.diff(result, dim=1) + assert (diffs >= -1e-6).all() # Allow small numerical error + + +class TestComputeWecfsDevice: + """Device handling tests.""" + + def test_device_cpu(self): + """Test on CPU device.""" + complex_data = build_triangle_complex_with_filters(device=torch.device("cpu")) + + result = compute_wecfs(complex_data, num_vals=10) + + assert result.device.type == "cpu" + + def test_device_consistency(self): + """Test that output device matches input device.""" + device = torch.device("cpu") + + filters = torch.randn(3, 2, device=device) + vweights = torch.ones(3, device=device) + + complex_data = [(filters, vweights)] + + result = compute_wecfs(complex_data, num_vals=10) + + assert result.device == device diff --git a/tests/test_wect.py b/tests/test_wect.py new file mode 100644 index 0000000..dbd90d9 --- /dev/null +++ b/tests/test_wect.py @@ -0,0 +1,359 @@ +"""Tests for the WECT module in wect.py""" + +import torch +import pytest + +from pyect import WECT, Complex + + +def build_triangle_complex(device="cpu"): + """Build a simple triangle complex for testing.""" + vcoords = torch.tensor( + [[-1.0, 0.0], [0.0, 1.0], [1.0, 0.0]], device=device + ) + vweights = torch.ones(3, device=device) + + ecoords = torch.tensor([[0, 1], [1, 2], [2, 0]], device=device) + eweights = torch.ones(3, device=device) + + fcoords = torch.tensor([[0, 1, 2]], device=device) + fweights = torch.ones(1, device=device) + + return Complex( + (vcoords, vweights), + (ecoords, eweights), + (fcoords, fweights), + ) + + +def build_line_segment_complex(device="cpu"): + """Build a simple line segment (1D) complex.""" + vcoords = torch.tensor([[-1.0, 0.0], [1.0, 0.0]], device=device) + vweights = torch.ones(2, device=device) + + ecoords = torch.tensor([[0, 1]], device=device) + eweights = torch.ones(1, device=device) + + return Complex((vcoords, vweights), (ecoords, eweights)) + + +class TestWECTConstruction: + """Tests for WECT module construction.""" + + def test_basic_construction(self): + """Test basic WECT construction.""" + dirs = torch.tensor([[1.0, 0.0]]) + wect = WECT(dirs, num_heights=10) + + assert wect.num_heights == 10 + assert wect.dirs.shape == (1, 2) + + def test_multiple_directions(self): + """Test WECT with multiple directions.""" + dirs = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + wect = WECT(dirs, num_heights=5) + + assert wect.dirs.shape == (3, 2) + + def test_direction_normalization(self): + """Test that directions are normalized.""" + dirs = torch.tensor([[3.0, 4.0]]) # norm = 5 + wect = WECT(dirs, num_heights=5) + + norms = torch.norm(wect.dirs, dim=1) + assert torch.allclose(norms, torch.ones(1), atol=1e-6) + + def test_invalid_num_heights(self): + """Test that non-positive num_heights raises error.""" + dirs = torch.tensor([[1.0, 0.0]]) + + with pytest.raises(ValueError, match="num_heights must be positive"): + WECT(dirs, num_heights=0) + + with pytest.raises(ValueError, match="num_heights must be positive"): + WECT(dirs, num_heights=-5) + + +class TestWECTForward: + """Tests for WECT forward pass.""" + + def test_output_shape_single_direction(self): + """Test output shape with single direction.""" + device = torch.device("cpu") + c = build_triangle_complex(device) + + dirs = torch.tensor([[1.0, 0.0]], device=device) + wect = WECT(dirs, num_heights=10).to(device) + + result = wect(c) + + assert result.shape == (1, 10) + + def test_output_shape_multiple_directions(self): + """Test output shape with multiple directions.""" + device = torch.device("cpu") + c = build_triangle_complex(device) + + dirs = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], device=device) + wect = WECT(dirs, num_heights=8).to(device) + + result = wect(c) + + assert result.shape == (3, 8) + + def test_output_is_finite(self): + """Test that output contains no NaN or Inf values.""" + device = torch.device("cpu") + c = build_triangle_complex(device) + + dirs = torch.tensor([[1.0, 0.0], [0.0, 1.0]], device=device) + wect = WECT(dirs, num_heights=10).to(device) + + result = wect(c) + + assert torch.isfinite(result).all() + + def test_exact_triangle_horizontal_direction(self): + """Test exact WECT values for triangle with horizontal direction.""" + device = torch.device("cpu") + + vcoords = torch.tensor( + [[-1.0, 0.0], [0.0, 1.0], [1.0, 0.0]], device=device + ) + vweights = torch.ones(3, device=device) + ecoords = torch.tensor([[0, 1], [1, 2], [2, 0]], device=device) + eweights = torch.ones(3, device=device) + fcoords = torch.tensor([[0, 1, 2]], device=device) + fweights = torch.ones(1, device=device) + + c = Complex( + (vcoords, vweights), + (ecoords, eweights), + (fcoords, fweights), + ) + + dirs = torch.tensor([[1.0, 0.0]], device=device) + wect = WECT(dirs, num_heights=3).to(device) + + result = wect(c) + + # From the existing test, expected result for this configuration + expected = torch.tensor([1.0, 1.0, 1.0], device=device) + assert torch.allclose(result[0], expected, atol=1e-6) + + def test_weighted_complex(self): + """Test WECT with weighted complex.""" + device = torch.device("cpu") + + vcoords = torch.tensor( + [[-1.0, 0.0], [0.0, 1.0], [1.0, 0.0]], device=device + ) + vweights = torch.tensor([0.5, 1.0, 1.5], device=device) + ecoords = torch.tensor([[0, 1], [1, 2], [2, 0]], device=device) + eweights = torch.tensor([0.5, 1.0, 0.5], device=device) + fcoords = torch.tensor([[0, 1, 2]], device=device) + fweights = torch.tensor([0.5], device=device) + + c = Complex( + (vcoords, vweights), + (ecoords, eweights), + (fcoords, fweights), + ) + + dirs = torch.tensor([[1.0, 0.0]], device=device) + wect = WECT(dirs, num_heights=3).to(device) + + result = wect(c) + + expected = torch.tensor([0.5, 1.0, 1.5], device=device) + assert torch.allclose(result[0], expected, atol=1e-6) + + +class TestWECTEdgeCases: + """Tests for WECT edge cases.""" + + def test_empty_complex(self): + """Test WECT with empty complex.""" + device = torch.device("cpu") + + vcoords = torch.zeros((0, 2), device=device) + vweights = torch.zeros(0, device=device) + + c = Complex((vcoords, vweights)) + + dirs = torch.tensor([[1.0, 0.0]], device=device) + wect = WECT(dirs, num_heights=5).to(device) + + result = wect(c) + + assert result.shape == (1, 5) + assert torch.allclose(result, torch.zeros((1, 5), device=device)) + + def test_single_vertex(self): + """Test WECT with single vertex.""" + device = torch.device("cpu") + + vcoords = torch.tensor([[0.0, 0.0]], device=device) + vweights = torch.ones(1, device=device) + + c = Complex((vcoords, vweights)) + + dirs = torch.tensor([[1.0, 0.0]], device=device) + wect = WECT(dirs, num_heights=5).to(device) + + result = wect(c) + + assert result.shape == (1, 5) + assert torch.isfinite(result).all() + + def test_vertices_at_origin(self): + """Test WECT when all vertices are at origin.""" + device = torch.device("cpu") + + vcoords = torch.tensor([[0.0, 0.0], [0.0, 0.0]], device=device) + vweights = torch.ones(2, device=device) + + c = Complex((vcoords, vweights)) + + dirs = torch.tensor([[1.0, 0.0]], device=device) + wect = WECT(dirs, num_heights=5).to(device) + + result = wect(c) + + assert result.shape == (1, 5) + assert torch.isfinite(result).all() + + def test_vertices_only_no_edges(self): + """Test WECT with vertices but no edges/faces.""" + device = torch.device("cpu") + + vcoords = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]], device=device) + vweights = torch.ones(3, device=device) + + c = Complex((vcoords, vweights)) + + dirs = torch.tensor([[1.0, 0.0]], device=device) + wect = WECT(dirs, num_heights=3).to(device) + + result = wect(c) + + # Just vertices, Euler characteristic = number of vertices at each level + assert result.shape == (1, 3) + assert torch.isfinite(result).all() + + +class TestWECT3D: + """Tests for WECT in 3D.""" + + def test_3d_tetrahedron(self): + """Test WECT on a 3D tetrahedron.""" + device = torch.device("cpu") + + vcoords = torch.tensor([ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.5, 1.0, 0.0], + [0.5, 0.5, 1.0] + ], device=device) + vweights = torch.ones(4, device=device) + + ecoords = torch.tensor([ + [0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3] + ], device=device) + eweights = torch.ones(6, device=device) + + fcoords = torch.tensor([ + [0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3] + ], device=device) + fweights = torch.ones(4, device=device) + + tcoords = torch.tensor([[0, 1, 2, 3]], device=device) + tweights = torch.ones(1, device=device) + + c = Complex( + (vcoords, vweights), + (ecoords, eweights), + (fcoords, fweights), + (tcoords, tweights), + ) + + dirs = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], device=device) + wect = WECT(dirs, num_heights=10).to(device) + + result = wect(c) + + assert result.shape == (3, 10) + assert torch.isfinite(result).all() + + def test_3d_multiple_directions(self): + """Test WECT with multiple 3D directions.""" + device = torch.device("cpu") + + vcoords = torch.tensor([ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + ], device=device) + vweights = torch.ones(3, device=device) + + ecoords = torch.tensor([[0, 1], [0, 2], [1, 2]], device=device) + eweights = torch.ones(3, device=device) + + fcoords = torch.tensor([[0, 1, 2]], device=device) + fweights = torch.ones(1, device=device) + + c = Complex( + (vcoords, vweights), + (ecoords, eweights), + (fcoords, fweights), + ) + + dirs = torch.randn(5, 3, device=device) + wect = WECT(dirs, num_heights=8).to(device) + + result = wect(c) + + assert result.shape == (5, 8) + + +class TestWECTTorchScript: + """Tests for TorchScript compatibility.""" + + def test_can_script(self): + """Test that WECT can be compiled with TorchScript.""" + dirs = torch.tensor([[1.0, 0.0], [0.0, 1.0]]) + wect = WECT(dirs, num_heights=10) + + scripted = torch.jit.script(wect) + assert scripted is not None + + def test_scripted_gives_same_result(self): + """Test that scripted WECT gives same results.""" + device = torch.device("cpu") + c = build_triangle_complex(device) + + dirs = torch.tensor([[1.0, 0.0]], device=device) + wect = WECT(dirs, num_heights=5).to(device) + scripted = torch.jit.script(wect) + + result_normal = wect(c) + result_scripted = scripted(c) + + assert torch.allclose(result_normal, result_scripted, atol=1e-6) + + +class TestWECTEvalMode: + """Tests for WECT in evaluation mode.""" + + def test_eval_mode(self): + """Test WECT works in eval mode.""" + device = torch.device("cpu") + c = build_triangle_complex(device) + + dirs = torch.tensor([[1.0, 0.0]], device=device) + wect = WECT(dirs, num_heights=5).to(device).eval() + + result = wect(c) + + assert result.shape == (1, 5) + assert torch.isfinite(result).all() From 6eb4cb7c93ed3d52cedeb1e54de1b9d37d5af632 Mon Sep 17 00:00:00 2001 From: quistew Date: Tue, 27 Jan 2026 14:38:32 -0700 Subject: [PATCH 2/2] Fix divide by zero --- tests/test_wecfs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_wecfs.py b/tests/test_wecfs.py index 52aa9e8..33f24e4 100644 --- a/tests/test_wecfs.py +++ b/tests/test_wecfs.py @@ -222,11 +222,11 @@ def test_constant_filter(self): assert result.shape == (1, 5) - def test_zero_filter(self): - """Test with zero filter function.""" + def test_near_zero_filter(self): + """Test with near-zero filter function.""" device = torch.device("cpu") - filters = torch.zeros(3, 1, device=device) + filters = torch.full((3, 1), 1e-6, device=device) vweights = torch.ones(3, device=device) complex_data = [(filters, vweights)]