跳转至

MuonClip / QK-Clip 与 QKV bias:为什么 attention 的 q/k projection 要 bias=False

工程笔记 / 训练踩坑记录 — 2026-06-08


一句话结论 (TL;DR)

QK-Clip 是在「Q/K projection 没有 bias」这个隐含前提下设计的——它只缩放权重 W_q/W_k,从不碰 bias。 给一个带 bias 的 attention 套上 Muon + QK-Clip,网络会把 attention logit 的幅度「藏」进 QK-Clip 够不到的 bias 路径里,导致 weight 被反复缩小直到学死、而 qkv bias 越长越大。修复就是把 q/k(最好连 v/o)projection 设成 bias=False,对齐 Kimi/DeepSeek/LLaMA/Qwen3 的主流做法。

1. 现象:weight 学死,bias 很大

一个用 FlatFormer 的 LiDAR 模型,用 Muon + QK-Clip 做大数据、长迭代训练。训练之后发现:

  • 模型的一些 layer 的 weight「学死了」(趋于 0);
  • 因为有 skip connection,端到端最终效果还可以
  • debug 定位到:qkv projection 的 bias 变得很大

2. 猜想(已被 QK-Clip 的数学和主流配置佐证)

用了 QK-Clip 来压住 attention logit 不要太大,但 QK-Clip 缩放的是 weight,没有处理 bias。于是网络为了把 attention logit 学大,就让 bias 学大,weight 则越来越小直到死掉。

这个猜想是对的,而且能直接从 QK-Clip 的公式推出来(见 §4)。

3. 背景:QK-Clip 与「主流都 bias=False」

3.1 QK-Clip(来自 Kimi K2 的 MuonClip)做了什么

为了抑制 Muon 训练中的 attention logit 爆炸(vanilla Muon 在中等规模就能让 max logit 冲到 1000+,引发 loss spike / 发散),Kimi K2 提出 QK-Clip

  • per-head 监控最大 logit S_max^h,算缩放因子 γ_h = min(1, τ / S_max^h)(Kimi 用 τ = 100);
  • 更新后直接缩放权重W_q^h ← γ^α · W_q^hW_k^h ← γ^(1-α) · W_k^hα = 0.5,q/k 各开方分担);
  • 不改当前 step 的前后向,只把 max logit 当信号去约束权重增长。Kimi K2 用 MuonClip 训了 15.5T token,零 loss spike

关键点:Kimi 论文里 attention 写成 Q^h = X·W_q^h通篇没有 bias 项;Fireworks 的拆解、Megatron-Core 的 qk_clip 实现,也都只动 W_q/W_k,从不碰 bias。换句话说 QK-Clip 默认 Q/K 是无 bias 的

3.2 现代 attention 普遍 bias=False

模型 QKV projection bias
GPT-2 / 早期 transformer ✅ 有(PyTorch nn.Linear 默认 bias=True
Qwen2 ✅ 故意加 QKV-bias(说是增强外推)
LLaMA / DeepSeek-V3 bias=False(DeepSeek attention_bias 默认 False)
Kimi K2 ❌ 无 bias(MLA,沿用 DeepSeek-V3 架构)
Qwen3 ❌ 删掉 QKV-bias,改用 QK-Norm
Gemma3 / ViT-5 ❌ bias-free(配 RMSNorm + QK-Norm)

注意 Qwen2 → Qwen3 的演变:Qwen2 专门加了 QKV-bias,Qwen3 把它删掉换成 QK-Norm——整个领域在主动远离 QKV bias。Kimi 的 QK-Clip 正是建立在「无 bias」这条主流惯例之上

4. 为什么会把 weight 学死、bias 学大(机理)

把带 bias 的 logit 展开(query \(i\) / key \(j\) / 某个 head):

S_ij = (W_q·x_i + b_q) · (W_k·x_j + b_k)
     =  (W_q x_i)·(W_k x_j)   ← T1  内容 × 内容
      + (W_q x_i)·b_k          ← T2  内容 × bias
      +  b_q·(W_k x_j)         ← T3  bias × 内容
      +  b_q·b_k               ← T4  bias × bias(常数)

QK-Clip 把 W_q,W_k 各乘 γ^0.5,对这四项的压制严重不均等

受 QK-Clip 缩放 对 attention 分布的作用
T1 内容×内容 ×γ(完全压制) 正常的 content attention
T2 / T3(含一个 bias) ×γ^0.5(只压一半 T3 是「query 无关的 key 偏好」,有用
T4 bias×bias ×1(完全免疫) 对所有 key 是常数 → softmax 平移不变 → 对 attention 毫无贡献

于是形成病态正反馈:

  1. bias 是 QK-Clip 够不到的「避难所」:要维持 logit 幅度,走 W 每次被 γ 砍,走 b 只被 γ^0.5 砍甚至不砍。梯度发现「把信号塞进 bias 更划算」,于是 b_q/b_k 不断长大、W_q/W_k 反复被砍又不重建 → weight → 0(学死)、bias → 大
  2. T4 反咬一口b_q·b_k 这个常数项对 attention 没用(softmax 平移不变),却会抬高 S_max——而 S_max 正是 QK-Clip 的触发信号。结果是没用的常数 bias 把 logit 顶上去,QK-Clip 却去砍有用的 content 权重,加速 weight 死亡。
  3. skip connection 兜底:head 退化成 query 无关(W_q ≈ 0,只剩常数 query),但残差让网络绕过这个废 head,所以端到端指标「还可以」——刚好掩盖了问题。

这与观察到的 weight 学死 + qkv bias 很大 + 有 skip 所以效果还行 完全吻合

诚实说明:公开资料里没有明确把 QK-Clip 与这个 bias 补偿失效模式串起来的讨论(现有对 QK-Clip 的批评是另一回事——「压 logit 削弱了表达极端 confidence 的能力」)。所以这个诊断基本是原创发现,但它与 QK-Clip 的数学、与「主流 bias=False」的事实三方自洽。

5. 怎么修(按推荐度排序)

  • 首选:q/k(最好连 v/o)projection 设 bias=False —— 直接对齐 Kimi/DeepSeek/LLaMA/Qwen3。T2/T3/T4 全消失,QK-Clip 自洽,盲区消失。最贴合「common practice」。
  • 若必须保留 bias:把 bias 一起纳入 QK-Clip 缩放 —— 让整条投影输出 W·x + b 都乘 γ^α(而不是只缩 W),四项都按预期被压制。
  • 换 QK-Norm(Qwen3 / Gemma3 的做法) —— 对 q、k 做 RMSNorm,无论幅度来自 weight 还是 bias 都被归一化,从机制上堵死 bias 补偿;社区普遍视为近乎「免费午餐」,是目前更受青睐的稳定化手段。

⚠️ 别想着加 weight decay 去「压住 bias」——那只会把本就在死的 weight 更快推向 0。

6. 很可能的根因 & 排查动作

base 是 FlatFormer(LiDAR 点云 transformer),attention 是标准 MHA 风格。和绝大多数非 LLM 的视觉 / 点云代码一样,QKV 很可能直接用了 PyTorch nn.Linear 的默认 bias=True——这正是 QK-Clip 假设不存在的那个 bias。

  • grep attention 模块里 q/k/v 的 nn.Linear(..., bias=?),确认是不是 True
  • 把 q/k(建议连 v/o)改成 bias=False 重训,复查 weight 范数是否不再塌、qkv bias 是否不再发散。
  • 顺手记录各 head 的 S_max 曲线与 q/k weight 范数,作为回归监控指标。

本质:把一个为「bias-free LLM」设计的稳定化技巧(MuonClip)搬到一个「默认带 bias」的视觉模型上时的典型水土不服。


关于这条笔记

训练踩坑记录,非论文总结。来源是 2026-06-08 排查 FlatFormer LiDAR 模型 Muon + QK-Clip 训练中「weight 学死、qkv bias 变大」时的诊断 + 对 Kimi K2 / 主流 attention 配置的查证。

参考: - Kimi K2 Technical Report — arXiv 2507.20534(QK-Clip 公式、τ=100、per-head、MLA) - Fireworks — Deep-dive into MuonClip(QK-Clip 只缩 W_q/W_k) - Megatron-Core core.optimizer.qk_clip(参考实现,只动权重) - DeepSeek-V3 — HuggingFace docsattention_bias 默认 False)、Technical Report arXiv 2412.19437 - Qwen3 用 QK-Norm 替换 Qwen2 的 QKV-bias;QK-Norm 综述见 Sebastian Raschka"QK-Norm is probably a free lunch"