
🚀 [FEA]: Make SFNO run on a 16GiB GPU #608

@forcadellvincent

Description


Version

0.11.0

On which installation method(s) does this occur?

uv

Describe the issue

Context

Hi,

The SFNO model should be able to run on a 16GiB GPU. In the past, using ai-models from the ECMWF, I've managed to make it run on an NVIDIA P100 16GiB.

Unfortunately, with the current implementation, I get an OutOfMemory error during the forward pass on an NVIDIA T4 16GiB.

A clone operation appears to be responsible for a slight memory increase at inference, which pushes the run over the memory limit.

Removing this cloning operation and setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True reduces memory usage slightly and makes the run possible (see logs below).
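
To make the cost concrete, here is an illustrative sketch of the difference between clone().squeeze(2) and a plain squeeze(2); the shape and dimension ordering are assumptions inferred from the tensor sizes in the memory profile below, not taken from sfno.py itself.

    import torch

    # Assumed layout (batch, time, lead_time, variable, lat, lon); the shape is
    # chosen to match the ~289 MiB tensors in the memory profile below.
    x = torch.empty(1, 1, 1, 73, 721, 1440, device="cuda")
    print(x.element_size() * x.nelement() / 2**20)  # ~289 MiB

    y = x.clone().squeeze(2)  # clone() materializes a second ~289 MiB buffer
    z = x.squeeze(2)          # a bare squeeze() returns a view, no new allocation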

Solution

  1. Allow setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True as an environment variable (see the user-side sketch after the snippet below).
  2. Replace x = x.clone().squeeze(2) with x = x.squeeze(2) here: https://github.com/NVIDIA/earth2studio/blob/6e75f557363447599b1b513ae9c9e751387f5ea4/earth2studio/models/px/sfno.py#L338C8-L338C33
  3. (optional) To further reduce memory usage, the squeeze and unsqueeze operations in the _forward function seem replaceable by indexing the tensor directly in the time loop. See the suggested code snippet below.
    @torch.inference_mode()
    def _forward(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        output_coords = self.output_coords(coords)
        for j, _ in enumerate(coords["batch"]):
            for i, t in enumerate(coords["time"]):
                # https://github.com/NVIDIA/modulus-makani/blob/933b17d5a1ebfdb0e16e2ebbd7ee78cfccfda9e1/makani/third_party/climt/zenith_angle.py#L197
                # Requires time zone data
                t = [
                    datetime.fromisoformat(dt.isoformat() + "+00:00")
                    for dt in timearray_to_datetime(t + coords["lead_time"])
                ]
                # Index the batch/time/lead_time dims directly and write the
                # result back in place, avoiding the clone + squeeze/unsqueeze.
                x[j, i : i + 1, 0] = self.model(
                    x[j, i : i + 1, 0], t, normalized_data=False
                )
        return x, output_coords
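
For item 1, a minimal user-side sketch of the workaround; how earth2studio itself should expose the option is left to the maintainers. The only constraint is that the variable must be set before PyTorch's CUDA caching allocator makes its first allocation, so setting it before importing torch is the safest approach:

    import os

    # Must be set before the first CUDA allocation; doing it before importing
    # torch guarantees the allocator picks it up.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

    import torch  # noqa: E402  (deliberately imported after the env var)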

Logs memory profiling

  • Right after the OOM error, with expandable_segments:True:
Total accounted for by active tensors: 13850.92 MiB (right after OOM error)
TOP 10 LARGEST TENSORS:
   shape                         size_mb  dtype          requires_grad
0  (1, 768, 721, 1440)       3041.718750  torch.float32  False
1  (1, 384, 721, 1440)       1520.859375  torch.float32  False
2  (1, 384, 721, 1440)       1520.859375  torch.float32  False
3  (1, 384, 721, 1440)       1520.859375  torch.float32  False
4  (1, 77, 721, 1440)         304.963989  torch.float32  False
5  (1, 74, 721, 1440)         293.082275  torch.float32  False
6  (1, 1, 73, 721, 1440)      289.121704  torch.float32  False
7  (1, 1, 1, 73, 721, 1440)   289.121704  torch.float32  False
8  (1, 1, 73, 721, 1440)      289.121704  torch.float32  False
9  (1, 1, 73, 721, 1440)      289.121704  torch.float32  False
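
The exact script that produced this listing is not included in the issue; a rough equivalent (an assumption, not the actual profiler used) can be built by walking live objects with gc and collecting the CUDA tensors, e.g.:

    import gc

    import pandas as pd
    import torch


    def cuda_tensor_report(top: int = 10) -> pd.DataFrame:
        """List the largest live CUDA tensors, similar to the table above."""
        rows = []
        for obj in gc.get_objects():
            try:
                if torch.is_tensor(obj) and obj.is_cuda:
                    rows.append(
                        {
                            "shape": tuple(obj.shape),
                            "size_mb": obj.element_size() * obj.nelement() / 2**20,
                            "dtype": obj.dtype,
                            "requires_grad": obj.requires_grad,
                        }
                    )
            except Exception:
                continue  # some tracked objects raise on attribute access
        df = pd.DataFrame(rows)
        if df.empty:
            return df
        df = df.sort_values("size_mb", ascending=False)
        print(f"Total accounted for by active tensors: {df['size_mb'].sum():.2f} MiB")
        return df.head(top).reset_index(drop=True)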


    Labels

    ? - Needs Triage, bug
