From 567f2dbf404f99144455a567f095f09d0e928c17 Mon Sep 17 00:00:00 2001 From: Connor1996 Date: Sun, 22 Feb 2026 18:28:12 -0800 Subject: [PATCH 1/2] impl Signed-off-by: Connor1996 --- src/extensions/CMakeLists.txt | 4 + src/extensions/bindings.cpp | 33 +++ src/extensions/src/flash_attention.cpp | 254 ++++++++++++++++++++++ src/extensions/src/flash_attention.metal | 31 +++ src/extensions/src/quantized_matmul.cpp | 210 ++++++++++++++++++ src/extensions/src/quantized_matmul.metal | 53 +++++ src/extensions/src/tiny_llm_ext.h | 61 ++++++ src/tiny_llm/attention.py | 115 +++++++++- src/tiny_llm/basics.py | 4 +- src/tiny_llm/batch.py | 1 - src/tiny_llm/embedding.py | 10 +- src/tiny_llm/generate.py | 46 +++- src/tiny_llm/kv_cache.py | 14 +- src/tiny_llm/layer_norm.py | 10 +- src/tiny_llm/positional_encoding.py | 43 +++- src/tiny_llm/quantize.py | 25 ++- src/tiny_llm/qwen2_week1.py | 98 ++++++++- src/tiny_llm/qwen2_week2.py | 104 ++++++++- src/tiny_llm/sampler.py | 15 +- 19 files changed, 1087 insertions(+), 44 deletions(-) create mode 100644 src/extensions/src/flash_attention.cpp create mode 100644 src/extensions/src/flash_attention.metal create mode 100644 src/extensions/src/quantized_matmul.cpp create mode 100644 src/extensions/src/quantized_matmul.metal diff --git a/src/extensions/CMakeLists.txt b/src/extensions/CMakeLists.txt index fccadd6..1805b41 100644 --- a/src/extensions/CMakeLists.txt +++ b/src/extensions/CMakeLists.txt @@ -37,6 +37,8 @@ target_sources( PUBLIC ${CMAKE_CURRENT_LIST_DIR}/src/axpby.cpp ${CMAKE_CURRENT_LIST_DIR}/src/utils.cpp + ${CMAKE_CURRENT_LIST_DIR}/src/quantized_matmul.cpp + ${CMAKE_CURRENT_LIST_DIR}/src/flash_attention.cpp ) # Add include headers @@ -58,6 +60,8 @@ if(MLX_BUILD_METAL) tiny_llm_ext SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/axpby.metal + ${CMAKE_CURRENT_LIST_DIR}/src/quantized_matmul.metal + ${CMAKE_CURRENT_LIST_DIR}/src/flash_attention.metal INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS} diff --git a/src/extensions/bindings.cpp 
b/src/extensions/bindings.cpp index 24437b0..a6c24c4 100644 --- a/src/extensions/bindings.cpp +++ b/src/extensions/bindings.cpp @@ -31,4 +31,37 @@ NB_MODULE(_ext, m) { Returns: array: ``alpha * x + beta * y`` )"); + + m.def("quantized_matmul", &tiny_llm_ext::quantized_matmul, "scales"_a, "biases"_a, "group_size"_a, "bits"_a, + "a"_a, "b"_a, "transpose_b"_a = false, "stream"_a = nb::none(), + R"( + Quantized matmul + + Args: + scales (array): Scaling factors. + biases (array): Biases. + group_size (int): Group size. + bits (int): Number of bits. + a (array): Input array (activations). + b (array): Input array (quantized weights). + transpose_b (bool): Whether to transpose ``b``. + + Returns: + array: Result of quantized matmul. + )"); + + m.def("flash_attention", &tiny_llm_ext::flash_attention, "query"_a, "key"_a, "value"_a, "mask"_a, "scale"_a = 1.0, + "num_kv_heads"_a, "num_heads"_a, "stream"_a = nb::none(), R"( + Flash attention layer + + Args: + query (array): Query array. + key (array): Key array. + value (array): Value array. + mask (array): Mask array. + scale (float): Scaling factor. 
+ + Returns: + array: ``softmax(query @ key.T * scale) @ value`` + )"); } diff --git a/src/extensions/src/flash_attention.cpp b/src/extensions/src/flash_attention.cpp new file mode 100644 index 0000000..20de5c2 --- /dev/null +++ b/src/extensions/src/flash_attention.cpp @@ -0,0 +1,254 @@ +#include +#include +#include +#include +#include +#include + +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/encoder.h" +#include "mlx/utils.h" +#include "tiny_llm_ext.h" + +#ifdef _METAL_ +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" +#endif + +namespace tiny_llm_ext { +mx::array flash_attention(const mx::array &q, const mx::array &k, const mx::array &v, const mx::array &mask, + const float scale, const int num_kv_heads, const int num_heads, mx::StreamOrDevice s) { + if (q.dtype() != mx::float32 || k.dtype() != mx::float32 || v.dtype() != mx::float32 || mask.dtype() != mx::float32) { + throw std::runtime_error("flash_attention: all input arrays must be float32"); + } + if (q.shape().size() != 3 || k.shape().size() != 3 || v.shape().size() != 3) { + throw std::runtime_error("flash_attention: all input arrays must be 3D"); + } + if (num_heads % num_kv_heads != 0) { + throw std::runtime_error("flash_attention: num_heads must be divisible by num_kv_heads"); + } + if (mask.shape().size() != 3) { + throw std::runtime_error("flash_attention: mask must be 3D"); + } + + // Q: [N, L, E] + // K: [N_KV, S, E] + // V: [N_KV, S, E] + // O: [N, L, E] + // M: [N, L, S] (optional, needs broadcasting) + + if (q.shape()[0] % num_heads != 0) { + throw std::runtime_error("flash_attention: q.shape[0] must be divisible by num_heads"); + } + if (k.shape()[0] % num_kv_heads != 0 || v.shape()[0] % num_kv_heads != 0) { + throw std::runtime_error("flash_attention: k.shape[0] and v.shape[0] must be divisible by num_kv_heads"); + } + if (q.shape()[2] != k.shape()[2] || q.shape()[2] != v.shape()[2]) { + throw std::runtime_error("flash_attention: q.shape[2] must 
be equal to k.shape[2] and v.shape[2]"); + } + if (q.shape()[0] / num_heads != k.shape()[0] / num_kv_heads) { + throw std::runtime_error("flash_attention: number of heads mismatch"); + } + if (k.shape()[1] != v.shape()[1]) { + throw std::runtime_error("flash_attention: k.shape[1] must be equal to v.shape[1]"); + } + if (mask.shape()[0] != q.shape()[0] || mask.shape()[1] != q.shape()[1] || mask.shape()[2] != k.shape()[1]) { + throw std::runtime_error("flash_attention: mask must be broadcastable to q, k, v"); + } + + return mx::array(q.shape(), mx::float32, + std::make_shared(to_stream(s), scale, num_kv_heads, num_heads), {q, k, v, mask}); +} + +void FlashAttention::eval_cpu(const std::vector &inputs, std::vector &outputs) { + auto &q = inputs[0]; + auto &k = inputs[1]; + auto &v = inputs[2]; + auto &mask = inputs[3]; + auto &out = outputs[0]; + + out.set_data(mx::allocator::malloc(out.nbytes())); + + auto &encoder = mx::cpu::get_command_encoder(stream()); + encoder.set_input_array(q); + encoder.set_input_array(k); + encoder.set_input_array(v); + encoder.set_input_array(mask); + encoder.set_output_array(out); + + if (!q.flags().row_contiguous) { + throw std::runtime_error("flash_attention: q must be contiguous"); + } + if (!k.flags().row_contiguous) { + throw std::runtime_error("flash_attention: k must be contiguous"); + } + if (!v.flags().row_contiguous) { + throw std::runtime_error("flash_attention: v must be contiguous"); + } + + encoder.dispatch([ + out_ptr = out.data(), + out_shape = out.shape(), + q = mx::array::unsafe_weak_copy(q), + k = mx::array::unsafe_weak_copy(k), + v = mx::array::unsafe_weak_copy(v), + mask = mx::array::unsafe_weak_copy(mask), + scale = scale_, + num_heads = num_heads_, + num_kv_heads = num_kv_heads_ + ] { + // Q: [N, L, E] + // K: [N_KV, S, E] + // V: [N_KV, S, E] + + const int N = q.shape()[0]; + const int L = q.shape()[1]; + const int S = k.shape()[1]; + const int E = q.shape()[2]; + + const int Br = 32; + const int Bc = 32; + const 
int Tr = (L + Br - 1) / Br; + const int Tc = (S + Bc - 1) / Bc; + + const int q_kv_heads_ratio = num_heads / num_kv_heads; + const float *q_ptr = q.data(); + const float *k_ptr = k.data(); + const float *v_ptr = v.data(); + const float *mask_ptr = mask.data(); + + for (int n = 0; n < N; n++) { + const float *q_batch = q_ptr + n * L * E; + const float *k_batch = k_ptr + (n / q_kv_heads_ratio) * S * E; + const float *v_batch = v_ptr + (n / q_kv_heads_ratio) * S * E; + const float *mask_batch = mask_ptr + n * L * S; + float *out_batch = out_ptr + n * L * E; + for (int i = 0; i < Tr; i++) { + // Divide L into blocks of size Br + std::vector q_i(Br * E, 0.0); + // Load q_i + // Why load into a separate buffer? We need to reuse q_i for every block of K and V, + // and it's more efficient to load once than to read from global memory repeatedly. + int br_upper_bound = std::min(L - i * Br, Br); + for (int a = 0; a < br_upper_bound; a++) { + for (int b = 0; b < E; b++) { + q_i[a * E + b] = q_batch[(i * Br + a) * E + b]; + } + } + + std::vector m_i(Br, -std::numeric_limits::infinity()); + std::vector p_i(Br * Bc, 0.0); + std::vector l_i(Br, 0.0); + std::vector o_i(Br * E, 0.0); + + for (int j = 0; j < Tc; j++) { + // Divide S into blocks of size Bc + std::vector k_j(Bc * E, 0.0); // should consider tranpose + std::vector v_j(Bc * E, 0.0); + // Load k_j and v_j + int bc_upper_bound = std::min(S - j * Bc, Bc); + for (int a = 0; a < bc_upper_bound; a++) { + for (int b = 0; b < E; b++) { + k_j[a * E + b] = k_batch[(j * Bc + a) * E + b]; + v_j[a * E + b] = v_batch[(j * Bc + a) * E + b]; + } + } + + // Compute matmul for s_i = q_i * k_j^T : [Br, E] x [E, Bc] -> [Br, Bc] + std::vector s_i(Br * Bc, 0.0); + for (int a = 0; a < br_upper_bound; a++) { + for (int b = 0; b < bc_upper_bound; b++) { + for (int c = 0; c < E; c++) { + s_i[a * Bc + b] += (q_i[a * E + c] * k_j[b * E + c]); + } + s_i[a * Bc + b] *= scale; + s_i[a * Bc + b] += mask_batch[(i * Br + a) * S + j * Bc + b]; + } + } + 
+ // Online softmax + // compute m_i = max(m_i, s_i) + std::vector m_i_diff(Br, 0.0); + for (int a = 0; a < br_upper_bound; a++) { + float rowmax = -std::numeric_limits::infinity(); + for (int b = 0; b < bc_upper_bound; b++) { + rowmax = std::max(rowmax, s_i[a * Bc + b]); + } + m_i_diff[a] = m_i[a] - rowmax; + m_i[a] = std::max(m_i[a], rowmax); + } + + // compute p_i = exp(s_i - m_i) + for (int a = 0; a < br_upper_bound; a++) { + for (int b = 0; b < bc_upper_bound; b++) { + p_i[a * Bc + b] = std::exp(s_i[a * Bc + b] - m_i[a]); + } + } + + // compute l_i = exp(m_i_diff) * l_i + sum(p_i) + for (int a = 0; a < br_upper_bound; a++) { + float rowsum = 0.0; + for (int b = 0; b < bc_upper_bound; b++) { + rowsum += p_i[a * Bc + b]; + } + l_i[a] = std::exp(m_i_diff[a]) * l_i[a] + rowsum; + } + + // compute o_i = diag(exp(m_i_diff)) * o_i from prev iteration + p_i * v_j + for (int a = 0; a < br_upper_bound; a++) { + for (int b = 0; b < E; b++) { + o_i[a * E + b] = std::exp(m_i_diff[a]) * o_i[a * E + b]; + // compute p_i * v_j + for (int c = 0; c < bc_upper_bound; c++) { + o_i[a * E + b] += p_i[a * Bc + c] * v_j[c * E + b]; + } + } + } + } + + // compute finial o_i + for (int a = 0; a < br_upper_bound; a++) { + for (int b = 0; b < E; b++) { + o_i[a * E + b] /= l_i[a]; + } + } + + // store o_i to out + for (int a = 0; a < br_upper_bound; a++) { + for (int b = 0; b < E; b++) { + out_batch[(i * Br + a) * E + b] = o_i[a * E + b]; + } + } + } + } + }); +} + +void FlashAttention::eval_gpu(const std::vector &inputs, std::vector &outputs) { + auto &q = inputs[0]; + auto &k = inputs[1]; + auto &v = inputs[2]; + auto &mask = inputs[3]; + auto &scale = inputs[4]; + + auto &out = outputs[0]; + + auto &s = stream(); + auto &d = mx::metal::device(s.device); + out.set_data(mx::allocator::malloc(out.nbytes())); + + auto lib = d.get_library("tiny_llm_ext"); + auto kernel = d.get_kernel("flash_attention", lib); + auto &compute_encoder = d.get_command_encoder(s.index); + 
compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(q, 0); + compute_encoder.set_input_array(k, 1); + compute_encoder.set_input_array(v, 2); + compute_encoder.set_input_array(mask, 3); + compute_encoder.set_input_array(scale, 4); + compute_encoder.set_output_array(out, 5); + +} + +} // namespace tiny_llm_ext diff --git a/src/extensions/src/flash_attention.metal b/src/extensions/src/flash_attention.metal new file mode 100644 index 0000000..745e567 --- /dev/null +++ b/src/extensions/src/flash_attention.metal @@ -0,0 +1,31 @@ +#include +#include "mlx/backend/metal/kernels/utils.h" + +using namespace metal; + +[[kernel]] void flash_attention_f32_e128( + device const float* q [[buffer(0)]], + device const float* k [[buffer(1)]], + device const float* v [[buffer(2)]], + device const float* mask [[buffer(3)]], + device float* out [[buffer(4)]], + constant const int* mask_shape [[buffer(5)]], + constant const int64_t* mask_strides [[buffer(6)]], + device const int &N [[buffer(7)]], + device const int &L [[buffer(8)]], + device const int &S [[buffer(9)]], + device const int &E [[buffer(10)]], + device const int &num_kv_heads [[buffer(11)]], + device const int &num_heads [[buffer(12)]], + device const float &scale [[buffer(13)]], + device const int &Br [[buffer(14)]], + device const int &Bc [[buffer(15)]], + [[maybe_unused]] device const int &Tr [[buffer(16)]], + device const int &Tc [[buffer(17)]], + uint2 group_id [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + + + +} diff --git a/src/extensions/src/quantized_matmul.cpp b/src/extensions/src/quantized_matmul.cpp new file mode 100644 index 0000000..57a756f --- /dev/null +++ b/src/extensions/src/quantized_matmul.cpp @@ -0,0 +1,210 @@ +#include +#include +#include + +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/encoder.h" +#include "mlx/utils.h" +#include "tiny_llm_ext.h" + +#ifdef 
_METAL_ +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" +#endif + +namespace tiny_llm_ext { + +mx::array quantized_matmul(const mx::array &scales, // Input array scales + const mx::array &biases, // Input array biases + const int group_size, // Group size + const int bits, // Number of bits + const mx::array &a, // Input array a (not quantized) + const mx::array &b, // Input array b (quantized) + const bool transpose_b, // Whether to transpose b + mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation +) { + if (scales.dtype() != mx::float16 && scales.dtype() != mx::bfloat16) { + throw std::runtime_error("quantized_matmul: scales must be float16 or bfloat16"); + } + if (scales.dtype() != biases.dtype()) { + throw std::runtime_error("quantized_matmul: scales and biases must be the same dtype"); + } + if (b.dtype() != mx::uint32) { + throw std::runtime_error("quantized_matmul: b must be uint32"); + } + if (a.dtype() != scales.dtype()) { + throw std::runtime_error("quantized_matmul: a must be the same dtype as scales"); + } + if (a.shape().size() != 2) { + throw std::runtime_error("quantized_matmul: a must be a 2D array"); + } + if (b.shape().size() != 2) { + throw std::runtime_error("quantized_matmul: b must be a 2D array"); + } + if (bits != 4) { + throw std::runtime_error("quantized_matmul: bits must be 4"); + } + if (group_size != 64) { + throw std::runtime_error("quantized_matmul: group_size must be 64"); + } + auto out_shape = a.shape(); + if (out_shape.size() != 2) { + throw std::runtime_error("quantized_matmul: a must be a 2D array"); + } + out_shape[1] = b.shape()[0]; + if (!transpose_b) { + throw std::runtime_error("quantized_matmul: b must be transposed"); + } + + if (scales.shape() != biases.shape()) { + throw std::runtime_error("quantized_matmul: scales and biases must have the same shape"); + } + if (b.shape()[0] != scales.shape()[0]) { + throw std::runtime_error("quantized_matmul: b must have the same 
number of rows as scales"); + } + if (b.shape()[1] != scales.shape()[1] * group_size / 8) { + throw std::runtime_error("quantized_matmul: a must have the same number of columns as scales"); + } + + return mx::array( + /* const mx::Shape& shape = */ out_shape, + /* mx::Dtype dtype = */ a.dtype(), + /* std::shared_ptr primitive = */ + std::make_shared(to_stream(s), group_size, bits), + /* const std::vector& inputs = */ {scales, biases, a, b}); +} + + +void quantized_matmul_impl(const mx::array &scales, const mx::array &biases, const mx::array &a, const mx::array &b, + mx::array &out, const int group_size, const int bits, mx::Stream stream) { + out.set_data(mx::allocator::malloc(out.nbytes())); + + // Get the CPU command encoder and register input and output arrays + auto &encoder = mx::cpu::get_command_encoder(stream); + encoder.set_input_array(scales); + encoder.set_input_array(biases); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + + if (!a.flags().row_contiguous) { + throw std::runtime_error("quantized_matmul: a must be contiguous"); + } + if (!b.flags().row_contiguous) { + throw std::runtime_error("quantized_matmul: b must be contiguous"); + } + + // Launch the CPU kernel + encoder.dispatch([ + out_ptr = out.data(), + out_shape = out.shape(), + out_strides = out.strides(), + scales = mx::array::unsafe_weak_copy(scales), + biases = mx::array::unsafe_weak_copy(biases), + a = mx::array::unsafe_weak_copy(a), + b = mx::array::unsafe_weak_copy(b), + group_size, bits + ](){ + int M = a.shape()[0]; + int N = a.shape()[1]; + int K = b.shape()[0]; + int group_per_row = N / group_size; + int pack_factor = 32 / bits; + uint32_t item_mask = (1 << bits) - 1; + + const float16_t* scales_ptr = scales.data(); + const float16_t* biases_ptr = biases.data(); + const float16_t* a_ptr = a.data(); + const uint32_t* b_ptr = b.data(); + + // Do the element-wise operation for each output + for (int i = 0; i < M; i++) { + for (int k = 0; k < 
K; k++) { + float sum = 0; + for (int g = 0; g < group_per_row; g++) { + auto scales_loc = mx::elem_to_loc(k * group_per_row + g, scales); + auto bias_loc = mx::elem_to_loc(k * group_per_row + g, biases); + auto a_elem = i * N + g * group_size; + // b stores 8x 4-bit values per uint32; convert element offset to packed-word offset. + auto b_elem = (k * N + g * group_size) / pack_factor; + for (int word = 0; word < group_size/pack_factor; word++) { + uint32_t packed = b_ptr[mx::elem_to_loc(b_elem + word, b)]; + for (int pack_idx = 0; pack_idx < pack_factor; pack_idx++) { + auto shift = (pack_idx * bits); + auto quantized_val = (packed >> shift) & item_mask; + float dequantized_val = static_cast(quantized_val) * scales_ptr[scales_loc] + biases_ptr[bias_loc]; + float a_val = a_ptr[mx::elem_to_loc(a_elem + word * pack_factor + pack_idx, a)]; + sum += a_val * dequantized_val; + } + } + } + + auto out_idx = mx::elem_to_loc(i * K + k, out_shape, out_strides); + out_ptr[out_idx] = static_cast(sum); + } + } + }); +} + +void QuantizedMatmul::eval_cpu(const std::vector &inputs, std::vector &outputs) { + auto &scales = inputs[0]; + auto &biases = inputs[1]; + auto &a = inputs[2]; + auto &b = inputs[3]; + auto &out = outputs[0]; + + quantized_matmul_impl(scales, biases, a, b, out, group_size_, bits_, stream()); +} + +void QuantizedMatmul::eval_gpu(const std::vector &inputs, std::vector &outputs) { + auto &scales = inputs[0]; + auto &biases = inputs[1]; + auto &a = inputs[2]; + auto &b = inputs[3]; + auto &out = outputs[0]; + + auto &s = stream(); + auto &d = mx::metal::device(s.device); + out.set_data(mx::allocator::malloc(out.nbytes())); + + auto lib = d.get_library("tiny_llm_ext"); + const char* kname; + if (a.dtype() == mx::float16) { + kname = "quantized_matmul_w4a16_g64_f16"; + } else if (a.dtype() == mx::bfloat16) { + kname = "quantized_matmul_w4a16_g64_bf16"; + } else { + throw std::runtime_error("quantized_matmul: a must be float16 or bfloat16"); + } + auto kernel = 
d.get_kernel(kname, lib); + + auto &compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(scales, 0); + compute_encoder.set_input_array(biases, 1); + compute_encoder.set_input_array(a, 2); + compute_encoder.set_input_array(b, 3); + compute_encoder.set_output_array(out, 4); + + int M = a.shape()[0]; + int N = a.shape()[1]; + int K = b.shape()[0]; + if (N % group_size_ != 0) { + throw std::runtime_error("quantized_matmul: N must be divisible by group_size"); + } + + compute_encoder.set_bytes(M, 5); + compute_encoder.set_bytes(N, 6); + compute_encoder.set_bytes(K, 7); + + size_t tpg_size = kernel->maxTotalThreadsPerThreadgroup(); + int x_size = kernel->threadExecutionWidth(); + int y_size = tpg_size / x_size; + + MTL::Size num_threadgroups = MTL::Size((M + x_size - 1) / x_size, (K + y_size - 1) / y_size, 1); + MTL::Size num_threads_per_group = MTL::Size(x_size, y_size, 1); + compute_encoder.dispatch_threadgroups(num_threadgroups, num_threads_per_group); +} + +} // namespace tiny_llm_ext diff --git a/src/extensions/src/quantized_matmul.metal b/src/extensions/src/quantized_matmul.metal new file mode 100644 index 0000000..9a16960 --- /dev/null +++ b/src/extensions/src/quantized_matmul.metal @@ -0,0 +1,53 @@ +#include +#include "mlx/backend/metal/kernels/utils.h" + +template +[[kernel]] void quantized_matmul_w4a16_g64( + device const T* scales [[buffer(0)]], + device const T* biases [[buffer(1)]], + device const T* a [[buffer(2)]], + device const uint32_t* b [[buffer(3)]], + device T* out [[buffer(4)]], + device const int &M [[buffer(5)]], + device const int &N [[buffer(6)]], + device const int &K [[buffer(7)]], + uint3 group_id [[threadgroup_position_in_grid]], + uint3 thread_id [[thread_position_in_threadgroup]], + uint3 threads_per_threadgroup [[threads_per_threadgroup]]) { + + const int group_size = 64; + const int bits = 4; + const int pack_factor = 32 / bits; // number of quantized 
values per uint + const int group_per_row = N / group_size; + const int item_mask = (1 << bits) - 1; + + // Each thread processes an element in the output matrix + const int i = group_id.x * threads_per_threadgroup.x + thread_id.x; + const int k = group_id.y * threads_per_threadgroup.y + thread_id.y; + + float sum = 0; + if (i < M && k < K) { + // dequantize b = quantized_b * scale + bias + for (int g = 0; g < group_per_row; g++) { + auto scales_loc = k * group_per_row + g; + auto bias_loc = k * group_per_row + g; + auto a_base_loc = i * N + g * group_size; + // b stores 8x 4-bit values per uint32; convert element offset to packed-word offset. + auto b_base_loc = (k * N + g * group_size) / pack_factor; + for (int word = 0; word < group_size/pack_factor; word++) { + uint32_t packed = b[b_base_loc + word]; + for (int pack_idx = 0; pack_idx < pack_factor; pack_idx++) { + auto shift = (pack_idx * bits); + auto quantized_val = (packed >> shift) & item_mask; + float dequantized_val = static_cast(quantized_val) * scales[scales_loc] + biases[bias_loc]; + float a_val = a[a_base_loc + word * pack_factor + pack_idx]; + sum += a_val * dequantized_val; + } + } + } + + out[i * K + k] = static_cast(sum); + } +} + +instantiate_kernel("quantized_matmul_w4a16_g64_f16", quantized_matmul_w4a16_g64, float16_t); \ No newline at end of file diff --git a/src/extensions/src/tiny_llm_ext.h b/src/extensions/src/tiny_llm_ext.h index bbca193..7103cc2 100644 --- a/src/extensions/src/tiny_llm_ext.h +++ b/src/extensions/src/tiny_llm_ext.h @@ -9,4 +9,65 @@ namespace tiny_llm_ext { void load_library(mx::Device d, const char *path); +mx::array quantized_matmul(const mx::array &scales, // Input array scales + const mx::array &biases, // Input array biases + const int group_size, // Group size + const int bits, // Number of bits + const mx::array &a, // Input array a (not quantized) + const mx::array &b, // Input array b (quantized) + const bool transpose_b, // Whether to transpose b + 
mx::StreamOrDevice s = {} // Stream on which to schedule the operation +); + +mx::array flash_attention(const mx::array &q, // Query [N, L, E], float32 + const mx::array &k, // Key [N_kv, S, E], float32 + const mx::array &v, // Value [N_kv, S, E], float32 + const mx::array &mask, // Mask [N, L, S], float32 + const float scale, // Scale factor + const int num_kv_heads, // Number of KV heads + const int num_heads, // Number of query heads + mx::StreamOrDevice s = {} // Stream on which to schedule the operation +); + +class QuantizedMatmul : public mx::Primitive { +public: + explicit QuantizedMatmul(mx::Stream stream, const int group_size, const int bits) + : mx::Primitive(stream), group_size_(group_size), bits_(bits) {}; + + void eval_cpu(const std::vector &inputs, std::vector &outputs) override; + void eval_gpu(const std::vector &inputs, std::vector &outputs) override; + + std::pair, std::vector> vmap(const std::vector &inputs, + const std::vector &axes) override { + throw std::runtime_error("QuantizedMatmul has no vmap implementation."); + } + + const char *name() const override { return "QuantizedMatmul"; } + +private: + int group_size_; + int bits_; +}; + +class FlashAttention : public mx::Primitive { +public: + explicit FlashAttention(mx::Stream stream, const float scale, const int num_kv_heads, const int num_heads) + : mx::Primitive(stream), scale_(scale), num_kv_heads_(num_kv_heads), num_heads_(num_heads) {}; + + void eval_cpu(const std::vector &inputs, std::vector &outputs) override; + void eval_gpu(const std::vector &inputs, std::vector &outputs) override; + + std::pair, std::vector> vmap(const std::vector &inputs, + const std::vector &axes) override { + throw std::runtime_error("FlashAttention has no vmap implementation."); + } + + const char *name() const override { return "FlashAttention"; } + +private: + float scale_; + int num_kv_heads_; + int num_heads_; +}; + } // namespace tiny_llm_ext diff --git a/src/tiny_llm/attention.py 
b/src/tiny_llm/attention.py index ccb1179..20bea23 100644 --- a/src/tiny_llm/attention.py +++ b/src/tiny_llm/attention.py @@ -1,5 +1,6 @@ import mlx.core as mx from .basics import softmax, linear +from extensions import tiny_llm_ext def scaled_dot_product_attention_simple( @@ -9,7 +10,13 @@ def scaled_dot_product_attention_simple( scale: float | None = None, mask: mx.array | None = None, ) -> mx.array: - pass + d_k = query.shape[-1] + factor = mx.rsqrt(d_k) if scale is None else scale + scores = mx.matmul(query, key.swapaxes(-2, -1)) * factor + if mask is not None: + scores = scores + mask + scores = softmax(scores, axis=-1) + return mx.matmul(scores, value) # output is (N.., L, D) class SimpleMultiHeadAttention: @@ -22,7 +29,14 @@ def __init__( wv: mx.array, wo: mx.array, ): - pass + self.hidden_size = hidden_size + self.num_heads = num_heads + assert hidden_size % num_heads == 0 + self.head_dim = hidden_size // num_heads + self.wq = wq + self.wk = wk + self.wv = wv + self.wo = wo def __call__( self, @@ -31,11 +45,52 @@ def __call__( value: mx.array, mask: mx.array | None = None, ) -> mx.array: - pass + N, L, _ = query.shape + assert query.shape == key.shape == value.shape + + querys = ( + linear(query, self.wq) + .reshape(N, L, self.num_heads, self.head_dim) + .transpose(0, 2, 1, 3) + ) + keys = ( + linear(key, self.wk) + .reshape(N, L, self.num_heads, self.head_dim) + .transpose(0, 2, 1, 3) + ) + values = ( + linear(value, self.wv) + .reshape(N, L, self.num_heads, self.head_dim) + .transpose(0, 2, 1, 3) + ) + + output = scaled_dot_product_attention_simple( + querys, + keys, + values, + scale=None, + mask=mask, + ) + + output = output.transpose(0, 2, 1, 3).reshape(N, L, self.hidden_size) + return linear(output, self.wo) def causal_mask(L: int, S: int, dtype: mx.Dtype) -> mx.array: - pass + """Generate a causal mask for attention mechanism. + + Args: + L (int): Length of the query sequence. + S (int): Length of the key/value sequence. 
+ dtype (mx.Dtype): Data type of the output mask. + + Returns: + mx.array: A (L, S) shaped array where positions that should not be attended to are set to -inf, + and positions that can be attended to are set to 0. + """ + mask = mx.full((L, S), -mx.inf, dtype=dtype) + mask = mx.triu(mask, k=(S-L+1)) + return mask def scaled_dot_product_attention_grouped( @@ -45,7 +100,29 @@ def scaled_dot_product_attention_grouped( scale: float | None = None, mask: mx.array | str | None = None, ) -> mx.array: - pass + expected_shape = query.shape + H_q, L, D = query.shape[-3:] + H, S, _ = key.shape[-3:] + assert H_q % H == 0 + group = H_q // H + + query = query.reshape(-1, H, group, L, D) + key = key.reshape(-1, H, 1, S, D) + value = value.reshape(-1, H, 1, S, D) + if mask == "causal": + mask = causal_mask(L, S, query.dtype) + elif mask is not None: + mask = mask.reshape(-1, H, group, L, S) + else: + mask = None + + return scaled_dot_product_attention_simple( + query, + key, + value, + scale=scale, + mask=mask, + ).reshape(expected_shape) # output is (N.., L, D) def flash_attention( @@ -55,4 +132,30 @@ def flash_attention( scale: float | None = None, mask: mx.array | None = None, ) -> mx.array: - pass + B, H_q, L, E = query.shape + _, H, S, _ = key.shape + assert H_q % H == 0 + + query = query.reshape(-1, L, E) + key = key.reshape(-1, S, E) + value = value.reshape(-1, S, E) + + query = mx.contiguous(query) + key = mx.contiguous(key) + value = mx.contiguous(value) + + if mask is None: + mask = mx.zeros((L, S)) + mask = mx.broadcast_to(mask, (B, H_q, L, S)).reshape(-1, L, S).astype(mx.float32) + mask = mx.contiguous(mask) + + result = tiny_llm_ext.flash_attention( + query, + key, + value, + mask, + scale, + num_heads=H_q, + num_kv_heads=H, + ) + return mx.contiguous(result.reshape(B, H_q, L, E)) diff --git a/src/tiny_llm/basics.py b/src/tiny_llm/basics.py index 082223f..ea36ed3 100644 --- a/src/tiny_llm/basics.py +++ b/src/tiny_llm/basics.py @@ -12,8 +12,8 @@ def linear( w: 
mx.array, bias: mx.array | None = None, ) -> mx.array: - pass + return mx.matmul(x, w.T) + (bias if bias is not None else 0) def silu(x: mx.array) -> mx.array: - pass + return x / (1 + mx.exp(-x)) diff --git a/src/tiny_llm/batch.py b/src/tiny_llm/batch.py index 329971c..56e2ef5 100644 --- a/src/tiny_llm/batch.py +++ b/src/tiny_llm/batch.py @@ -96,7 +96,6 @@ def _print_progress( else: print(f" Prefill: idle, {queue_size} requests in queue", flush=True) - def batch_generate( model: any, tokenizer: TokenizerWrapper, diff --git a/src/tiny_llm/embedding.py b/src/tiny_llm/embedding.py index c66ccf6..d629324 100644 --- a/src/tiny_llm/embedding.py +++ b/src/tiny_llm/embedding.py @@ -1,12 +1,14 @@ import mlx.core as mx - +from .basics import linear class Embedding: def __init__(self, vocab_size: int, embedding_dim: int, weight: mx.array): - pass + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + self.weight = weight def __call__(self, x: mx.array) -> mx.array: - pass + return self.weight[x, :] def as_linear(self, x: mx.array) -> mx.array: - pass + return linear(x, self.weight) diff --git a/src/tiny_llm/generate.py b/src/tiny_llm/generate.py index d2bc3ed..33e98ea 100644 --- a/src/tiny_llm/generate.py +++ b/src/tiny_llm/generate.py @@ -3,7 +3,7 @@ from .qwen2_week1 import Qwen2ModelWeek1 from .qwen2_week2 import Qwen2ModelWeek2 from typing import Callable - +from .kv_cache import TinyKvFullCache def simple_generate( model: Qwen2ModelWeek1, @@ -12,14 +12,54 @@ def simple_generate( sampler: Callable[[mx.array], mx.array] | None, ) -> str: def _step(model, y): - pass + output_logits = model(y[None, :]) + logits = output_logits[:, -1, :] + # avoid numerical instability + if sampler is None: + return mx.argmax(logits, axis=-1) + else: + logprobs = logits - mx.logsumexp(logits, keepdims=True) + return sampler(logprobs) + + + context = mx.array(tokenizer.encode(prompt)) + token = _step(model, context) + detokenizer = tokenizer.detokenizer + while token.item() != 
tokenizer.eos_token_id: + detokenizer.add_token(token.item()) + print(detokenizer.last_segment, end="", flush=True) + + context = mx.concat([context, token]) + token = _step(model, context) + mx.eval(token) + return "" + def simple_generate_with_kv_cache( model: Qwen2ModelWeek2, tokenizer: TokenizerWrapper, prompt: str ) -> str: def _step(model, y, offset, kv_cache): - pass + output_logits = model(y[None, :], offset, kv_cache) + logits = output_logits[:, -1, :] + # avoid numerical instability + return mx.argmax(logits, axis=-1) + + offset = 0 + kv_cache = [TinyKvFullCache() for _ in range(model.num_hidden_layers)] + + context = mx.array(tokenizer.encode(prompt)) + token = _step(model, context, offset, kv_cache) + offset += context.size + detokenizer = tokenizer.detokenizer + while token.item() != tokenizer.eos_token_id: + detokenizer.add_token(token.item()) + print(detokenizer.last_segment, end="", flush=True) + + token = _step(model, token, offset, kv_cache) + offset += 1 + mx.eval(token) + return "" def speculative_generate( diff --git a/src/tiny_llm/kv_cache.py b/src/tiny_llm/kv_cache.py index 0b98a5a..8fb8812 100644 --- a/src/tiny_llm/kv_cache.py +++ b/src/tiny_llm/kv_cache.py @@ -98,4 +98,16 @@ def update_and_fetch( mask_length: int | None = None, mask: mx.array | str | None = None, ) -> tuple[mx.array, mx.array, int, Optional[mx.array]]: - pass + + if self.key_values is None: + self.key_values = (key, value) + self.offset: int = key.shape[2] + return key, value, self.offset, None + else: + B, H, S, D = key.shape + self.key_values = ( + mx.concat([self.key_values[0], key], axis=2), + mx.concat([self.key_values[1], value], axis=2), + ) + self.offset += S + return self.key_values[0], self.key_values[1], self.offset, None diff --git a/src/tiny_llm/layer_norm.py b/src/tiny_llm/layer_norm.py index 9b7af24..284878c 100644 --- a/src/tiny_llm/layer_norm.py +++ b/src/tiny_llm/layer_norm.py @@ -3,7 +3,13 @@ class RMSNorm: def __init__(self, dim: int, weight: mx.array, 
eps: float = 1e-5): - pass + self.dim = dim + self.weight = weight + self.eps = eps def __call__(self, x: mx.array) -> mx.array: - pass + orig_dtype = x.dtype + # use high precision for mean and square + x = x.astype(mx.float32) + mean = mx.mean(mx.square(x), axis=-1, keepdims=True) + return (x * self.weight / mx.sqrt(mean + self.eps)).astype(orig_dtype) diff --git a/src/tiny_llm/positional_encoding.py b/src/tiny_llm/positional_encoding.py index 8cb206f..f11c157 100644 --- a/src/tiny_llm/positional_encoding.py +++ b/src/tiny_llm/positional_encoding.py @@ -9,9 +9,48 @@ def __init__( base: int = 10000, traditional: bool = False, ): - pass + assert dims % 2 == 0, "dims must be even" + self.dims = dims + freqs = mx.power(base, mx.arange(0, dims // 2) * -2 / dims) + self.theta = mx.outer(mx.arange(seq_len), freqs) + self.cos = mx.cos(self.theta) + self.sin = mx.sin(self.theta) + + self.traditional = traditional def __call__( self, x: mx.array, offset: list[slice] | slice | None = None ) -> mx.array: - pass + N, S, H, D = x.shape + half_dims = self.dims // 2 + if self.traditional: + even = x[..., 0::2] + odd = x[..., 1::2] + else: + even = x[..., :half_dims] + odd = x[..., half_dims:] + + if offset is not None: + if isinstance(offset, slice): + assert offset.stop - offset.start == S, f"offset must be of length {S}" + elif isinstance(offset, list): + assert len(offset) == N, ( + f"offsets must have the same length as batch size {N}" + ) + for o in offset: + assert o.stop - o.start == S, f"offset must be of length {S}" + offset = mx.array([list(range(i.start, i.stop)) for i in offset]) + else: + raise ValueError("offset must be a slice or a list of slices") + + cos = self.cos[:S, :] if offset is None else self.cos[offset, :] + sin = self.sin[:S, :] if offset is None else self.sin[offset, :] + + cos = cos.reshape(-1, S, 1, half_dims) + sin = sin.reshape(-1, S, 1, half_dims) + r = [even * cos - odd * sin, odd * cos + even * sin] + if self.traditional: + r = mx.stack(r, 
axis=-1) + else: + r = mx.concat(r, axis=-1) + return r.reshape(N, S, H, D) # zip even and odd diff --git a/src/tiny_llm/quantize.py b/src/tiny_llm/quantize.py index 09ef193..1a8cfd5 100644 --- a/src/tiny_llm/quantize.py +++ b/src/tiny_llm/quantize.py @@ -1,6 +1,8 @@ import mlx.core as mx from typing import Any +from extensions import tiny_llm_ext + def dequantize_linear(mx_layer: Any) -> mx.array: w = mx.dequantize( @@ -48,7 +50,19 @@ def quantized_matmul( b: mx.array, transpose_b: bool = False, ) -> mx.array: - pass + *N, D = a.shape + a = a.reshape(-1, D) + a = mx.contiguous(a) + b = mx.contiguous(b) + return tiny_llm_ext.quantized_matmul( + scales, + biases, + group_size, + bits, + a, + b, + transpose_b, + ).reshape(*N, -1) def quantized_linear( @@ -56,4 +70,11 @@ def quantized_linear( w: QuantizedWeights, bias: mx.array | None = None, ) -> mx.array: - pass + v = quantized_matmul( + w.scales, w.biases, w.group_size, w.bits, x, w.weight, transpose_b=True + ) + + if bias is not None: + return v + bias + else: + return v \ No newline at end of file diff --git a/src/tiny_llm/qwen2_week1.py b/src/tiny_llm/qwen2_week1.py index e196fb1..cc4b055 100644 --- a/src/tiny_llm/qwen2_week1.py +++ b/src/tiny_llm/qwen2_week1.py @@ -7,7 +7,6 @@ from .embedding import Embedding from .quantize import dequantize_linear - class Qwen2MultiHeadAttention: def __init__( self, @@ -24,15 +23,43 @@ def __init__( max_seq_len: int = 32768, theta: int = 1000000, ): - pass + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + assert hidden_size % num_heads == 0 + self.head_dim = hidden_size // num_heads + self.wq = wq + self.wk = wk + self.wv = wv + self.wo = wo + self.bq = bq + self.bk = bk + self.bv = bv + self.rope = RoPE(self.head_dim, max_seq_len, theta, traditional=False) def __call__( self, x: mx.array, mask: mx.array | str | None = None, ) -> mx.array: - pass + B, L, _ = x.shape + query = linear(x, self.wq, bias=self.bq).reshape(B, L, 
self.num_heads, self.head_dim) + key = linear(x, self.wk, bias=self.bk).reshape(B, L, self.num_kv_heads, self.head_dim) + value = linear(x, self.wv, bias=self.bv).reshape(B, L, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + + query = self.rope(query, offset=slice(0, L)).transpose(0, 2, 1, 3) + key = self.rope(key, offset=slice(0, L)).transpose(0, 2, 1, 3) + # TODO: why use float32 to compute + output = scaled_dot_product_attention_grouped( + query.astype(mx.float32), + key.astype(mx.float32), + value.astype(mx.float32), + scale=None, + mask=mask + ).astype(x.dtype).transpose(0, 2, 1, 3).reshape(B, L, self.hidden_size) + output = linear(output, self.wo) + return output class Qwen2MLP: def __init__( @@ -43,10 +70,17 @@ def __init__( w_up: mx.array, w_down: mx.array, ): - pass + self.dim = dim + self.hidden_dim = hidden_dim + self.w_gate = w_gate + self.w_up = w_up + self.w_down = w_down def __call__(self, x: mx.array) -> mx.array: - pass + gate = silu(linear(x, self.w_gate)) + value = linear(x, self.w_up) + + return linear(gate * value, self.w_down) class Qwen2TransformerBlock: @@ -72,22 +106,66 @@ def __init__( max_seq_len: int = 32768, theta: int = 1000000, ): - pass - + self.input_layernorm = RMSNorm(hidden_size, w_input_layernorm, rms_norm_eps) + self.post_attention_layernorm = RMSNorm(hidden_size, w_post_attention_layernorm, rms_norm_eps) + self.attention = Qwen2MultiHeadAttention(hidden_size, num_attention_heads, num_kv_heads, wq, wk, wv, wo, bq, bk, bv, max_seq_len, theta) + self.mlp = Qwen2MLP(hidden_size, intermediate_size, w_gate, w_up, w_down) + def __call__( self, x: mx.array, mask: mx.array | str | None = None, ) -> mx.array: - pass + r = self.attention(self.input_layernorm(x), mask) + h = r + x + r = self.mlp(self.post_attention_layernorm(h)) + return r + h class Qwen2ModelWeek1: def __init__(self, mlx_model: Any): - pass + model = mlx_model.model + args = mlx_model.args + + self.embedding = Embedding(args.vocab_size, args.hidden_size, 
weight=dequantize_linear(model.embed_tokens)) + self.layers = [] + for layer in model.layers: + wq = dequantize_linear(layer.self_attn.q_proj) + wk = dequantize_linear(layer.self_attn.k_proj) + wv = dequantize_linear(layer.self_attn.v_proj) + wo = dequantize_linear(layer.self_attn.o_proj) + w_gate = dequantize_linear(layer.mlp.gate_proj) + w_up = dequantize_linear(layer.mlp.up_proj) + w_down = dequantize_linear(layer.mlp.down_proj) + + self.layers.append(Qwen2TransformerBlock( + args.num_attention_heads, args.num_key_value_heads, + args.hidden_size, args.intermediate_size, args.rms_norm_eps, + wq, wk, wv, wo, + layer.self_attn.q_proj.bias, layer.self_attn.k_proj.bias, layer.self_attn.v_proj.bias, + w_gate, w_up, w_down, + layer.input_layernorm.weight, layer.post_attention_layernorm.weight, + args.max_position_embeddings, args.rope_theta + )) + + if not args.tie_word_embeddings: + self.w_lm_head = dequantize_linear(mlx_model.lm_head) + else: + self.w_lm_head = None + + self.rms_norm = RMSNorm(args.hidden_size, model.norm.weight, args.rms_norm_eps) def __call__( self, inputs: mx.array, ) -> mx.array: - pass + x = self.embedding(inputs) + + for layer in self.layers: + x = layer(x, mask="causal") + x = self.rms_norm(x) + + if self.w_lm_head is None: + return self.embedding.as_linear(x) + else: + return linear(x, self.w_lm_head) diff --git a/src/tiny_llm/qwen2_week2.py b/src/tiny_llm/qwen2_week2.py index 1604cba..4957eb0 100644 --- a/src/tiny_llm/qwen2_week2.py +++ b/src/tiny_llm/qwen2_week2.py @@ -1,11 +1,12 @@ import mlx.core as mx + from .basics import linear, silu from .attention import scaled_dot_product_attention_grouped from .layer_norm import RMSNorm from .positional_encoding import RoPE from typing import Any from .embedding import Embedding -from .quantize import dequantize_linear, QuantizedWeights +from .quantize import dequantize_linear, quantized_linear, QuantizedWeights from .kv_cache import TinyKvCache @@ -26,7 +27,19 @@ def __init__( theta: int = 1000000, 
use_flash_attention: bool = False, ): - pass + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + assert hidden_size % num_heads == 0 + self.head_dim = hidden_size // num_heads + self.wq = wq + self.wk = wk + self.wv = wv + self.wo = wo + self.bq = bq + self.bk = bk + self.bv = bv + self.rope = RoPE(self.head_dim, max_seq_len, theta, traditional=False) def __call__( self, @@ -35,8 +48,27 @@ def __call__( cache: TinyKvCache, mask: mx.array | str | None = None, ) -> mx.array: - pass + B, L, _ = x.shape + query = quantized_linear(x, self.wq, bias=self.bq).reshape(B, L, self.num_heads, self.head_dim) + key = quantized_linear(x, self.wk, bias=self.bk).reshape(B, L, self.num_kv_heads, self.head_dim) + value = quantized_linear(x, self.wv, bias=self.bv).reshape(B, L, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + + offset = list(map(lambda x: slice(x, x+L), offsets)) + query = self.rope(query, offset).transpose(0, 2, 1, 3) + key = self.rope(key, offset).transpose(0, 2, 1, 3) + key, value, _, _ = cache.update_and_fetch(key, value) + + # TODO: why use float32 to compute + output = scaled_dot_product_attention_grouped( + query.astype(mx.float32), + key.astype(mx.float32), + value.astype(mx.float32), + scale=None, + mask=mask + ).astype(x.dtype).transpose(0, 2, 1, 3).reshape(B, L, self.hidden_size) + output = quantized_linear(output, self.wo) + return output class Qwen2MLP: def __init__( @@ -47,10 +79,17 @@ def __init__( w_up: QuantizedWeights, w_down: QuantizedWeights, ): - pass + self.dim = dim + self.hidden_dim = hidden_dim + self.w_gate = w_gate + self.w_up = w_up + self.w_down = w_down def __call__(self, x: mx.array) -> mx.array: - pass + gate = silu(quantized_linear(x, self.w_gate)) + value = quantized_linear(x, self.w_up) + + return quantized_linear(gate * value, self.w_down) class Qwen2TransformerBlock: @@ -77,8 +116,11 @@ def __init__( theta: int = 1000000, use_flash_attention: bool = False, ): - pass - + 
self.input_layernorm = RMSNorm(hidden_size, w_input_layernorm, rms_norm_eps) + self.post_attention_layernorm = RMSNorm(hidden_size, w_post_attention_layernorm, rms_norm_eps) + self.attention = Qwen2MultiHeadAttention(hidden_size, num_attention_heads, num_kv_heads, wq, wk, wv, wo, bq, bk, bv, max_seq_len, theta) + self.mlp = Qwen2MLP(hidden_size, intermediate_size, w_gate, w_up, w_down) + def __call__( self, x: mx.array, @@ -86,7 +128,11 @@ def __call__( cache: TinyKvCache, mask: mx.array | str | None = None, ) -> mx.array: - pass + offsets = [offset] * x.shape[0] + r = self.attention(self.input_layernorm(x), offsets, cache, mask) + h = r + x + r = self.mlp(self.post_attention_layernorm(h)) + return r + h class Qwen2ModelWeek2: @@ -96,7 +142,37 @@ def __init__( enable_flash_attn: bool = False, ): self.num_hidden_layers = mlx_model.args.num_hidden_layers - pass + model = mlx_model.model + args = mlx_model.args + self.precision = mx.float16 + + self.embedding = Embedding(args.vocab_size, args.hidden_size, weight=dequantize_linear(model.embed_tokens)) + self.layers = [] + for layer in model.layers: + wq = QuantizedWeights.from_mlx_layer(layer.self_attn.q_proj) + wk = QuantizedWeights.from_mlx_layer(layer.self_attn.k_proj) + wv = QuantizedWeights.from_mlx_layer(layer.self_attn.v_proj) + wo = QuantizedWeights.from_mlx_layer(layer.self_attn.o_proj) + w_gate = QuantizedWeights.from_mlx_layer(layer.mlp.gate_proj) + w_up = QuantizedWeights.from_mlx_layer(layer.mlp.up_proj) + w_down = QuantizedWeights.from_mlx_layer(layer.mlp.down_proj) + + self.layers.append(Qwen2TransformerBlock( + args.num_attention_heads, args.num_key_value_heads, + args.hidden_size, args.intermediate_size, args.rms_norm_eps, + wq, wk, wv, wo, + layer.self_attn.q_proj.bias.astype(self.precision), layer.self_attn.k_proj.bias.astype(self.precision), layer.self_attn.v_proj.bias.astype(self.precision), + w_gate, w_up, w_down, + layer.input_layernorm.weight.astype(self.precision), 
layer.post_attention_layernorm.weight.astype(self.precision), + args.max_position_embeddings, args.rope_theta + )) + + if not args.tie_word_embeddings: + self.w_lm_head = QuantizedWeights.from_mlx_layer(mlx_model.lm_head) + else: + self.w_lm_head = None + + self.rms_norm = RMSNorm(args.hidden_size, model.norm.weight.astype(self.precision), args.rms_norm_eps) def __call__( self, @@ -104,4 +180,12 @@ def __call__( offset: int, cache: list[TinyKvCache], ) -> mx.array: - pass + x = self.embedding(inputs) + for layer, layer_cache in zip(self.layers, cache): + x = layer(x, offset, layer_cache, mask="causal") + x = self.rms_norm(x) + + if self.w_lm_head is None: + return self.embedding.as_linear(x) + else: + return quantized_linear(x, self.w_lm_head) diff --git a/src/tiny_llm/sampler.py b/src/tiny_llm/sampler.py index 5d2ed9d..6d315bf 100644 --- a/src/tiny_llm/sampler.py +++ b/src/tiny_llm/sampler.py @@ -4,8 +4,21 @@ def make_sampler(temp: float, top_p: float, top_k: int | None): def sample(logprobs: mx.array): + if top_k is not None and top_k > 0: + mask_idx = mx.argpartition(-logprobs, top_k - 1)[..., top_k:] + logprobs = mx.put_along_axis(logprobs, mask_idx, mx.array(-float("inf"), logprobs.dtype), axis=-1) + if top_p is not None and 0 < top_p < 1: + sorted_idx = mx.argsort(-logprobs, axis=-1) + sorted_logprobs = mx.take_along_axis(logprobs, sorted_idx, axis=-1) + cumsum = mx.cumsum(mx.exp(sorted_logprobs), axis=-1) + mask = cumsum < top_p + mask[..., 0] = True # ensure at least one token is kept + logprobs = mx.put_along_axis(logprobs, sorted_idx, mx.where(mask, sorted_logprobs, -mx.inf), axis=-1) + if temp == 0: return mx.argmax(logprobs, axis=-1) - pass + else: + return mx.random.categorical(logprobs / temp, axis=-1) return sample + From 759c5c39cd0a1e1ab3152df89eb2ef8265b31413 Mon Sep 17 00:00:00 2001 From: Connor1996 Date: Tue, 3 Mar 2026 00:20:57 -0800 Subject: [PATCH 2/2] fix(ref): remove flash attention debug sentinel write Signed-off-by: Connor1996 --- src/extensions_ref/src/flash_attention.metal | 8 -------- 1 file changed, 8 
deletions(-) diff --git a/src/extensions_ref/src/flash_attention.metal b/src/extensions_ref/src/flash_attention.metal index 98e5821..0d67dc5 100644 --- a/src/extensions_ref/src/flash_attention.metal +++ b/src/extensions_ref/src/flash_attention.metal @@ -61,14 +61,6 @@ using namespace metal; } } - if (simd_lid == 0) { - for (int c = 0; c < E; c++) { - if (is_i_in_range && n < N) { - out[n * L * E + (i * Br + a) * E + c] = -233.0; - } - } - } - for (int j = 0; j < Tc; j++) { bool is_j_in_range = j * Bc + b < S && b < Bc;