大模型训练时如何计算显存占用
2026/6/11 13:04:53 网站建设 项目流程

首先了解一些基本概念

以Llama13B为例,

首先是输入输出:

这里的2是因为每个值都是float16,占两个字节

然后转换为MB,输入输出相加为20MB,所占显存大小和其他部分相比可以忽略不计

这里的2是因为每个值都是float16,占两个字节

1B和1GB大致相当

都是float32存储的

为什么优化器要存模型参数

  • 从归属上看:模型参数属于 Model,优化器属于 Optimizer。
  • 从物理内存上看:优化器不复制模型参数,而是通过引用直接修改它们;但优化器会为每个参数分配额外的状态缓存(如动量缓冲池)。
  • 在大模型显存规划中:评估优化器带来的显存压力时,必须将这部分“辅助状态”计算在内(例如 Adam 需要额外增加约 8~16 GB/十亿参数的显存消耗,具体取决于精度格式)。

为什么平滑值不能用float16

因为会丢失精度,梯度很小,学习率更小

在反向传播中会用到前向传播中的激活值

https://zhuanlan.zhihu.com/p/673916177

关于激活值显存占用更详细可以参考上面这个链接

具体的 "34" 是一个经验估算值或特定实现下的精确计数,涵盖了 LayerNorm 的统计量、MLP 层的多个线性变换输入输出缓存等。

这里的系数 "5" 可能对应:Q, K, V, Score, Output 这 5 个主要张量的保存需求。

激活值计算好像漏乘了2(FP16占两个字节)

计算 QKT。其中 Q 和 K 的 shape 都是[b, a, s, h/a]。矩阵乘法后,得到的分数矩阵 shape 为[b, a, s, s]

显存占用:需要保存 Q 和 K 用于反向传播,大小为 bsh。

分数矩阵本身:大小为 bs^2a。

在计算总显存时,Attention模块与序列长度相关的主要二次方项来自于 bs^2a ,将sbh提取出括号后得到 as/h

参考视频:

RethinkFun投稿视频-RethinkFun视频分享-哔哩哔哩视频

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询