CS 8803 GPU Project 4: Implementing a self-attention kernel in CUDA

Introduction

Large Language Models (LLMs) are being widely used across numerous applications. The key model architecture powering these LLMs is the Transformer, which relies on a fundamental component known as the self-attention operation.

In this project, we will delve into the mechanics of how the attention operation works. The primary goal is to implement a highly optimized CUDA kernel for this operation, specifically targeting the FlashAttention algorithm.

Self-attention

Attention is a powerful operation first introduced in the paper "Attention Is All You Need". The basic operation defines a relationship between three matrices: Query (Q), Key (K), and Value (V). The output is calculated using the following formula:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V ,$$

where $d$ is the dimension of the key and query vectors. This scaling factor is used to prevent the dot products from growing too large, which could otherwise saturate the softmax function and destabilize training.
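
As a concrete reference, here is a minimal PyTorch sketch of this formula (the function and variable names are illustrative, not part of the assignment code):

import math
import torch
import torch.nn.functional as F

def attention_reference(q, k, v):
    # q, k, v: [batch_size, seq_len, d]
    d = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)   # [batch_size, seq_len, seq_len]
    weights = F.softmax(scores, dim=-1)               # normalize each row
    return weights @ v                                # [batch_size, seq_len, d]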

Self-attention is a specific application of this operation where the Q, K, and V matrices are all generated from the same input sequence, $X$.

In the context of language models, this input sequence $X$ typically represents a sentence or a batch of sentences. The text is first split into tokens, and an embedding layer then maps each token to a high-dimensional vector in an embedding space. This process results in the input tensor $X$ having shape [batch_size, seq_len, hidden_dim].

From this single input $X$, the Query (Q), Key (K), and Value (V) matrices are generated through linear projections. This is done by multiplying $X$ with three distinct, learnable weight matrices ($W_Q$, $W_K$, and $W_V$):

  • Query: $Q = X \cdot W_Q$
  • Key: $K = X \cdot W_K$
  • Value: $V = X \cdot W_V$

Each of these weight matrices ($W_Q$, $W_K$, $W_V$) typically has a dimension of [hidden_dim, hidden_dim]. This means that Q, K, and V will also have the same shape as $X$ ([batch_size, seq_len, hidden_dim]), effectively projecting the input embeddings into three different "representation subspaces" required for the attention operation.

For a better visual and conceptual understanding of self-attention, refer to these links and YouTube videos:

Causal attention

In many Transformer models, such as those used for text generation, we must ensure that tokens can only attend to previous tokens in the sequence. This is achieved by applying a causal mask (also known as a look-ahead mask) to the score matrix ($QK^T$).

This mask keeps only the lower triangle of the score matrix and is applied before the softmax operation. By setting all elements corresponding to future tokens to negative infinity, we ensure that a token at a given position cannot "see" or attend to any tokens that come after it.
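
For illustration, a causal mask can be built and applied in PyTorch roughly as follows (a small sketch, assuming a square score matrix of shape [seq_len, seq_len]):

import torch

seq_len = 4
scores = torch.randn(seq_len, seq_len)
# True strictly above the diagonal marks the "future" positions that must be hidden.
future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
masked_scores = scores.masked_fill(future, float('-inf'))
# After softmax, the -inf entries contribute exactly zero attention weight.
weights = torch.softmax(masked_scores, dim=-1)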

Multi-headed attention (MHA)

To allow the model to learn richer representations, we want each token to attend to other tokens in multiple contexts or "representation subspaces." This is the goal of Multi-headed Attention (MHA), which runs multiple self-attention operations in parallel.

Instead of performing $h$ separate projections for $h$ heads, it's far more efficient to:

  1. Perform one large linear projection for each of Q, K, and V (using $W_Q, W_K, W_V$ as described previously).
  2. "Split" the resulting $Q$, $K$, and $V$ matrices into $h$ heads.

This "split" is typically a tensor reshape, where the hidden_dim (or $d_{\text{model}}$) is divided into $h$ smaller chunks, one for each head. For example, $Q$ (shape [batch_size, seq_len, hidden_dim]) is reshaped to [batch_size, seq_len, num_heads, head_dim], where num_heads is the number of heads and head_dim is the dimension of each head (head_dim = hidden_dim / num_heads).

We can then define $Q^i$, $K^i$, and $V^i$ as the $i$-th head (slice) from these reshaped tensors. The attention operation is then computed for each head independently:

$$ head_i = \text{Attention}(Q^i, K^i, V^i) $$

The outputs of all heads ($head_1, ..., head_h$) are then concatenated back together (reversing the reshape) and passed through a final linear projection layer ($W^O$) to produce the final output:

$$ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, ..., head_h)W^O $$

It is important to note that while this architecture is "Multi-head," all $h$ heads can be computed efficiently in parallel within a single, large batched matrix multiplication (BMM), which is key to its performance on GPUs.
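
As a rough illustration of the split and merge (a sketch with illustrative shapes, not the assignment's exact code):

import torch

batch_size, seq_len, hidden_dim, num_heads = 2, 16, 64, 4
head_dim = hidden_dim // num_heads
q = torch.randn(batch_size, seq_len, hidden_dim)

# Split: [B, S, H] -> [B, S, h, H/h] -> [B, h, S, H/h]
q_heads = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

# ... attention is computed for all heads as one batched operation ...

# Merge: [B, h, S, H/h] -> [B, S, h, H/h] -> [B, S, H]
merged = q_heads.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_dim)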

Setup

Follow these instructions to set up the project environment, especially if you are working on the ICE cluster.

0. Allocate GPU Resources

First, allocate an interactive session.

From the command line (via SSH): A typical allocation request for one H100 GPU, 8 CPUs, and 128GB of memory for 4 hours looks like this:

salloc --gres=gpu:h100 --cpus-per-task=8 --mem=128G --time=4:00:00

From the web: Alternatively, you can use the Open OnDemand website to request GPU allocation.

1. Cloning this repo

Important: Use Scratch Directory on ICE Cluster

When cloning this repository on the ICE cluster, make sure to work inside your scratch directory. Your home directory has only 30 GB of storage, while scratch provides up to 300 GB.

You can check your scratch path by running:

pace-quota

on any ICE node.

To simplify future access, consider creating a symlink from your home directory to your scratch directory:

ln -s /path/to/your/scratch ~/scratch

This ensures you always work within the larger storage space and avoid exceeding your home directory limit. Once this setup is done, please clone this repo and start working on this project.

cd <path to your scratch directory>
git clone <URL_TO_THIS_REPO>

Important: The default home folder (~/) is NFS-based and has a very limited quota. Working from this directory will be very slow and may cause your programs to fail. Always use ~/scratch/ for your projects.

2. Load uv

To get started, you first need to load the uv module, which is a fast Python package installer and virtual environment manager.

module load uv

3. Create a python virtual environment with uv

Next, create a new virtual environment and install the required packages.

# Create a virtual environment using Python 3.12
uv venv --python 3.12

# Activate the new environment
source .venv/bin/activate

# Install the required libraries 
uv pip install --no-cache-dir torch numpy transformers

4. At later PACE logins

For any future sessions after you log out and log back into PACE, you only need to re-activate the virtual environment to get started:

source .venv/bin/activate

Tasks

Task 1 (2 pts)

Your first task is to implement a naive, non-optimized self-attention operation in PyTorch. This will help you become familiar with the multi-headed self-attention mechanism.

You will implement your algorithm in the forward() method of attention.py. You can then use the provided test.py script to compare your results with PyTorch's built-in torch.nn.MultiheadAttention to verify correctness.

Here are the steps to follow:

  1. Split Heads: Use the torch.Tensor.view() method to reshape the input q, k, and v tensors. You will change their dimensions from [batch_size, seq_len, hidden_dim] to [batch_size, seq_len, num_heads, head_dim] to split the hidden_dim into multiple attention heads.

  2. Transpose: Use the .transpose() method to swap the sequence length and head dimensions. This changes the tensor shape to [batch_size, num_heads, seq_len, head_dim], which is the standard layout for batched attention computation.

  3. Compute Scores: Compute the scaled dot-product attention scores. This is done by performing a matrix multiplication of $Q$ and $K^T$, and then scaling by the square root of the head dimension ($d_k$). You can achieve this using the @ operator for matrix multiplication and math.sqrt() for scaling. The full operation is $QK^T / \sqrt{d}$.

  4. Apply Causal Mask (if enabled): If the causal=True flag is passed to the function, you must apply a causal (look-ahead) mask.

  • First, create the mask using torch.triu to get an upper-triangular matrix.
  • Then, apply this mask to the score matrix using the .masked_fill(mask, value) method. You should fill the masked positions with float('-inf') to ensure they become zero after the softmax.
  5. Apply Softmax: Use torch.nn.functional.softmax() (with torch.nn.functional typically imported as F) to apply the softmax function to each row (i.e., along the last dimension) of the scaled and masked score matrix.

  6. Compute Output: Compute the final output matrix $O$ by multiplying the softmax-normalized attention weights with the V (Value) tensor. We have already provided the code for transposing and concatenating the heads back into the final output shape. (A hedged sketch combining these steps appears below.)
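
For orientation, here is one possible arrangement of these steps as a sketch. The function name is illustrative, and it assumes q, k, and v arrive with shape [batch_size, seq_len, hidden_dim]; your forward() in attention.py may be structured differently.

import math
import torch
import torch.nn.functional as F

def naive_mha_core(q, k, v, num_heads, causal=False):
    B, S, H = q.shape
    hd = H // num_heads
    # Steps 1-2: split heads and transpose: [B, S, H] -> [B, h, S, hd]
    q = q.view(B, S, num_heads, hd).transpose(1, 2)
    k = k.view(B, S, num_heads, hd).transpose(1, 2)
    v = v.view(B, S, num_heads, hd).transpose(1, 2)
    # Step 3: scaled dot-product scores: [B, h, S, S]
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(hd)
    # Step 4: optional causal mask
    if causal:
        mask = torch.triu(torch.ones(S, S, dtype=torch.bool, device=q.device), diagonal=1)
        scores = scores.masked_fill(mask, float('-inf'))
    # Step 5: softmax along the last dimension
    attn = F.softmax(scores, dim=-1)
    # Step 6: weighted sum of values: [B, h, S, hd]
    return attn @ v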

Once you have completed your implementation, please read through the test code in test.py to understand how it works. You can then run the test case from your terminal with the following command:

python -m task1.test

Task 2 (3 pts)

In this task, you will implement a naive self-attention operation using a combination of PyTorch and your own custom CUDA kernels.

While we will continue to use the PyTorch framework for high-level evaluation, the core computations will be replaced by your CUDA code. These kernels will be bound to Python, allowing them to be called directly from the attention.py script instead of using standard PyTorch tensor operations. Please examine the forward() method to understand how this Python-CUDA binding is structured.

Your specific task is to complete the attention_kernel.cu file by implementing the following components.

Steps to Follow:

  1. Softmax Kernel: This kernel will compute the softmax function, which normalizes a vector of values into a probability distribution. To ensure numerical stability, your implementation must first find the maximum value of each input row, subtract it from all elements in that row, and then compute the exponentials and their sum. This two-pass approach prevents overflow or underflow issues with large input values.

    You must use shared memory (__shared__) to minimize costly global memory accesses. Implement techniques such as parallel reduction within a thread block to efficiently find the row maximum and, subsequently, the sum of the exponents.

    Finally, you must implement a batched softmax kernel. The design should map one CUDA block to the softmax of one input row vector. As you can observe from the test_batched_softmax() function, gridDim.x (the first dimension of the grid) corresponds to the batch dimension. (A PyTorch reference of the required numerics appears after this list.)

  2. GEMM Kernels: The full attention operation requires two distinct matrix multiplication (GEMM) steps: $S = QK^T$ and $O = SV$, where $S$ has by then been scaled, masked, and passed through the softmax. Instead of creating one complex kernel that handles transposition on the fly, you will implement two separate, optimized GEMM kernels:

    • GEMM_NT: Performs A @ B.T (Normal-Transposed). This kernel will be used for the $QK^T$ computation.
    • GEMM_NN: Performs A @ B (Normal-Normal). This kernel will be used for the $SV$ computation.

    Note on Matrix Layout: Our N (Normal) and T (Transposed) notation differs from the cuBLAS convention. In PyTorch and our CUDA implementation, matrices are row-major by default. Therefore, N signifies a row-major matrix, and T signifies a column-major (transposed row-major) matrix.

    For the implementation, you must use the standard tiled-GEMM approach (as you learned in Project 1). This involves using shared memory to stage tiles of the input matrices, which significantly reduces global memory bandwidth and improves performance. Pay close attention to the test code to determine how batching is handled in the kernel launch parameters.

  3. Attention Computation: Finally, you will assemble the full operation inside the custom_attention() host function in attention_kernel.cu. You must call your newly implemented kernels in the correct sequence.

    We have provided the scale_and_causal_mask kernel for you. You must call this kernel after your $QK^T$ (GEMM_NT) computation to perform both the scaling (division by $\sqrt{d}$) and the application of the causal mask (if enabled).

    It is your responsibility to choose the correct block size, grid size, and tile size (TILE_SIZE) for all kernel launches. If you are unsure about the launch configuration, study the testing functions in test.py for guidance.
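
For the softmax kernel, the numerics your CUDA code should reproduce can be written in a few lines of PyTorch. This is only a reference for checking your kernel's math (not the kernel itself), assuming each row of the input is normalized independently:

import torch

def stable_softmax_reference(x):
    # x: [batch, n]; one independent softmax per row, as in the batched kernel.
    row_max = x.max(dim=-1, keepdim=True).values   # pass 1: find the row maximum
    e = torch.exp(x - row_max)                     # shift by the max for stability
    return e / e.sum(dim=-1, keepdim=True)         # pass 2: normalize by the row sum

Similarly, the GEMM kernels can be checked against A @ B.transpose(-2, -1) for GEMM_NT and A @ B for GEMM_NN on row-major inputs.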

To test your code, you must first compile the CUDA kernels. The provided script handles the PyTorch C++ extension building process.

# Compile the CUDA kernels
python -m task2.compile 

After a successful compilation, you can test your individual kernels to isolate any issues:

# Test only the batched softmax kernel
python -m task2.test --softmax

# Test only the batched GEMM kernels
python -m task2.test --gemm

Once your individual kernels are passing, you can test the full end-to-end attention kernel.

# Test the complete custom attention implementation
python -m task2.test --attention

Note: This CUDA implementation in Task 2 still involves multiple separate kernel calls (GEMM, scale/mask, softmax, GEMM). Because each kernel launch has overhead, we are not expecting to see significant performance benefits over PyTorch just yet.


Task 3 (3 pts)

Now, you will implement the FlashAttention-2 algorithm in PyTorch. This task is similar in spirit to Task 1 (as it's a PyTorch-only implementation), but the logic is significantly more complex. Your goal here is to understand and implement the attention operation in a tiled manner, which is the key to its efficiency.

Your implementation must follow Algorithm 1 from the paper "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" by Tri Dao.

I strongly suggest you first read the original FlashAttention-1 paper. This will help you understand why this tiled approach, which avoids materializing the massive intermediate $N \times N$ matrices in global memory, is so much faster than the naive PyTorch implementation from Task 1 that requires multiple separate kernel calls.

You will be implementing the forward pass only, so you can safely ignore the $L$ (log-sum-exp) statistic in Algorithm 1, which is primarily saved for the backward pass.

If you trace through Algorithm 1 and carefully track the intermediate values and their tensor dimensions, you should be able to implement the logic in Python without too much difficulty.

Important Typo in Paper!

There is a well-known typo in Algorithm 1 of the original paper. Please refer to this GitHub issue for the correction.

Tips for PyTorch Implementation: The notation $\mathrm{diag}(v)A$ on lines 10 and 12 may seem confusing to implement. It represents a diagonal matrix (formed from vector $v$) multiplied by a dense matrix $A$. If you look closely, however, this is simply a per-row scaling: row $i$ of $A$ is multiplied by $v_i$, which is very easy to implement in PyTorch using broadcasting (a short sketch follows the bullets below).

  • For $\mathrm{diag}(v)A$, you can use: A * v.unsqueeze(-1)
  • Similarly, for $\mathrm{diag}(v)^{-1}A$, you can use: A / v.unsqueeze(-1)
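
To make the broadcasting concrete, here is a hedged sketch of one inner-loop update from Algorithm 1 for a single (batch, head) slice. The names are illustrative: O_i, l_i, and m_i are the running output, row-sum, and row-max statistics for the current Q block.

import math
import torch

def inner_update(Q_i, K_j, V_j, O_i, l_i, m_i):
    # Q_i: [B_r, d], K_j/V_j: [B_c, d], O_i: [B_r, d], l_i/m_i: [B_r]
    d = Q_i.size(-1)
    S_ij = Q_i @ K_j.T / math.sqrt(d)                     # [B_r, B_c]
    m_new = torch.maximum(m_i, S_ij.max(dim=-1).values)   # updated running row max
    P_ij = torch.exp(S_ij - m_new.unsqueeze(-1))          # exponentials w.r.t. the new max
    scale = torch.exp(m_i - m_new)                        # rescales the old statistics
    l_new = scale * l_i + P_ij.sum(dim=-1)
    O_new = scale.unsqueeze(-1) * O_i + P_ij @ V_j        # diag(scale) O_i + P_ij V_j
    return O_new, l_new, m_new

# After the last K/V block: O_i = O_i / l_i.unsqueeze(-1), i.e. diag(l_i)^{-1} O_i.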

To test the functionality of your implementation, run the following command:

python -m task3.test



Task 4 (4 pts)

This is the final and most important task. You will now implement the full FlashAttention algorithm in PyTorch + CUDA.

Your goal is to translate the tiled logic you developed in Task 3 (the PyTorch FlashAttention-2 algorithm) into a single, high-performance CUDA kernel. This fused kernel will perform the entire attention computation (GEMMs, masking, softmax, and output GEMM) in one pass, using tiling to avoid materializing the large $N \times N$ matrices in global memory.

You must use the pre-defined shared memory arrays to store blocks of Q, K, and V. You can assume for your implementation that the maximum head_dim will be 128.

A critical part of this task is performance tuning. In the custom_flash_attention() host function, you must experiment with and tune the tile sizes (B_r and B_c, the block sizes for rows and columns) to find the configuration that achieves the best performance on your GPU.
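
One simple way to tune the tile sizes is a small sweep over candidate (B_r, B_c) pairs with CUDA-event timing. The sketch below takes the attention binding as an argument; how your own host function actually accepts its tile parameters is up to you, so treat the call signature as an assumption.

import torch

def sweep_tiles(flash_fn, q, k, v,
                candidates=((32, 32), (64, 32), (64, 64), (128, 32)), iters=10):
    # flash_fn(q, k, v, B_r, B_c) is whatever Python binding you expose for your kernel.
    results = {}
    for B_r, B_c in candidates:
        flash_fn(q, k, v, B_r, B_c)                  # warm-up launch
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            flash_fn(q, k, v, B_r, B_c)
        end.record()
        torch.cuda.synchronize()
        results[(B_r, B_c)] = start.elapsed_time(end) / iters   # average milliseconds
    return results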

Testing Your Kernel: First, compile your new CUDA kernel with the following command.

python -m task4.compile

Then, run the test and benchmarking script:

python -m task4.test

The test code will verify the correctness of your output and then compare the performance of your custom FlashAttention kernel against the naive torch.nn.MultiheadAttention implementation.

Reporting Your Results: The score for this project will be based on the performance difference you achieve. You must report the 4 speedup numbers from the test script in the Results.md file.


Task 5 (4 pts)

This task focuses on the second phase of LLM inference: token generation (decode), and how to optimize it using a KV Cache.

How LLMs Generate New Tokens

LLM token generation is a two-step process:

  1. Prefill Phase: We run the full self-attention (as in Task 4) on the entire input prompt (e.g., 1024 tokens) to generate the first new token.
  2. Decode Phase: To generate every subsequent token, we append the newest token to our input sequence and run the attention operation again.

However, since we use causal attention, old tokens can never attend to new tokens. Only the single, newest token needs to attend to all the previous tokens (including itself). This insight changes the computation dramatically:

  • Instead of a full matrix-matrix multiply (GEMM), the attention computation for the new token is effectively a matrix-vector multiply (GEMV).

This is much faster, but we still need the Key (K) and Value (V) matrices from all the old tokens. To avoid recomputing them at every step, we store them in a KV Cache and simply reuse them.

  • To better understand KV Caching, please refer to this article: link
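
As an illustration of why the decode step is so much lighter, here is a hedged PyTorch sketch of one decode step for a single head, assuming a preallocated cache and illustrative names:

import math
import torch

def decode_step(q_new, k_new, v_new, k_cache, v_cache, t):
    # q_new, k_new, v_new: [1, head_dim] for the single newest token (one head shown)
    # k_cache, v_cache:    [max_seq_len, head_dim], positions 0..t-1 already filled
    k_cache[t] = k_new.squeeze(0)   # append the new K/V into the preallocated cache
    v_cache[t] = v_new.squeeze(0)
    d = q_new.size(-1)
    # The score computation degenerates to a matrix-vector product:
    scores = q_new @ k_cache[: t + 1].T / math.sqrt(d)   # [1, t + 1]
    weights = torch.softmax(scores, dim=-1)
    return weights @ v_cache[: t + 1]                    # [1, head_dim]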

Your Task

Go through attention.py and test.py to understand the new workflow. You will see that it first generates Q, K, and V. Then, it uses a CUDA kernel to update (append) the newest K and V vectors into the KV cache. Finally, it runs a new decode kernel using the single new Q and the entire K and V from the cache.

Step 1: Implement update_cache_kernel()

Your first task is to write the update_cache_kernel() in CUDA. This kernel's simple job is to copy the new K and V vectors (from the current step) into the correct slot in the kv_cache tensor, based on the current token index.

We do this because we don't want to allocate and free new memory blocks at every decode step, as that would be extremely slow. Instead, we pre-allocate a fixed, large amount of memory for the KV cache (up to a max_seq_len) and just write into it.
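
In PyTorch terms, the kernel's job amounts to one indexed write per tensor. Here is a reference of the intended effect; the shapes are assumptions, so match them to the actual kv_cache layout used in the code.

def update_cache_reference(k_cache, v_cache, k_new, v_new, t):
    # k_cache, v_cache: [batch, num_heads, max_seq_len, head_dim], preallocated once
    # k_new,  v_new:    [batch, num_heads, 1, head_dim] for the token at index t
    k_cache[:, :, t, :] = k_new.squeeze(2)
    v_cache[:, :, t, :] = v_new.squeeze(2)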

This method can become inefficient if the allocated max_seq_len is much larger than the real sequence length. The state-of-the-art solution to this problem is PagedAttention. If you are interested, please look into the PagedAttention paper.

Step 2: Implement the Decode Kernel

Your second task is to modify the FlashAttention kernel you built in Task 4 and turn it into a specialized kernel for the decode phase.

The main change is that you can now assume the Q tensor's sequence length is always 1. This greatly simplifies your kernel's logic. The core FlashAttention concepts (tiling, shared memory, online softmax) are the same, but this assumption will likely eliminate one of the main for loops in your implementation (the one that iterates over blocks of Q, $B_r$).

Testing and Performance

After implementing both kernels, you will be able to run test.py. This test will:

  1. Run a prefill on 1024 tokens.
  2. Use your update_cache_kernel and decode_kernel to generate 100 new tokens, one by one.

Again, try tuning the tile sizes (B_r and B_c) in your decode kernel for the best performance. Be aware that a tile size that is too large may exceed the available shared memory.

Compile and run the test:

python -m task5.compile
python -m task5.test

The test script will benchmark your decode kernel. Please report the speedup numbers it prints in Results.md.

(Optional) Optimizing the Decode Kernel

Explore performance improvements to the decode stage by implementing a chunked FlashAttention kernel, as described in Stanford CRFM’s FlashDecoding article (link). Compare its performance against the baseline decode kernel to quantify potential speedups.


Task 6 (3 pts)

This task brings all your work together: you will implement an end-to-end LLM inference pipeline on a dummy LLM.

Your objective is to fill in the forward() method of the customAttention class. This single function must now be able to intelligently handle both the prefill and decode phases of inference.

Specifically, your forward() method must:

  1. Detect the Prefill Phase: If the model detects the prefill phase, it must call your optimized Task 4 FlashAttention kernel to process the full prompt.
  2. Detect the Decode Phase: If the model detects the decode phase, it must call your specialized Task 5 decode kernel, using the provided KV cache.
  3. Manage the Cache: In the decode phase, it must correctly call your update_kv_cache (from Task 5) to append the newly computed K and V vectors to the cache, making them available for the next decode step. In the prefill phase, it must store the K and V tensors for the full prompt into the cache for future use. (A hedged dispatch sketch follows this list.)
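
Here is a hedged sketch of that dispatch logic. The method and attribute names are illustrative only; the real forward() in attention.py defines its own signature, kernel bindings, and cache bookkeeping.

def forward(self, q, k, v):
    seq_len = q.size(-2)
    if seq_len > 1:
        # Prefill: full prompt. Run the Task 4 fused kernel and seed the cache.
        out = self.flash_attention_prefill(q, k, v)   # illustrative name
        self.store_prompt_kv(k, v)                    # illustrative name
        self.cur_len = seq_len
    else:
        # Decode: single new token. Append K/V, then run the Task 5 decode kernel.
        self.update_kv_cache(k, v, self.cur_len)      # illustrative name
        out = self.flash_attention_decode(q, self.kv_cache, self.cur_len)  # illustrative name
        self.cur_len += 1
    return out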

Testing and Reporting

The test script for this task simulates a full generative inference loop that generates 100 tokens:

python -m task6.test

This test will measure two key performance metrics:

  • TTFT (Time To First Token): This measures the performance of your prefill kernel (Task 4).
  • TBT (Time Between Tokens): This measures the average performance of your decode kernel (Task 5) over many steps.

Please report the final TTFT and TBT times in Results.md.


Task 7 (1 pt)

This final task is a demo to see your kernels working in a real end-to-end LLM inference pipeline with a Hugging Face model. You will see your custom attention kernel generating real tokens!

We (or the "cursor") have provided all the code for you, which attaches your prefill (Task 4) and decode (Task 5) kernels to a GPT-2 model. There is nothing new for you to code.

Your only job here is to run the model and verify that it works correctly. Make sure that your LLM is generating something useful and coherent. If it generates random, noisy tokens, it means your attention kernel has a bug. You will need to go back and debug your CUDA implementations from Tasks 4 and 5.

Running the Demo

To run the inference with your custom kernels:

python -m task7.inference

Please report the Generated text in Results.md.

Verifying Correctness

You might observe that the GPT-2 model generates sentences that are not factually correct. This is expected, as GPT-2 is an older model.

The important thing is to check if your model's "intelligence" matches the original GPT-2. To do this, you can run the reference implementation, which uses the standard Hugging Face attention:

python -m task7.inference_ref

Please check whether you think the two models are "equally intelligent." If your model (task7.inference) seems significantly dumber or produces nonsensical garbage compared to the reference, go back and check your kernel implementations for bugs.

Submissions

We are still finalizing the official code collection process for submission.

For all current beta testers, please upload your completed Results.md file directly to the class Teams channel. We appreciate your help in testing this assignment.

Furthermore, please feel free to make any pull requests to this repository to update the content of this README file. If you find any discrepancies, or if you think additional materials would be helpful for future students, your contributions are welcome.


Credits

Euijun Chung (echung67@gatech.edu)
