跳到主要内容

模板崩塌诊断:推理是不是变成了套话

本章讲什么:RAGEN-2 的另一个贡献——诊断工具。为什么 entropy 看不出“模板崩塌”,怎么用互信息看出来。

6.1 两种“崩塌”不一样

要解决的小问题。 训练到后期,模型的 <think> 看起来还很丰富(熵不低),但其实不管输入是什么状态, 都吐出差不多的套话。这叫模板崩塌(template collapse),传统的熵指标看不出来。

两个轴拆开看。 collapse_metrics.py:5 的文档注释说得很清楚:

现象含义指标
熵崩塌同一输入下变得太确定H(Z|X) 低(组内多样性低)
模板崩塌推理变得与输入无关I(X;Z) 低(跨输入可区分性低)

关键点:高熵低 MI 是隐藏的坏事——每条推理都不一样(熵高),但都跳过了输入(MI 低)。

6.2 思路:“让推理去认输入”能不能认出来

核心直觉:拿一段推理 z,问“它到底是哪个输入状态 x 生的?”。如果推理真的依赖输入,那么用生它的那个 x 去算 log p(z|x) 应该明显高于用别的 x'。反之,如果是套话,用哪个 x 算都差不多。

所以 RAGEN 构造一个 交叉 log-prob 矩阵:行是推理 z,列是候选输入 x,格子是“用 x_j 接 z_i 算的 log-prob”。

x₁ x₂ x₃ ← 候选输入(每组一个代表 prompt)
z₁ [ ★ . . ] ★=matched(真正生它的输入)
z₂ [ . ★ . ] 其他=marginal 的成分
z₃ [ . . ★ ]
↑ 采样出的推理

MI ≈ mean( matched - marginal )
检索准确率 = “每行 argmax 落在 ★ 上”的比例

MI 高 → ★ 明显高于同行其他格 → 健康;模板崩塌 → ★ 和别的差不多 → 检索准确率跳到随机水平 1/N。

6.3 X 和 Z 怎么划

这个划分是诊断能不能做准的关键(docs/reference_mutual_information_metrics.md §2):

变量内容
X(条件)系统提示 + 用户轮(状态)+ assistant 前缀 + <think> 标签
Z(推理)<think></think> 之间的推理内容(不含两个标签)

为什么把 <think> 归入 X。 它是“开始推理”的控制标记,每条都一样,放进 Z 会用高概率常量 token 稀释指标。 同理 </think> 是格式稳定信号,放进 Z 会把“推理依赖”和“收不收得住标签”混为一谈。

这些 token 在 ctx_manager 里预先抽好:_build_first_turn_prompt_and_reasoning_idsctx_manager.py:637)返回 (prompt_ids, reasoning_ids),只在开了 collapse 检测且 enable_think 时才构造(ctx_manager.py:1110)。

6.4 交叉 log-prob 怎么算

_compute_cross_log_probscollapse_metrics.py:455)是核心:对每个候选 prompt x_j,拼出 [x_j | z] 序列, 调 actor 的 compute_log_prob 做 teacher-forcing,把推理 token 的 log-prob 求和:

# 示意,非源码(改编自 collapse_metrics.py:_compute_cross_log_probs)
for j, gid in enumerate(unique_groups):
prompt = prompts[gid] # 该组的代表 prompt x_j
cross_ids = cat([prompt.expand(NK,-1), reasoning_ids], dim=1) # [x_j | z]
log_probs = compute_log_prob_fn(cross_batch)["old_log_probs"]
# 按推理长度归一化,减长度偏差
cross_log_probs[:, j] = (log_probs * mask).sum(-1) / token_counts

然后 _compute_log_prob_statscollapse_metrics.py:566)算两个量:

  • matched:每行取“真正生它的那列”(对角线)—— log p(z|x)。
  • marginallogsumexp(所有列) - log(N)——均匀 prompt 混合下的 log p_mix(z)。

6.5 三组指标

有了 matched / marginal,三组指标顺理成章:

MI 估计collapse_metrics.py:581):

# 示意,非源码(改编自 collapse_metrics.py:_compute_mi_estimate)
mi = matched.mean() - marginal.mean() # Î(X;Z)
mi_upper_bound = math.log(N_prompts) # 理论上限

检索准确率collapse_metrics.py:610):argmax 落对列的比例,还算 retrieval@2/4/8 和“高出随机水平多少”。 它还处理了“不同组 prompt 恰好相同”的情况:用 prompt 签名(token 元组)把同 prompt 的列当成等价,避免误判 (collapse_metrics.py:655)。

条件熵 / 推理熵collapse_metrics.py:693):H(Z|X) ≈ -matched.mean(),H(Z) ≈ -marginal.mean(),于是 I(X;Z) = H(Z) - H(Z|X) 与 MI 估计一致。

6.6 何时算、怎么采样

频率。 should_computecollapse_metrics.py:228):step 1 算一次,之后每 compute_freq 步算一次;step 0 不算。 只在 full 模式算collapse_metrics.py:93)——只有 full 模式下“同 group_id = 同 prompt”这个前提才成立。

两种采样策略。

  • 首轮(first_turn):_sample_first_turn_pairscollapse_metrics.py:720),只拿每条轨迹的第一轮。
  • 轨迹均匀(trajectory-uniform):_sample_trajectory_uniformcollapse_metrics.py:798),先均匀选轨迹再均匀选轮, 每条轨迹权重相等(不偏长轨迹)。

过滤前算,保证公平。 主循环里 collapse 指标是在 filter 之前算的(agent_trainer.py:1070),这样诊断看的是 原始分布而不是过滤后的。

6.7 巧妙之处

  • EMA z-score 让指标可比。 matched-marginal 的绝对值随任务差别很大,collapse_metrics.py:316 用 marginal 的标准差 (加 EMA 平滑)归一化成 z-score,跨 step / 跨任务都能比。
  • 随机水平作准绳。 检索准确率总是配一个 retrieval_chance_levelcollapse_metrics.py:687)——只看绝对准确率会被 N 忽悠,“高出随机多少”才是健康信号。
  • 多卡 padding 对齐。 log_prob_world_size>1 时把交叉批次 pad 到能被 world_size 整除(collapse_metrics.py:521), 避免分布式 compute_log_prob 报错。

6.8 边界与局限

  • 只在 full 模式用。turn-level 模式下直接返回空指标(collapse_metrics.py:93)。
  • 需要 ≥2 个不同 prompt_compute_metrics_for_pairs 里 N_prompts<2 直接返空(collapse_metrics.py:254)。
  • 偏贵。交叉 log-prob 是 O(N 个 prompt × NK 个推理) 的 forward,所以默认只每 compute_freq 步算一次。
  • turn-uniform 暂关compute_turn_uniform=Falsecollapse_metrics.py:152),代码保留但不走。

6.9 横向对比

RAGEN 是 ai-frontier-reference 里“把诊断当一等公民”的 agent RL 框架:许多 RLHF / agent-RL 库(如 veRL 本身、 TinyZero)主要提供训练能力,而 RAGEN 额外把“为什么崩、怎么提前发现”做成了一等公民功能:奖励方差过滤(第3章)

  • MI 诊断(本章)+ 三种 early stopping(第2章)。它复用 veRL 做分布式训练,自己则专注于“多轮缝合 + 诊断”。

6.10 代码地图

主题文件符号
诊断总入口ragen/trainer/collapse_metrics.pyCollapseDetector.compute_collapse_metrics
交叉 log-prob 矩阵ragen/trainer/collapse_metrics.py_compute_cross_log_probs
matched / marginalragen/trainer/collapse_metrics.py_compute_log_prob_stats
MI 估计ragen/trainer/collapse_metrics.py_compute_mi_estimate
检索准确率ragen/trainer/collapse_metrics.py_compute_retrieval_accuracy
条件熵 / 推理熵ragen/trainer/collapse_metrics.py_compute_reasoning_entropy
X/Z 抽 tokenragen/llm_agent/ctx_manager.py_build_first_turn_prompt_and_reasoning_ids · _build_all_turns_prompt_and_reasoning_ids
何时算ragen/trainer/collapse_metrics.pyshould_compute
接入主循环ragen/trainer/agent_trainer.pyfitcollapse_metrics) · init_workers
原理参考docs/reference_mutual_information_metrics.md