多头注意力(Multi-Head Attention)

单头注意力用一组 (Q,K,V) 做一次检索与聚合。多头注意力的动机是:让模型在不同子空间并行地学习不同的对齐关系(例如语法依赖、实体指代、局部模式等)。

标准形式:

[ \mathrm{MHA}(X)=\mathrm{Concat}(\mathrm{head}_1,\ldots,\mathrm{head}_H)W_O ]

其中每个 head:

[ \mathrm{head}_h=\mathrm{Attention}(XW_Q^{(h)}, XW_K^{(h)}, XW_V^{(h)}) ]

通常取 (d_k=d_v=d/H),使得拼接后维度回到 (d)。

为什么多头有效:一个工程直觉

如果只有一头,注意力权重矩阵((T\times T))只有一张“关系图”。多头相当于有 (H) 张关系图并行学习,然后再线性组合回主干表示。这让模型既能捕获:

  • 局部模式(例如邻近 token)

  • 长程依赖(例如跨句依赖)

  • 不同类型的相关性(语法/语义/格式/引用等)

实现要点:QKV 打包与 reshape

工程实现常见做法:

  • 一次线性层得到 (QKV \in \mathbb{R}^{T \times 3d})

  • reshape 为 ((T, H, d/H)),并转置到适合 GEMM 的布局

这样做的理由是:减少 kernel launch、便于算子融合、提升吞吐。

注意:多头不是“越多越好”

在固定 (d) 下增大头数会让每头维度 (d/H) 变小,可能造成表达瓶颈;另一方面头数也会影响 KV cache 带宽与推理性能。实践里常见的选择由模型规模与硬件决定(例如 7B/13B/70B 的头数通常不同)。

Grouped-Query Attention(GQA)与 Multi-Query Attention(MQA)

这是推理优化中的高频改动,核心是:减少 K/V 的头数以降低 KV cache 的显存与带宽开销。

  • MHA:每个 head 都有独立的 (K,V)

  • GQA:多个 query head 共享一组 (K,V)(按 group 共享)

  • MQA:所有 query head 共享同一组 (K,V)

代价是表达能力可能下降,但对长上下文推理的性价比很高(尤其在吞吐受 KV 带宽限制时)。