显存消耗的组成与模型权重计算
1.1 核心问题
大模型训练时显存被什么占满了?不同量化精度下模型权重需要多少显存?
1.2 原文核心要点
深度神经网络训练的显存消耗主要包括两大部分:模型状态(模型权重、梯度、优化器状态)和激活值(各个非线性模块的中间激活值)。不同量化精度下的显存占用差异巨大。
1.3 显存消耗的两大组成部分
换句话说,显存就像你的工作台空间:一部分放置"工具箱和材料"(模型状态),一部分作为"临时加工区"(激活值)。前者大小固定,后者随工作量波动。
| 组成部分 | 具体内容 | 说明 |
|---|---|---|
| 模型状态 | 模型权重(参数)、梯度、优化器状态 | 与模型参数量Φ成正比,是固定开销 |
| 激活值 | 各个非线性模块的中间激活值 | 与batch_size和序列长度相关,是动态开销 |
1.4 模型权重与量化精度的关系
假设模型参数量为Φ(单位:参数个数),不同量化精度下的显存占用如下:
| 量化程度 | 每参数字节数 | 显存占用 | 1B参数模型 | 7B参数模型 |
|---|---|---|---|---|
| FP32 | 4字节 | 4Φ | 4GB | 28GB |
| FP16/BF16 | 2字节 | 2Φ | 2GB | 14GB |
| INT8 | 1字节 | 1Φ | 1GB | 7GB |
| INT4 | 0.5字节 | ≤1Φ | 0.5GB | 3.5GB |
1.5 模型参数量的计算公式
以Llama-3模型为例,其参数量由以下符号定义:
| 符号 | 含义 |
|---|---|
| n_vocab | 词表中词的个数 |
| d_hidden | 隐藏层维度(嵌入向量的维度) |
| n_head | 注意力头的数量 |
| n_kv-head | 分组查询注意力中的键值头数量 |
| n_layer | Transformer的层数 |
| d_FFN | 前馈神经网络的隐藏层维度 |
| b | 输入数据的批次大小(batch size) |
| s | 输入序列长度 |
模型总参数量公式:
$$
\Phi = n_{\text{vocab}} \times d_{\text{hidden}} + n_{\text{layer}} \times \left[ d_{\text{hidden}} + \left(2 + 2 \cdot \frac{n_{\text{kv}}}{n_{\text{head}}}\right) d_{\text{hidden}}^2 + d_{\text{hidden}} + 3 \cdot d_{\text{hidden}} \cdot d_{\text{FFN}} \right] + d_{\text{hidden}} + d_{\text{hidden}} \times n_{\text{vocab}}
$$
| 组成部分 | 公式项 | 说明 |
|---|---|---|
| 词嵌入层 | $n_{\text{vocab}} \times d_{\text{hidden}}$ | 词表大小 × 隐藏维度 |
| Transformer层(×$n_{\text{layer}}$) | 含 QKV 投影 + FFN | GQA 时 KV 头数 < 注意力头数 |
| 输出层 | $d_{\text{hidden}} + d_{\text{hidden}} \times n_{\text{vocab}}$ | LayerNorm + 输出投影 |
注意:
- 当 n_kv-head = 1 时为多查询注意力(MQA)
- 当 n_kv-head = n_head 时为多头注意力(MHA)
- 当 1 < n_kv-head < n_head 时为分组查询注意力(GQA)
1.6 通俗理解
直观类比
想象你在搬家,需要把所有家当装上卡车(GPU显存)。
- 模型权重= 你的家具(沙发、床、桌子)——这些是固定的,搬多少次都一样重。
- 梯度= 每件家具的"搬运说明书"——和家具数量一一对应,同样多。
- 优化器状态= 每件家具的"维修记录"和"使用日志"——Adam优化器需要记录每个参数的"动量"和"方差",所以额外占用2倍的家具重量。
- 激活值= 搬运过程中的临时存放点——搬的批次(batch_size)越多,需要的临时空间越大。
量化精度就像选择不同精度的"包装方式":
- FP32 = 用厚实的防震泡沫包裹每件家具(4字节/参数,最安全但最占空间)
- FP16 = 用薄一些的包装(2字节/参数,空间减半)
- INT8 = 只用塑料薄膜简单裹一下(1字节/参数)
核心要点
- 显存 = 模型状态(固定)+ 激活值(动态),两者都需要关注
- 量化精度每降一档,模型权重显存减半
- 7B模型仅权重(FP32)就需要28GB,整体训练显存远超单卡容量
1.7 小结
| 维度 | 说明 |
|---|---|
| 两大组成 | 模型状态(权重+梯度+优化器)+ 激活值 |
| 量化关系 | FP32=4Φ, FP16=2Φ, INT8=1Φ |
| 参数计算 | 含词嵌入层 + n_layer个Transformer层 + 输出层 |
| 关键认知 | 7B模型FP32权重=28GB,训练总显存约112GB |
2. FP32训练与混合精度训练
2.1 核心问题
FP32训练需要多少显存?混合精度训练能节省显存吗?
2.2 原文核心要点
使用AdamW优化器进行FP32训练,模型状态总显存为16Φ。混合精度训练并没有节省模型状态的显存!其真正优势是加速计算和降低激活值显存。
2.3 FP32训练的显存占用
通俗来讲,训练模型不仅要存"模型本身",还要存"每个参数的更新历史"(优化器状态),这才是显存的大头。
使用AdamW优化器进行FP32训练时: