# 大规模训练技术 本章从工程视角系统梳理大模型训练的“系统三角”: - **显存**:参数/梯度/优化器状态/激活 - **通信**:带宽、拓扑、同步点 - **稳定性**:混合精度、梯度尺度、长训练的可复现与容错 目标不是罗列名词,而是给出一套可落地的选型逻辑:在给定模型规模、上下文长度、集群形态下,如何组合 DP/TP/PP/ZeRO/FSDP,并在可用的成本内把训练跑稳。 ## 分布式训练基础 ### 数据并行 数据并行(Data Parallel, DP)把 batch 切到多卡上,每张卡做前向与反向,最后对梯度做 all-reduce 同步: - 优点:概念简单,几乎所有模型都适用 - 缺点:单卡仍需放下完整参数与优化器状态(大模型会爆显存) 典型瓶颈:梯度同步通信。优化方向包括通信-计算重叠、ZeRO/FSDP 分片等。 ### 模型并行 “模型并行”是一个总称,常见两类: #### 张量并行(Tensor Parallel, TP) 把大矩阵乘按维度切分到多卡(例如把 FFN/Attention 的线性层按列/按行切),核心代价是:每层都会引入通信(all-reduce / all-gather)。 收益是:单卡显存压力下降,可以训练更大的 \(d\) 或更长的序列。 #### 参数/状态分片(ZeRO / FSDP) 把参数、梯度、优化器状态分片到不同卡上(典型是 ZeRO stage 1/2/3,或 PyTorch FSDP)。 工程直觉: - **显存收益最大**,但会引入更多通信与复杂同步点 - checkpoint 与恢复更复杂(但生态已相对成熟) ### 流水线并行 流水线并行(Pipeline Parallel, PP)按层切分模型,把不同 layer 放在不同 GPU 上,并通过 micro-batch 填满流水线。 - 优点:显存随层切分下降,可以承载更深模型 - 缺点:pipeline bubble(气泡开销)、调度复杂、跨 stage 通信 现代系统常把 DP + TP + PP 组合成 3D 并行,以适配超大模型。 ## 混合精度训练(FP16 / BF16) 混合精度的目标是:用更低精度的 GEMM 提升吞吐与降低显存,同时用策略保证数值稳定。 ### FP16 vs BF16 - **FP16**:指数位更少,溢出/下溢更常见,通常需要 loss scaling - **BF16**:指数位更宽,更不易溢出,训练更稳(现代硬件普遍支持) ### Loss Scaling(主要针对 FP16) 当梯度很小导致下溢时,引入缩放系数 \(s\): - 前向:loss 乘以 \(s\) - 反向:梯度也乘以 \(s\) - 更新前:把梯度除以 \(s\) 并通过动态策略在溢出时降低 \(s\)。 ## 梯度累积与梯度检查点 ### 梯度累积(Gradient Accumulation) 当显存不够放下目标 batch 时,把一个大 batch 拆成多个 micro-batch,累积梯度再更新一次参数。 注意:学习率与有效 batch size 相关;改变累积步数相当于改变有效 batch,需要同步调整学习率/调度。 ### 梯度检查点(Activation Checkpointing) 反向传播需要保存中间激活。检查点的做法是: - 前向只保存少量 checkpoint - 反向时对缺失部分重新前向计算 它用额外算力换显存,是训练长上下文/深层模型的常用手段。 ## 超参数调优 大模型训练中最关键的超参通常不是“花哨的 trick”,而是这些朴素项: - **学习率与调度**:warmup + decay(cosine/linear) - **有效 batch size**:与收敛速度/稳定性强相关 - **权重衰减(AdamW)**:与泛化相关 - **梯度裁剪**:防止少数异常 batch 破坏训练 - **数据混比与去重**:对最终效果的影响往往大于很多结构微调 --- ## 工程落地建议:如何选并行策略(简化版决策树) 1. **能否单卡放下参数+优化器状态?** - 能:先用 DP(最简单) - 不能:上 ZeRO/FSDP 2. **attention/MLP 的 GEMM 是否过大导致单卡算力吃不满?** - 是:考虑 TP(并行 GEMM) 3. **模型太深导致单卡激活爆显存?** - 是:优先 checkpoint;仍不够再上 PP --- ## 本章小结 - DP 最简单但显存压力最大;ZeRO/FSDP 通过分片把大模型“放得下”。 - TP 解决大 GEMM 的承载与吞吐;PP 解决深层模型的层级切分。 - 混合精度与 checkpoint 是“把训练跑稳且跑得动”的工程关键。