diff --git a/libs/architectures/architectures/__init__.py b/libs/architectures/architectures/__init__.py
index 7ff6db14c..70f65d4cc 100644
--- a/libs/architectures/architectures/__init__.py
+++ b/libs/architectures/architectures/__init__.py
@@ -1,6 +1,8 @@
 from .base import Architecture
 from .supervised import (
     SupervisedArchitecture,
+    SupervisedDecimatedResNet,
+    SupervisedDecimatedSVDNet,
     SupervisedFrequencyDomainResNet,
     SupervisedMultiModalResNet,
     SupervisedSpectrogramDomainResNet,
diff --git a/libs/architectures/architectures/supervised.py b/libs/architectures/architectures/supervised.py
index 1ade07fff..d52f74f9e 100644
--- a/libs/architectures/architectures/supervised.py
+++ b/libs/architectures/architectures/supervised.py
@@ -5,6 +5,7 @@
 from jaxtyping import Float
 from ml4gw.nn.resnet.resnet_1d import NormLayer, ResNet1D
 from ml4gw.nn.resnet.resnet_2d import ResNet2D
+from ml4gw.nn.svd import DenseResidualBlock, FreqDomainSVDProjection
 from torch import Tensor
 import torch
@@ -228,6 +229,338 @@ def forward(self, X, X_fft):
         return self.classifier(concat)
 
 
+class SupervisedDecimatedResNet(SupervisedArchitecture):
+    """
+    Multi-branch ResNet1D for decimated time-domain inputs.
+
+    Each decimation segment gets its own ResNet1D branch that
+    produces an embedding. Embeddings are concatenated and passed
+    through a final linear classifier. This allows each branch to
+    specialize in its frequency band.
+
+    Args:
+        num_ifos:
+            Number of interferometer channels.
+        num_branches:
+            Number of decimation segments (branches).
+        branch_layers:
+            ResNet layer configuration for each branch. Can be a
+            single list (shared across branches) or a list of lists
+            (per-branch).
+        branch_classes:
+            Embedding dimension for each branch. Can be a single int
+            (shared) or a list of ints (per-branch).
+        kernel_size:
+            Convolution kernel size for ResNet blocks.
+        norm_layer:
+            Normalization layer getter.
+    """
+
+    def __init__(
+        self,
+        num_ifos: int,
+        num_branches: int,
+        branch_layers: list,
+        branch_classes: list[int] | int,
+        kernel_size: int = 3,
+        zero_init_residual: bool = False,
+        groups: int = 1,
+        width_per_group: int = 64,
+        stride_type: Optional[list[Literal["stride", "dilation"]]] = None,
+        norm_layer: Optional[NormLayer] = None,
+        **kwargs,
+    ):
+        super().__init__()
+        self.num_branches = num_branches
+
+        # Normalize branch_classes to a list
+        if isinstance(branch_classes, int):
+            branch_classes = [branch_classes] * num_branches
+
+        # Normalize branch_layers: if it's a flat list of ints,
+        # use the same layers for all branches
+        if branch_layers and isinstance(branch_layers[0], int):
+            branch_layers = [branch_layers] * num_branches
+
+        self.branches = torch.nn.ModuleList()
+        for i in range(num_branches):
+            self.branches.append(
+                ResNet1D(
+                    in_channels=num_ifos,
+                    layers=branch_layers[i],
+                    classes=branch_classes[i],
+                    kernel_size=kernel_size,
+                    zero_init_residual=zero_init_residual,
+                    groups=groups,
+                    width_per_group=width_per_group,
+                    stride_type=stride_type,
+                    norm_layer=norm_layer,
+                )
+            )
+
+        total_classes = sum(branch_classes)
+        self.classifier = torch.nn.Linear(total_classes, 1)
+
+    def forward(self, *segments):
+        embeddings = []
+        for i, seg in enumerate(segments):
+            embeddings.append(self.branches[i](seg))
+        concat = torch.cat(embeddings, dim=-1)
+        return self.classifier(concat).squeeze(-1)
+
+
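A minimal usage sketch of `SupervisedDecimatedResNet`, mirroring the shapes exercised by the tests added later in this diff (configs are illustrative):

```python
import torch
from architectures.supervised import SupervisedDecimatedResNet

arch = SupervisedDecimatedResNet(
    num_ifos=2,
    num_branches=3,
    branch_layers=[2, 2],        # flat list of ints -> shared across branches
    branch_classes=[8, 16, 32],  # per-branch embedding dimensions
)
# One tensor per decimation segment, shape (batch, num_ifos, n_samples)
segments = [torch.randn(8, 2, 128) for _ in range(3)]
scores = arch(*segments)  # shape (8,): one logit per batch element
```
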
+class SupervisedDecimatedSVDNet(SupervisedArchitecture):
+    """
+    Multi-branch frequency-domain SVD network for BNS detection.
+
+    Each decimation branch:
+      1. FFTs its time-domain segment to frequency domain
+      2. Projects onto a reduced SVD basis (filtering noise orthogonal
+         to the signal manifold)
+      3. Processes SVD coefficients through a dense residual network
+
+    Branch embeddings are concatenated and classified.
+
+    This follows the approach of frequency-domain SVD projection,
+    adapted for multi-rate decimated detection.
+
+    Args:
+        num_ifos: Number of interferometer channels.
+        num_branches: Number of decimation segments.
+        n_svd: SVD components per branch.
+        branch_hidden_dim: Hidden dimension for each branch's dense
+            network (legacy, single width). Ignored if
+            branch_hidden_dims is provided.
+        branch_hidden_dims: Tapering hidden dimensions for each
+            branch's dense network. Can be a single list like
+            [512, 256, 128] (shared across branches) or a list of
+            lists (per-branch). If an int, wraps in [int] for
+            backward compat. Use "shallow" for a minimal
+            Linear -> ReLU -> Dropout -> Linear architecture.
+        branch_embed_dim: Output embedding dimension per branch.
+        num_dense_blocks: Number of dense residual blocks per branch
+            (legacy, used when branch_hidden_dims is None).
+        num_blocks_per_stage: Residual blocks at each width stage
+            (used with branch_hidden_dims).
+        norm_type: Normalization type for dense blocks, "layer" or
+            "batch".
+        per_ifo_svd: If True, use per-IFO SVD projection weights.
+        svd_basis_path: Path to HDF5 with precomputed V matrices.
+        freeze_svd: Whether to freeze SVD layers initially.
+        dropout: Dropout rate in dense blocks.
+        normalize_svd: If True, apply LayerNorm to SVD output
+            before the dense network. Recommended for stable
+            training when SVD coefficients have large scale.
+    """
+
+    def __init__(
+        self,
+        num_ifos: int,
+        num_branches: int,
+        n_svd: list[int] | int = 100,
+        branch_hidden_dim: list[int] | int = 128,
+        branch_hidden_dims: Optional[list | int | str] = None,
+        branch_embed_dim: list[int] | int = 32,
+        num_dense_blocks: int = 3,
+        num_blocks_per_stage: int = 2,
+        norm_type: str = "layer",
+        per_ifo_svd: bool = False,
+        svd_basis_path: Optional[str] = None,
+        freeze_svd: bool = True,
+        dropout: float = 0.1,
+        normalize_svd: bool = False,
+        **kwargs,
+    ):
+        super().__init__()
+        self.num_branches = num_branches
+
+        if isinstance(n_svd, int):
+            n_svd = [n_svd] * num_branches
+        if isinstance(branch_embed_dim, int):
+            branch_embed_dim = [branch_embed_dim] * num_branches
+
+        # Determine dense network architecture
+        use_shallow = branch_hidden_dims == "shallow"
+        use_tapering = (
+            branch_hidden_dims is not None and not use_shallow
+        )
+        if use_tapering:
+            # Normalize branch_hidden_dims
+            if isinstance(branch_hidden_dims, int):
+                branch_hidden_dims = [branch_hidden_dims]
+            # If it's a flat list of ints, share across branches
+            if branch_hidden_dims and isinstance(
+                branch_hidden_dims[0], int
+            ):
+                branch_hidden_dims = [
+                    branch_hidden_dims
+                ] * num_branches
+        elif not use_shallow:
+            # Legacy: single hidden dim with num_dense_blocks
+            if isinstance(branch_hidden_dim, int):
+                branch_hidden_dim = [branch_hidden_dim] * num_branches
+
+        # Load precomputed SVD bases
+        V_matrices, n_freqs = self._load_svd_bases(
+            svd_basis_path, num_branches
+        )
+
+        self.svd_layers = torch.nn.ModuleList()
+        self.svd_norms = torch.nn.ModuleList()
+        self.branches = torch.nn.ModuleList()
+
+        for i in range(num_branches):
+            V = V_matrices[i] if V_matrices else None
+            # Fall back to n_svd when there is no basis file, or when
+            # the file has no entry for this branch (n_freqs[i] == 0)
+            n_freq = n_freqs[i] if n_freqs and n_freqs[i] else n_svd[i]
+
+            # SVD projection: time -> freq -> n_svd coefficients per IFO
+            svd_layer = FreqDomainSVDProjection(
+                num_ifos, n_freq, n_svd[i], V,
+                per_channel=per_ifo_svd,
+            )
+            if freeze_svd:
+                svd_layer.freeze()
+            self.svd_layers.append(svd_layer)
+
+            # Optional normalization on SVD output.
+            # Use LayerNorm (not BatchNorm) to avoid train/eval
+            # discrepancy where BatchNorm causes output collapse
+            svd_out_dim = n_svd[i] * num_ifos
+            if normalize_svd:
+                self.svd_norms.append(
+                    torch.nn.LayerNorm(svd_out_dim)
+                )
+            else:
+                self.svd_norms.append(torch.nn.Identity())
+
+            # Dense network: SVD coefficients -> embedding
+            e_dim = branch_embed_dim[i]
+
+            if use_shallow:
+                layers = self._build_shallow_network(
+                    svd_out_dim, e_dim, dropout,
+                )
+            elif use_tapering:
+                dims = branch_hidden_dims[i]
+                layers = self._build_tapering_network(
+                    svd_out_dim, dims, e_dim,
+                    num_blocks_per_stage, dropout,
+                )
+            else:
+                h_dim = branch_hidden_dim[i]
+                layers = self._build_flat_network(
+                    svd_out_dim, h_dim, e_dim,
+                    num_dense_blocks, dropout,
+                )
+
+            self.branches.append(torch.nn.Sequential(*layers))
+
+        total_embed = sum(branch_embed_dim)
+        self.classifier = torch.nn.Linear(total_embed, 1)
+
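For intuition only: `FreqDomainSVDProjection` lives in ml4gw, but given the `(2 * n_freq, n_svd)` V matrices the HDF5 test below writes, the projection is presumably equivalent to something like this hedged sketch (not the library implementation):

```python
import torch

def freq_domain_svd_project(x, V):
    """x: (batch, ifos, n_samples); V: (2 * n_freq, n_svd)."""
    xf = torch.fft.rfft(x, dim=-1)                 # (batch, ifos, n_freq)
    feats = torch.cat([xf.real, xf.imag], dim=-1)  # (batch, ifos, 2 * n_freq)
    coeffs = feats @ V                             # (batch, ifos, n_svd)
    # Flatten to match svd_out_dim = n_svd * num_ifos above
    return coeffs.flatten(start_dim=1)             # (batch, ifos * n_svd)
```
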
+    @staticmethod
+    def _build_shallow_network(in_dim, out_dim, dropout):
+        """Build minimal dense network: Linear -> ReLU -> Dropout -> Linear.
+
+        Designed for weak-signal detection where a simpler model
+        generalizes better than a deep one.
+        """
+        hidden = min(in_dim, 128)
+        return [
+            torch.nn.Linear(in_dim, hidden),
+            torch.nn.ReLU(),
+            torch.nn.Dropout(dropout),
+            torch.nn.Linear(hidden, out_dim),
+        ]
+
+    @staticmethod
+    def _build_flat_network(
+        in_dim, hidden_dim, out_dim, num_blocks, dropout,
+    ):
+        """Build legacy flat dense network (single hidden width)."""
+        layers = [
+            torch.nn.Linear(in_dim, hidden_dim),
+            torch.nn.GELU(),
+        ]
+        for _ in range(num_blocks):
+            layers.append(DenseResidualBlock(hidden_dim, dropout))
+        layers.append(torch.nn.Linear(hidden_dim, out_dim))
+        return layers
+
+    @staticmethod
+    def _build_tapering_network(
+        in_dim, hidden_dims, out_dim,
+        blocks_per_stage, dropout,
+    ):
+        """Build tapering dense network with dimension transitions.
+
+        Creates a network that tapers through decreasing hidden
+        dimensions (e.g. [512, 256, 128]), with residual blocks at
+        each stage and linear resize layers between stages.
+        """
+        layers = [
+            torch.nn.Linear(in_dim, hidden_dims[0]),
+            torch.nn.GELU(),
+        ]
+
+        for stage_idx, dim in enumerate(hidden_dims):
+            # Residual blocks at this width
+            for _ in range(blocks_per_stage):
+                layers.append(DenseResidualBlock(dim, dropout))
+
+            # Resize to next stage (if not the last)
+            if stage_idx < len(hidden_dims) - 1:
+                next_dim = hidden_dims[stage_idx + 1]
+                layers.append(torch.nn.Linear(dim, next_dim))
+                layers.append(torch.nn.GELU())
+
+        # Final projection to embedding dim
+        layers.append(torch.nn.Linear(hidden_dims[-1], out_dim))
+        return layers
+
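To make the stage layout concrete, here is what the builder above assembles for a `[64, 32]` taper with one block per stage (the config used in `test_tapering_forward` below); `DenseResidualBlock` is assumed to be dimension-preserving, as `_build_flat_network` already relies on:

```python
import torch
from architectures.supervised import SupervisedDecimatedSVDNet

layers = SupervisedDecimatedSVDNet._build_tapering_network(
    32, [64, 32], 8, 1, 0.1
)
# Linear(32, 64), GELU, DenseResidualBlock(64),
# Linear(64, 32), GELU, DenseResidualBlock(32), Linear(32, 8)
net = torch.nn.Sequential(*layers)
assert net(torch.randn(4, 32)).shape == (4, 8)
```
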
+ """ + layers = [ + torch.nn.Linear(in_dim, hidden_dims[0]), + torch.nn.GELU(), + ] + + for stage_idx, dim in enumerate(hidden_dims): + # Residual blocks at this width + for _ in range(blocks_per_stage): + layers.append( + DenseResidualBlock(dim, dropout) + ) + + # Resize to next stage (if not the last) + if stage_idx < len(hidden_dims) - 1: + next_dim = hidden_dims[stage_idx + 1] + layers.append(torch.nn.Linear(dim, next_dim)) + layers.append(torch.nn.GELU()) + + # Final projection to embedding dim + layers.append(torch.nn.Linear(hidden_dims[-1], out_dim)) + return layers + + @staticmethod + def _load_svd_bases(path, num_branches): + """Load precomputed V matrices and n_freq from HDF5.""" + if path is None: + return None, None + import numpy as np + import h5py + V_matrices = [] + n_freqs = [] + with h5py.File(path, "r") as f: + for i in range(num_branches): + key = f"branch_{i}" + if key in f: + V_matrices.append(np.array(f[key]["V"])) + n_freqs.append(int(f[key].attrs["n_freq"])) + else: + V_matrices.append(None) + n_freqs.append(0) + return V_matrices, n_freqs + + def set_svd_frozen(self, frozen: bool): + """Freeze or unfreeze all SVD projection layers.""" + for svd_layer in self.svd_layers: + if frozen: + svd_layer.freeze() + else: + svd_layer.unfreeze() + + def forward(self, *segments): + embeddings = [] + for i, seg in enumerate(segments): + # FFT + SVD projection -> normalize -> dense embedding + svd_coeffs = self.svd_layers[i](seg) + svd_coeffs = self.svd_norms[i](svd_coeffs) + embeddings.append(self.branches[i](svd_coeffs)) + concat = torch.cat(embeddings, dim=-1) + return self.classifier(concat).squeeze(-1) + + class SupervisedTimeSpectrogramResNet(SupervisedArchitecture): """ Spectrogram and Time Domain ResNet that processes a combination of diff --git a/libs/architectures/tests/test_supervised.py b/libs/architectures/tests/test_supervised.py new file mode 100644 index 000000000..e5fca41bc --- /dev/null +++ b/libs/architectures/tests/test_supervised.py @@ -0,0 +1,233 @@ +import tempfile + +import h5py +import numpy as np +import torch +from architectures.supervised import ( + SupervisedDecimatedResNet, + SupervisedDecimatedSVDNet, +) + + +class TestSupervisedDecimatedResNet: + def test_forward_shape(self): + """4 branches with shared config produce scalar output.""" + batch = 8 + num_ifos = 2 + num_branches = 4 + arch = SupervisedDecimatedResNet( + num_ifos=num_ifos, + num_branches=num_branches, + branch_layers=[2, 2], + branch_classes=16, + ) + segments = [torch.randn(batch, num_ifos, 128) for _ in range(4)] + out = arch(*segments) + assert out.shape == (batch,) + + def test_per_branch_classes(self): + """Per-branch embedding dims via list of ints.""" + batch = 4 + num_ifos = 2 + arch = SupervisedDecimatedResNet( + num_ifos=num_ifos, + num_branches=3, + branch_layers=[2, 2], + branch_classes=[8, 16, 32], + ) + segments = [torch.randn(batch, num_ifos, 64) for _ in range(3)] + out = arch(*segments) + assert out.shape == (batch,) + + def test_per_branch_layers(self): + """Per-branch layer configs via list of lists.""" + batch = 4 + num_ifos = 2 + arch = SupervisedDecimatedResNet( + num_ifos=num_ifos, + num_branches=2, + branch_layers=[[2, 2], [3, 3]], + branch_classes=16, + ) + segments = [torch.randn(batch, num_ifos, 128) for _ in range(2)] + out = arch(*segments) + assert out.shape == (batch,) + + +class TestSupervisedDecimatedSVDNet: + """Tests for frequency-domain SVD network. 
+
+
+class TestSupervisedDecimatedSVDNet:
+    """Tests for frequency-domain SVD network.
+
+    Without an SVD basis file, n_freq falls back to n_svd[i],
+    so input segments must have n_samples = (n_svd - 1) * 2.
+    """
+
+    def _n_samples(self, n_svd):
+        """Compute required n_samples for a given n_svd."""
+        return (n_svd - 1) * 2
+
+    def test_shallow_forward(self):
+        """Shallow network (no SVD basis file) produces scalar output."""
+        batch = 8
+        num_ifos = 2
+        n_svd = 16
+        n_samples = self._n_samples(n_svd)
+        arch = SupervisedDecimatedSVDNet(
+            num_ifos=num_ifos,
+            num_branches=2,
+            n_svd=n_svd,
+            branch_hidden_dims="shallow",
+            branch_embed_dim=8,
+            freeze_svd=False,
+        )
+        segments = [
+            torch.randn(batch, num_ifos, n_samples) for _ in range(2)
+        ]
+        out = arch(*segments)
+        assert out.shape == (batch,)
+
+    def test_flat_forward(self):
+        """Flat (legacy) dense network produces scalar output."""
+        batch = 4
+        num_ifos = 2
+        n_svd = 16
+        n_samples = self._n_samples(n_svd)
+        arch = SupervisedDecimatedSVDNet(
+            num_ifos=num_ifos,
+            num_branches=2,
+            n_svd=n_svd,
+            branch_hidden_dim=32,
+            num_dense_blocks=2,
+            branch_embed_dim=8,
+            freeze_svd=False,
+        )
+        segments = [
+            torch.randn(batch, num_ifos, n_samples) for _ in range(2)
+        ]
+        out = arch(*segments)
+        assert out.shape == (batch,)
+
+    def test_tapering_forward(self):
+        """Tapering dense network produces scalar output."""
+        batch = 4
+        num_ifos = 2
+        n_svd = 16
+        n_samples = self._n_samples(n_svd)
+        arch = SupervisedDecimatedSVDNet(
+            num_ifos=num_ifos,
+            num_branches=2,
+            n_svd=n_svd,
+            branch_hidden_dims=[64, 32],
+            branch_embed_dim=8,
+            num_blocks_per_stage=1,
+            freeze_svd=False,
+        )
+        segments = [
+            torch.randn(batch, num_ifos, n_samples) for _ in range(2)
+        ]
+        out = arch(*segments)
+        assert out.shape == (batch,)
+
+    def test_freeze_unfreeze(self):
+        """set_svd_frozen toggles requires_grad on SVD params."""
+        arch = SupervisedDecimatedSVDNet(
+            num_ifos=2,
+            num_branches=2,
+            n_svd=16,
+            branch_hidden_dims="shallow",
+            branch_embed_dim=8,
+            freeze_svd=True,
+        )
+        # Initially frozen
+        for svd_layer in arch.svd_layers:
+            for p in svd_layer.parameters():
+                assert not p.requires_grad
+
+        # Unfreeze
+        arch.set_svd_frozen(False)
+        for svd_layer in arch.svd_layers:
+            for p in svd_layer.parameters():
+                assert p.requires_grad
+
+        # Re-freeze
+        arch.set_svd_frozen(True)
+        for svd_layer in arch.svd_layers:
+            for p in svd_layer.parameters():
+                assert not p.requires_grad
+
+    def test_normalize_svd(self):
+        """normalize_svd=True still produces correct output shape."""
+        batch = 4
+        num_ifos = 2
+        n_svd = 16
+        n_samples = self._n_samples(n_svd)
+        arch = SupervisedDecimatedSVDNet(
+            num_ifos=num_ifos,
+            num_branches=2,
+            n_svd=n_svd,
+            branch_hidden_dims="shallow",
+            branch_embed_dim=8,
+            normalize_svd=True,
+            freeze_svd=False,
+        )
+        segments = [
+            torch.randn(batch, num_ifos, n_samples) for _ in range(2)
+        ]
+        out = arch(*segments)
+        assert out.shape == (batch,)
+        # Check that LayerNorm was used (not Identity)
+        for norm in arch.svd_norms:
+            assert isinstance(norm, torch.nn.LayerNorm)
+
+    def test_per_ifo_svd(self):
+        """per_ifo_svd=True forward pass produces correct shape."""
+        batch = 4
+        num_ifos = 2
+        n_svd = 16
+        n_samples = self._n_samples(n_svd)
+        arch = SupervisedDecimatedSVDNet(
+            num_ifos=num_ifos,
+            num_branches=2,
+            n_svd=n_svd,
+            branch_hidden_dims="shallow",
+            branch_embed_dim=8,
+            per_ifo_svd=True,
+            freeze_svd=False,
+        )
+        segments = [
+            torch.randn(batch, num_ifos, n_samples) for _ in range(2)
+        ]
+        out = arch(*segments)
+        assert out.shape == (batch,)
+
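`test_load_from_hdf5` below also documents the HDF5 layout `_load_svd_bases` expects: one `branch_{i}` group per branch holding a `V` dataset and an `n_freq` attribute. Offline, such a file might be built from a bank of templates; a hypothetical sketch, where `template_banks` stands in for per-branch `(num_templates, 2 * n_freq)` arrays of stacked real/imag template vectors:

```python
import h5py
import numpy as np

def write_svd_basis(path, template_banks, n_svd):
    with h5py.File(path, "w") as f:
        for i, templates in enumerate(template_banks):
            # Leading right-singular vectors span the signal manifold
            _, _, Vt = np.linalg.svd(templates, full_matrices=False)
            grp = f.create_group(f"branch_{i}")
            grp.create_dataset("V", data=Vt[:n_svd].T)  # (2 * n_freq, n_svd)
            grp.attrs["n_freq"] = templates.shape[1] // 2
```
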
+    def test_load_from_hdf5(self):
+        """Loading V matrices from HDF5 produces correct shape."""
+        batch = 4
+        num_ifos = 2
+        n_svd = 8
+        n_freq = 33  # e.g. from a 64-sample segment
+        n_samples = (n_freq - 1) * 2  # 64
+
+        with tempfile.NamedTemporaryFile(suffix=".hdf5") as tmp:
+            # Create HDF5 with V matrices
+            with h5py.File(tmp.name, "w") as f:
+                for i in range(2):
+                    grp = f.create_group(f"branch_{i}")
+                    grp.create_dataset(
+                        "V", data=np.random.randn(2 * n_freq, n_svd)
+                    )
+                    grp.attrs["n_freq"] = n_freq
+
+            arch = SupervisedDecimatedSVDNet(
+                num_ifos=num_ifos,
+                num_branches=2,
+                n_svd=n_svd,
+                branch_hidden_dims="shallow",
+                branch_embed_dim=8,
+                svd_basis_path=tmp.name,
+                freeze_svd=False,
+            )
+            segments = [
+                torch.randn(batch, num_ifos, n_samples) for _ in range(2)
+            ]
+            out = arch(*segments)
+            assert out.shape == (batch,)
diff --git a/projects/train/tests/test_callbacks.py b/projects/train/tests/test_callbacks.py
new file mode 100644
index 000000000..116470e24
--- /dev/null
+++ b/projects/train/tests/test_callbacks.py
@@ -0,0 +1,99 @@
+from unittest.mock import MagicMock
+
+import torch
+from architectures.supervised import SupervisedDecimatedSVDNet
+from train.callbacks import SVDUnfreezeCallback
+
+
+def _make_model(freeze=True):
+    """Create a small SVD network for testing."""
+    return SupervisedDecimatedSVDNet(
+        num_ifos=2,
+        num_branches=2,
+        n_svd=16,
+        branch_hidden_dims="shallow",
+        branch_embed_dim=8,
+        freeze_svd=freeze,
+    )
+
+
+def _make_trainer_and_module(model, epoch):
+    """Create mock trainer and pl_module with a real optimizer."""
+    # Only optimize non-frozen params (mimics real training)
+    trainable = [p for p in model.parameters() if p.requires_grad]
+    optimizer = torch.optim.Adam(trainable, lr=1e-3)
+
+    trainer = MagicMock()
+    trainer.current_epoch = epoch
+    trainer.optimizers = [optimizer]
+
+    pl_module = MagicMock()
+    pl_module.model = model
+
+    return trainer, pl_module
+
+
+class TestSVDUnfreezeCallback:
+    def test_before_unfreeze_epoch(self):
+        """SVD params stay frozen before unfreeze_epoch."""
+        model = _make_model(freeze=True)
+        callback = SVDUnfreezeCallback(unfreeze_epoch=10)
+        trainer, pl_module = _make_trainer_and_module(model, epoch=5)
+
+        callback.on_train_epoch_start(trainer, pl_module)
+
+        for svd_layer in model.svd_layers:
+            for p in svd_layer.parameters():
+                assert not p.requires_grad
+
+    def test_at_unfreeze_epoch(self):
+        """At unfreeze_epoch, SVD params are unfrozen and added."""
+        model = _make_model(freeze=True)
+        callback = SVDUnfreezeCallback(unfreeze_epoch=10, svd_lr_factor=0.01)
+        trainer, pl_module = _make_trainer_and_module(model, epoch=10)
+
+        num_groups_before = len(trainer.optimizers[0].param_groups)
+        callback.on_train_epoch_start(trainer, pl_module)
+
+        # SVD params should now require grad
+        for svd_layer in model.svd_layers:
+            for p in svd_layer.parameters():
+                assert p.requires_grad
+
+        # New param group should have been added
+        opt = trainer.optimizers[0]
+        assert len(opt.param_groups) == num_groups_before + 1
+
+        # Check LR of new group
+        svd_group = opt.param_groups[-1]
+        base_lr = opt.param_groups[0]["lr"]
+        assert abs(svd_group["lr"] - base_lr * 0.01) < 1e-10
+
+    def test_idempotent(self):
+        """Calling again after unfreeze has no effect."""
+        model = _make_model(freeze=True)
+        callback = SVDUnfreezeCallback(unfreeze_epoch=10)
+        trainer, pl_module = _make_trainer_and_module(model, epoch=10)
+
+        callback.on_train_epoch_start(trainer, pl_module)
+        num_groups_after_first = len(trainer.optimizers[0].param_groups)
+
+        # Call again at epoch 11
+        trainer.current_epoch = 11
+        callback.on_train_epoch_start(trainer, pl_module)
+        assert (
+            len(trainer.optimizers[0].param_groups)
+            == num_groups_after_first
+        )
+
+    def test_model_without_set_svd_frozen(self, capsys):
+        """Model without set_svd_frozen prints warning, no crash."""
+        model = MagicMock(spec=[])  # No attributes at all
+        callback = SVDUnfreezeCallback(unfreeze_epoch=0)
+
+        trainer = MagicMock()
+        trainer.current_epoch = 0
+
+        pl_module = MagicMock()
+        pl_module.model = model
+
+        callback.on_train_epoch_start(trainer, pl_module)
+        captured = capsys.readouterr()
+        assert "Warning" in captured.out
diff --git a/projects/train/train/callbacks.py b/projects/train/train/callbacks.py
index b052558cc..fac9c53d7 100644
--- a/projects/train/train/callbacks.py
+++ b/projects/train/train/callbacks.py
@@ -9,13 +9,14 @@
 from botocore.exceptions import ClientError, ConnectTimeoutError
 from lightning import pytorch as pl
 from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.cli import SaveConfigCallback
 from lightning.pytorch.loggers import WandbLogger
 from lightning.pytorch.utilities import grad_norm
 
 BOTO_RETRY_EXCEPTIONS = (ClientError, ConnectTimeoutError)
 
 
-class WandbSaveConfig(pl.cli.SaveConfigCallback):
+class WandbSaveConfig(SaveConfigCallback):
     """
     Override of `lightning.pytorch.cli.SaveConfigCallback` for use with
     WandB to ensure all the hyperparameters are logged to the WandB dashboard.
@@ -170,3 +171,65 @@ def on_before_optimizer_step(self, trainer, pl_module, optimizer):
         norms = grad_norm(pl_module, norm_type=self.norm_type)
         total_norm = norms[f"grad_{float(self.norm_type)}_norm_total"]
         self.log(f"grad_norm_{self.norm_type}", total_norm)
+
+
+class SVDUnfreezeCallback(Callback):
+    """Two-phase training callback for SVD networks.
+
+    During Phase 1, SVD projection layers are frozen and only the
+    dense network trains. At `unfreeze_epoch`, SVD layers are
+    unfrozen and added to the optimizer with a reduced learning rate.
+
+    Args:
+        unfreeze_epoch: Epoch at which to unfreeze SVD layers.
+        svd_lr_factor: Factor to multiply the base LR by for SVD
+            parameters. E.g. 0.01 means SVD params train at 1% of
+            the main learning rate.
+    """
+
+    def __init__(
+        self, unfreeze_epoch: int = 300, svd_lr_factor: float = 0.01
+    ):
+        super().__init__()
+        self.unfreeze_epoch = unfreeze_epoch
+        self.svd_lr_factor = svd_lr_factor
+        self._unfrozen = False
+
+    def on_train_epoch_start(self, trainer, pl_module):
+        if self._unfrozen:
+            return
+        if trainer.current_epoch < self.unfreeze_epoch:
+            return
+
+        self._unfrozen = True
+
+        # Unfreeze SVD layers
+        if hasattr(pl_module.model, "set_svd_frozen"):
+            pl_module.model.set_svd_frozen(False)
+        else:
+            print(
+                "Warning: model has no set_svd_frozen method, "
+                "SVDUnfreezeCallback has no effect"
+            )
+            return
+
+        # Add SVD parameters to optimizer with reduced LR
+        optimizer = trainer.optimizers[0]
+        base_lr = optimizer.param_groups[0]["lr"]
+        svd_lr = base_lr * self.svd_lr_factor
+
+        svd_params = []
+        for svd_layer in pl_module.model.svd_layers:
+            svd_params.extend(
+                p for p in svd_layer.parameters() if p.requires_grad
+            )
+
+        if svd_params:
+            optimizer.add_param_group(
+                {"params": svd_params, "lr": svd_lr}
+            )
+            print(
+                f"Epoch {trainer.current_epoch}: Unfroze SVD layers, "
+                f"added {len(svd_params)} params at lr={svd_lr:.2e} "
+                f"(base lr={base_lr:.2e})"
+            )
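A sketch of wiring the callback into a run (the `Trainer` arguments are illustrative; the project presumably configures this through its Lightning CLI):

```python
from lightning.pytorch import Trainer
from train.callbacks import SVDUnfreezeCallback

# Phase 1: only the dense networks train. From epoch 300 on, the SVD
# layers join the optimizer at 1% of the base learning rate.
callback = SVDUnfreezeCallback(unfreeze_epoch=300, svd_lr_factor=0.01)
trainer = Trainer(max_epochs=600, callbacks=[callback])
```
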
+ """ + + def __init__( + self, unfreeze_epoch: int = 300, svd_lr_factor: float = 0.01 + ): + super().__init__() + self.unfreeze_epoch = unfreeze_epoch + self.svd_lr_factor = svd_lr_factor + self._unfrozen = False + + def on_train_epoch_start(self, trainer, pl_module): + if self._unfrozen: + return + if trainer.current_epoch < self.unfreeze_epoch: + return + + self._unfrozen = True + + # Unfreeze SVD layers + if hasattr(pl_module.model, "set_svd_frozen"): + pl_module.model.set_svd_frozen(False) + else: + print( + "Warning: model has no set_svd_frozen method, " + "SVDUnfreezeCallback has no effect" + ) + return + + # Add SVD parameters to optimizer with reduced LR + optimizer = trainer.optimizers[0] + base_lr = optimizer.param_groups[0]["lr"] + svd_lr = base_lr * self.svd_lr_factor + + svd_params = [] + for svd_layer in pl_module.model.svd_layers: + svd_params.extend( + p for p in svd_layer.parameters() if p.requires_grad + ) + + if svd_params: + optimizer.add_param_group( + {"params": svd_params, "lr": svd_lr} + ) + print( + f"Epoch {trainer.current_epoch}: Unfroze SVD layers, " + f"added {len(svd_params)} params at lr={svd_lr:.2e} " + f"(base lr={base_lr:.2e})" + ) diff --git a/projects/train/train/model/__init__.py b/projects/train/train/model/__init__.py index 92e9cdd0e..0edba2c7d 100644 --- a/projects/train/train/model/__init__.py +++ b/projects/train/train/model/__init__.py @@ -1,8 +1,10 @@ from .autoencoder import AutoencoderAframe from .base import AframeBase from .supervised import ( + SimpleDecimatedAframe, SupervisedAframe, SupervisedAframeS4, + SupervisedDecimatedAframe, SupervisedMultiModalAframe, SupervisedTimeSpectrogramAframe, ) diff --git a/projects/train/train/model/supervised.py b/projects/train/train/model/supervised.py index 77c68463d..a074c6abf 100644 --- a/projects/train/train/model/supervised.py +++ b/projects/train/train/model/supervised.py @@ -184,6 +184,74 @@ def validation_step(self, batch, _) -> None: ) +class SupervisedDecimatedAframe(SupervisedAframe): + """Model class for multi-branch decimated architectures.""" + + def __init__(self, arch: SupervisedArchitecture, *args, **kwargs) -> None: + super().__init__(arch, *args, **kwargs) + + def forward(self, *segments): + return self.model(*segments) + + def score(self, *segments): + return self(*segments) + + def train_step(self, batch: tuple[tuple, Tensor]) -> Tensor: + segments, y = batch + y_hat = self(*segments) + return torch.nn.functional.binary_cross_entropy_with_logits( + y_hat, y.squeeze(-1) + ) + + def validation_step(self, batch, _) -> None: + shift = batch[0] + bg_segments = batch[1] + fg_segments = batch[2] + + y_bg = self.score(*bg_segments) + + # For each segment, reshape views and compute predictions + num_views, batch_size = fg_segments[0].shape[:2] + reshaped = [] + for seg in fg_segments: + shape = seg.shape[2:] + reshaped.append(seg.view(num_views * batch_size, *shape)) + + y_fg = self.score(*reshaped) + y_fg = y_fg.view(num_views, batch_size) + y_fg = y_fg.mean(0) + + self.metric.update(shift, y_bg, y_fg) + self.log( + "valid_auroc", + self.metric, + on_step=True, + on_epoch=True, + sync_dist=True, + ) + + +class SimpleDecimatedAframe(SupervisedDecimatedAframe): + """Decimated model with simple BCE+accuracy validation. + + Uses the same train_step as SupervisedDecimatedAframe but + replaces timeslide-based validation with straightforward + BCE loss and accuracy. Suitable for pre-whitened data where + background and signal segments are already labeled. 
+ """ + + def validation_step(self, batch, _) -> None: + segments, y = batch + y_hat = self.model(*segments) + loss = torch.nn.functional.binary_cross_entropy_with_logits( + y_hat, y.squeeze(-1) + ) + preds = (y_hat > 0).float() + acc = (preds == y.squeeze(-1)).float().mean() + self.log("val_loss", loss, prog_bar=True, sync_dist=True) + self.log("val_acc", acc, prog_bar=True, sync_dist=True) + + class SupervisedAframeS4(SupervisedAframe): def __init__(self, arch: SupervisedArchitecture, *args, **kwargs) -> None: super().__init__(arch, *args, **kwargs)