Shared Memory Out of Resources #101

@JoeZhang-0x000

Description

Is there an existing issue for this?

  • I have searched the existing issues.

Describe the bug:

When running the following code in an environment with:

  • torch 2.7
  • triton 3.5

an error occurs. The error message reports that the kernel requests far more shared memory than the hardware limit allows.

It is strongly suggested to reproduce this environment using the nvcr 25.04 container (https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-04.html).
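
The reported versions can be confirmed with a quick check (a trivial sanity check, not part of the original script):

import torch
import triton

print(torch.__version__)   # reported as 2.7
print(triton.__version__)  # reported as 3.5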

To reproduce:

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor, block_size
import torch


BLOCK_SIZE_M = block_size()
BLOCK_SIZE_N = block_size()
BLOCK_SIZE_K = block_size()

def arrangement(
    input,
    other,
    output,
    block_size_m=None,
    block_size_n=None,
    block_size_k=None,
):
    if block_size_m is None:
        block_size_m = BLOCK_SIZE_M

    if block_size_n is None:
        block_size_n = BLOCK_SIZE_N

    if block_size_k is None:
        block_size_k = BLOCK_SIZE_K

    output_arranged = output.tile((block_size_m, block_size_n))

    input_arranged = input.tile((block_size_m, block_size_k))
    input_arranged = input_arranged.tile((1, -1))
    input_arranged = input_arranged.expand((-1, output_arranged.shape[1]))
    input_arranged.dtype = input_arranged.dtype.squeeze(0)

    other = other.permute((1, 0))
    other_arranged = other.tile((block_size_k, block_size_n))
    other_arranged = other_arranged.tile((-1, 1))
    other_arranged = other_arranged.expand((output_arranged.shape[0], -1))
    other_arranged.dtype = other_arranged.dtype.squeeze(1)

    return input_arranged, other_arranged, output_arranged

def arrangement_with_bias(
    input,
    other,
    output,
    bias,
    block_size_m=None,
    block_size_n=None,
    block_size_k=None,
):
    if block_size_m is None:
        block_size_m = BLOCK_SIZE_M

    if block_size_n is None:
        block_size_n = BLOCK_SIZE_N

    if block_size_k is None:
        block_size_k = BLOCK_SIZE_K

    input_arranged, other_arranged, output_arranged = arrangement(
        input,
        other,
        output,
        block_size_m,
        block_size_n,
        block_size_k,
    )

    # bias (1, N)
    bias_arranged = bias.tile((1, block_size_n)) # (1, N/BN) x (1, BN)
    bias_arranged = bias_arranged.expand((output_arranged.shape[0], -1)) # (M/BM, N/BN) x (1, BN)
    return input_arranged, other_arranged, output_arranged, bias_arranged

def application(input, other, output):
    accumulator = ntl.zeros(output.shape, dtype=ntl.float32)
    for k in range(input.shape[0]):
        accumulator += ntl.dot(input[k], other[k])
    output = accumulator

def application_with_bias(input, other, output, bias):
    application(input, other, output)
    output = output + bias


kernel = ninetoothed.make(arrangement, application, (Tensor(2) for _ in range(3)), max_num_configs=2)
kernel_with_bias = ninetoothed.make(arrangement_with_bias, application_with_bias, (Tensor(2) for _ in range(4)), max_num_configs=2)

def linear(input, other, bias=None):
    # A: [M, K], B: [N, K]
    # C = A @ B.T -> C: [M, N]
    assert input.shape[1] == other.shape[1], "Inner dimension K must match for NT GEMM"
    output_shape = (input.shape[0], other.shape[0])
    output = torch.empty(output_shape, dtype=input.dtype, device=input.device)
    if bias is not None:
        assert bias.shape == (other.shape[0],), "Bias shape must match output dimension N"
        kernel_with_bias(input, other, output, bias.unsqueeze(0))
    else:
        kernel(input, other, output)

    return output


if __name__ == "__main__":
    DEV = "cuda"
    M, N, K = 10, 1024, 128
    x = torch.randn((M, K), device=DEV)
    # column major
    w = torch.randn((N, K), device=DEV)

    ref = torch.nn.functional.linear(x, w)
    out = linear(x, w)

    print((ref - out).abs().max().item())

Actual behavior (error traceback):

  File "/usr/local/lib/python3.12/dist-packages/ninetoothed/jit.py", line 128, in __call__
    return self._launch(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.ninetoothed/0d50f784bb02b2257c47e55e747d790c6b13998ce24eb2d6e308c303e88c7c71.py", line 31, in launch_application
    application_with_auto_tuning[lambda meta: (((ninetoothed_tensor_0.size(0) - (meta['ninetoothed_meta_prefix_ninetoothed_BLOCK_SIZE_0'] - 1) - 1 + meta['ninetoothed_meta_prefix_ninetoothed_BLOCK_SIZE_0'] - 1) // meta['ninetoothed_meta_prefix_ninetoothed_BLOCK_SIZE_0'] + 1 - 1 + 1 - 1 + 1) * ((ninetoothed_tensor_2.size(1) - (meta['ninetoothed_meta_prefix_ninetoothed_BLOCK_SIZE_1'] - 1) - 1 + meta['ninetoothed_meta_prefix_ninetoothed_BLOCK_SIZE_1'] - 1) // meta['ninetoothed_meta_prefix_ninetoothed_BLOCK_SIZE_1'] + 1),)](ninetoothed_tensor_0, ninetoothed_tensor_0.size(0), ninetoothed_tensor_0.size(1), ninetoothed_tensor_0.stride(0), ninetoothed_tensor_0.stride(1), ninetoothed_tensor_1, ninetoothed_tensor_1.size(0), ninetoothed_tensor_1.size(1), ninetoothed_tensor_1.stride(0), ninetoothed_tensor_1.stride(1), ninetoothed_tensor_2, ninetoothed_tensor_2.size(0), ninetoothed_tensor_2.size(1), ninetoothed_tensor_2.stride(0), ninetoothed_tensor_2.stride(1))
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 419, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py", line 250, in run
    ret = self.fn.run(
          ^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 756, in run
    launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 490, in launch_metadata
    self._init_handles()
  File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 464, in _init_handles
    raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory"))
  File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 456, in raise_
    raise err
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 1835008, Hardware limit: 232448. Reducing block sizes or `num_stages` may help.
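
The error message suggests reducing block sizes or num_stages. One possible workaround sketch (untested; the concrete block sizes are illustrative placeholders, not values from this report) is to bind fixed block sizes into the arrangement with functools.partial, so the autotuner cannot pick a configuration whose shared-memory footprint exceeds the hardware limit:

import functools

# Hypothetical workaround: bind fixed block sizes instead of letting the
# autotuner choose them. The values 64/64/32 are illustrative placeholders.
arrangement_fixed = functools.partial(
    arrangement, block_size_m=64, block_size_n=64, block_size_k=32
)
kernel_fixed = ninetoothed.make(
    arrangement_fixed, application, (Tensor(2) for _ in range(3)), max_num_configs=2
)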

Environment details:

No response
