大规模训练技术

本章从工程视角系统梳理大模型训练的“系统三角”:

  • 显存:参数/梯度/优化器状态/激活

  • 通信:带宽、拓扑、同步点

  • 稳定性:混合精度、梯度尺度、长训练的可复现与容错

目标不是罗列名词,而是给出一套可落地的选型逻辑:在给定模型规模、上下文长度、集群形态下,如何组合 DP/TP/PP/ZeRO/FSDP,并在可用的成本内把训练跑稳。

分布式训练基础

数据并行

数据并行(Data Parallel, DP)把 batch 切到多卡上,每张卡做前向与反向,最后对梯度做 all-reduce 同步:

  • 优点:概念简单,几乎所有模型都适用

  • 缺点:单卡仍需放下完整参数与优化器状态(大模型会爆显存)

典型瓶颈:梯度同步通信。优化方向包括通信-计算重叠、ZeRO/FSDP 分片等。

模型并行

“模型并行”是一个总称,常见两类:

张量并行(Tensor Parallel, TP)

把大矩阵乘按维度切分到多卡(例如把 FFN/Attention 的线性层按列/按行切),核心代价是:每层都会引入通信(all-reduce / all-gather)。

收益是:单卡显存压力下降,可以训练更大的 (d) 或更长的序列。

参数/状态分片(ZeRO / FSDP)

把参数、梯度、优化器状态分片到不同卡上(典型是 ZeRO stage 1/2/3,或 PyTorch FSDP)。

工程直觉:

  • 显存收益最大,但会引入更多通信与复杂同步点

  • checkpoint 与恢复更复杂(但生态已相对成熟)

流水线并行

流水线并行(Pipeline Parallel, PP)按层切分模型,把不同 layer 放在不同 GPU 上,并通过 micro-batch 填满流水线。

  • 优点:显存随层切分下降,可以承载更深模型

  • 缺点:pipeline bubble(气泡开销)、调度复杂、跨 stage 通信

现代系统常把 DP + TP + PP 组合成 3D 并行,以适配超大模型。

混合精度训练(FP16 / BF16)

混合精度的目标是:用更低精度的 GEMM 提升吞吐与降低显存,同时用策略保证数值稳定。

FP16 vs BF16

  • FP16:指数位更少,溢出/下溢更常见,通常需要 loss scaling

  • BF16:指数位更宽,更不易溢出,训练更稳(现代硬件普遍支持)

Loss Scaling(主要针对 FP16)

当梯度很小导致下溢时,引入缩放系数 (s):

  • 前向:loss 乘以 (s)

  • 反向:梯度也乘以 (s)

  • 更新前:把梯度除以 (s)

并通过动态策略在溢出时降低 (s)。

梯度累积与梯度检查点

梯度累积(Gradient Accumulation)

当显存不够放下目标 batch 时,把一个大 batch 拆成多个 micro-batch,累积梯度再更新一次参数。

注意:学习率与有效 batch size 相关;改变累积步数相当于改变有效 batch,需要同步调整学习率/调度。

梯度检查点(Activation Checkpointing)

反向传播需要保存中间激活。检查点的做法是:

  • 前向只保存少量 checkpoint

  • 反向时对缺失部分重新前向计算

它用额外算力换显存,是训练长上下文/深层模型的常用手段。

超参数调优

大模型训练中最关键的超参通常不是“花哨的 trick”,而是这些朴素项:

  • 学习率与调度:warmup + decay(cosine/linear)

  • 有效 batch size:与收敛速度/稳定性强相关

  • 权重衰减(AdamW):与泛化相关

  • 梯度裁剪:防止少数异常 batch 破坏训练

  • 数据混比与去重:对最终效果的影响往往大于很多结构微调


工程落地建议:如何选并行策略(简化版决策树)

  1. 能否单卡放下参数+优化器状态?

    • 能:先用 DP(最简单)

    • 不能:上 ZeRO/FSDP

  2. attention/MLP 的 GEMM 是否过大导致单卡算力吃不满?

    • 是:考虑 TP(并行 GEMM)

  3. 模型太深导致单卡激活爆显存?

    • 是:优先 checkpoint;仍不够再上 PP


本章小结

  • DP 最简单但显存压力最大;ZeRO/FSDP 通过分片把大模型“放得下”。

  • TP 解决大 GEMM 的承载与吞吐;PP 解决深层模型的层级切分。

  • 混合精度与 checkpoint 是“把训练跑稳且跑得动”的工程关键。