feat: NVIDIA-style Hierarchical Gradient Checkpointing #193
Merged: jacobbieker merged 5 commits into openclimatefix:main from Sidharth1743:feat/gradient-checkpointing on Jan 7, 2026.
Commits (5):
9c4c80c  Added gradient checkpointing (Sidharth1743)
ad304d6  docstrings matched, pytest modifications, benchmark merged (Sidharth1743)
40f8ca3  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
9d554f2  move graphcast to models/graphcast/model.py (Sidharth1743)
ac5b607  Remove redundant checkpoint flags test (jacobbieker)
graph_weather/models/graphcast/__init__.py (new file):

"""GraphCast model with gradient checkpointing."""

from .model import GraphCast, GraphCastConfig

__all__ = ["GraphCast", "GraphCastConfig"]
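
Assuming the package path implied by commit 9d554f2 above (an inference, not stated verbatim in the diff), this re-export makes the public API importable as:

# Path inferred from the "move graphcast to models/graphcast/model.py" commit.
from graph_weather.models.graphcast import GraphCast, GraphCastConfig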
graph_weather/models/graphcast/model.py (new file):

"""GraphCast model with hierarchical gradient checkpointing.

This module provides a complete GraphCast-style weather forecasting model
with NVIDIA-style hierarchical gradient checkpointing for memory-efficient training.

Based on:
- NVIDIA PhysicsNeMo GraphCast implementation
"""

from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint

from graph_weather.models.layers.decoder import Decoder
from graph_weather.models.layers.encoder import Encoder
from graph_weather.models.layers.processor import Processor


class GraphCast(torch.nn.Module):
    """GraphCast model with hierarchical gradient checkpointing.

    This model combines Encoder, Processor, and Decoder with NVIDIA-style
    hierarchical checkpointing controls for flexible memory-compute tradeoffs.

    Hierarchical checkpointing methods:
    - set_checkpoint_model(flag): Checkpoint entire forward pass
    - set_checkpoint_encoder(flag): Checkpoint encoder section
    - set_checkpoint_processor(segments): Checkpoint processor with configurable segments
    - set_checkpoint_decoder(flag): Checkpoint decoder section
    """

    def __init__(
        self,
        lat_lons: list,
        resolution: int = 2,
        input_dim: int = 78,
        output_dim: int = 78,
        hidden_dim: int = 256,
        num_processor_blocks: int = 9,
        hidden_layers: int = 2,
        mlp_norm_type: str = "LayerNorm",
        use_checkpointing: bool = False,
        efficient_batching: bool = False,
    ):
        """
        Initialize GraphCast model with hierarchical checkpointing support.

        Args:
            lat_lons: List of (lat, lon) tuples defining the grid points
            resolution: H3 resolution level
            input_dim: Input feature dimension
            output_dim: Output feature dimension
            hidden_dim: Hidden dimension for all layers
            num_processor_blocks: Number of message passing blocks in processor
            hidden_layers: Number of hidden layers in MLPs
            mlp_norm_type: Normalization type for MLPs
            use_checkpointing: Enable fine-grained checkpointing in all layers
            efficient_batching: Use efficient batching (avoid graph replication)
        """
        super().__init__()

        self.lat_lons = lat_lons
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.efficient_batching = efficient_batching

        # Initialize components
        self.encoder = Encoder(
            lat_lons=lat_lons,
            resolution=resolution,
            input_dim=input_dim,
            output_dim=hidden_dim,
            output_edge_dim=hidden_dim,
            hidden_dim_processor_node=hidden_dim,
            hidden_dim_processor_edge=hidden_dim,
            hidden_layers_processor_node=hidden_layers,
            hidden_layers_processor_edge=hidden_layers,
            mlp_norm_type=mlp_norm_type,
            use_checkpointing=use_checkpointing,
            efficient_batching=efficient_batching,
        )

        self.processor = Processor(
            input_dim=hidden_dim,
            edge_dim=hidden_dim,
            num_blocks=num_processor_blocks,
            hidden_dim_processor_node=hidden_dim,
            hidden_dim_processor_edge=hidden_dim,
            hidden_layers_processor_node=hidden_layers,
            hidden_layers_processor_edge=hidden_layers,
            mlp_norm_type=mlp_norm_type,
            use_checkpointing=use_checkpointing,
        )

        self.decoder = Decoder(
            lat_lons=lat_lons,
            resolution=resolution,
            input_dim=hidden_dim,
            output_dim=output_dim,
            hidden_dim_processor_node=hidden_dim,
            hidden_dim_processor_edge=hidden_dim,
            hidden_layers_processor_node=hidden_layers,
            hidden_layers_processor_edge=hidden_layers,
            mlp_norm_type=mlp_norm_type,
            hidden_dim_decoder=hidden_dim,
            hidden_layers_decoder=hidden_layers,
            use_checkpointing=use_checkpointing,
            efficient_batching=efficient_batching,
        )

        # Hierarchical checkpointing flags (default: use fine-grained checkpointing)
        self._checkpoint_model = False
        self._checkpoint_encoder = False
        self._checkpoint_processor_segments = 0  # 0 = use layer's internal checkpointing
        self._checkpoint_decoder = False

    def set_checkpoint_model(self, checkpoint_flag: bool):
        """
        Checkpoint entire model as a single segment.

        When enabled, creates one checkpoint for the entire forward pass.
        This provides maximum memory savings but highest recomputation cost.
        Disables all other hierarchical checkpointing when enabled.

        Args:
            checkpoint_flag: If True, checkpoint entire model. If False, use hierarchical checkpointing.
        """
        self._checkpoint_model = checkpoint_flag
        if checkpoint_flag:
            # Disable all fine-grained checkpointing
            self._checkpoint_encoder = False
            self._checkpoint_processor_segments = 0
            self._checkpoint_decoder = False

    def set_checkpoint_encoder(self, checkpoint_flag: bool):
        """
        Checkpoint encoder section.

        Checkpoints the encoder forward pass as a single segment.
        Only effective when set_checkpoint_model(False).

        Args:
            checkpoint_flag: If True, checkpoint encoder section.
        """
        self._checkpoint_encoder = checkpoint_flag

    def set_checkpoint_processor(self, checkpoint_segments: int):
        """
        Checkpoint processor with configurable segments.

        Controls how the processor is checkpointed:
        - 0: Use processor's internal per-block checkpointing
        - -1: Checkpoint entire processor as one segment
        - N > 0: Checkpoint every N blocks (not yet implemented)

        Only effective when set_checkpoint_model(False).

        Args:
            checkpoint_segments: Checkpointing strategy (0, -1, or positive integer).
        """
        self._checkpoint_processor_segments = checkpoint_segments

    def set_checkpoint_decoder(self, checkpoint_flag: bool):
        """
        Checkpoint decoder section.

        Checkpoints the decoder forward pass as a single segment.
        Only effective when set_checkpoint_model(False).

        Args:
            checkpoint_flag: If True, checkpoint decoder section.
        """
        self._checkpoint_decoder = checkpoint_flag

    def _encoder_forward(self, features: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Encoder forward pass (for checkpointing).
        """
        return self.encoder(features)

    def _processor_forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: Tensor,
        batch_size: Optional[int] = None,
    ) -> Tensor:
        """
        Processor forward pass (for checkpointing).
        """
        return self.processor(
            x,
            edge_index,
            edge_attr,
            batch_size=batch_size,
            efficient_batching=self.efficient_batching,
        )

    def _decoder_forward(
        self,
        processed_features: Tensor,
        original_features: Tensor,
        batch_size: int,
    ) -> Tensor:
        """
        Decoder forward pass (for checkpointing).
        """
        return self.decoder(processed_features, original_features, batch_size)

    def _custom_forward(self, features: Tensor) -> Tensor:
        """
        Forward pass with hierarchical checkpointing.
        """
        batch_size = features.shape[0]

        # Encoder
        if self._checkpoint_encoder:
            latent_features, edge_index, edge_attr = checkpoint(
                self._encoder_forward,
                features,
                use_reentrant=False,
                preserve_rng_state=False,
            )
        else:
            latent_features, edge_index, edge_attr = self.encoder(features)

        # Processor
        if self._checkpoint_processor_segments == -1:
            # Checkpoint entire processor as one block
            processed_features = checkpoint(
                self._processor_forward,
                latent_features,
                edge_index,
                edge_attr,
                batch_size if self.efficient_batching else None,
                use_reentrant=False,
                preserve_rng_state=False,
            )
        else:
            # Use processor's internal checkpointing (controlled by use_checkpointing)
            processed_features = self.processor(
                latent_features,
                edge_index,
                edge_attr,
                batch_size=batch_size,
                efficient_batching=self.efficient_batching,
            )

        # Decoder
        if self._checkpoint_decoder:
            output = checkpoint(
                self._decoder_forward,
                processed_features,
                features,
                batch_size,
                use_reentrant=False,
                preserve_rng_state=False,
            )
        else:
            output = self.decoder(processed_features, features, batch_size)

        return output

    def forward(self, features: Tensor) -> Tensor:
        """Forward pass through GraphCast model.

        Args:
            features: Input features of shape [batch_size, num_points, input_dim]

        Returns:
            Output predictions of shape [batch_size, num_points, output_dim]
        """
        if self._checkpoint_model:
            # Checkpoint entire model as one segment
            return checkpoint(
                self._custom_forward,
                features,
                use_reentrant=False,
                preserve_rng_state=False,
            )
        else:
            # Use hierarchical checkpointing
            return self._custom_forward(features)
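
Every checkpointed section above goes through torch.utils.checkpoint with use_reentrant=False and preserve_rng_state=False: PyTorch frees the section's intermediate activations after the forward pass and recomputes them from the saved inputs during backward, trading one extra forward through that section for memory. A minimal, self-contained sketch of that mechanism on a toy MLP (not this model; skipping RNG-state preservation is safe here because the block has no dropout or other randomized ops):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

mlp = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8))
x = torch.randn(4, 8, requires_grad=True)

# Baseline: an ordinary forward keeps every intermediate activation alive
# until backward has consumed it.
mlp(x).sum().backward()
grad_plain = x.grad.clone()

# Checkpointed: only the inputs to `mlp` are saved; the intermediates are
# recomputed during backward, so gradients come out identical.
x.grad = None
out = checkpoint(mlp, x, use_reentrant=False, preserve_rng_state=False)
out.sum().backward()

assert torch.allclose(grad_plain, x.grad)

GraphCastConfig, defined at the bottom of the same file, packages the setter calls above into named presets: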
class GraphCastConfig:
    """Configuration helper for GraphCast checkpointing strategies.

    Provides pre-defined checkpointing strategies for different use cases.
    """

    @staticmethod
    def no_checkpointing(model: GraphCast):
        """
        Disable all checkpointing (maximum speed, maximum memory).
        """
        model.set_checkpoint_model(False)
        model.set_checkpoint_encoder(False)
        model.set_checkpoint_processor(0)
        model.set_checkpoint_decoder(False)

    @staticmethod
    def full_checkpointing(model: GraphCast):
        """
        Checkpoint entire model (maximum memory savings, slowest).
        """
        model.set_checkpoint_model(True)

    @staticmethod
    def balanced_checkpointing(model: GraphCast):
        """
        Balanced strategy (good memory savings, moderate speed).
        """
        model.set_checkpoint_model(False)
        model.set_checkpoint_encoder(True)
        model.set_checkpoint_processor(-1)
        model.set_checkpoint_decoder(True)

    @staticmethod
    def processor_only_checkpointing(model: GraphCast):
        """
        Checkpoint only processor (targets main memory bottleneck).
        """
        model.set_checkpoint_model(False)
        model.set_checkpoint_encoder(False)
        model.set_checkpoint_processor(-1)
        model.set_checkpoint_decoder(False)

    @staticmethod
    def fine_grained_checkpointing(model: GraphCast):
        """
        Fine-grained per-layer checkpointing (best memory savings).

        This checkpoints each individual MLP and processor block separately.
        Provides the best memory savings with moderate recomputation cost.
        Note: Model must be created with use_checkpointing=True.
        """
        # Fine-grained is enabled via use_checkpointing=True in __init__
        # This just disables hierarchical checkpointing
        model.set_checkpoint_model(False)
        model.set_checkpoint_encoder(False)
        model.set_checkpoint_processor(0)
        model.set_checkpoint_decoder(False)
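
A sketch of how these presets might be compared on real hardware; the coarse grid, feature tensor, and CUDA peak-memory loop are illustrative assumptions, not part of this PR. Only the GraphCast and GraphCastConfig calls come from the diff above:

import torch
from graph_weather.models.graphcast import GraphCast, GraphCastConfig  # path assumed, see above

# Illustrative coarse 10-degree grid; real training would use a denser one.
lat_lons = [(lat, lon) for lat in range(-90, 90, 10) for lon in range(0, 360, 10)]
model = GraphCast(lat_lons, use_checkpointing=True).cuda()
features = torch.randn(1, len(lat_lons), 78, device="cuda")  # input_dim defaults to 78

strategies = {
    "none": GraphCastConfig.no_checkpointing,
    "full": GraphCastConfig.full_checkpointing,
    "balanced": GraphCastConfig.balanced_checkpointing,
    "processor_only": GraphCastConfig.processor_only_checkpointing,
    "fine_grained": GraphCastConfig.fine_grained_checkpointing,
}

for name, apply_strategy in strategies.items():
    apply_strategy(model)  # presets only flip flags, so they can be swapped freely
    model.zero_grad(set_to_none=True)
    torch.cuda.reset_peak_memory_stats()
    model(features).pow(2).mean().backward()
    peak_mib = torch.cuda.max_memory_allocated() / 2**20
    print(f"{name:>15}: peak {peak_mib:.0f} MiB")

One subtlety visible in the presets themselves: because fine-grained checkpointing is baked in at construction via use_checkpointing=True, no_checkpointing and fine_grained_checkpointing set the same four flags, so per-layer checkpointing stays active under both. A true no-checkpointing baseline needs a model built with use_checkpointing=False.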