Skip to content

Add frequency-domain SVD projection and dense residual block#256

Open
mcoughlin wants to merge 5 commits intoML4GW:mainfrom
mcoughlin:svd-projection
Open

Add frequency-domain SVD projection and dense residual block#256
mcoughlin wants to merge 5 commits intoML4GW:mainfrom
mcoughlin:svd-projection

Conversation

@mcoughlin
Copy link

New ml4gw.nn.svd module with two components:

  • FreqDomainSVDProjection: FFT + linear projection layer initialized from precomputed SVD right singular vectors. Supports shared or per-channel (per-IFO) weights, and freeze/unfreeze for two-phase training. Adapted from DINGO's LinearProjectionRB.

  • DenseResidualBlock: LayerNorm(x + MLP(x)) residual block for processing SVD coefficients. Uses LayerNorm exclusively — BatchNorm causes train/eval output collapse in GW detection where batch composition (signal/noise ratio) differs between train and eval.

Includes 134 parametrized tests covering shapes, gradients, V matrix initialization, freeze/unfreeze, save/load, and train/eval consistency.

New ml4gw.nn.svd module with two components:

- FreqDomainSVDProjection: FFT + linear projection layer initialized
  from precomputed SVD right singular vectors. Supports shared or
  per-channel (per-IFO) weights, and freeze/unfreeze for two-phase
  training. Adapted from DINGO's LinearProjectionRB.

- DenseResidualBlock: LayerNorm(x + MLP(x)) residual block for
  processing SVD coefficients. Uses LayerNorm exclusively — BatchNorm
  causes train/eval output collapse in GW detection where batch
  composition (signal/noise ratio) differs between train and eval.

Includes 134 parametrized tests covering shapes, gradients, V matrix
initialization, freeze/unfreeze, save/load, and train/eval consistency.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@wbenoit26
Copy link
Contributor

@mcoughlin Out of curiosity, how much of this was able to be done by Claude?

@mcoughlin
Copy link
Author

@wbenoit26 A lot of it, but also a lot of discussion ;)

@github-actions
Copy link

github-actions bot commented Feb 16, 2026

Coverage report

Click to see where and how coverage changed

FileStatementsMissingCoverageCoverage
(new stmts)
Lines missing
  ml4gw/nn
  __init__.py
  ml4gw/nn/svd
  __init__.py
  dense.py
  projection.py
Project Total  

This report was generated by python-coverage-comment-action

@EthanMarx
Copy link
Collaborator

Very cool. Not surprised claude is decent at this.

I think a feature that is missing is the actual fitting of the SVD? Or is it assumed that the SVD will be fit outside and saved to a file?

Copy link
Contributor

@deepchatterjeeligo deepchatterjeeligo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @mcoughlin we discussed about this morning. What came out of the discussion was that in addition to unittests, we should add some content in our documentation pages showing a usecase: something I thought out loud was a figure showing how adding more svd components approaches the waveform, but that is up for more discussion.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we reuse the parts from our ResNet implementation or generalize them for the residual blocks? I worry this will become a duplicate residual block implementation.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's a bit too awkward, and I don't think it would generalize too much.

Comment on lines +84 to +85
if V_tensor is not None:
proj.weight.data = V_tensor.T.contiguous()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the .T syntax for transpose will be deprecated soon. I have seen warnings about it. Can we explicitly supply the axes? Also, I'm curious about the .contiguous why that is needed.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Comment on lines +120 to +133
x_freq = torch.fft.rfft(x, dim=-1)

# Stack real and imaginary: (batch, channels, 2 * n_freq)
x_ri = torch.cat([x_freq.real, x_freq.imag], dim=-1)

if self.per_channel:
proj_list = []
for ch in range(self.num_channels):
proj_list.append(self.projections[ch](x_ri[:, ch, :]))
x_proj = torch.stack(proj_list, dim=1)
else:
x_proj = self.projection(x_ri)

return x_proj.reshape(batch_size, -1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a more basic question: I don't think you are computing the SVD here, right? Without the pretrained V matrix, the projections are random linear layers, which is not what we want.

I think it will be great addition to have the SVD creation feature using our own waveforms.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added an example for this.

mcoughlin and others added 3 commits February 20, 2026 08:31
Add compute_basis() static method to FreqDomainSVDProjection for
computing SVD basis vectors from waveform banks. Fix bug where
frequency-domain input had imaginary parts discarded by premature
float32 cast. Add comprehensive tests for compute_basis() and
documentation with ml4gw IMRPhenomD waveform generation examples.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The compute_basis() static method uses sklearn's randomized_svd.
Add scikit-learn to project dependencies and tox test deps so CI
can run the TestComputeBasis tests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants