Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions book/src/week2-04-flash-attention.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ Before creating the lazy output array, validate all shape and dtype constraints

Then implement `FlashAttention::eval_cpu(...)` with tiled online softmax. Use `Br = 32` and `Bc = 32`, iterate over `(n, i, j)` tiles, map query heads to KV heads with `q_kv_heads_ratio = num_heads / num_kv_heads`, and accumulate in float32. Mask values should be applied in each tile before updating `m_i` and `l_i`.

When `mask == "causal"`, treat it as a block-level optimization opportunity: if a tile is fully invalid, skip that tile entirely; if a tile is fully valid, skip mask read/add for that tile and continue with matmul + online softmax. Also note that `L` and `S` are not always equal in causal attention, so do not hardcode logic that assumes `L == S`.

You can test your implementation by running:

```bash
Expand Down
3 changes: 2 additions & 1 deletion src/extensions_ref/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ NB_MODULE(_ext, m) {
)");

m.def("flash_attention", &tiny_llm_ext_ref::flash_attention, "query"_a, "key"_a, "value"_a, "mask"_a, "scale"_a = 1.0,
"num_kv_heads"_a, "num_heads"_a, "stream"_a = nb::none(), R"(
"is_causal"_a = false, "num_kv_heads"_a, "num_heads"_a, "stream"_a = nb::none(), R"(
Flash attention layer

Args:
Expand All @@ -41,6 +41,7 @@ NB_MODULE(_ext, m) {
value (array): Value array.
mask (array): Mask array.
scale (float): Scaling factor.
is_causal (bool): Enable causal-mask fast path.

Returns:
array: ``softmax(query @ key.T * scale) @ value``
Expand Down
54 changes: 35 additions & 19 deletions src/extensions_ref/src/flash_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

namespace tiny_llm_ext_ref {
mx::array flash_attention(const mx::array &q, const mx::array &k, const mx::array &v, const mx::array &mask,
const float scale, const int num_kv_heads, const int num_heads, mx::StreamOrDevice s) {
const float scale, const bool is_causal, const int num_kv_heads, const int num_heads,
mx::StreamOrDevice s) {
if (q.dtype() != mx::float32 || k.dtype() != mx::float32 || v.dtype() != mx::float32 || mask.dtype() != mx::float32) {
throw std::runtime_error("flash_attention: all input arrays must be float32");
}
Expand Down Expand Up @@ -54,7 +55,8 @@ mx::array flash_attention(const mx::array &q, const mx::array &k, const mx::arra
}

return mx::array(q.shape(), mx::float32,
std::make_shared<FlashAttention>(to_stream(s), scale, num_kv_heads, num_heads), {q, k, v, mask});
std::make_shared<FlashAttention>(to_stream(s), scale, is_causal, num_kv_heads, num_heads),
{q, k, v, mask});
}

void FlashAttention::eval_cpu(const std::vector<mx::array> &inputs, std::vector<mx::array> &outputs) {
Expand Down Expand Up @@ -91,7 +93,7 @@ void FlashAttention::eval_cpu(const std::vector<mx::array> &inputs, std::vector<
encoder.dispatch([out_ptr = out.data<float>(), out_shape = out.shape(), q = mx::array::unsafe_weak_copy(q),
k = mx::array::unsafe_weak_copy(k), v = mx::array::unsafe_weak_copy(v),
mask = mx::array::unsafe_weak_copy(mask), num_heads = num_heads_, num_kv_heads = num_kv_heads_,
scale = scale_]() {
scale = scale_, is_causal = is_causal_]() {
const int64_t N = q.shape()[0];
const int64_t L = q.shape()[1];
const int64_t S = k.shape()[1];
Expand Down Expand Up @@ -126,7 +128,14 @@ void FlashAttention::eval_cpu(const std::vector<mx::array> &inputs, std::vector<
std::vector<float> o_i(Br * E, 0.0);
std::vector<float> l_i(Br, 0.0);
std::vector<float> m_i(Br, -std::numeric_limits<float>::infinity());
const int64_t causal_offset = S - L;
for (int64_t j = 0; j < Tc; j++) {
int64_t row_max = i * Br + br_upper_bound - 1;
int64_t col_min = j * Bc;
// Causal masking: if the entire block of K is masked out by causal mask, we can skip the computation for this block.
if (is_causal && col_min > row_max + causal_offset) {
continue;
}
int bc_upper_bound = std::min(S - j * Bc, Bc);
// Each kernel processes a block of Br x Bc
// Load Kj and Vj
Expand Down Expand Up @@ -154,14 +163,20 @@ void FlashAttention::eval_cpu(const std::vector<mx::array> &inputs, std::vector<
}

// Add mask and scale
int64_t row_min = i * Br;
int64_t col_max = j * Bc + bc_upper_bound - 1;
bool block_all_valid = is_causal && (col_max <= row_min + causal_offset);
for (int64_t a = 0; a < br_upper_bound; a++) {
for (int64_t b = 0; b < bc_upper_bound; b++) {
int m_idx_1 = n;
int m_idx_2 = i * Br + a;
int m_idx_3 = j * Bc + b;
int m_idx_converted = mx::elem_to_loc(m_idx_1 * L * S + m_idx_2 * S + m_idx_3, mask);
s_i[a * Bc + b] *= scale;
s_i[a * Bc + b] += m_ptr[m_idx_converted];
// If the block is all valid, we don't need to add mask because it's all zeros. Otherwise we need to add mask for each element.
if (!block_all_valid) {
int m_idx_1 = n;
int m_idx_2 = i * Br + a;
int m_idx_3 = j * Bc + b;
int m_idx_converted = mx::elem_to_loc(m_idx_1 * L * S + m_idx_2 * S + m_idx_3, mask);
s_i[a * Bc + b] += m_ptr[m_idx_converted];
}
}
}

Expand Down Expand Up @@ -283,15 +298,16 @@ void FlashAttention::eval_gpu(const std::vector<mx::array> &inputs, std::vector<
const int S = k.shape()[1];
const int E = q.shape()[2];

compute_encoder.set_bytes(N, 7);
compute_encoder.set_bytes(L, 8);
compute_encoder.set_bytes(S, 9);
compute_encoder.set_bytes(E, 10);
compute_encoder.set_bytes(static_cast<int>(is_causal_), 7);
compute_encoder.set_bytes(N, 8);
compute_encoder.set_bytes(L, 9);
compute_encoder.set_bytes(S, 10);
compute_encoder.set_bytes(E, 11);

// Make sure the data type matches with the metal kernel: otherwise you'll get flaky issues and stuck :(
compute_encoder.set_bytes(num_kv_heads_, 11);
compute_encoder.set_bytes(num_heads_, 12);
compute_encoder.set_bytes(scale_, 13);
compute_encoder.set_bytes(num_kv_heads_, 12);
compute_encoder.set_bytes(num_heads_, 13);
compute_encoder.set_bytes(scale_, 14);

size_t tgp_size = kernel->maxTotalThreadsPerThreadgroup();
size_t simd_width = kernel->threadExecutionWidth();
Expand Down Expand Up @@ -320,10 +336,10 @@ void FlashAttention::eval_gpu(const std::vector<mx::array> &inputs, std::vector<
const int Tr = (L + Br - 1) / Br;
const int Tc = (S + Bc - 1) / Bc;

compute_encoder.set_bytes(Br, 14);
compute_encoder.set_bytes(Bc, 15);
compute_encoder.set_bytes(Tr, 16);
compute_encoder.set_bytes(Tc, 17);
compute_encoder.set_bytes(Br, 15);
compute_encoder.set_bytes(Bc, 16);
compute_encoder.set_bytes(Tr, 17);
compute_encoder.set_bytes(Tc, 18);

MTL::Size num_threadgroups = MTL::Size(N, Tr, 1);
MTL::Size num_threads_per_group = MTL::Size(Br, simd_width, 1);
Expand Down
45 changes: 29 additions & 16 deletions src/extensions_ref/src/flash_attention.metal
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@ using namespace metal;
device float* out [[buffer(4)]],
constant const int* mask_shape [[buffer(5)]],
constant const int64_t* mask_strides [[buffer(6)]],
device const int &N [[buffer(7)]],
device const int &L [[buffer(8)]],
device const int &S [[buffer(9)]],
device const int &E [[buffer(10)]],
device const int &num_kv_heads [[buffer(11)]],
device const int &num_heads [[buffer(12)]],
device const float &scale [[buffer(13)]],
device const int &Br [[buffer(14)]],
device const int &Bc [[buffer(15)]],
[[maybe_unused]] device const int &Tr [[buffer(16)]],
device const int &Tc [[buffer(17)]],
device const int &is_causal [[buffer(7)]],
device const int &N [[buffer(8)]],
device const int &L [[buffer(9)]],
device const int &S [[buffer(10)]],
device const int &E [[buffer(11)]],
device const int &num_kv_heads [[buffer(12)]],
device const int &num_heads [[buffer(13)]],
device const float &scale [[buffer(14)]],
device const int &Br [[buffer(15)]],
device const int &Bc [[buffer(16)]],
[[maybe_unused]] device const int &Tr [[buffer(17)]],
device const int &Tc [[buffer(18)]],
uint2 group_id [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
Expand Down Expand Up @@ -70,6 +71,13 @@ using namespace metal;
}

for (int j = 0; j < Tc; j++) {
int row_max = min((i + 1) * Br - 1, L - 1);
int col_min = j * Bc;
// Causal masking: if the entire block of K is masked out by causal mask, we can skip the computation for this block.
if (is_causal && col_min > row_max + (S - L)) {
continue;
}

bool is_j_in_range = j * Bc + b < S && b < Bc;

device const float *k_ptr = k_ptr_base + j * Bc * E;
Expand All @@ -84,11 +92,16 @@ using namespace metal;
}
s_a_b *= scale;
if (is_i_in_range && is_j_in_range) {
int64_t m_idx_1 = n;
int64_t m_idx_2 = i * Br + a;
int64_t m_idx_3 = j * Bc + b;
int64_t m_idx_converted = elem_to_loc(m_idx_1 * L * S + m_idx_2 * S + m_idx_3, mask_shape, mask_strides, 3);
s_a_b += mask[m_idx_converted];
int row_min = i * Br;
int col_max = min((j + 1) * Bc - 1, S - 1);
bool block_all_valid = is_causal && (col_max <= row_min + (S - L));
if (!block_all_valid) {
int64_t m_idx_1 = n;
int64_t m_idx_2 = i * Br + a;
int64_t m_idx_3 = j * Bc + b;
int64_t m_idx_converted = elem_to_loc(m_idx_1 * L * S + m_idx_2 * S + m_idx_3, mask_shape, mask_strides, 3);
s_a_b += mask[m_idx_converted];
}
} else {
s_a_b = -1e9;
}
Expand Down
9 changes: 6 additions & 3 deletions src/extensions_ref/src/tiny_llm_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,17 @@ class QuantizedMatmul : public mx::Primitive {
};

mx::array flash_attention(const mx::array &q, const mx::array &k, const mx::array &v, const mx::array &mask,
const float scale, const int num_kv_heads, const int num_heads, mx::StreamOrDevice s = {});
const float scale, const bool is_causal, const int num_kv_heads, const int num_heads,
mx::StreamOrDevice s = {});

mx::array flash_attention_no_mask(const mx::array &q, const mx::array &k, const mx::array &v,
const float scale, const int num_kv_heads, const int num_heads, mx::StreamOrDevice s = {});

class FlashAttention : public mx::Primitive {
public:
explicit FlashAttention(mx::Stream stream, const float scale, const int num_kv_heads, const int num_heads)
: mx::Primitive(stream), scale_(scale), num_kv_heads_(num_kv_heads), num_heads_(num_heads) {};
explicit FlashAttention(mx::Stream stream, const float scale, const bool is_causal, const int num_kv_heads,
const int num_heads)
: mx::Primitive(stream), scale_(scale), is_causal_(is_causal), num_kv_heads_(num_kv_heads), num_heads_(num_heads) {};

void eval_cpu(const std::vector<mx::array> &inputs, std::vector<mx::array> &outputs) override;
void eval_gpu(const std::vector<mx::array> &inputs, std::vector<mx::array> &outputs) override;
Expand All @@ -62,6 +64,7 @@ class FlashAttention : public mx::Primitive {

private:
float scale_;
bool is_causal_;
int num_kv_heads_;
int num_heads_;
};
Expand Down
17 changes: 9 additions & 8 deletions src/tiny_llm_ref/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def flash_attention(
key: mx.array,
value: mx.array,
scale: float | None = None,
mask: mx.array | None = None,
mask: mx.array | str | None = None,
) -> mx.array:
factor = mx.rsqrt(query.shape[-1]) if scale is None else mx.array(scale)
factor = factor.astype(query.dtype)
Expand All @@ -85,21 +85,22 @@ def flash_attention(
query = mx.contiguous(query)
key = mx.contiguous(key)
value = mx.contiguous(value)
is_causal = mask == "causal"
N = query.shape[0]
if mask is None:
mask = mx.reshape(
mx.broadcast_to(mx.zeros((L, S)), (*B, H_q, L, S)), (N, L, S)
).astype(mx.float32)
if is_causal:
mask = mx.broadcast_to(causal_mask(L, S, mx.float32), (*B, H_q, L, S))
elif mask is None:
mask = mx.broadcast_to(mx.zeros((L, S), dtype=mx.float32), (*B, H_q, L, S))
else:
mask = mx.reshape(mx.broadcast_to(mask, (*B, H_q, L, S)), (N, L, S)).astype(
mx.float32
)
mask = mx.broadcast_to(mask, (*B, H_q, L, S))
mask = mx.contiguous(mask.reshape(N, L, S)).astype(mx.float32)
result = tiny_llm_ext_ref.flash_attention(
query,
key,
value,
mask,
factor,
is_causal=is_causal,
num_heads=H_q,
num_kv_heads=H,
)
Expand Down
4 changes: 0 additions & 4 deletions src/tiny_llm_ref/qwen2_week2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from .attention import (
scaled_dot_product_attention_grouped,
flash_attention,
causal_mask,
)
from .layer_norm import RMSNorm
from .positional_encoding import RoPE
Expand Down Expand Up @@ -81,9 +80,6 @@ def __call__(
projection_k, projection_v, _, mask = cache.update_and_fetch(
projection_k, projection_v, mask_length=L, mask=mask
)
S = projection_k.shape[-2]
if mask == "causal":
mask = causal_mask(L, S, mx.float32)
if self.use_flash_attention:
x = flash_attention(
projection_q.astype(mx.float32),
Expand Down
4 changes: 0 additions & 4 deletions src/tiny_llm_ref/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from .attention import (
scaled_dot_product_attention_grouped,
flash_attention,
causal_mask,
)
from .layer_norm import RMSNorm
from .positional_encoding import RoPE
Expand Down Expand Up @@ -85,9 +84,6 @@ def __call__(
projection_k, projection_v, _, mask = cache.update_and_fetch(
projection_k, projection_v, mask_length=L, mask=mask
)
S = projection_k.shape[-2]
if mask == "causal":
mask = causal_mask(L, S, mx.float32)
if self.use_flash_attention:
x = flash_attention(
projection_q.astype(mx.float32),
Expand Down
53 changes: 30 additions & 23 deletions tests_refsol/test_week_2_day_4.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .utils import *


def attention_helper(stream: mx.Stream, H_q, H, L, E, S, BATCH, with_mask: bool):
def attention_helper(stream: mx.Stream, H_q, H, L, E, S, BATCH, mask_mode: str):
precision = mx.float32
with mx.stream(stream):
q_shape = (BATCH, H_q, L, E)
Expand All @@ -15,7 +15,14 @@ def attention_helper(stream: mx.Stream, H_q, H, L, E, S, BATCH, with_mask: bool)
query = mx.random.uniform(shape=q_shape, dtype=precision)
key = mx.random.uniform(shape=kv_shape, dtype=precision)
value = mx.random.uniform(shape=kv_shape, dtype=precision)
mask = mx.random.uniform(shape=mask_shape, dtype=precision) if with_mask else None
if mask_mode == "no_mask":
mask = None
elif mask_mode == "mask":
mask = mx.random.uniform(shape=mask_shape, dtype=precision)
elif mask_mode == "causal":
mask = "causal"
else:
raise ValueError(f"Unknown mask_mode: {mask_mode}")

reference_output = mx.fast.scaled_dot_product_attention(
q=query,
Expand All @@ -35,36 +42,36 @@ def attention_helper(stream: mx.Stream, H_q, H, L, E, S, BATCH, with_mask: bool)
assert_allclose(user_output, reference_output, precision=mx.float16)


@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"])
def test_task_2_flash_attention_cpu_small(with_mask: bool):
attention_helper(mx.cpu, 6, 3, 2, 5, 3, 1, with_mask)
@pytest.mark.parametrize("mask_mode", ["no_mask", "mask", "causal"])
def test_task_2_flash_attention_cpu_small(mask_mode: str):
    """Flash attention on CPU with tiny shapes, across all three mask modes."""
    attention_helper(mx.cpu, H_q=6, H=3, L=2, E=5, S=3, BATCH=1, mask_mode=mask_mode)


@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"])
def test_task_2_flash_attention_cpu(with_mask: bool):
attention_helper(mx.cpu, 18, 6, 7, 5, 3, 10, with_mask)
@pytest.mark.parametrize("mask_mode", ["no_mask", "mask"])
def test_task_2_flash_attention_cpu(mask_mode: str):
    """Flash attention on CPU with mid-size shapes and grouped query heads."""
    attention_helper(mx.cpu, H_q=18, H=6, L=7, E=5, S=3, BATCH=10, mask_mode=mask_mode)


@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"])
def test_task_2_flash_attention_cpu_large(with_mask: bool):
attention_helper(mx.cpu, 28, 4, 16, 128, 16, 3, with_mask)
@pytest.mark.parametrize("mask_mode", ["no_mask", "mask", "causal"])
def test_task_2_flash_attention_cpu_large(mask_mode: str):
    """Flash attention on CPU with a large head dimension (E=128), all mask modes."""
    attention_helper(mx.cpu, H_q=28, H=4, L=16, E=128, S=16, BATCH=3, mask_mode=mask_mode)


@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"])
def test_task_3_flash_attention_gpu_extra_small(with_mask: bool):
attention_helper(mx.gpu, 1, 1, 5, 7, 4, 1, with_mask)
@pytest.mark.parametrize("mask_mode", ["no_mask", "mask"])
def test_task_3_flash_attention_gpu_extra_small(mask_mode: str):
    """Flash attention on GPU with a single head and minimal shapes."""
    attention_helper(mx.gpu, H_q=1, H=1, L=5, E=7, S=4, BATCH=1, mask_mode=mask_mode)


@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"])
def test_task_3_flash_attention_gpu_small(with_mask: bool):
attention_helper(mx.gpu, 6, 3, 2, 5, 3, 1, with_mask)
@pytest.mark.parametrize("mask_mode", ["no_mask", "mask", "causal"])
def test_task_3_flash_attention_gpu_small(mask_mode: str):
    """Flash attention on GPU with tiny shapes, across all three mask modes."""
    attention_helper(mx.gpu, H_q=6, H=3, L=2, E=5, S=3, BATCH=1, mask_mode=mask_mode)


@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"])
def test_task_3_flash_attention_gpu(with_mask: bool):
attention_helper(mx.gpu, 18, 6, 7, 5, 3, 10, with_mask)
@pytest.mark.parametrize("mask_mode", ["no_mask", "mask"])
def test_task_3_flash_attention_gpu(mask_mode: str):
    """Flash attention on GPU with mid-size shapes and grouped query heads."""
    attention_helper(mx.gpu, H_q=18, H=6, L=7, E=5, S=3, BATCH=10, mask_mode=mask_mode)


@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"])
def test_task_3_flash_attention_gpu_large(with_mask: bool):
attention_helper(mx.gpu, 28, 4, 16, 128, 16, 3, with_mask)
@pytest.mark.parametrize("mask_mode", ["no_mask", "mask", "causal"])
def test_task_3_flash_attention_gpu_large(mask_mode: str):
    """Flash attention on GPU with a large head dimension (E=128), all mask modes."""
    attention_helper(mx.gpu, H_q=28, H=4, L=16, E=128, S=16, BATCH=3, mask_mode=mask_mode)