feat: NVIDIA-style Hierarchical Gradient Checkpointing #193
Conversation
sorry @jacobbieker I don't have the permissions to use the 'Request Review' button on this new PR yet (no gear icon visible)
jacobbieker
left a comment
Overall, thanks for adding this. The docstyle needs to be fixed to match the style used in the rest of the repository. Additionally, I would rather not have a continuously increasing number of very similar large benchmarking scripts; I would suggest refactoring the current one so it can also run this benchmark, since both benchmark memory savings with different changes. Finally, the tests should be reorganized.
jacobbieker
left a comment
Looks nice, thanks for this. Just one minor bit of moving around, then we can merge.
Removed unused test for checkpoint flags.
jacobbieker
left a comment
LGTM! Thanks for this.
Pull Request
Description
Implemented NVIDIA GraphCast-style hierarchical gradient checkpointing for the GraphProcessor and related layers, significantly reducing GPU memory usage during training.
Summary of Changes:
Hierarchical Checkpointing: Implemented a new GraphCast wrapper class that allows fine-grained control over checkpointing (Model, Encoder, Processor, Decoder levels).
Core Optimization: Modified graph_net_block.py to use torch.utils.checkpoint with use_reentrant=False and preserve_rng_state=False, matching NVIDIA's PhysicsNeMo implementation for maximum memory efficiency (see the sketch after this list).
New API: Added a configuration interface (GraphCastConfig) with preset strategies (e.g., balanced, full, processor_only); a usage sketch is shown further below.
Motivation: Makes it possible to train and run weather forecasts on low-VRAM GPUs.
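As a rough illustration of the core optimization above, here is a minimal sketch of non-reentrant checkpointing around a single block. The ProcessorBlock module, its layers, and the use_checkpointing flag are hypothetical stand-ins, not the actual code in graph_net_block.py.

```python
# Minimal sketch only: ProcessorBlock and its layers are hypothetical
# stand-ins for the real modules in graph_net_block.py.
import torch
from torch.utils.checkpoint import checkpoint


class ProcessorBlock(torch.nn.Module):
    def __init__(self, dim: int, use_checkpointing: bool = True):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(dim, dim), torch.nn.SiLU(), torch.nn.Linear(dim, dim)
        )
        self.use_checkpointing = use_checkpointing

    def _block(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.mlp(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_checkpointing and self.training:
            # Non-reentrant checkpointing discards activations inside the
            # block after the forward pass and recomputes them during
            # backward; skipping RNG-state capture avoids extra overhead
            # at the cost of exact dropout reproducibility.
            return checkpoint(
                self._block, x, use_reentrant=False, preserve_rng_state=False
            )
        return self._block(x)
```

The memory saving comes from trading compute for storage: checkpointed blocks are run twice (forward, then again during backward), so intermediate activations never need to be kept alive across the whole backward pass.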
Fixes #189
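A possible shape for the preset interface is sketched below. The field names and the exact flag combinations per preset are illustrative assumptions; only the strategy names (full, balanced, processor_only) and the GraphCastConfig name come from this PR.

```python
# Hypothetical sketch of a preset-based config; field names are assumed.
from dataclasses import dataclass


@dataclass
class GraphCastConfig:
    checkpoint_encoder: bool = False
    checkpoint_processor: bool = False
    checkpoint_decoder: bool = False

    @classmethod
    def from_strategy(cls, strategy: str) -> "GraphCastConfig":
        presets = {
            "full": cls(True, True, True),
            "balanced": cls(False, True, True),
            "processor_only": cls(False, True, False),
        }
        return presets[strategy]


# Example: enable checkpointing on the processor and decoder only.
config = GraphCastConfig.from_strategy("balanced")
```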
How Has This Been Tested?
I have added a new test suite and updated the benchmarking script; a condensed sketch of the key checks and the memory measurement follows the benchmark results below.
Unit Tests: Ran pytest tests/models/test_gradient_checkpointing.py which verifies:
Output Equivalence: Confirmed that model(x) produces the exact same output whether checkpointing is True or False.
Gradient Flow: Verified that gradients propagate correctly through the checkpointed layers.
Hierarchical Control: Verified that flags correctly toggle checkpointing for specific sub-modules.
Memory Benchmarks: Ran python scripts/benchmark_memory.py on an NVIDIA RTX 3050 (4GB VRAM).
Results:
5.0° Grid (Batch 1): Memory dropped from 3012 MB (Baseline) to 856 MB (Checkpointed). (71.6% Reduction)
2.5° Grid (Batch 1-4): Previously failed with OOM. Now runs successfully (Peak mem: ~1.2GB - 2.6GB).
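The sketch below condenses the kind of checks described above. make_model, sample_input, and the use_checkpointing flag are placeholder names, not the exact fixtures or API in the repository, and the equivalence check assumes a deterministic forward pass.

```python
# Condensed sketch of the verification and measurement approach;
# make_model, sample_input and use_checkpointing are placeholder names.
import torch


def test_output_equivalence_and_gradients(make_model, sample_input):
    baseline = make_model(use_checkpointing=False)
    checkpointed = make_model(use_checkpointing=True)
    checkpointed.load_state_dict(baseline.state_dict())

    # Output equivalence: checkpointing must not change the forward result
    # (assumes no dropout or other stochastic layers).
    assert torch.allclose(baseline(sample_input), checkpointed(sample_input))

    # Gradient flow: backprop through the checkpointed model and confirm
    # every parameter received a gradient.
    checkpointed(sample_input).sum().backward()
    assert all(p.grad is not None for p in checkpointed.parameters())


def peak_memory_mb(model, batch, device="cuda"):
    """Run one forward/backward pass and report peak allocated memory in MB."""
    torch.cuda.reset_peak_memory_stats(device)
    model(batch.to(device)).sum().backward()
    return torch.cuda.max_memory_allocated(device) / 1024**2
```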
[x] Yes
If your changes affect data processing, have you plotted any changes? i.e. have you done a quick sanity check?
[x] Yes (Verified numerical equivalence of model outputs)
Checklist: