1 change: 1 addition & 0 deletions graph_weather/models/__init__.py
@@ -7,6 +7,7 @@
WrapperImageModel,
WrapperMetaModel,
)
from .graphcast import GraphCast, GraphCastConfig
from .layers.assimilator_decoder import AssimilatorDecoder
from .layers.assimilator_encoder import AssimilatorEncoder
from .layers.decoder import Decoder
5 changes: 5 additions & 0 deletions graph_weather/models/graphcast/__init__.py
@@ -0,0 +1,5 @@
"""GraphCast model with gradient checkpointing."""

from .model import GraphCast, GraphCastConfig

__all__ = ["GraphCast", "GraphCastConfig"]
345 changes: 345 additions & 0 deletions graph_weather/models/graphcast/model.py
@@ -0,0 +1,345 @@
"""GraphCast model with hierarchical gradient checkpointing.

This module provides a complete GraphCast-style weather forecasting model
with NVIDIA-style hierarchical gradient checkpointing for memory-efficient training.

Based on:
- NVIDIA PhysicsNeMo GraphCast implementation
"""

from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint

from graph_weather.models.layers.decoder import Decoder
from graph_weather.models.layers.encoder import Encoder
from graph_weather.models.layers.processor import Processor


class GraphCast(torch.nn.Module):
"""GraphCast model with hierarchical gradient checkpointing.

This model combines Encoder, Processor, and Decoder with NVIDIA-style
hierarchical checkpointing controls for flexible memory-compute tradeoffs.

Hierarchical checkpointing methods:
- set_checkpoint_model(flag): Checkpoint entire forward pass
- set_checkpoint_encoder(flag): Checkpoint encoder section
- set_checkpoint_processor(segments): Checkpoint processor with configurable segments
- set_checkpoint_decoder(flag): Checkpoint decoder section
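
    Example (illustrative only; ``features`` is a [batch, num_points, input_dim] tensor):
        >>> model = GraphCast(lat_lons)
        >>> model.set_checkpoint_processor(-1)  # checkpoint processor as one segment
        >>> out = model(features)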
"""

def __init__(
self,
lat_lons: list,
resolution: int = 2,
input_dim: int = 78,
output_dim: int = 78,
hidden_dim: int = 256,
num_processor_blocks: int = 9,
hidden_layers: int = 2,
mlp_norm_type: str = "LayerNorm",
use_checkpointing: bool = False,
efficient_batching: bool = False,
):
"""
Initialize GraphCast model with hierarchical checkpointing support.

Args:
lat_lons: List of (lat, lon) tuples defining the grid points
resolution: H3 resolution level
input_dim: Input feature dimension
output_dim: Output feature dimension
hidden_dim: Hidden dimension for all layers
num_processor_blocks: Number of message passing blocks in processor
hidden_layers: Number of hidden layers in MLPs
mlp_norm_type: Normalization type for MLPs
use_checkpointing: Enable fine-grained checkpointing in all layers
efficient_batching: Use efficient batching (avoid graph replication)
"""
super().__init__()

self.lat_lons = lat_lons
self.input_dim = input_dim
self.output_dim = output_dim
self.efficient_batching = efficient_batching

# Initialize components
self.encoder = Encoder(
lat_lons=lat_lons,
resolution=resolution,
input_dim=input_dim,
output_dim=hidden_dim,
output_edge_dim=hidden_dim,
hidden_dim_processor_node=hidden_dim,
hidden_dim_processor_edge=hidden_dim,
hidden_layers_processor_node=hidden_layers,
hidden_layers_processor_edge=hidden_layers,
mlp_norm_type=mlp_norm_type,
use_checkpointing=use_checkpointing,
efficient_batching=efficient_batching,
)

self.processor = Processor(
input_dim=hidden_dim,
edge_dim=hidden_dim,
num_blocks=num_processor_blocks,
hidden_dim_processor_node=hidden_dim,
hidden_dim_processor_edge=hidden_dim,
hidden_layers_processor_node=hidden_layers,
hidden_layers_processor_edge=hidden_layers,
mlp_norm_type=mlp_norm_type,
use_checkpointing=use_checkpointing,
)

self.decoder = Decoder(
lat_lons=lat_lons,
resolution=resolution,
input_dim=hidden_dim,
output_dim=output_dim,
hidden_dim_processor_node=hidden_dim,
hidden_dim_processor_edge=hidden_dim,
hidden_layers_processor_node=hidden_layers,
hidden_layers_processor_edge=hidden_layers,
mlp_norm_type=mlp_norm_type,
hidden_dim_decoder=hidden_dim,
hidden_layers_decoder=hidden_layers,
use_checkpointing=use_checkpointing,
efficient_batching=efficient_batching,
)

        # Hierarchical checkpointing flags. All default to off, so the model
        # falls back to the layers' fine-grained use_checkpointing setting.
self._checkpoint_model = False
self._checkpoint_encoder = False
self._checkpoint_processor_segments = 0 # 0 = use layer's internal checkpointing
self._checkpoint_decoder = False

def set_checkpoint_model(self, checkpoint_flag: bool):
"""
Checkpoint entire model as a single segment.

When enabled, creates one checkpoint for the entire forward pass.
This provides maximum memory savings but highest recomputation cost.
Disables all other hierarchical checkpointing when enabled.

Args:
checkpoint_flag: If True, checkpoint entire model. If False, use hierarchical checkpointing.
"""
self._checkpoint_model = checkpoint_flag
if checkpoint_flag:
# Disable all fine-grained checkpointing
self._checkpoint_encoder = False
self._checkpoint_processor_segments = 0
self._checkpoint_decoder = False

def set_checkpoint_encoder(self, checkpoint_flag: bool):
"""
Checkpoint encoder section.

Checkpoints the encoder forward pass as a single segment.
        Takes effect only when model-level checkpointing is disabled via set_checkpoint_model(False).

Args:
checkpoint_flag: If True, checkpoint encoder section.
"""
self._checkpoint_encoder = checkpoint_flag

def set_checkpoint_processor(self, checkpoint_segments: int):
"""
Checkpoint processor with configurable segments.

Controls how the processor is checkpointed:
- 0: Use processor's internal per-block checkpointing
- -1: Checkpoint entire processor as one segment
- N > 0: Checkpoint every N blocks (not yet implemented)

        Takes effect only when model-level checkpointing is disabled via set_checkpoint_model(False).

Args:
checkpoint_segments: Checkpointing strategy (0, -1, or positive integer).
"""
self._checkpoint_processor_segments = checkpoint_segments

def set_checkpoint_decoder(self, checkpoint_flag: bool):
"""
Checkpoint decoder section.

Checkpoints the decoder forward pass as a single segment.
        Takes effect only when model-level checkpointing is disabled via set_checkpoint_model(False).

Args:
checkpoint_flag: If True, checkpoint decoder section.
"""
self._checkpoint_decoder = checkpoint_flag

def _encoder_forward(self, features: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
"""
Encoder forward pass (for checkpointing).
"""
return self.encoder(features)

def _processor_forward(
self,
x: Tensor,
edge_index: Tensor,
edge_attr: Tensor,
batch_size: Optional[int] = None,
) -> Tensor:
"""
Processor forward pass (for checkpointing).
"""
return self.processor(
x,
edge_index,
edge_attr,
batch_size=batch_size,
efficient_batching=self.efficient_batching,
)

def _decoder_forward(
self,
processed_features: Tensor,
original_features: Tensor,
batch_size: int,
) -> Tensor:
"""
Decoder forward pass (for checkpointing).
"""
return self.decoder(processed_features, original_features, batch_size)

def _custom_forward(self, features: Tensor) -> Tensor:
"""
Forward pass with hierarchical checkpointing.
"""
batch_size = features.shape[0]

# Encoder
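        # The checkpoint() calls below use the non-reentrant implementation
        # (use_reentrant=False) and skip RNG save/restore
        # (preserve_rng_state=False); the latter is safe only if the wrapped
        # layers use no randomness such as dropout, which is assumed here.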
if self._checkpoint_encoder:
latent_features, edge_index, edge_attr = checkpoint(
self._encoder_forward,
features,
use_reentrant=False,
preserve_rng_state=False,
)
else:
latent_features, edge_index, edge_attr = self.encoder(features)

# Processor
if self._checkpoint_processor_segments == -1:
# Checkpoint entire processor as one block
processed_features = checkpoint(
self._processor_forward,
latent_features,
edge_index,
edge_attr,
batch_size if self.efficient_batching else None,
use_reentrant=False,
preserve_rng_state=False,
)
else:
# Use processor's internal checkpointing (controlled by use_checkpointing)
processed_features = self.processor(
latent_features,
edge_index,
edge_attr,
batch_size=batch_size,
efficient_batching=self.efficient_batching,
)

# Decoder
if self._checkpoint_decoder:
output = checkpoint(
self._decoder_forward,
processed_features,
features,
batch_size,
use_reentrant=False,
preserve_rng_state=False,
)
else:
output = self.decoder(processed_features, features, batch_size)

return output

def forward(self, features: Tensor) -> Tensor:
"""Forward pass through GraphCast model.

Args:
features: Input features of shape [batch_size, num_points, input_dim]

Returns:
Output predictions of shape [batch_size, num_points, output_dim]
"""
if self._checkpoint_model:
# Checkpoint entire model as one segment
return checkpoint(
self._custom_forward,
features,
use_reentrant=False,
preserve_rng_state=False,
)
else:
# Use hierarchical checkpointing
return self._custom_forward(features)


class GraphCastConfig:
"""Configuration helper for GraphCast checkpointing strategies.

    Provides preset strategies for common memory/speed tradeoffs.
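
    Example:
        >>> GraphCastConfig.balanced_checkpointing(model)  # model: a GraphCast instance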
"""

@staticmethod
def no_checkpointing(model: GraphCast):
"""
Disable all checkpointing (maximum speed, maximum memory).
"""
model.set_checkpoint_model(False)
model.set_checkpoint_encoder(False)
model.set_checkpoint_processor(0)
model.set_checkpoint_decoder(False)

@staticmethod
def full_checkpointing(model: GraphCast):
"""
Checkpoint entire model (maximum memory savings, slowest).
"""
model.set_checkpoint_model(True)

@staticmethod
def balanced_checkpointing(model: GraphCast):
"""
Balanced strategy (good memory savings, moderate speed).
"""
model.set_checkpoint_model(False)
model.set_checkpoint_encoder(True)
model.set_checkpoint_processor(-1)
model.set_checkpoint_decoder(True)

@staticmethod
def processor_only_checkpointing(model: GraphCast):
"""
Checkpoint only processor (targets main memory bottleneck).
"""
model.set_checkpoint_model(False)
model.set_checkpoint_encoder(False)
model.set_checkpoint_processor(-1)
model.set_checkpoint_decoder(False)

@staticmethod
def fine_grained_checkpointing(model: GraphCast):
"""
Fine-grained per-layer checkpointing (best memory savings).

This checkpoints each individual MLP and processor block separately.
Provides the best memory savings with moderate recomputation cost.
Note: Model must be created with use_checkpointing=True.
"""
# Fine-grained is enabled via use_checkpointing=True in __init__
# This just disables hierarchical checkpointing
model.set_checkpoint_model(False)
model.set_checkpoint_encoder(False)
model.set_checkpoint_processor(0)
model.set_checkpoint_decoder(False)
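For reviewers, a minimal end-to-end sketch of how the new model and presets fit together (illustrative only, not part of the diff; the toy grid, feature dimension, and loss are arbitrary assumptions):

import torch

from graph_weather.models import GraphCast, GraphCastConfig

# Toy coarse grid; real use passes the dataset's actual grid points.
lat_lons = [(lat, lon) for lat in range(-90, 90, 10) for lon in range(0, 360, 10)]

model = GraphCast(lat_lons, input_dim=78, output_dim=78)
GraphCastConfig.processor_only_checkpointing(model)  # target the main memory bottleneck

features = torch.randn(2, len(lat_lons), 78)
loss = model(features).square().mean()
loss.backward()  # checkpointed processor segments are recomputed here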
1 change: 1 addition & 0 deletions graph_weather/models/layers/assimilator_decoder.py
@@ -117,6 +117,7 @@ def __init__(
hidden_layers_node=hidden_layers_processor_node,
hidden_layers_edge=hidden_layers_processor_edge,
norm_type=mlp_norm_type,
use_checkpointing=self.use_checkpointing,
)
self.node_decoder = MLP(
input_dim,
1 change: 1 addition & 0 deletions graph_weather/models/layers/encoder.py
@@ -147,6 +147,7 @@ def __init__(
hidden_layers_processor_node,
hidden_layers_processor_edge,
mlp_norm_type,
use_checkpointing=self.use_checkpointing,
)

def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: