Description
Is there an existing issue for this?
- I have searched the existing issues.
Describe the bug:
When running the following code in an environment with:
- torch 2.7
- triton 3.5
an error occurs. The error message says that the kernel requests far more shared memory than the hardware limit allows.
It is highly suggested to reproduce this environment using nvcr-2504 (https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-04.html).
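For reference, the "Hardware limit" value in the traceback below can be cross-checked against what Triton itself reports for the device. A minimal sketch, assuming Triton 3.x still exposes triton.runtime.driver.active.utils.get_device_properties (the helper and key names may differ between releases):

import torch
from triton.runtime import driver

# Ask Triton's active driver for the device properties it uses at compile time.
# "max_shared_mem" is the shared-memory limit the compiler checks kernels against.
device = torch.cuda.current_device()
props = driver.active.utils.get_device_properties(device)
print(props["max_shared_mem"])  # should match the "Hardware limit" in the traceback (232448 here)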
To reproduce:
import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor, block_size
import torch
import math
from icecream import ic

BLOCK_SIZE_M = block_size()
BLOCK_SIZE_N = block_size()
BLOCK_SIZE_K = block_size()


def arrangement(
    input,
    other,
    output,
    block_size_m=None,
    block_size_n=None,
    block_size_k=None,
):
    if block_size_m is None:
        block_size_m = BLOCK_SIZE_M
    if block_size_n is None:
        block_size_n = BLOCK_SIZE_N
    if block_size_k is None:
        block_size_k = BLOCK_SIZE_K

    output_arranged = output.tile((block_size_m, block_size_n))

    input_arranged = input.tile((block_size_m, block_size_k))
    input_arranged = input_arranged.tile((1, -1))
    input_arranged = input_arranged.expand((-1, output_arranged.shape[1]))
    input_arranged.dtype = input_arranged.dtype.squeeze(0)

    other = other.permute((1, 0))
    other_arranged = other.tile((block_size_k, block_size_n))
    other_arranged = other_arranged.tile((-1, 1))
    other_arranged = other_arranged.expand((output_arranged.shape[0], -1))
    other_arranged.dtype = other_arranged.dtype.squeeze(1)

    return input_arranged, other_arranged, output_arranged


def arrangement_with_bias(
    input,
    other,
    output,
    bias,
    block_size_m=None,
    block_size_n=None,
    block_size_k=None,
):
    if block_size_m is None:
        block_size_m = BLOCK_SIZE_M
    if block_size_n is None:
        block_size_n = BLOCK_SIZE_N
    if block_size_k is None:
        block_size_k = BLOCK_SIZE_K

    input_arranged, other_arranged, output_arranged = arrangement(
        input,
        other,
        output,
        block_size_m,
        block_size_n,
        block_size_k,
    )

    # bias (1, N)
    bias_arranged = bias.tile((1, block_size_n))  # (1, N/BN) x (1, BN)
    bias_arranged = bias_arranged.expand((output_arranged.shape[0], -1))  # (M/BM, N/BN) x (1, BN)

    return input_arranged, other_arranged, output_arranged, bias_arranged


def application(input, other, output):
    accumulator = ntl.zeros(output.shape, dtype=ntl.float32)
    for k in range(input.shape[0]):
        accumulator += ntl.dot(input[k], other[k])
    output = accumulator


def application_with_bias(input, other, output, bias):
    application(input, other, output)
    output = output + bias


kernel = ninetoothed.make(arrangement, application, (Tensor(2) for _ in range(3)), max_num_configs=2)
kernel_with_bias = ninetoothed.make(arrangement_with_bias, application_with_bias, (Tensor(2) for _ in range(4)), max_num_configs=2)


def linear(input, other, bias=None):
    # A: [M, K], B: [N, K]
    # C = A @ B.T -> C: [M, N]
    assert input.shape[1] == other.shape[1], "Inner dimension K must match for NT GEMM"

    output_shape = (input.shape[0], other.shape[0])
    output = torch.empty(output_shape, dtype=input.dtype, device=input.device)

    if bias is not None:
        assert bias.shape == (other.shape[0],), "Bias shape must match output dimension N"
        kernel_with_bias(input, other, output, bias.unsqueeze(0))
    else:
        kernel(input, other, output)

    return output


if __name__ == "__main__":
    DEV = "cuda"
    M, N, K = 10, 1024, 128

    x = torch.randn((M, K), device=DEV)
    # column major
    w = torch.randn((N, K), device=DEV)

    ref = torch.nn.functional.linear(x, w)
    out = linear(x, w)
    print((ref - out).abs().max().item())
Expected behavior:
The script should print a small maximum absolute difference against torch.nn.functional.linear. Instead, the kernel launch fails with the following traceback:
File "/usr/local/lib/python3.12/dist-packages/ninetoothed/jit.py", line 128, in __call__
return self._launch(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.ninetoothed/0d50f784bb02b2257c47e55e747d790c6b13998ce24eb2d6e308c303e88c7c71.py", line 31, in launch_application
application_with_auto_tuning[lambda meta: (((ninetoothed_tensor_0.size(0) - (meta['ninetoothed_meta_prefix_ninetoothed_BLOCK_SIZE_0'] - 1) - 1 + meta['ninetoothed_meta_prefix_ninetoothed_BLOCK_SIZE_0'] - 1) // meta['ninetoothed_meta_prefix_ninetoothed_BLOCK_SIZE_0'] + 1 - 1 + 1 - 1 + 1) * ((ninetoothed_tensor_2.size(1) - (meta['ninetoothed_meta_prefix_ninetoothed_BLOCK_SIZE_1'] - 1) - 1 + meta['ninetoothed_meta_prefix_ninetoothed_BLOCK_SIZE_1'] - 1) // meta['ninetoothed_meta_prefix_ninetoothed_BLOCK_SIZE_1'] + 1),)](ninetoothed_tensor_0, ninetoothed_tensor_0.size(0), ninetoothed_tensor_0.size(1), ninetoothed_tensor_0.stride(0), ninetoothed_tensor_0.stride(1), ninetoothed_tensor_1, ninetoothed_tensor_1.size(0), ninetoothed_tensor_1.size(1), ninetoothed_tensor_1.stride(0), ninetoothed_tensor_1.stride(1), ninetoothed_tensor_2, ninetoothed_tensor_2.size(0), ninetoothed_tensor_2.size(1), ninetoothed_tensor_2.stride(0), ninetoothed_tensor_2.stride(1))
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 419, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py", line 250, in run
ret = self.fn.run(
^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 756, in run
launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 490, in launch_metadata
self._init_handles()
File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 464, in _init_handles
raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory"))
File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 456, in raise_
raise err
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 1835008, Hardware limit: 232448. Reducing block sizes or `num_stages` may help.
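The error message suggests reducing the block sizes. As a possible workaround, an untested sketch: pin the symbolic block sizes to small constants through the block_size_* parameters that arrangement already accepts, e.g. with functools.partial (the value 32 is only illustrative, and this assumes tile() accepts plain integers):

import functools

# Untested workaround sketch: replace the symbolic block_size() values with
# small fixed constants so the generated kernel requests less shared memory.
arrangement_small = functools.partial(
    arrangement, block_size_m=32, block_size_n=32, block_size_k=32
)
kernel_small = ninetoothed.make(
    arrangement_small, application, (Tensor(2) for _ in range(3))
)

Even if that avoids the launch failure, the auto-tuner arguably should not propose configurations whose shared-memory requirement (1835008 bytes here) exceeds the hardware limit in the first place.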
Environment details:
No response