Description
Version
0.11.0
On which installation method(s) does this occur?
uv
Describe the issue
Context
Hi,
The SFNO model should be able to run on a 16GiB GPU. In the past, using ai-models from ECMWF, I managed to run it on an NVIDIA P100 16GiB.
Unfortunately, with the current implementation, I get an OutOfMemory error during the forward pass on an NVIDIA T4 16GiB.
A clone operation appears to be responsible for a slight memory increase at inference time, which pushes usage past the memory limit.
Removing this cloning operation and setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True decreases memory usage slightly and makes the run possible (see logs below).
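For reference, a minimal sketch of enabling the allocator option. The variable must be set before the first CUDA allocation, so either export it in the shell or set it at the top of the entry script before torch touches the GPU:

    import os

    # Must be set before the CUDA caching allocator is initialized,
    # i.e. before the first CUDA tensor is created.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    import torch  # noqa: E402

    x = torch.zeros(1, device="cuda")  # allocator now uses expandable segments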
Solution
- Allow PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True as an env variable.
- Modify x = x.clone().squeeze(2) to x = x.squeeze(2) here: https://github.com/NVIDIA/earth2studio/blob/6e75f557363447599b1b513ae9c9e751387f5ea4/earth2studio/models/px/sfno.py#L338C8-L338C33
- (Optional) To further reduce memory usage, the squeeze and unsqueeze operations in the _forward function seem replaceable by indexing the tensor directly in the time loop. See the suggested code snippet below.
@torch.inference_mode()
def _forward(
    self,
    x: torch.Tensor,
    coords: CoordSystem,
) -> tuple[torch.Tensor, CoordSystem]:
    output_coords = self.output_coords(coords)
    for j, _ in enumerate(coords["batch"]):
        for i, t in enumerate(coords["time"]):
            # https://github.com/NVIDIA/modulus-makani/blob/933b17d5a1ebfdb0e16e2ebbd7ee78cfccfda9e1/makani/third_party/climt/zenith_angle.py#L197
            # Requires time zone data
            t = [
                datetime.fromisoformat(dt.isoformat() + "+00:00")
                for dt in timearray_to_datetime(t + coords["lead_time"])
            ]
            # Index the time dimension in place instead of squeeze/unsqueeze,
            # avoiding the clone and its full-size extra copy.
            x[j, i : i + 1, 0] = self.model(x[j, i : i + 1, 0], t, normalized_data=False)
    return x, output_coords
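To illustrate why the clone matters, a minimal sketch (assuming a CUDA device is available; the shape mirrors the (1, 1, 1, 73, 721, 1440) tensors in the profile below) comparing the allocation behaviour of squeeze alone versus clone().squeeze():

    import torch

    # Shape mirrors the (1, 1, 1, 73, 721, 1440) entries in the profile below.
    x = torch.zeros(1, 1, 1, 73, 721, 1440, device="cuda")
    before = torch.cuda.memory_allocated()

    view = x.squeeze(2)            # squeeze returns a view: no new allocation
    mid = torch.cuda.memory_allocated()

    copy = x.clone().squeeze(2)    # clone materializes a full ~289 MiB copy first
    after = torch.cuda.memory_allocated()

    print(f"squeeze only:  +{(mid - before) / 2**20:.1f} MiB")
    print(f"clone+squeeze: +{(after - mid) / 2**20:.1f} MiB")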
Logs: memory profiling
- right after the OOM error
- with expandable_segments:True
Total accounted for by active tensors: 13850.92 MiB (right after OOM error)
TOP 10 LARGEST TENSORS:
shape size_mb dtype requires_grad
0 (1, 768, 721, 1440) 3041.718750 torch.float32 False
1 (1, 384, 721, 1440) 1520.859375 torch.float32 False
2 (1, 384, 721, 1440) 1520.859375 torch.float32 False
3 (1, 384, 721, 1440) 1520.859375 torch.float32 False
4 (1, 77, 721, 1440) 304.963989 torch.float32 False
5 (1, 74, 721, 1440) 293.082275 torch.float32 False
6 (1, 1, 73, 721, 1440) 289.121704 torch.float32 False
7 (1, 1, 1, 73, 721, 1440) 289.121704 torch.float32 False
8 (1, 1, 73, 721, 1440) 289.121704 torch.float32 False
9 (1, 1, 73, 721, 1440) 289.121704 torch.float32 False
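For completeness, a sketch of how such a per-tensor accounting can be produced. This is an assumption about the profiling method, not necessarily the exact script used: walk gc.get_objects(), keep live CUDA tensors, and tabulate their sizes.

    import gc
    import torch
    import pandas as pd

    rows = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                rows.append({
                    "shape": tuple(obj.shape),
                    "size_mb": obj.element_size() * obj.nelement() / 2**20,
                    "dtype": obj.dtype,
                    "requires_grad": obj.requires_grad,
                })
        except Exception:
            continue  # some objects raise on attribute access during a GC walk

    df = pd.DataFrame(rows).sort_values("size_mb", ascending=False).reset_index(drop=True)
    print(f"Total accounted for by active tensors: {df['size_mb'].sum():.2f} MiB")
    print(df.head(10).to_string())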