Add FiLM module (PoC) and unit test #184
Merged: jacobbieker merged 5 commits into openclimatefix:main from AmanKushwaha-17:feature/film-conditioning on Dec 15, 2025.
Commits (5):

- be577da (AmanKushwaha-17): Add FiLM generator + applier (MetNet-style one-hot) and a basic unit …
- da7de59 (pre-commit-ci[bot]): [pre-commit.ci] auto fixes from pre-commit.com hooks
- 2422d1c (AmanKushwaha-17): film: add Google-style docstrings (module) and delete comment msg
- bafa7de (pre-commit-ci[bot]): [pre-commit.ci] auto fixes from pre-commit.com hooks
- 558e12f (AmanKushwaha-17): refactor: move FiLM module under models/layers and remove re-export
New file (75 lines), the FiLM module (imported in the test as `graph_weather.models.layers.film`):

```python
import torch
import torch.nn as nn


class FiLMGenerator(nn.Module):
    """
    Generates FiLM parameters (gamma and beta) from a lead-time index.

    A one-hot vector for the given lead time is expanded to the batch size
    and passed through a small MLP to produce FiLM modulation parameters.

    Args:
        num_lead_times (int): Number of possible lead-time categories.
        hidden_dim (int): Hidden size for the internal MLP.
        feature_dim (int): Output dimensionality of gamma and beta.
    """

    def __init__(self, num_lead_times: int, hidden_dim: int, feature_dim: int):
        super().__init__()
        self.num_lead_times = num_lead_times
        self.feature_dim = feature_dim
        self.network = nn.Sequential(
            nn.Linear(num_lead_times, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * feature_dim),
        )

    def forward(self, batch_size: int, lead_time: int, device=None):
        """
        Compute FiLM gamma and beta parameters.

        Args:
            batch_size (int): Number of samples to generate parameters for.
            lead_time (int): Lead-time index used to construct the one-hot input.
            device (optional): Device to place tensors on. Defaults to CPU.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                gamma: Tensor of shape (batch_size, feature_dim).
                beta: Tensor of shape (batch_size, feature_dim).
        """
        one_hot = torch.zeros(batch_size, self.num_lead_times, device=device)
        one_hot[:, lead_time] = 1.0
        gamma_beta = self.network(one_hot)
        gamma = gamma_beta[:, : self.feature_dim]
        beta = gamma_beta[:, self.feature_dim :]
        return gamma, beta


class FiLMApplier(nn.Module):
    """
    Applies FiLM modulation to an input tensor.

    Gamma and beta are broadcast to match the dimensionality of the input,
    and the FiLM operation is applied elementwise.
    """

    def forward(self, x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
        """
        Apply FiLM conditioning.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, ...).
            gamma (torch.Tensor): Scaling parameters of shape (B, C).
            beta (torch.Tensor): Bias parameters of shape (B, C).

        Returns:
            torch.Tensor: Output tensor after FiLM modulation, same shape as `x`.
        """
        while gamma.ndim < x.ndim:
            gamma = gamma.unsqueeze(-1)
            beta = beta.unsqueeze(-1)
        return x * gamma + beta
```

> Review comment from a project Member, on `__init__`: "Please add Google-style docstrings for all the methods"
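To illustrate how the two classes could be wired together, here is a hedged, self-contained sketch of a lead-time-conditioned conv block. The `FiLMGenerator` here is a minimal re-declaration so the snippet runs on its own (the real class lives in `graph_weather.models.layers.film`), and `ConditionedBlock` is a hypothetical wrapper invented for this example, not part of the PR.

```python
import torch
import torch.nn as nn


class FiLMGenerator(nn.Module):
    """Minimal stand-in for the PR's FiLMGenerator (same math, no docstrings)."""

    def __init__(self, num_lead_times: int, hidden_dim: int, feature_dim: int):
        super().__init__()
        self.num_lead_times = num_lead_times
        self.feature_dim = feature_dim
        self.network = nn.Sequential(
            nn.Linear(num_lead_times, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * feature_dim),
        )

    def forward(self, batch_size: int, lead_time: int, device=None):
        one_hot = torch.zeros(batch_size, self.num_lead_times, device=device)
        one_hot[:, lead_time] = 1.0
        gamma_beta = self.network(one_hot)
        return gamma_beta[:, : self.feature_dim], gamma_beta[:, self.feature_dim :]


class ConditionedBlock(nn.Module):
    """Hypothetical conv -> FiLM -> ReLU block, conditioned on forecast lead time."""

    def __init__(self, channels: int, num_lead_times: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.film_gen = FiLMGenerator(num_lead_times, hidden_dim=32, feature_dim=channels)

    def forward(self, x: torch.Tensor, lead_time: int) -> torch.Tensor:
        h = self.conv(x)
        gamma, beta = self.film_gen(x.shape[0], lead_time, device=x.device)
        # Broadcast (B, C) -> (B, C, 1, 1) so the modulation applies per channel.
        return torch.relu(h * gamma[:, :, None, None] + beta[:, :, None, None])


block = ConditionedBlock(channels=8, num_lead_times=10)
out = block(torch.randn(4, 8, 16, 16), lead_time=3)
assert out.shape == (4, 8, 16, 16)
```

One design note worth flagging: because `FiLMGenerator.forward` takes a single integer `lead_time`, every sample in the batch shares the same lead time, which matches the MetNet-style one-hot scheme the first commit message references.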
New test file (22 lines):
```python
import torch
from graph_weather.models.layers.film import FiLMGenerator, FiLMApplier


def test_film_shapes():
    batch = 4
    feature_dim = 16
    num_steps = 10
    hidden_dim = 8
    lead_time = 3

    gen = FiLMGenerator(num_steps, hidden_dim, feature_dim)
    apply = FiLMApplier()

    gamma, beta = gen(batch, lead_time, device="cpu")

    assert gamma.shape == (batch, feature_dim)
    assert beta.shape == (batch, feature_dim)

    x = torch.randn(batch, feature_dim, 8, 8)
    out = apply(x, gamma, beta)
    assert out.shape == x.shape
```
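The test above only checks shapes. A complementary numeric check (not part of the PR) can confirm the broadcast math directly: with gamma fixed at 2 and beta at 1, the FiLM operation `x * gamma + beta` should reduce to `2*x + 1` elementwise. The `film` function below re-implements the applier's broadcasting inline so the snippet is self-contained.

```python
import torch


def film(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    # Same broadcasting as FiLMApplier: append trailing singleton dims
    # until (B, C) lines up with (B, C, H, W).
    while gamma.ndim < x.ndim:
        gamma = gamma.unsqueeze(-1)
        beta = beta.unsqueeze(-1)
    return x * gamma + beta


x = torch.randn(2, 3, 4, 4)
gamma = torch.full((2, 3), 2.0)  # pure scaling by 2
beta = torch.ones(2, 3)          # constant shift of 1
y = film(x, gamma, beta)

assert y.shape == x.shape
assert torch.allclose(y, 2.0 * x + 1.0)
```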