Bugs in Triton operator? #9

@XintianHan

Description

Hi. Thanks for the nice Triton implementation. I may have found a bug in the Triton operator: it seems the operator does not support head dim = 192, although dim = 128 and dim = 256 both work.

Take the example below:

from lightning_attention import lightning_attention
import torch

# shapes: batch b, heads h, sequence length n, head dim d
b = 1
h = 16
n = 64
d = 192  # not a power of two; d = 128 and d = 256 both work
q = torch.randn(b, h, n, d).to("cuda")
k = torch.randn(b, h, n, d).to("cuda")
v = torch.randn(b, h, n, d).to("cuda")
slope_rate = torch.ones(h).to("cuda")
output = lightning_attention(
    q, k, v, True, slope_rate.squeeze(-1).squeeze(-1)  # squeezes are no-ops on a 1-D tensor
)
print("test succeeded!")

It gives me this error:

  File "<string>", line 41, in _fwd_kernel
  File "/home/tiger/.local/lib/python3.9/site-packages/triton/compiler.py", line 1621, in compile
    next_module = compile(module)
  File "/home/tiger/.local/lib/python3.9/site-packages/triton/compiler.py", line 1550, in <lambda>
    lambda src: ast_to_ttir(src, signature, configs[0], constants)),
  File "/home/tiger/.local/lib/python3.9/site-packages/triton/compiler.py", line 963, in ast_to_ttir
    return optimize_triton_ir(mod)
  File "/home/tiger/.local/lib/python3.9/site-packages/triton/compiler.py", line 957, in optimize_triton_ir
    pm.run(mod)
RuntimeError: PassManager::run failed

The failure is raised from line 370, in forward:

    _fwd_kernel[grid](

Any advice here?
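
My guess at the cause: Triton block shapes (e.g. the range passed to tl.arange) must be powers of two, and the kernel appears to derive a block dimension directly from d, so d = 192 fails to compile while 128 and 256 succeed. If that is right, a workaround would be to zero-pad the head dimension to the next power of two and slice the result back. A minimal, untested sketch (lightning_attention_padded is a name I made up):

import torch
import torch.nn.functional as F
import triton
from lightning_attention import lightning_attention

def lightning_attention_padded(q, k, v, causal, slope_rate):
    # Pad the head dimension to the next power of two so the Triton
    # kernel sees a block-friendly size, then slice the output back.
    d = q.shape[-1]
    d_pad = triton.next_power_of_2(d)
    if d_pad != d:
        q = F.pad(q, (0, d_pad - d))  # pad only the last (head) dim
        k = F.pad(k, (0, d_pad - d))
        v = F.pad(v, (0, d_pad - d))
    out = lightning_attention(q, k, v, causal, slope_rate)
    # Zero columns in q and k contribute nothing to the q.k dot
    # products, and zero columns in v yield zero output channels, so
    # dropping the padded channels should recover the d = 192 result.
    return out[..., :d]

With this wrapper, lightning_attention_padded(q, k, v, True, slope_rate) on the repro above should return a (b, h, n, 192) tensor, at the cost of the extra compute on the padded channels.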
