大规模训练技术
本章从工程视角系统梳理大模型训练的“系统三角”:
显存:参数/梯度/优化器状态/激活
通信:带宽、拓扑、同步点
稳定性:混合精度、梯度尺度、长训练的可复现与容错
目标不是罗列名词,而是给出一套可落地的选型逻辑:在给定模型规模、上下文长度、集群形态下,如何组合 DP/TP/PP/ZeRO/FSDP,并在可用的成本内把训练跑稳。
分布式训练基础
数据并行
数据并行(Data Parallel, DP)把 batch 切到多卡上,每张卡做前向与反向,最后对梯度做 all-reduce 同步:
优点:概念简单,几乎所有模型都适用
缺点:单卡仍需放下完整参数与优化器状态(大模型会爆显存)
典型瓶颈:梯度同步通信。优化方向包括通信-计算重叠、ZeRO/FSDP 分片等。
模型并行
“模型并行”是一个总称,常见两类:
张量并行(Tensor Parallel, TP)
把大矩阵乘按维度切分到多卡(例如把 FFN/Attention 的线性层按列/按行切),核心代价是:每层都会引入通信(all-reduce / all-gather)。
收益是:单卡显存压力下降,可以训练更大的 (d) 或更长的序列。
参数/状态分片(ZeRO / FSDP)
把参数、梯度、优化器状态分片到不同卡上(典型是 ZeRO stage 1/2/3,或 PyTorch FSDP)。
工程直觉:
显存收益最大,但会引入更多通信与复杂同步点
checkpoint 与恢复更复杂(但生态已相对成熟)
流水线并行
流水线并行(Pipeline Parallel, PP)按层切分模型,把不同 layer 放在不同 GPU 上,并通过 micro-batch 填满流水线。
优点:显存随层切分下降,可以承载更深模型
缺点:pipeline bubble(气泡开销)、调度复杂、跨 stage 通信
现代系统常把 DP + TP + PP 组合成 3D 并行,以适配超大模型。
混合精度训练(FP16 / BF16)
混合精度的目标是:用更低精度的 GEMM 提升吞吐与降低显存,同时用策略保证数值稳定。
FP16 vs BF16
FP16:指数位更少,溢出/下溢更常见,通常需要 loss scaling
BF16:指数位更宽,更不易溢出,训练更稳(现代硬件普遍支持)
Loss Scaling(主要针对 FP16)
当梯度很小导致下溢时,引入缩放系数 (s):
前向:loss 乘以 (s)
反向:梯度也乘以 (s)
更新前:把梯度除以 (s)
并通过动态策略在溢出时降低 (s)。
梯度累积与梯度检查点
梯度累积(Gradient Accumulation)
当显存不够放下目标 batch 时,把一个大 batch 拆成多个 micro-batch,累积梯度再更新一次参数。
注意:学习率与有效 batch size 相关;改变累积步数相当于改变有效 batch,需要同步调整学习率/调度。
梯度检查点(Activation Checkpointing)
反向传播需要保存中间激活。检查点的做法是:
前向只保存少量 checkpoint
反向时对缺失部分重新前向计算
它用额外算力换显存,是训练长上下文/深层模型的常用手段。
超参数调优
大模型训练中最关键的超参通常不是“花哨的 trick”,而是这些朴素项:
学习率与调度:warmup + decay(cosine/linear)
有效 batch size:与收敛速度/稳定性强相关
权重衰减(AdamW):与泛化相关
梯度裁剪:防止少数异常 batch 破坏训练
数据混比与去重:对最终效果的影响往往大于很多结构微调
工程落地建议:如何选并行策略(简化版决策树)
能否单卡放下参数+优化器状态?
能:先用 DP(最简单)
不能:上 ZeRO/FSDP
attention/MLP 的 GEMM 是否过大导致单卡算力吃不满?
是:考虑 TP(并行 GEMM)
模型太深导致单卡激活爆显存?
是:优先 checkpoint;仍不够再上 PP
本章小结
DP 最简单但显存压力最大;ZeRO/FSDP 通过分片把大模型“放得下”。
TP 解决大 GEMM 的承载与吞吐;PP 解决深层模型的层级切分。
混合精度与 checkpoint 是“把训练跑稳且跑得动”的工程关键。