From 7c7c777a1b5ee7cf713119548a832073f6ffc674 Mon Sep 17 00:00:00 2001
From: Eve Callicoat
Date: Sat, 12 Apr 2025 23:36:36 +0000
Subject: [PATCH 1/2] Cleanup FlashAttention 2/3 a bit

- Revert submodule commits
- Remove leftover wraperrs in fa2
- Remove superfluous .git* files in subdir
---
 flue-flash-attn-v2/cutlass     |  2 +-
 flue-flash-attn-v2/src/lib.rs  | 10 ++++------
 flue-flash-attn-v3/.gitignore  |  7 -------
 flue-flash-attn-v3/.gitmodules |  3 ---
 flue-flash-attn-v3/cutlass     |  2 +-
 5 files changed, 6 insertions(+), 18 deletions(-)
 delete mode 100644 flue-flash-attn-v3/.gitignore
 delete mode 100644 flue-flash-attn-v3/.gitmodules

diff --git a/flue-flash-attn-v2/cutlass b/flue-flash-attn-v2/cutlass
index 62750a2..afa1772 160000
--- a/flue-flash-attn-v2/cutlass
+++ b/flue-flash-attn-v2/cutlass
@@ -1 +1 @@
-Subproject commit 62750a2b75c802660e4894434dc55e839f322277
+Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008
diff --git a/flue-flash-attn-v2/src/lib.rs b/flue-flash-attn-v2/src/lib.rs
index f05d2d9..5daf4dd 100644
--- a/flue-flash-attn-v2/src/lib.rs
+++ b/flue-flash-attn-v2/src/lib.rs
@@ -143,10 +143,9 @@ impl FlashAttn {
         let seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
         let elem_count = out_shape.elem_count();
-        let dst = unsafe { dev.alloc::(elem_count) }.w()?;
+        let dst = unsafe { dev.alloc::(elem_count) }?;
         let softmax_lse = dev
-            .alloc_zeros::(b_sz * 128 * num_heads * seqlen_q)
-            .w()?;
+            .alloc_zeros::(b_sz * 128 * num_heads * seqlen_q)?;
 
         let is_bf16 = if is_bf16 { 1 } else { 0 };
 
@@ -610,10 +609,9 @@ impl FlashAttnVarLen {
         let seqlen_k_rounded = round_multiple(self.max_seqlen_k, 128);
 
         let elem_count = out_shape.elem_count();
-        let dst = unsafe { dev.alloc::(elem_count) }.w()?;
+        let dst = unsafe { dev.alloc::(elem_count) }?;
         let softmax_lse = dev
-            .alloc_zeros::(batch_size * num_heads * self.max_seqlen_q)
-            .w()?;
+            .alloc_zeros::(batch_size * num_heads * self.max_seqlen_q)?;
 
         let is_bf16 = if is_bf16 { 1 } else { 0 };
 
diff --git a/flue-flash-attn-v3/.gitignore b/flue-flash-attn-v3/.gitignore
deleted file mode 100644
index fc378ca..0000000
--- a/flue-flash-attn-v3/.gitignore
+++ /dev/null
@@ -1,7 +0,0 @@
-.idea
-target
-Cargo.lock
-.venv
-hkernel/build/*
-__pycache__
-*.egg-info
\ No newline at end of file
diff --git a/flue-flash-attn-v3/.gitmodules b/flue-flash-attn-v3/.gitmodules
deleted file mode 100644
index 2b822e9..0000000
--- a/flue-flash-attn-v3/.gitmodules
+++ /dev/null
@@ -1,3 +0,0 @@
-[submodule "cutlass"]
-	path = cutlass
-	url = https://github.com/NVIDIA/cutlass.git
\ No newline at end of file
diff --git a/flue-flash-attn-v3/cutlass b/flue-flash-attn-v3/cutlass
index 62750a2..4c42f73 160000
--- a/flue-flash-attn-v3/cutlass
+++ b/flue-flash-attn-v3/cutlass
@@ -1 +1 @@
-Subproject commit 62750a2b75c802660e4894434dc55e839f322277
+Subproject commit 4c42f73fdab5787e3bb57717f35a8cb1b3c0dc6d

From 9c984ccf74e1bacc4f86f2139a503a4f6b39c84a Mon Sep 17 00:00:00 2001
From: Eve Callicoat
Date: Sat, 12 Apr 2025 23:39:48 +0000
Subject: [PATCH 2/2] Rustfmt

---
 flue-flash-attn-v2/src/lib.rs | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/flue-flash-attn-v2/src/lib.rs b/flue-flash-attn-v2/src/lib.rs
index 5daf4dd..1accaa1 100644
--- a/flue-flash-attn-v2/src/lib.rs
+++ b/flue-flash-attn-v2/src/lib.rs
@@ -144,8 +144,7 @@ impl FlashAttn {
 
         let elem_count = out_shape.elem_count();
         let dst = unsafe { dev.alloc::(elem_count) }?;
-        let softmax_lse = dev
-            .alloc_zeros::(b_sz * 128 * num_heads * seqlen_q)?;
+        let softmax_lse = dev.alloc_zeros::(b_sz * 128 * num_heads * seqlen_q)?;
 
         let is_bf16 = if is_bf16 { 1 } else { 0 };
 
@@ -610,8 +609,7 @@ impl FlashAttnVarLen {
 
         let elem_count = out_shape.elem_count();
         let dst = unsafe { dev.alloc::(elem_count) }?;
-        let softmax_lse = dev
-            .alloc_zeros::(batch_size * num_heads * self.max_seqlen_q)?;
+        let softmax_lse = dev.alloc_zeros::(batch_size * num_heads * self.max_seqlen_q)?;
 
         let is_bf16 = if is_bf16 { 1 } else { 0 };
 