Add frequency-domain SVD projection and dense residual block#256
Add frequency-domain SVD projection and dense residual block#256mcoughlin wants to merge 5 commits intoML4GW:mainfrom
Conversation
New ml4gw.nn.svd module with two components: - FreqDomainSVDProjection: FFT + linear projection layer initialized from precomputed SVD right singular vectors. Supports shared or per-channel (per-IFO) weights, and freeze/unfreeze for two-phase training. Adapted from DINGO's LinearProjectionRB. - DenseResidualBlock: LayerNorm(x + MLP(x)) residual block for processing SVD coefficients. Uses LayerNorm exclusively — BatchNorm causes train/eval output collapse in GW detection where batch composition (signal/noise ratio) differs between train and eval. Includes 134 parametrized tests covering shapes, gradients, V matrix initialization, freeze/unfreeze, save/load, and train/eval consistency. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
|
@mcoughlin Out of curiosity, how much of this was able to be done by Claude? |
|
@wbenoit26 A lot of it, but also a lot of discussion ;) |
Coverage reportClick to see where and how coverage changed
This report was generated by python-coverage-comment-action |
||||||||||||||||||||||||||||||||||||||||||||||||
|
Very cool. Not surprised claude is decent at this. I think a feature that is missing is the actual fitting of the SVD? Or is it assumed that the SVD will be fit outside and saved to a file? |
deepchatterjeeligo
left a comment
There was a problem hiding this comment.
Hi @mcoughlin we discussed about this morning. What came out of the discussion was that in addition to unittests, we should add some content in our documentation pages showing a usecase: something I thought out loud was a figure showing how adding more svd components approaches the waveform, but that is up for more discussion.
There was a problem hiding this comment.
Can we reuse the parts from our ResNet implementation or generalize them for the residual blocks? I worry this will become a duplicate residual block implementation.
There was a problem hiding this comment.
I think it's a bit too awkward, and I don't think it would generalize too much.
ml4gw/nn/svd/projection.py
Outdated
| if V_tensor is not None: | ||
| proj.weight.data = V_tensor.T.contiguous() |
There was a problem hiding this comment.
I believe the .T syntax for transpose will be deprecated soon. I have seen warnings about it. Can we explicitly supply the axes? Also, I'm curious about the .contiguous why that is needed.
| x_freq = torch.fft.rfft(x, dim=-1) | ||
|
|
||
| # Stack real and imaginary: (batch, channels, 2 * n_freq) | ||
| x_ri = torch.cat([x_freq.real, x_freq.imag], dim=-1) | ||
|
|
||
| if self.per_channel: | ||
| proj_list = [] | ||
| for ch in range(self.num_channels): | ||
| proj_list.append(self.projections[ch](x_ri[:, ch, :])) | ||
| x_proj = torch.stack(proj_list, dim=1) | ||
| else: | ||
| x_proj = self.projection(x_ri) | ||
|
|
||
| return x_proj.reshape(batch_size, -1) |
There was a problem hiding this comment.
This is a more basic question: I don't think you are computing the SVD here, right? Without the pretrained V matrix, the projections are random linear layers, which is not what we want.
I think it will be great addition to have the SVD creation feature using our own waveforms.
Add compute_basis() static method to FreqDomainSVDProjection for computing SVD basis vectors from waveform banks. Fix bug where frequency-domain input had imaginary parts discarded by premature float32 cast. Add comprehensive tests for compute_basis() and documentation with ml4gw IMRPhenomD waveform generation examples. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The compute_basis() static method uses sklearn's randomized_svd. Add scikit-learn to project dependencies and tox test deps so CI can run the TestComputeBasis tests. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
New ml4gw.nn.svd module with two components:
FreqDomainSVDProjection: FFT + linear projection layer initialized from precomputed SVD right singular vectors. Supports shared or per-channel (per-IFO) weights, and freeze/unfreeze for two-phase training. Adapted from DINGO's LinearProjectionRB.
DenseResidualBlock: LayerNorm(x + MLP(x)) residual block for processing SVD coefficients. Uses LayerNorm exclusively — BatchNorm causes train/eval output collapse in GW detection where batch composition (signal/noise ratio) differs between train and eval.
Includes 134 parametrized tests covering shapes, gradients, V matrix initialization, freeze/unfreeze, save/load, and train/eval consistency.