Building up to FlashAttention-2
FlashAttention and FlashAttention-2 are some of the most notable improvements in recent years that have allowed Large Language Models to be scaled to the sizes they reach today. They mark one of the most significant efficiency improvements in both training and inference and are used in almost every LLM that people use today. DeepSeek even went as far as to make specific implementation decisions to steer the model towards using it. This will cover both FlashAttention and the changes made in FlashAttention-2 (the version DeepSeek uses) but will skip an overview of the Scaled Dot Product Attention and Multi Head Attention that it is designed to replace, although there is a small overview of them within my overview of [[RWKV]].
FlashAttention:
FlashAttention is an alternative to standard dot-product attention that produces the exact same outputs for a given input while greatly improving the algorithm's memory efficiency. The main performance bottleneck of standard attention comes from how it interacts with memory, due to the scale of the information it has to manage and perform calculations on. For this overview not much knowledge of low-level I/O or computer architecture is needed, but one basic idea will be crucial: HBM (the GPU's large off-chip memory) is slower but can hold much more, while SRAM (the GPU's small on-chip memory) is much faster but can hold far less. Since standard attention operates on the entire token sequence at once, its full attention matrix cannot fit in SRAM and must be read from and written to the much slower HBM.
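To make the bottleneck concrete, here is a minimal NumPy sketch of standard single-head attention (the names and shapes are my own illustration, not taken from the paper): the full $N \times N$ score matrix is materialized, and it is exactly this matrix that cannot fit in SRAM for long sequences.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Standard scaled dot-product attention for a single head.

    Q, K, V: (N, d) arrays. Materializes the full (N, N) score matrix,
    which is what forces reads/writes to the slower HBM for long sequences.
    """
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                       # (N, N) scores
    P = np.exp(S - S.max(axis=-1, keepdims=True))  # numerically stable softmax
    P = P / P.sum(axis=-1, keepdims=True)
    return P @ V                                   # (N, d) output
```

The tiled algorithm described next has to reproduce this output exactly.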
The main idea behind FlashAttention comes in the form of tiling. The idea is to break the sequence up into smaller groups of "tiles", each holding less information. If the attention algorithm can be broken up into these smaller pieces, each piece can be performed entirely in the much faster SRAM, giving the algorithm very notable speed improvements.
Forward Pass:
FlashAttention breaks the sequence up into blocks of size $B_c$ and $B_r$, where $B_c = \lceil \tfrac{M}{4d} \rceil$ is the column block size and $B_r = \min\!\left(\lceil \tfrac{M}{4d} \rceil, d\right)$ is the row block size (with on-chip memory size $M$ and head dimension $d$). In order to get to this point the Queries $Q$, Keys $K$, and Values $V$ have to be calculated first just like usual. Once they are, $Q$ is broken up into $Q_1, \dots, Q_{T_r}$ (each of size $B_r \times d$) and $K$ and $V$ are broken up into $K_1, \dots, K_{T_c}$ and $V_1, \dots, V_{T_c}$ respectively (each of size $B_c \times d$). The algorithm iterates through all $K_j$ and $V_j$ to get information from each attention score $S_{ij} = Q_i K_j^\top / \sqrt{d}$, with which it adds some calculations to its running summary of the sequence. For this post I will be shortening this to $S_j$ for the $j$-th tile we are performing on to simplify the notation, so $S_1$ denotes the first tile's score and $S_2$ the next (and so on for each subsequent tile). As well we will need the value tiles $V_1$, $V_2$, and so on, which are indexed in the same way.
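As a rough sketch of how the tile sizes fall out of the SRAM budget (the formulas follow the FlashAttention paper; treating $M$ as a size in elements is my simplification):

```python
import math

def tile_sizes(N, d, M):
    """Tile sizes and counts for FlashAttention.

    N: sequence length, d: head dimension,
    M: on-chip SRAM budget in elements (a simplifying assumption).
    """
    B_c = math.ceil(M / (4 * d))          # column block size (K/V tiles)
    B_r = min(math.ceil(M / (4 * d)), d)  # row block size (Q tiles)
    T_c = math.ceil(N / B_c)              # number of K/V tiles
    T_r = math.ceil(N / B_r)              # number of Q tiles
    return B_r, B_c, T_r, T_c
```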
At the very start of the algorithm, some local computations need to be performed on the first tile. The algorithm keeps running values $m$ (the row-wise maximum) and $\ell$ (the row-wise sum of exponentials) for normalization; for the first tile these are $m_1 = \mathrm{rowmax}(S_1)$ and $\ell_1 = \mathrm{rowsum}(e^{S_1 - m_1})$. These are used to compute the exponentiated score for this tile, $\tilde{P}_1 = e^{S_1 - m_1}$.
These are then used to compute another running value $O$ which acts as the algorithm's current partial output. This calculation is very simple for the first tile, $O_1 = \mathrm{diag}(\ell_1)^{-1}\,\tilde{P}_1 V_1$, but the method in which the information is accumulated in this output will be shown for the second tile.
We then move on to the next tile in the sequence. The values of $m_1$, $\ell_1$, and $O_1$ will be regularly used in this calculation. First we perform similar calculations on the new tile to get local values $\tilde{m}_2 = \mathrm{rowmax}(S_2)$, $\tilde{P}_2 = e^{S_2 - \tilde{m}_2}$, and $\tilde{\ell}_2 = \mathrm{rowsum}(\tilde{P}_2)$, and then merge them into new normalization values. I will be calling the merged values simply $m = \max(m_1, \tilde{m}_2)$ and $\ell = e^{m_1 - m}\,\ell_1 + e^{\tilde{m}_2 - m}\,\tilde{\ell}_2$, since they are now the running values, and only the old values will keep their subscripts.
Once these are calculated we then calculate a new partial output $O = \mathrm{diag}(\ell)^{-1}\left(\mathrm{diag}(\ell_1)\,e^{m_1 - m}\,O_1 + e^{\tilde{m}_2 - m}\,\tilde{P}_2 V_2\right)$. This same procedure we denoted for tile 2 will then be repeated for each subsequent tile, with the current running values being treated as $m_1$ and $\ell_1$ and the current output being treated as $O_1$.
Once every tile has been run through, the final output will be the exact same as if a scaled dot-product attention calculation had been performed on the entire sequence at once. This form of tiling allows the information within each tile to be placed entirely in SRAM, which lets the computer perform each step much quicker than it otherwise would on HBM, even though the algorithm is much more complex to human eyes. This algorithm is then parallelized over the batch size and each attention head to make it even more efficient.
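Below is a minimal NumPy sketch of this forward pass under the definitions above (the function name, default tile sizes, and the absence of masking and dropout are my simplifications; it illustrates the recurrence, it is not a reimplementation of the actual CUDA kernels). On random inputs it should match the naive reference earlier up to floating-point error, which makes an easy `np.allclose` sanity check.

```python
import numpy as np

def flash_attention_forward(Q, K, V, B_r=64, B_c=64):
    """Tiled (FlashAttention-style) forward pass for a single head.

    Q, K, V: (N, d) arrays. B_r, B_c: row/column tile sizes.
    Only per-tile blocks plus the running statistics m, l, O are held,
    so nothing of size (N, N) is ever materialized.
    """
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))
    l = np.zeros(N)                       # running softmax denominators
    m = np.full(N, -np.inf)               # running row maxima

    for i in range(0, N, B_r):            # loop over query tiles
        Qi = Q[i:i + B_r]
        Oi, li, mi = O[i:i + B_r], l[i:i + B_r], m[i:i + B_r]
        for j in range(0, N, B_c):        # loop over key/value tiles
            Kj, Vj = K[j:j + B_c], V[j:j + B_c]
            S = Qi @ Kj.T * scale                     # tile of attention scores
            m_tilde = S.max(axis=-1)                  # local row max
            P_tilde = np.exp(S - m_tilde[:, None])    # local exponentiated scores
            l_tilde = P_tilde.sum(axis=-1)            # local row sums
            m_new = np.maximum(mi, m_tilde)           # merged running max
            l_new = np.exp(mi - m_new) * li + np.exp(m_tilde - m_new) * l_tilde
            # rescale the old partial output and fold in this tile's contribution
            Oi = (np.exp(mi - m_new)[:, None] * li[:, None] * Oi
                  + np.exp(m_tilde - m_new)[:, None] * (P_tilde @ Vj)) / l_new[:, None]
            mi, li = m_new, l_new
        O[i:i + B_r], l[i:i + B_r], m[i:i + B_r] = Oi, li, mi
    return O
```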
This also has efficiency improvements for the backward pass during training. It is shown that both the score matrix $S$ and the attention probabilities $P$ can be very simply recalculated if the rest of the information ($Q$, $K$, and $V$ for each layer and $m$ and $\ell$ for the $i$-th tile at that layer) is stored in memory, so we are able to save the memory that would otherwise be spent on keeping them for the backward pass at very little computational cost. This is a significant change even though it may not seem like it, because $S$ and $P$ are full $N \times N$ matrices for a sequence of length $N$ despite being mere intermediates. The specifics of the backward pass will be shown later for FlashAttention-2, since it also makes a couple of improvements to those calculations.
FlashAttention-2:
FlashAttention-2 makes some seemingly small but very powerful improvements to the algorithm. Some of the changes, especially those about workload balancing, require much deeper GPU architectural knowledge, but I consider those specific implementation details and I will be skipping over them. For FlashAttention-2 there are two main changes to the algorithm detailed above.
First, the method in which we compute the partial output has its calculation split in two: the algorithm keeps an unnormalized running output inside the loop and only divides by the final $\ell$ once, after the last tile. This removes the need to perform the normalization at every step, and these non-matmul operations are much more computationally heavy on modern GPUs in comparison to the matrix multiplications.
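A sketch of that change in the same NumPy style (only the inner loop over key/value tiles for one query tile is shown; names and tile sizes are again my own): the accumulator stays unnormalized inside the loop, and the division happens once at the end.

```python
import numpy as np

def flash_attention2_inner(Qi, K, V, B_c=64):
    """Process one query tile against all key/value tiles, FlashAttention-2 style.

    The running output `acc` stays unnormalized inside the loop; the single
    division by the final denominator happens once at the very end.
    """
    d = Qi.shape[-1]
    scale = 1.0 / np.sqrt(d)
    acc = np.zeros_like(Qi, dtype=float)   # unnormalized running output
    l = np.zeros(Qi.shape[0])              # running denominator
    m = np.full(Qi.shape[0], -np.inf)      # running row max

    for j in range(0, K.shape[0], B_c):
        Kj, Vj = K[j:j + B_c], V[j:j + B_c]
        S = Qi @ Kj.T * scale
        m_new = np.maximum(m, S.max(axis=-1))
        P = np.exp(S - m_new[:, None])
        correction = np.exp(m - m_new)      # rescales the previous accumulator
        l = correction * l + P.sum(axis=-1)
        acc = correction[:, None] * acc + P @ Vj
        m = m_new
    return acc / l[:, None]                 # normalize once at the very end
```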
Second, instead of saving both $m$ and $\ell$ in memory we save one collective value, the logsumexp $L = m + \log \ell$, for each tile. This doubles down on the goal of limiting the amount of interaction with HBM during training.
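The reason a single value suffices is that the backward pass only ever needs the two statistics in the combination below, so the softmax probabilities can be recovered directly from $L$:

$$
L = m + \log \ell
\qquad\Longrightarrow\qquad
e^{S - L} = \frac{e^{S - m}}{\ell} = \mathrm{softmax}(S).
$$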
As well, one of the architectural changes that improves the algorithm the most comes in its improved parallelism. As a reminder, the first algorithm only parallelized over the batch and over each attention head. An additional sequence-level parallelization is implemented in FlashAttention-2. Even though the pass over the key/value tiles is sequential, the calculations for different query tiles are independent of one another, so they can be carried out at the same time. This allows the algorithm to move much faster through the sequence, since a large portion of the computation is now done in parallel across tiles rather than one tile at a time. This is one of the most important changes to the low-level design of the algorithm, because it allowed the attention to extend to much longer sequences, which is one of the biggest, if not the biggest, weak points of almost every attention algorithm.
Backward Pass:
The backward pass for the model also uses this same tiling system, in which the gradients are computed in tiles to use the same GPU-based speedups that the forward pass has. The backward pass uses the same tiles as the forward pass did. As stated above, the only information stored for this pass is $Q$, $K$, and $V$ for each layer and the softmax statistics ($m$ and $\ell$, or the combined $L$ in FlashAttention-2) for each tile within that layer. As well, the derivative with respect to the output, $dO$, is simply handed down from the specified loss function of the model. First both $S$ and the probabilities $P = e^{S - L}$ are recalculated for the given tile.
This allows each of $dP$ and $dS$ to be calculated very simply since they are per-tile values: $dP_j = dO\,V_j^\top$ and $dS_j = P_j \circ (dP_j - D)$. They use another value, $D = \mathrm{rowsum}(dO \circ O)$, in their calculations.
Since each of $Q$, $K$, and $V$ is not a per-tile value but spans the entire sequence, their derivatives need to take the entire sequence into account. This uses the same running-accumulation strategy as the forward pass and builds up a partial result for each derivative $dQ$, $dK$, and $dV$ as the backward pass runs through each tile. Once every tile is done, the final state of each of these is the gradient that gets used.
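Putting those pieces together, here is a sketch of the tiled backward pass in the same NumPy style (no masking or dropout; names and tile sizes are mine; $O$ is the saved forward output and $L$ the saved logsumexp): $S$ and $P$ are recomputed per tile, while $dQ$, $dK$, and $dV$ are accumulated across tiles.

```python
import numpy as np

def flash_attention_backward(Q, K, V, O, L, dO, B_r=64, B_c=64):
    """Tiled backward pass sketch for a single head.

    Q, K, V, O, dO: (N, d) arrays; L: (N,) per-row logsumexp saved from the
    forward pass. S and P are recomputed per tile instead of being stored,
    and dQ, dK, dV are accumulated across tiles.
    """
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    dQ = np.zeros_like(Q)
    dK = np.zeros_like(K)
    dV = np.zeros_like(V)
    D = np.sum(dO * O, axis=-1)            # rowsum(dO ∘ O), one value per query row

    for j in range(0, N, B_c):             # loop over key/value tiles
        Kj, Vj = K[j:j + B_c], V[j:j + B_c]
        for i in range(0, N, B_r):         # loop over query tiles
            Qi, dOi = Q[i:i + B_r], dO[i:i + B_r]
            S = Qi @ Kj.T * scale                      # recomputed scores
            P = np.exp(S - L[i:i + B_r][:, None])      # recomputed probabilities
            dV[j:j + B_c] += P.T @ dOi
            dP = dOi @ Vj.T
            dS = P * (dP - D[i:i + B_r][:, None])      # softmax backward per tile
            dQ[i:i + B_r] += dS @ Kj * scale
            dK[j:j + B_c] += dS.T @ Qi * scale
    return dQ, dK, dV
```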