attention_kernel_triton

Tested on an NVIDIA RTX 5000 Ada Generation GPU:

  • prefill_flash_attention.py

Output:

>> Q: torch.Size([8, 64, 1024, 64]), K: torch.Size([8, 16, 1024, 64]), V: torch.Size([8, 16, 1024, 64]), causal: True, GQA_group_size: 4
Benchmarking reference implementation...
Reference implementation: 35.219 ms
Benchmarking Triton implementation...
Triton implementation: 0.780 ms
Speedup: 45.180x
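
The reference path in this benchmark is plain PyTorch attention that materializes the full score matrix. A minimal sketch of what such a GQA causal reference could look like is below (function and variable names are illustrative assumptions, not the script's actual API):

```python
import torch

def ref_gqa_causal_attention(Q, K, V, group_size):
    """Plain PyTorch GQA attention with a causal mask (O(S^2) score matrix)."""
    B, Hq, S, D = Q.shape
    # Each KV head is shared by `group_size` consecutive query heads.
    K = K.repeat_interleave(group_size, dim=1)                      # [B, H_q, S, D]
    V = V.repeat_interleave(group_size, dim=1)                      # [B, H_q, S, D]
    scores = (Q.float() @ K.float().transpose(-2, -1)) / D ** 0.5   # [B, H_q, S, S]
    causal = torch.tril(torch.ones(S, S, dtype=torch.bool, device=Q.device))
    scores = scores.masked_fill(~causal, float("-inf"))
    return (torch.softmax(scores, dim=-1) @ V.float()).to(Q.dtype)  # [B, H_q, S, D]

if __name__ == "__main__":
    # Shapes matching the run above: 64 query heads, 16 KV heads, GQA group size 4.
    B, Hq, Hkv, S, D = 8, 64, 16, 1024, 64
    Q = torch.randn(B, Hq, S, D, device="cuda", dtype=torch.float16)
    K = torch.randn(B, Hkv, S, D, device="cuda", dtype=torch.float16)
    V = torch.randn(B, Hkv, S, D, device="cuda", dtype=torch.float16)
    print(ref_gqa_causal_attention(Q, K, V, group_size=Hq // Hkv).shape)
```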
  • prefill_flash_attention_var_len_padding.py

Output:

>> Lengths: tensor([1024,  468,  631,  258,  353,  599,  732,   94], device='cuda:0')
>> Q: torch.Size([8, 64, 1024, 64]), K: torch.Size([8, 16, 1024, 64]), V: torch.Size([8, 16, 1024, 64]), causal: True, GQA_group_size: 4
Benchmarking reference implementation...
Reference implementation: 75.014 ms
Benchmarking Triton implementation...
Triton implementation: 0.596 ms
Speedup: 125.807x
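
For the variable-length case, a padded reference additionally has to mask out keys beyond each sequence's length (and zero the padded query rows). A hedged sketch, assuming right-padded batches and the same GQA grouping as above:

```python
import torch

def ref_varlen_gqa_causal_attention(Q, K, V, lengths, group_size):
    """Causal GQA attention where sequence b only has `lengths[b]` valid (right-padded) tokens."""
    B, Hq, S, D = Q.shape
    K = K.repeat_interleave(group_size, dim=1)
    V = V.repeat_interleave(group_size, dim=1)
    scores = (Q.float() @ K.float().transpose(-2, -1)) / D ** 0.5      # [B, H_q, S, S]
    pos = torch.arange(S, device=Q.device)
    causal = pos[:, None] >= pos[None, :]                              # query i may see key j <= i
    valid_key = pos[None, :] < lengths[:, None]                        # padded keys are masked out
    mask = causal[None, None, :, :] & valid_key[:, None, None, :]      # [B, 1, S, S]
    scores = scores.masked_fill(~mask, float("-inf"))
    O = torch.softmax(scores, dim=-1) @ V.float()
    valid_q = (pos[None, :] < lengths[:, None])[:, None, :, None]      # zero padded query rows
    return (O * valid_q).to(Q.dtype)
```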
  • decode_flash_attention_redundant.py

Output:

>> q: torch.Size([8, 64, 1, 64]), K: torch.Size([8, 16, 1024, 64]), V: torch.Size([8, 16, 1024, 64]), GQA_group_size: 4
 shape of ref: torch.Size([8, 64, 1, 64])
 shape of tri: torch.Size([8, 64, 1, 64])
Benchmarking reference implementation...
Reference implementation: 0.692 ms
Benchmarking Triton implementation...
Triton implementation: 0.102 ms
Speedup: 6.812x
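
During decode, a single new query token per sequence attends to the whole cached K/V, so no causal mask is needed. A minimal PyTorch sketch of such a reference (names are illustrative):

```python
import torch

def ref_gqa_decode_attention(q, K, V, group_size):
    """One-token decode: q holds the query for the newest position, K/V is the full cache."""
    D = q.shape[-1]
    K = K.repeat_interleave(group_size, dim=1)                          # [B, H_q, S, D]
    V = V.repeat_interleave(group_size, dim=1)
    scores = (q.float() @ K.float().transpose(-2, -1)) / D ** 0.5       # [B, H_q, 1, S]
    return (torch.softmax(scores, dim=-1) @ V.float()).to(q.dtype)      # [B, H_q, 1, D]
```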
  • decode_flash_attention_redundant_var_len_paged.py

Output:

shape of ref_O: torch.Size([3, 4, 64])
shape of triton_O: torch.Size([3, 4, 64])
Number of NaNs in triton_O: 0
Ratio of NaNs in triton_O: 0.0
Max absolute values - ref: 1.0302734375  tri: 1.0302734375
Max absolute difference: 0.0009765625
Benchmarking reference implementation...
Reference implementation: 0.763 ms
Benchmarking Triton implementation...
Triton implementation: 0.017 ms
Speedup: 44.067x
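
With a paged KV cache, each sequence's keys and values live in fixed-size blocks addressed through a block table. The layout below ([num_blocks, block_size, H_kv, D] plus a per-sequence block table) is only an assumption for illustration; the script's actual cache layout may differ. A naive reference can gather each sequence's blocks back into contiguous K/V and run dense decode attention:

```python
import torch

def gather_paged_kv(K_cache, V_cache, block_table, length, block_size):
    """Materialize one sequence's contiguous K/V from an (assumed) paged layout
    K_cache/V_cache: [num_blocks, block_size, H_kv, D]; block_table: physical block ids."""
    n_blocks = (length + block_size - 1) // block_size
    K = K_cache[block_table[:n_blocks]].reshape(-1, *K_cache.shape[2:])[:length]   # [length, H_kv, D]
    V = V_cache[block_table[:n_blocks]].reshape(-1, *V_cache.shape[2:])[:length]
    return K, V

def ref_paged_gqa_decode(q, K_cache, V_cache, block_tables, lengths, block_size, group_size):
    """q: [B, H_q, D] -> O: [B, H_q, D], looping over sequences of different lengths."""
    D = q.shape[-1]
    outs = []
    for b in range(q.shape[0]):
        K, V = gather_paged_kv(K_cache, V_cache, block_tables[b], int(lengths[b]), block_size)
        K = K.repeat_interleave(group_size, dim=1)                                 # [len, H_q, D]
        V = V.repeat_interleave(group_size, dim=1)
        scores = torch.einsum("hd,shd->hs", q[b].float(), K.float()) / D ** 0.5
        outs.append(torch.einsum("hs,shd->hd", torch.softmax(scores, dim=-1), V.float()))
    return torch.stack(outs).to(q.dtype)
```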
  • sparsetoken_decode_flash_attention_redundant_var_len_paged.py

Output:

real kept ratio: 0.01997607291090003
shape of ref_O: torch.Size([3, 32, 256])
shape of tri_O: torch.Size([3, 32, 256])
Number of NaNs in triton_O: 0
Ratio of NaNs in triton_O: 0.0
Max absolute values - ref: 0.183837890625  tri: 0.183837890625
Max absolute difference: 0.0001220703125
Benchmarking reference implementation...
Reference implementation: 7751.282 ms
Benchmarking Triton implementation...
Triton implementation: 0.846 ms
Speedup: 9162.089x
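
Sparse-token decode attends only over a small kept subset of the cache (the "real kept ratio" above is roughly 2%, presumably kept tokens divided by total tokens). A per-sequence sketch of the gather-style approach, assuming the kept indices are given and the paged K/V has already been materialized as in the previous sketch (all names are illustrative):

```python
import torch

def sparse_token_decode(q, K, V, keep_idx, group_size):
    """q: [H_q, D]; K/V: [S, H_kv, D] for one sequence; keep_idx: indices of the tokens kept."""
    D = q.shape[-1]
    K = K[keep_idx].repeat_interleave(group_size, dim=1)    # [S_kept, H_q, D]
    V = V[keep_idx].repeat_interleave(group_size, dim=1)
    scores = torch.einsum("hd,shd->hs", q.float(), K.float()) / D ** 0.5
    return torch.einsum("hs,shd->hd", torch.softmax(scores, dim=-1), V.float()).to(q.dtype)

# The reported "real kept ratio" would then correspond to keep_idx.numel() / S,
# averaged over sequences (an assumption about how the script defines it).
```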
  • sparsetoken_decode_flash_attention_redundant.py

Output:

>> q: torch.Size([1, 32, 1, 128]), K: torch.Size([1, 8, 32000, 128]), V: torch.Size([1, 8, 32000, 128]), GQA_group_size: 4
real kept ratio: 0.019923828125
shape of ref_O: torch.Size([1, 32, 1, 128])
shape of tri_O: torch.Size([1, 32, 1, 128])
Number of NaNs in triton_O: 0
Ratio of NaNs in triton_O: 0.0
shape of ref_O_by_mask: torch.Size([1, 32, 1, 128])
Max absolute values - ref: 0.07025146484375  tri: 0.07025146484375
Max absolute difference: 3.0517578125e-05
Benchmarking reference implementation...
Reference implementation: 4.111 ms
Benchmarking naive_by_mask implementation...
Reference by mask implementation: 2.412 ms
Benchmarking Triton implementation...
Triton implementation: 0.057 ms
Speedup over reference: 71.995x
Speedup over reference by mask: 42.236x
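
The naive_by_mask baseline in this run suggests an alternative to gathering: keep the full K/V and push dropped tokens to -inf before the softmax. A hedged sketch of that mask-based approach, assuming a per-sequence boolean keep mask (it could just as well be per KV head; names are illustrative, not the script's actual API):

```python
import torch

def decode_attention_by_mask(q, K, V, keep_mask, group_size):
    """q: [B, H_q, 1, D]; K/V: [B, H_kv, S, D]; keep_mask: [B, S] bool, True where a token is kept."""
    D = q.shape[-1]
    K = K.repeat_interleave(group_size, dim=1)
    V = V.repeat_interleave(group_size, dim=1)
    scores = (q.float() @ K.float().transpose(-2, -1)) / D ** 0.5          # [B, H_q, 1, S]
    scores = scores.masked_fill(~keep_mask[:, None, None, :], float("-inf"))
    return (torch.softmax(scores, dim=-1) @ V.float()).to(q.dtype)          # [B, H_q, 1, D]
```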
