模板崩塌诊断:推理是不是变成了套话
本章讲什么:RAGEN-2 的另一个贡献——诊断工具。为什么 entropy 看不出“模板崩塌”,怎么用互信息看出来。
6.1 两种“崩塌”不一样
要解决的小问题。 训练到后期,模型的 <think> 看起来还很丰富(熵不低),但其实不管输入是什么状态,
都吐出差不多的套话。这叫模板崩塌(template collapse),传统的熵指标看不出来。
两个轴拆开看。 collapse_metrics.py:5 的文档注释说得很清楚:
| 现象 | 含义 | 指标 |
|---|---|---|
| 熵崩塌 | 同一输入下变得太确定 | H(Z|X) 低(组内多样性低) |
| 模板崩塌 | 推理变得与输入无关 | I(X;Z) 低(跨输入可区分性低) |
关键点:高熵低 MI 是隐藏的坏事——每条推理都不一样(熵高),但都跳过了输入(MI 低)。
6.2 思路:“让推理去认输入”能不能认出来
核心直觉:拿一段推理 z,问“它到底是哪个输入状态 x 生的?”。如果推理真的依赖输入,那么用生它的那个 x 去算 log p(z|x) 应该明显高于用别的 x'。反之,如果是套话,用哪个 x 算都差不多。
所以 RAGEN 构造一个 交叉 log-prob 矩阵:行是推理 z,列是候选输入 x,格子是“用 x_j 接 z_i 算的 log-prob”。
x₁ x₂ x₃ ← 候选输入(每组一个代表 prompt)
z₁ [ ★ . . ] ★=matched(真正生它的输入)
z₂ [ . ★ . ] 其他=marginal 的成分
z₃ [ . . ★ ]
↑ 采样出的推理
MI ≈ mean( matched - marginal )
检索准确率 = “每行 argmax 落在 ★ 上”的比例
MI 高 → ★ 明显高于同行其他格 → 健康;模板崩塌 → ★ 和别的差不多 → 检索准确率跳到随机水平 1/N。
6.3 X 和 Z 怎么划
这个划分是诊断能不能做准的关键(docs/reference_mutual_information_metrics.md §2):
| 变量 | 内容 |
|---|---|
| X(条件) | 系统提示 + 用户轮(状态)+ assistant 前缀 + <think> 标签 |
| Z(推理) | <think> 与 </think> 之间的推理内容(不含两个标签) |
为什么把 <think> 归入 X。 它是“开始推理”的控制标记,每条都一样,放进 Z 会用高概率常量 token 稀释指标。
同理 </think> 是格式稳定信号,放进 Z 会把“推理依赖”和“收不收得住标签”混为一谈。
这些 token 在 ctx_manager 里预先抽好:_build_first_turn_prompt_and_reasoning_ids(ctx_manager.py:637)返回
(prompt_ids, reasoning_ids),只在开了 collapse 检测且 enable_think 时才构造(ctx_manager.py:1110)。
6.4 交叉 log-prob 怎么算
_compute_cross_log_probs(collapse_metrics.py:455)是核心:对每个候选 prompt x_j,拼出 [x_j | z] 序列,
调 actor 的 compute_log_prob 做 teacher-forcing,把推理 token 的 log-prob 求和:
# 示意,非源码(改编自 collapse_metrics.py:_compute_cross_log_probs)
for j, gid in enumerate(unique_groups):
prompt = prompts[gid] # 该组的代表 prompt x_j
cross_ids = cat([prompt.expand(NK,-1), reasoning_ids], dim=1) # [x_j | z]
log_probs = compute_log_prob_fn(cross_batch)["old_log_probs"]
# 按推理长度归一化,减长度偏差
cross_log_probs[:, j] = (log_probs * mask).sum(-1) / token_counts
然后 _compute_log_prob_stats(collapse_metrics.py:566)算两个量:
- matched:每行取“真正生它的那列”(对角线)—— log p(z|x)。
- marginal:
logsumexp(所有列) - log(N)——均匀 prompt 混合下的 log p_mix(z)。
6.5 三组指标
有了 matched / marginal,三组指标顺理成章:
MI 估计(collapse_metrics.py:581):
# 示意,非源码(改编自 collapse_metrics.py:_compute_mi_estimate)
mi = matched.mean() - marginal.mean() # Î(X;Z)
mi_upper_bound = math.log(N_prompts) # 理论上限
检索准确率(collapse_metrics.py:610):argmax 落对列的比例,还算 retrieval@2/4/8 和“高出随机水平多少”。
它还处理了“不同组 prompt 恰好相同”的情况:用 prompt 签名(token 元组)把同 prompt 的列当成等价,避免误判
(collapse_metrics.py:655)。
条件熵 / 推理熵(collapse_metrics.py:693):H(Z|X) ≈ -matched.mean(),H(Z) ≈ -marginal.mean(),于是
I(X;Z) = H(Z) - H(Z|X) 与 MI 估计一致。