码力全开 / FlashAttention简述

Created Tue, 30 Dec 2025 15:46:57 +0800 Modified Tue, 30 Dec 2025 16:24:43 +0800
918 Words 1 min

Transformer中模型的核心是自注意力机制(self-attention),其时间和存储复杂度均为$O(n^{2})$。于是有人提出了近似注意力的方法来降低注意力计算和内存需求,而FlashAttention在此基础上考虑了内存访问(IO)的开销。

关于FlashAttention的详细介绍,可以参考文章《FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness》

通过减少GPU内存读取和写入,FlashAttention的运行速度比PyTorch标准注意力快2-4倍,而所需的内存减少了5-20倍。

我们知道GPU的SRAM(静态随机存取内存)的IO读写速度大概为19TB/s,而GPU的HBM高带宽显存的读写速度为1.5TB/s,但是它们的存储容量分别是20MB和40GB。

FlashAttention在运行注意力机制算法时,需要从HBM中读取Q、K、V这3个矩阵,并在SRAM中进行计算,并在计算完成后写回到HBM中。因此就得考虑如何减少这样的IO开销,从而得到更好的效果。

在标准注意力机制实现中,假设HBM中矩阵$Q,K,V\in\mathbb{R}^{N\times d}$。其过程如下:

  1. 从HBM中加载Q,K,计算$S=QK^{T}$,并将S写入HBM
  2. 从HBM中读取S,计算$P=\text{Softmax}(S)$,将P写入HBM
  3. 从HBM中加载P和V,计算$O=PV$,将O写入HBM
  4. 返回O

从上面步骤可以看到$$ S=QK^{T}\in\mathbb{R}^{N\times N},\quad P=\text{Softmax}(S)\in\mathbb{R}^{N\times N},O=PV\in\mathbb{R}^{N\times d} $$

在计算过程中需要存储中间值S和P到HBM中,这会极大占用HBM(高带宽显存)。

而FlashAttention希望可以避免从HBM中读取和写入注意力矩阵,从而对如下方面进行优化:

  1. 在不访问整个输入的情况下计算softmax函数的缩减
  2. 在后向传播中不能存储中间注意力矩阵S和P

对于第1点,FlashAttention将输入分割成块,并在输入块上进行多次传递,从而以增量方式进行softmax缩减,从而大大加快运行的速度。 对于第2点,FlashAttention存储一个softmax函数的归一化因子,通过这个归一化因子在反向传播过程中再重新计算这2个矩阵,从而大大节省所占用的内存

参考视频:

https://www.bilibili.com/video/BV1zs4y1J7tb/ https://zhuanlan.zhihu.com/p/669926191

如果喜欢这篇文章或对您有帮助,可以:[☕] 请我喝杯咖啡 | [💓] 小额赞助