码力全开 / 预训练大模型内存占用估计

Created Tue, 22 Jul 2025 08:30:57 +0800 Modified Tue, 22 Jul 2025 09:44:23 +0800
1020 Words 1 min

这里主要介绍一下对于预训练大模型微调时内存占用情况的估算,对于一个预训练大模型,其训练时内存的占用需要考虑如下一些部分及相应的精度:

  • 模型参数
  • 梯度
  • 优化器状态
  • 激活值

对于精度其由数据类型决定,其中FP32为4字节,FP16/BF16为2字节,INT8为1字节,INT4为0.5字节。

对于模型参数内存的占用,可以使用如下公式进行表示:

$$ M_{\text{params}}=P\times B $$

其中P为模型参数量,而B为单参数字节数(精度)。以LLaMA-70B为例,使用BF16进行训练,其参数内存约等于$70\times 10^{9}\times 2\approx 140\text{GB}$。

之后是梯度的内存占用,主要用于反向传播生成过程。其可以用如下公式进行表示:

$$ M_{\text{grads}} = P\times B_{\text{grad}} $$

梯度通常与模型参数是相同的精度,即$B_{\text{grad}}=B$,同上所述,对于70B模型使用BF16训练时,梯度内存约等于140GB。

接下来是优化器状态的内存占用,这是训练开销的主要部分。之前的模型参数内存和梯度内存都是基础占用。对于不同的优化器,其占用情况有所不同。这里以Adam优化器为例,其占用可以使用如下公式进行表示$$ M_{\text{opt}}=P\times (2\times 4+B) $$

其中动量与二阶矩一般采用FP32,因此其内存占用为$2\times 4\times P$,之后需要保存模型参数副本,通常也是FP32。还是以70B模型使用Adam优化器为例,优化器状态内存约等于$70\times 10^{9}\times 12B\approx 840\text{GB}$。

假设使用简单的SGD,由于其不需要存储模型参数和梯度,因此其计算公式为$$ M_{\text{SGD}}=P\times B+P\times B=2P\times B $$

其中参数内存为$P\times B$,梯度内存为$P\times B$。

若使用改进版SGDM(SGD with Momentum),由于引入动量机制(一阶矩),因此其计算公式为$$ M_{\text{SGDM}}=P\times B+P\times B+P\times 4=(2B+4)P $$

其中动量内存一般采用FP32。

而如果采用RMSProp优化器,由于需要存储梯度平方的指数移动平均(二阶矩),其计算公式为$$ M_{\text{RMSProp}} = P\times B+P\times B+P\times 4=(2B+4)P $$

对于优化器,可以借助其更新步骤数学公式中$w_{t}$及$g_{t}$分别确定其参数与梯度。

下面再考虑激活值内存,这部分是动态占用。其近似范围可以使用如下公式进行估计:

$$ M_{\text{activations}}\approx (0.7\sim 1.5)\times M_{\text{params}} $$

由于该部分内存依赖于batch size、序列长度、模型结构,因此没有统一的公式。

综上所述,对于一般的情况,其总内存占用可以使用如下公式进行估计:

$$ M_{\text{total}}\approx M_{\text{params}}+M_{\text{grads}}+M_{\text{opt}} $$

主要包括参数、梯度和优化器这3部分。而对于完整版本,还需要考虑激活值及安全余量。此时可以使用如下公式:

$$ M_{\text{total}}\approx (M_{\text{params}}+M_{\text{grads}}+M_{\text{opt}}+M_{\text{activations}})\times 1.2 $$

其中1.2为系统预留内存缓冲(安全余量)。

如果喜欢这篇文章或对您有帮助,可以:[☕] 请我喝杯咖啡 | [💓] 小额赞助