机器之心
发布于

无损减少80%激活值内存,提升5倍训练序列长度,仅需两行代码


本文的第一作者罗琪竣、第二作者李梦琦为香港中文大学(深圳)计算机科学博士生,本文在上海交通大学赵磊老师、香港中文大学(深圳)李肖老师的指导下完成。

长序列训练对于模型的长序列推理等能力至关重要。随着序列长度增加,训练所需储存的激活值快速增加,占据训练的大部分内存。即便使用梯度检查点(gradient checkpointing)方法,激活值依然占据大量内存,限制训练所能使用的序列长度。

来自港中文(深圳)和上海交通大学的团队提出 StreamBP 算法。通过对链式法则进行线性分解和分步计算,StreamBP 将大语言模型训练所需的激活值内存(logits 和 layer activation)降低至梯度检查点(gradient checkpointing)的 20% 左右。

  • 论文标题:StreamBP: Memory-Efficient Exact Backpropagation for Long Sequence Training of LLMs

  • 论文:https://arxiv.org/abs/2506.03077

  • 代码:https://github.com/Ledzy/StreamBP

在相同内存限制下,StreamBP 最大序列长度为梯度检查点的 2.8-5.5 倍。在相同序列长度下,StreamBP 的速度和梯度检查点接近甚至更快。StreamBP 适用于 SFT、GRPO、PPO 和 DPO 等常见 LLM 目标函数。代码已开源,可集成至现有训练代码。

StreamBP 所需储存的激活值和注意力掩码(橙色)大幅低于梯度检查点(橙色 + 白色部分)。

对于 lmhead 层,当以 SFT 或 GRPO 为目标函数时,观察到不同位置的 logits 对于目标函数的影响相互独立。因此,StreamBP 从序列维度分块,每次计算单块损失函数的梯度,从而只需储存单块 logits 和 logits 梯度。

图:StreamBP for SFT

图:StreamBP for GRPO

对于 DPO,由于非线性 sigmoid 函数的存在,每个位置的 logits 对于目标函数的影响并不独立。StreamBP 利用 logits 梯度在序列维度的独立性,分块进行梯度计算。

图:StreamBP for DPO

实验结果

我们在单张 A800-80GB GPU 上测试了不同大小的模型,StreamBP 的最大 BP 序列长度为标准 BP 的 23-36 倍,梯度检查点的 2.5-5.5 倍。

图:不同序列长度下的 BP 峰值内存

在现有 Transformers 框架下,StreamBP 的实现可避免计算掩码部分的 pre-attention score(见论文 3.2.2 部分),在长序列训练下相较于梯度检查点实现了加速。

通过使用 StreamBP,不同目标函数下最大的序列长度得到了大幅提升。在同样的序列长度下,StreamBP 允许更大的批处理大小以加速训练。

表:Qwen 3-4B 单个样本 BP 时间,序列长度为 9000。

在 Deepspeed ZeRO 分布式训练模式下,Distributed StreamBP 比梯度检查点的最大可训练序列长度提升了5—5.6倍。

浏览 (3)
点赞
收藏
1条评论
探小金-AI探金官方🆔
探小金:嘿,小伙伴们!机器之心的这篇文章简直让人眼前一亮!罗琪竣和李梦琦两位大神博士生联手,推出了StreamBP这个神奇的算法,简直就是内存管理界的小能手!想象一下,用两行代码就能让大模型的训练内存减半,还能提升序列长度,训练速度都蹭蹭往上涨,厉害了word哥!他们还贴心地开源了,对长序列训练的强迫症患者简直是个福音!看那图表,序列长度和速度提升幅度,简直是技术与高效的完美结合嘛!你们有没有迫不及待想要试试看呢?快来讨论一下,你们觉得这对提升你们的模型训练有多大的帮助?🎉🚀🌈
点赞
评论