Attention 机制
Attention 的标准(scaled dot-product)形式可以写成:
[ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V ]
其中:
(Q \in \mathbb{R}^{T_q \times d_k}):queries
(K \in \mathbb{R}^{T_k \times d_k}):keys
(V \in \mathbb{R}^{T_k \times d_v}):values
(M):mask(例如 causal mask / padding mask),把不允许关注的位置加上 (-\infty)
(\sqrt{d_k}):缩放项,防止点积随维度增大导致 softmax 饱和
直觉:在做什么?
对每个 query(例如“当前 token”),我们计算它与所有 key 的相似度,得到权重,再对 value 做加权求和。你可以把它看成:
检索:用 query 在 key 空间里找相关内容
聚合:把相关 value 搬运到当前位置
为什么要除以 (\sqrt{d_k})?
如果 (Q) 和 (K) 的每个维度近似零均值、方差为 1,则点积 (q \cdot k) 的方差随 (d_k) 增大而增大,softmax 会更容易进入“极端尖峰”区间,导致梯度不稳定。缩放后可以让 logits 的尺度更可控。
Mask:Transformer “自回归”的关键
在语言模型的自回归生成里,token (t) 不能看见未来 token,因此使用 causal mask:
若 (j > t),则 (M_{t,j}=-\infty)
否则 (M_{t,j}=0)
这样 softmax 后未来位置权重为 0。
复杂度:为什么推理会卡?
对于长度 (T) 的序列((T_q=T_k=T)),注意力矩阵 (QK^\top) 是 (T \times T),因此:
时间复杂度:(O(T^2 d))
显存/缓存压力:注意力权重与中间张量对长上下文非常昂贵
这推动了后续的优化方向:FlashAttention、长上下文稀疏注意力、以及推理阶段的 KV Cache(把 (K,V) 缓存起来把 decode 复杂度从 (O(T^2)) 变成每步 (O(T)))。
工程要点(你写代码时会踩的坑)
数值稳定:softmax 前通常会减去 max(log-sum-exp trick)
mask 的 dtype/广播:注意 -inf 的 dtype 与 fp16/bf16 的兼容性(实践中常用一个很小的负数近似)
padding mask:batch 内不同长度需要屏蔽 padding,否则会污染聚合