
Conversation

@Sidharth1743 (Contributor)

Pull Request

Description

Implemented hierarchical gradient checkpointing, following NVIDIA's GraphCast implementation, for the GraphProcessor and related layers, significantly reducing GPU memory usage during training.

Summary of Changes:

- Hierarchical Checkpointing: Implemented a new GraphCast wrapper class that allows fine-grained control over checkpointing (Model, Encoder, Processor, and Decoder levels).
- Core Optimization: Modified graph_net_block.py to use torch.utils.checkpoint with use_reentrant=False and preserve_rng_state=False, matching NVIDIA's PhysicsNeMo implementation for maximum memory efficiency (see the sketch after this list).
- New API: Added a configuration interface (GraphCastConfig) with preset strategies (e.g., balanced, full, processor_only).
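
For illustration, the core change amounts to wrapping each processor block's forward pass in torch.utils.checkpoint with the flags above. The class and argument names below (GraphNetBlock, mlp, node_feats) are placeholders, not the exact identifiers in graph_net_block.py:

```python
import torch
from torch.utils.checkpoint import checkpoint


class GraphNetBlock(torch.nn.Module):
    """Illustrative processor block; the real code lives in graph_net_block.py."""

    def __init__(self, mlp: torch.nn.Module, use_checkpointing: bool = False):
        super().__init__()
        self.mlp = mlp
        self.use_checkpointing = use_checkpointing

    def forward(self, node_feats: torch.Tensor) -> torch.Tensor:
        if self.use_checkpointing and self.training:
            # Drop intermediate activations after the forward pass and
            # recompute them during backward, trading compute for memory.
            # use_reentrant=False selects the newer non-reentrant variant;
            # preserve_rng_state=False skips saving/restoring the RNG state,
            # which is safe when the block contains no randomness (e.g. dropout).
            return checkpoint(
                self.mlp,
                node_feats,
                use_reentrant=False,
                preserve_rng_state=False,
            )
        return self.mlp(node_feats)
```

And the preset strategies might be selected roughly like this (only the name GraphCastConfig and the preset names come from this PR; the field and constructor signatures are assumptions):

```python
# Hypothetical usage; the actual field and constructor names may differ.
config = GraphCastConfig(checkpoint_strategy="processor_only")  # or "balanced", "full"
model = GraphCast(config)
```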

Motivation: This enables anyone to run weather forecasting models on low-VRAM GPUs.

Fixes #189

How Has This Been Tested?

I have added a new test suite and updated the benchmarking script.

  1. Unit Tests: Ran pytest tests/models/test_gradient_checkpointing.py, which verifies (a sketch of these checks follows below):
     - Output Equivalence: model(x) produces exactly the same output whether checkpointing is enabled or disabled.
     - Gradient Flow: gradients propagate correctly through the checkpointed layers.
     - Hierarchical Control: the flags correctly toggle checkpointing for specific sub-modules.

  2. Memory Benchmarks: Ran python scripts/benchmark_memory.py on an NVIDIA RTX 3050 (4 GB VRAM). Results:
     - 5.0° grid (batch 1): peak memory dropped from 3012 MB (baseline) to 856 MB (checkpointed), a 71.6% reduction.
     - 2.5° grid (batches 1-4): previously failed with OOM; now runs successfully (peak memory ~1.2-2.6 GB).

[x] Yes
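
As a rough, self-contained sketch of what the equivalence and gradient-flow checks look like (TinyProcessor is a stand-in defined here purely for illustration; the actual tests live in tests/models/test_gradient_checkpointing.py and exercise the real model):

```python
import torch
from torch.utils.checkpoint import checkpoint


class TinyProcessor(torch.nn.Module):
    """Stand-in module for illustrating the test pattern only."""

    def __init__(self, dim: int = 32, use_checkpointing: bool = False):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            torch.nn.Sequential(torch.nn.Linear(dim, dim), torch.nn.SiLU())
            for _ in range(4)
        )
        self.use_checkpointing = use_checkpointing

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if self.use_checkpointing and torch.is_grad_enabled():
                x = checkpoint(layer, x, use_reentrant=False, preserve_rng_state=False)
            else:
                x = layer(x)
        return x


def test_output_equivalence_and_gradient_flow():
    torch.manual_seed(0)
    base = TinyProcessor(use_checkpointing=False)
    ckpt = TinyProcessor(use_checkpointing=True)
    ckpt.load_state_dict(base.state_dict())  # identical weights

    x = torch.randn(2, 32, requires_grad=True)
    out_base, out_ckpt = base(x), ckpt(x)
    # Checkpointing changes *when* activations are stored, not the math,
    # so the outputs should be bitwise identical.
    assert torch.equal(out_base, out_ckpt)

    out_ckpt.sum().backward()
    # Every parameter should receive a gradient through the recomputed segments.
    assert all(p.grad is not None for p in ckpt.parameters())
```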

If your changes affect data processing, have you plotted any changes? i.e. have you done a quick sanity check?

[x] Yes (Verified numerical equivalence of model outputs)

Checklist:

  • [x] My code follows OCF's coding style guidelines
  • [x] I have performed a self-review of my own code
  • [x] I have made corresponding changes to the documentation
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] I have checked my code and corrected any misspellings

@Sidharth1743 (Contributor, Author)

Sorry @jacobbieker, I don't have permission to use the 'Request Review' button on this new PR yet (no gear icon visible).
Ready for review!

@jacobbieker jacobbieker self-requested a review January 5, 2026 10:08
@jacobbieker (Member) left a comment

Overall, thanks for adding this. The docstyle needs to be fixed to match the style used in the rest of the repository. Additionally, I would rather not have a continuously growing number of very similar large benchmarking scripts; I would suggest refactoring the current one so it can also run this benchmark, since both benchmark memory savings from different changes. Finally, the tests should be reorganized.

@Sidharth1743 Sidharth1743 force-pushed the feat/gradient-checkpointing branch from 596bc54 to ad304d6 Compare January 7, 2026 17:36
@jacobbieker (Member) left a comment

Looks nice, thanks for this. Just one minor bit of moving things around, then we can merge.

Removed unused test for checkpoint flags.
@jacobbieker (Member) left a comment

LGTM! Thanks for this.

@jacobbieker jacobbieker merged commit c65cc70 into openclimatefix:main Jan 7, 2026