Transformer中模型的核心是自注意力机制(self-attention),其时间和存储复杂度均为$O(n^{2})$。于是有人提出了近似注意力的方法来降低注意力计算和内存需求,而FlashAttention在此基础上考虑了内存访问(IO)的开销。
关于FlashAttention的详细介绍,可以参考文章《FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness》。
通过减少GPU内存读取和写入,FlashAttention的运行速度比PyTorch标准注意力快2-4倍,而所需的内存减少了5-20倍。
我们知道GPU的SRAM(静态随机存取内存)的IO读写速度大概为19TB/s,而GPU的HBM高带宽显存的读写速度为1.5TB/s,但是它们的存储容量分别是20MB和40GB。
FlashAttention在运行注意力机制算法时,需要从HBM中读取Q、K、V这3个矩阵,并在SRAM中进行计算,并在计算完成后写回到HBM中。因此就得考虑如何减少这样的IO开销,从而得到更好的效果。
在标准注意力机制实现中,假设HBM中矩阵$Q,K,V\in\mathbb{R}^{N\times d}$。其过程如下:
- 从HBM中加载Q,K,计算$S=QK^{T}$,并将S写入HBM
- 从HBM中读取S,计算$P=\text{Softmax}(S)$,将P写入HBM
- 从HBM中加载P和V,计算$O=PV$,将O写入HBM
- 返回O
从上面步骤可以看到$$ S=QK^{T}\in\mathbb{R}^{N\times N},\quad P=\text{Softmax}(S)\in\mathbb{R}^{N\times N},O=PV\in\mathbb{R}^{N\times d} $$
在计算过程中需要存储中间值S和P到HBM中,这会极大占用HBM(高带宽显存)。
而FlashAttention希望可以避免从HBM中读取和写入注意力矩阵,从而对如下方面进行优化:
- 在不访问整个输入的情况下计算softmax函数的缩减
- 在后向传播中不能存储中间注意力矩阵S和P
对于第1点,FlashAttention将输入分割成块,并在输入块上进行多次传递,从而以增量方式进行softmax缩减,从而大大加快运行的速度。 对于第2点,FlashAttention存储一个softmax函数的归一化因子,通过这个归一化因子在反向传播过程中再重新计算这2个矩阵,从而大大节省所占用的内存
参考视频:
https://www.bilibili.com/video/BV1zs4y1J7tb/ https://zhuanlan.zhihu.com/p/669926191
