diff --git a/book/src/week2-04-flash-attention.md b/book/src/week2-04-flash-attention.md index 3578e82..4d97f78 100644 --- a/book/src/week2-04-flash-attention.md +++ b/book/src/week2-04-flash-attention.md @@ -100,6 +100,8 @@ Before creating the lazy output array, validate all shape and dtype constraints Then implement `FlashAttention::eval_cpu(...)` with tiled online softmax. Use `Br = 32` and `Bc = 32`, iterate over `(n, i, j)` tiles, map query heads to KV heads with `q_kv_heads_ratio = num_heads / num_kv_heads`, and accumulate in float32. Mask values should be applied in each tile before updating `m_i` and `l_i`. +When `mask == "causal"`, treat it as a block-level optimization opportunity: if a tile is fully invalid, skip that tile entirely; if a tile is fully valid, skip mask read/add for that tile and continue with matmul + online softmax. Also note that `L` and `S` are not always equal in causal attention, so do not hardcode logic that assumes `L == S`. + You can test your implementation by running: ```bash diff --git a/src/extensions_ref/bindings.cpp b/src/extensions_ref/bindings.cpp index afb4aec..ed8dd39 100644 --- a/src/extensions_ref/bindings.cpp +++ b/src/extensions_ref/bindings.cpp @@ -32,7 +32,7 @@ NB_MODULE(_ext, m) { )"); m.def("flash_attention", &tiny_llm_ext_ref::flash_attention, "query"_a, "key"_a, "value"_a, "mask"_a, "scale"_a = 1.0, - "num_kv_heads"_a, "num_heads"_a, "stream"_a = nb::none(), R"( + "is_causal"_a = false, "num_kv_heads"_a, "num_heads"_a, "stream"_a = nb::none(), R"( Flash attention layer Args: @@ -41,6 +41,7 @@ NB_MODULE(_ext, m) { value (array): Value array. mask (array): Mask array. scale (float): Scaling factor. + is_causal (bool): Enable causal-mask fast path. Returns: array: ``softmax(query @ key.T * scale) @ value`` diff --git a/src/extensions_ref/src/flash_attention.cpp b/src/extensions_ref/src/flash_attention.cpp index 006ae51..1d1e375 100644 --- a/src/extensions_ref/src/flash_attention.cpp +++ b/src/extensions_ref/src/flash_attention.cpp @@ -14,7 +14,8 @@ namespace tiny_llm_ext_ref { mx::array flash_attention(const mx::array &q, const mx::array &k, const mx::array &v, const mx::array &mask, - const float scale, const int num_kv_heads, const int num_heads, mx::StreamOrDevice s) { + const float scale, const bool is_causal, const int num_kv_heads, const int num_heads, + mx::StreamOrDevice s) { if (q.dtype() != mx::float32 || k.dtype() != mx::float32 || v.dtype() != mx::float32 || mask.dtype() != mx::float32) { throw std::runtime_error("flash_attention: all input arrays must be float32"); } @@ -54,7 +55,8 @@ mx::array flash_attention(const mx::array &q, const mx::array &k, const mx::arra } return mx::array(q.shape(), mx::float32, - std::make_shared(to_stream(s), scale, num_kv_heads, num_heads), {q, k, v, mask}); + std::make_shared(to_stream(s), scale, is_causal, num_kv_heads, num_heads), + {q, k, v, mask}); } void FlashAttention::eval_cpu(const std::vector &inputs, std::vector &outputs) { @@ -91,7 +93,7 @@ void FlashAttention::eval_cpu(const std::vector &inputs, std::vector< encoder.dispatch([out_ptr = out.data(), out_shape = out.shape(), q = mx::array::unsafe_weak_copy(q), k = mx::array::unsafe_weak_copy(k), v = mx::array::unsafe_weak_copy(v), mask = mx::array::unsafe_weak_copy(mask), num_heads = num_heads_, num_kv_heads = num_kv_heads_, - scale = scale_]() { + scale = scale_, is_causal = is_causal_]() { const int64_t N = q.shape()[0]; const int64_t L = q.shape()[1]; const int64_t S = k.shape()[1]; @@ -126,7 +128,14 @@ void FlashAttention::eval_cpu(const std::vector &inputs, std::vector< std::vector o_i(Br * E, 0.0); std::vector l_i(Br, 0.0); std::vector m_i(Br, -std::numeric_limits::infinity()); + const int64_t causal_offset = S - L; for (int64_t j = 0; j < Tc; j++) { + int64_t row_max = i * Br + br_upper_bound - 1; + int64_t col_min = j * Bc; + // Causal masking: if the entire block of K is masked out by causal mask, we can skip the computation for this block. + if (is_causal && col_min > row_max + causal_offset) { + continue; + } int bc_upper_bound = std::min(S - j * Bc, Bc); // Each kernel processes a block of Br x Bc // Load Kj and Vj @@ -154,14 +163,20 @@ void FlashAttention::eval_cpu(const std::vector &inputs, std::vector< } // Add mask and scale + int64_t row_min = i * Br; + int64_t col_max = j * Bc + bc_upper_bound - 1; + bool block_all_valid = is_causal && (col_max <= row_min + causal_offset); for (int64_t a = 0; a < br_upper_bound; a++) { for (int64_t b = 0; b < bc_upper_bound; b++) { - int m_idx_1 = n; - int m_idx_2 = i * Br + a; - int m_idx_3 = j * Bc + b; - int m_idx_converted = mx::elem_to_loc(m_idx_1 * L * S + m_idx_2 * S + m_idx_3, mask); s_i[a * Bc + b] *= scale; - s_i[a * Bc + b] += m_ptr[m_idx_converted]; + // If the block is all valid, we don't need to add mask because it's all zeros. Otherwise we need to add mask for each element. + if (!block_all_valid) { + int m_idx_1 = n; + int m_idx_2 = i * Br + a; + int m_idx_3 = j * Bc + b; + int m_idx_converted = mx::elem_to_loc(m_idx_1 * L * S + m_idx_2 * S + m_idx_3, mask); + s_i[a * Bc + b] += m_ptr[m_idx_converted]; + } } } @@ -283,15 +298,16 @@ void FlashAttention::eval_gpu(const std::vector &inputs, std::vector< const int S = k.shape()[1]; const int E = q.shape()[2]; - compute_encoder.set_bytes(N, 7); - compute_encoder.set_bytes(L, 8); - compute_encoder.set_bytes(S, 9); - compute_encoder.set_bytes(E, 10); + compute_encoder.set_bytes(static_cast(is_causal_), 7); + compute_encoder.set_bytes(N, 8); + compute_encoder.set_bytes(L, 9); + compute_encoder.set_bytes(S, 10); + compute_encoder.set_bytes(E, 11); // Make sure the data type matches with the metal kernel: otherwise you'll get flaky issues and stuck :( - compute_encoder.set_bytes(num_kv_heads_, 11); - compute_encoder.set_bytes(num_heads_, 12); - compute_encoder.set_bytes(scale_, 13); + compute_encoder.set_bytes(num_kv_heads_, 12); + compute_encoder.set_bytes(num_heads_, 13); + compute_encoder.set_bytes(scale_, 14); size_t tgp_size = kernel->maxTotalThreadsPerThreadgroup(); size_t simd_width = kernel->threadExecutionWidth(); @@ -320,10 +336,10 @@ void FlashAttention::eval_gpu(const std::vector &inputs, std::vector< const int Tr = (L + Br - 1) / Br; const int Tc = (S + Bc - 1) / Bc; - compute_encoder.set_bytes(Br, 14); - compute_encoder.set_bytes(Bc, 15); - compute_encoder.set_bytes(Tr, 16); - compute_encoder.set_bytes(Tc, 17); + compute_encoder.set_bytes(Br, 15); + compute_encoder.set_bytes(Bc, 16); + compute_encoder.set_bytes(Tr, 17); + compute_encoder.set_bytes(Tc, 18); MTL::Size num_threadgroups = MTL::Size(N, Tr, 1); MTL::Size num_threads_per_group = MTL::Size(Br, simd_width, 1); diff --git a/src/extensions_ref/src/flash_attention.metal b/src/extensions_ref/src/flash_attention.metal index 98e5821..b78af0e 100644 --- a/src/extensions_ref/src/flash_attention.metal +++ b/src/extensions_ref/src/flash_attention.metal @@ -11,17 +11,18 @@ using namespace metal; device float* out [[buffer(4)]], constant const int* mask_shape [[buffer(5)]], constant const int64_t* mask_strides [[buffer(6)]], - device const int &N [[buffer(7)]], - device const int &L [[buffer(8)]], - device const int &S [[buffer(9)]], - device const int &E [[buffer(10)]], - device const int &num_kv_heads [[buffer(11)]], - device const int &num_heads [[buffer(12)]], - device const float &scale [[buffer(13)]], - device const int &Br [[buffer(14)]], - device const int &Bc [[buffer(15)]], - [[maybe_unused]] device const int &Tr [[buffer(16)]], - device const int &Tc [[buffer(17)]], + device const int &is_causal [[buffer(7)]], + device const int &N [[buffer(8)]], + device const int &L [[buffer(9)]], + device const int &S [[buffer(10)]], + device const int &E [[buffer(11)]], + device const int &num_kv_heads [[buffer(12)]], + device const int &num_heads [[buffer(13)]], + device const float &scale [[buffer(14)]], + device const int &Br [[buffer(15)]], + device const int &Bc [[buffer(16)]], + [[maybe_unused]] device const int &Tr [[buffer(17)]], + device const int &Tc [[buffer(18)]], uint2 group_id [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -70,6 +71,13 @@ using namespace metal; } for (int j = 0; j < Tc; j++) { + int row_max = min((i + 1) * Br - 1, L - 1); + int col_min = j * Bc; + // Causal masking: if the entire block of K is masked out by causal mask, we can skip the computation for this block. + if (is_causal && col_min > row_max + (S - L)) { + continue; + } + bool is_j_in_range = j * Bc + b < S && b < Bc; device const float *k_ptr = k_ptr_base + j * Bc * E; @@ -84,11 +92,16 @@ using namespace metal; } s_a_b *= scale; if (is_i_in_range && is_j_in_range) { - int64_t m_idx_1 = n; - int64_t m_idx_2 = i * Br + a; - int64_t m_idx_3 = j * Bc + b; - int64_t m_idx_converted = elem_to_loc(m_idx_1 * L * S + m_idx_2 * S + m_idx_3, mask_shape, mask_strides, 3); - s_a_b += mask[m_idx_converted]; + int row_min = i * Br; + int col_max = min((j + 1) * Bc - 1, S - 1); + bool block_all_valid = is_causal && (col_max <= row_min + (S - L)); + if (!block_all_valid) { + int64_t m_idx_1 = n; + int64_t m_idx_2 = i * Br + a; + int64_t m_idx_3 = j * Bc + b; + int64_t m_idx_converted = elem_to_loc(m_idx_1 * L * S + m_idx_2 * S + m_idx_3, mask_shape, mask_strides, 3); + s_a_b += mask[m_idx_converted]; + } } else { s_a_b = -1e9; } diff --git a/src/extensions_ref/src/tiny_llm_ext.h b/src/extensions_ref/src/tiny_llm_ext.h index 599ba9f..04dc9db 100644 --- a/src/extensions_ref/src/tiny_llm_ext.h +++ b/src/extensions_ref/src/tiny_llm_ext.h @@ -40,15 +40,17 @@ class QuantizedMatmul : public mx::Primitive { }; mx::array flash_attention(const mx::array &q, const mx::array &k, const mx::array &v, const mx::array &mask, - const float scale, const int num_kv_heads, const int num_heads, mx::StreamOrDevice s = {}); + const float scale, const bool is_causal, const int num_kv_heads, const int num_heads, + mx::StreamOrDevice s = {}); mx::array flash_attention_no_mask(const mx::array &q, const mx::array &k, const mx::array &v, const float scale, const int num_kv_heads, const int num_heads, mx::StreamOrDevice s = {}); class FlashAttention : public mx::Primitive { public: - explicit FlashAttention(mx::Stream stream, const float scale, const int num_kv_heads, const int num_heads) - : mx::Primitive(stream), scale_(scale), num_kv_heads_(num_kv_heads), num_heads_(num_heads) {}; + explicit FlashAttention(mx::Stream stream, const float scale, const bool is_causal, const int num_kv_heads, + const int num_heads) + : mx::Primitive(stream), scale_(scale), is_causal_(is_causal), num_kv_heads_(num_kv_heads), num_heads_(num_heads) {}; void eval_cpu(const std::vector &inputs, std::vector &outputs) override; void eval_gpu(const std::vector &inputs, std::vector &outputs) override; @@ -62,6 +64,7 @@ class FlashAttention : public mx::Primitive { private: float scale_; + bool is_causal_; int num_kv_heads_; int num_heads_; }; diff --git a/src/tiny_llm_ref/attention.py b/src/tiny_llm_ref/attention.py index a53a68d..1f81da9 100644 --- a/src/tiny_llm_ref/attention.py +++ b/src/tiny_llm_ref/attention.py @@ -71,7 +71,7 @@ def flash_attention( key: mx.array, value: mx.array, scale: float | None = None, - mask: mx.array | None = None, + mask: mx.array | str | None = None, ) -> mx.array: factor = mx.rsqrt(query.shape[-1]) if scale is None else mx.array(scale) factor = factor.astype(query.dtype) @@ -85,21 +85,22 @@ def flash_attention( query = mx.contiguous(query) key = mx.contiguous(key) value = mx.contiguous(value) + is_causal = mask == "causal" N = query.shape[0] - if mask is None: - mask = mx.reshape( - mx.broadcast_to(mx.zeros((L, S)), (*B, H_q, L, S)), (N, L, S) - ).astype(mx.float32) + if is_causal: + mask = mx.broadcast_to(causal_mask(L, S, mx.float32), (*B, H_q, L, S)) + elif mask is None: + mask = mx.broadcast_to(mx.zeros((L, S), dtype=mx.float32), (*B, H_q, L, S)) else: - mask = mx.reshape(mx.broadcast_to(mask, (*B, H_q, L, S)), (N, L, S)).astype( - mx.float32 - ) + mask = mx.broadcast_to(mask, (*B, H_q, L, S)) + mask = mx.contiguous(mask.reshape(N, L, S)).astype(mx.float32) result = tiny_llm_ext_ref.flash_attention( query, key, value, mask, factor, + is_causal=is_causal, num_heads=H_q, num_kv_heads=H, ) diff --git a/src/tiny_llm_ref/qwen2_week2.py b/src/tiny_llm_ref/qwen2_week2.py index 7db2e27..0c38ecd 100644 --- a/src/tiny_llm_ref/qwen2_week2.py +++ b/src/tiny_llm_ref/qwen2_week2.py @@ -3,7 +3,6 @@ from .attention import ( scaled_dot_product_attention_grouped, flash_attention, - causal_mask, ) from .layer_norm import RMSNorm from .positional_encoding import RoPE @@ -81,9 +80,6 @@ def __call__( projection_k, projection_v, _, mask = cache.update_and_fetch( projection_k, projection_v, mask_length=L, mask=mask ) - S = projection_k.shape[-2] - if mask == "causal": - mask = causal_mask(L, S, mx.float32) if self.use_flash_attention: x = flash_attention( projection_q.astype(mx.float32), diff --git a/src/tiny_llm_ref/qwen3.py b/src/tiny_llm_ref/qwen3.py index 4d28ea8..d104619 100644 --- a/src/tiny_llm_ref/qwen3.py +++ b/src/tiny_llm_ref/qwen3.py @@ -3,7 +3,6 @@ from .attention import ( scaled_dot_product_attention_grouped, flash_attention, - causal_mask, ) from .layer_norm import RMSNorm from .positional_encoding import RoPE @@ -85,9 +84,6 @@ def __call__( projection_k, projection_v, _, mask = cache.update_and_fetch( projection_k, projection_v, mask_length=L, mask=mask ) - S = projection_k.shape[-2] - if mask == "causal": - mask = causal_mask(L, S, mx.float32) if self.use_flash_attention: x = flash_attention( projection_q.astype(mx.float32), diff --git a/tests_refsol/test_week_2_day_4.py b/tests_refsol/test_week_2_day_4.py index 3e8b3d4..6cbcef6 100644 --- a/tests_refsol/test_week_2_day_4.py +++ b/tests_refsol/test_week_2_day_4.py @@ -4,7 +4,7 @@ from .utils import * -def attention_helper(stream: mx.Stream, H_q, H, L, E, S, BATCH, with_mask: bool): +def attention_helper(stream: mx.Stream, H_q, H, L, E, S, BATCH, mask_mode: str): precision = mx.float32 with mx.stream(stream): q_shape = (BATCH, H_q, L, E) @@ -15,7 +15,14 @@ def attention_helper(stream: mx.Stream, H_q, H, L, E, S, BATCH, with_mask: bool) query = mx.random.uniform(shape=q_shape, dtype=precision) key = mx.random.uniform(shape=kv_shape, dtype=precision) value = mx.random.uniform(shape=kv_shape, dtype=precision) - mask = mx.random.uniform(shape=mask_shape, dtype=precision) if with_mask else None + if mask_mode == "no_mask": + mask = None + elif mask_mode == "mask": + mask = mx.random.uniform(shape=mask_shape, dtype=precision) + elif mask_mode == "causal": + mask = "causal" + else: + raise ValueError(f"Unknown mask_mode: {mask_mode}") reference_output = mx.fast.scaled_dot_product_attention( q=query, @@ -35,36 +42,36 @@ def attention_helper(stream: mx.Stream, H_q, H, L, E, S, BATCH, with_mask: bool) assert_allclose(user_output, reference_output, precision=mx.float16) -@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"]) -def test_task_2_flash_attention_cpu_small(with_mask: bool): - attention_helper(mx.cpu, 6, 3, 2, 5, 3, 1, with_mask) +@pytest.mark.parametrize("mask_mode", ["no_mask", "mask", "causal"]) +def test_task_2_flash_attention_cpu_small(mask_mode: str): + attention_helper(mx.cpu, 6, 3, 2, 5, 3, 1, mask_mode) -@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"]) -def test_task_2_flash_attention_cpu(with_mask: bool): - attention_helper(mx.cpu, 18, 6, 7, 5, 3, 10, with_mask) +@pytest.mark.parametrize("mask_mode", ["no_mask", "mask"]) +def test_task_2_flash_attention_cpu(mask_mode: str): + attention_helper(mx.cpu, 18, 6, 7, 5, 3, 10, mask_mode) -@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"]) -def test_task_2_flash_attention_cpu_large(with_mask: bool): - attention_helper(mx.cpu, 28, 4, 16, 128, 16, 3, with_mask) +@pytest.mark.parametrize("mask_mode", ["no_mask", "mask", "causal"]) +def test_task_2_flash_attention_cpu_large(mask_mode: str): + attention_helper(mx.cpu, 28, 4, 16, 128, 16, 3, mask_mode) -@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"]) -def test_task_3_flash_attention_gpu_extra_small(with_mask: bool): - attention_helper(mx.gpu, 1, 1, 5, 7, 4, 1, with_mask) +@pytest.mark.parametrize("mask_mode", ["no_mask", "mask"]) +def test_task_3_flash_attention_gpu_extra_small(mask_mode: str): + attention_helper(mx.gpu, 1, 1, 5, 7, 4, 1, mask_mode) -@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"]) -def test_task_3_flash_attention_gpu_small(with_mask: bool): - attention_helper(mx.gpu, 6, 3, 2, 5, 3, 1, with_mask) +@pytest.mark.parametrize("mask_mode", ["no_mask", "mask", "causal"]) +def test_task_3_flash_attention_gpu_small(mask_mode: str): + attention_helper(mx.gpu, 6, 3, 2, 5, 3, 1, mask_mode) -@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"]) -def test_task_3_flash_attention_gpu(with_mask: bool): - attention_helper(mx.gpu, 18, 6, 7, 5, 3, 10, with_mask) +@pytest.mark.parametrize("mask_mode", ["no_mask", "mask"]) +def test_task_3_flash_attention_gpu(mask_mode: str): + attention_helper(mx.gpu, 18, 6, 7, 5, 3, 10, mask_mode) -@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"]) -def test_task_3_flash_attention_gpu_large(with_mask: bool): - attention_helper(mx.gpu, 28, 4, 16, 128, 16, 3, with_mask) +@pytest.mark.parametrize("mask_mode", ["no_mask", "mask", "causal"]) +def test_task_3_flash_attention_gpu_large(mask_mode: str): + attention_helper(mx.gpu, 28, 4, 16, 128, 16, 3, mask_mode)