SNR-Adaptive Filtering:筛掉学不到东西的 prompt
本章讲什么:RAGEN-2 的一个轻量干预——在梯度更新前,按“组内奖励方差”筛掉低信号的 prompt 组。
5.1 为什么要过滤
要解决的小问题。 GRPO 用“组内差异”做优势。但如果一个 prompt 的 N 条轨迹全对或全错,组内方差≈ 0,优势也≈ 0——这个 prompt 产生的梯度几乎是噪声,却照样占 GPU、拉低信噪比(SNR)。
思路(信噪比代理)。 用组内奖励方差当“这个 prompt 有多少可学东西”的轻量代理:方差大 = 有的 轨迹好有的坏 = 梯度有信息;方差小 = 要么太难要么太简单 = 没什么可学。每步只留高方差的组。
README 还指出这直接治“模板崩塌”的根因(参第 4 章)。这是 V2 两个贡献中“下药”的那个(另一个是“诊断”)。
5.2 接入点:在主循环哪里调
过滤发生在 rollout 之后、算优势之前(agent_trainer.py:1078):
# 示意,非源码(改编自 agent_trainer.py:fit)
with marked_timer("filter", timing_raw):
batch, filter_metrics = self.rollout_filter.filter(batch)
# 留下的样本比例记进 meta_info,给后面 loss scaling 用
batch.meta_info["filter_kept_ratio"] = metrics.get("rollout/filter_kept_ratio", 1.0)
过滤器在 init_workers(agent_trainer.py:603)用 build_rollout_filter 造好,参数全从 rollout 配置读:
rollout_filter_metric(默认 reward_variance)、rollout_filter_strategy、rollout_filter_value、rollout_filter_type。
5.3 三类指标维度
build_rollout_filter(rollout_filter.py:768)按 metric 选出三种过滤器:
| metric | 过滤器类 | 按什么算组分 |
|---|---|---|
reward / reward_sum / reward_variance | RewardRolloutFilter | 组内奖励的均值 / 和 / 标准差 |
entropy / entropy_variance | EntropyRolloutFilter | 组内熵的均值 / 标准差(需要 compute_log_prob) |
length | LengthRolloutFilter | 组内响应长度的均值 |
默认的主角是 reward_variance:RewardRolloutFilter._selection_scores(rollout_filter.py:355)返回 in_group_std。
5.4 从“一批轨迹”到“组打分”
过滤以**组(group)**为单位(整组留或整组丢)。RewardRolloutFilter.filter(rollout_filter.py:367)把 batch
按 num_groups × group_size reshape,算每组的 std/mean/max/sum:
# 示意,非源码(改编自 rollout_filter.py:RewardRolloutFilter.filter)
rm_scores = scores.view(num_groups, group_size) # 每行一组
in_group_std = rm_scores.std(dim=-1) # 组内方差 = 信号代理
top_groups = self._select_top_groups(in_group_std) # 选高信号组
turn-level 模式要先按 episode 聚。 若 batch 里有 episode_ids(single_turn / limited_multi_turn),过滤器先把
同一 episode 的多个样本聚成一个 episode 级奖励,再 reshape 成组(rollout_filter.py:376)——这样“组”的含义始终是
“同一初始状态的 N 条 episode”,而不是 token 样本。
5.5 四种选组策略
_select_top_groups(rollout_filter.py:68)根据 strategy 分派。画个“组按得分排序后怎么划线”的图:
组按得分从高到低排好: [g1][g2][g3][g4][g5][g6][g7][g8]
│←─── 保留 ───→│←── 丢弃 ──→│
top_p (linear): 累加得分 ≥ p×总和 就停(nucleus 风格)
top_k : 保留前 k=⌊p×num_groups⌋ 个组
top_k_abs : 保留前 k 个(k 是绝对个数,不是比例)
min_p : 保留得分 ≥ max_score×p 的组(阈值相对最高分)
top_p 的两种聚合模式。(rollout_filter.py:91)
linear(默认):阈值 =p × sum(scores) - eps,从高到低累加得分直到超阈,遇到 ≤0 的分就停(rollout_filter.py:115)。softmax:先对得分做 softmax 变成概率,再做经典 nucleus 累积截断(rollout_filter.py:98)。
value >= 1.0 是“不过滤”的特例。 默认 config/base.yaml:60 的 rollout_filter_value: 1.0 就是不过滤(全保留);
要开 V2 过滤才设成如 0.9。filter 在这种情况下提前返回,但会把 reward_std 贴到 batch 上(rollout_filter.py:458)。
5.6 过滤之后:loss scaling
过滤会改变有效批量,为了不让梯度量级突变,RAGEN 可选按 filter_kept_ratio 缩放优势(agent_trainer.py:1403):
# 示意,非源码(改编自 agent_trainer.py:fit)
if filter_loss_scaling == "linear":
batch.batch["advantages"] *= filter_kept_ratio
elif filter_loss_scaling == "sqrt":
batch.batch["advantages"] *= (filter_kept_ratio ** 0.5)
还有一个可选的 soft advantage reweight(agent_trainer.py:1374):不硬筛,而是按每个 prompt 的
std / max_std 软加权,把低方差 prompt 的梯度按比例调小。
5.7 顺手服务的梯度分析
RewardRolloutFilter.split_into_buckets(rollout_filter.py:502)能把 batch 按 reward_std 分桶(等分位或固定 RV 间隔),
供梯度分析用——这是 RAGEN “诊断”定位的一部分(配合 gradient_reporter.py,docs/guide_gradient_analysis.md)。
5.8 边界与坑
- “组”的整除要求。 turn-level 模式下
num_episodes % num_groups != 0会直接报错(rollout_filter.py:402)。 - 全过滤掉的保护。 连续多步过滤出 0 样本会触发 early stopping(
agent_trainer.py:1143,阈值rollout_filter_empty_stop_steps)。 - min_p 重复返回语句。
rollout_filter.py:180-182有一行死代码(return indices[mask]写了两遍),无功能影响。
5.9 代码地图
| 主题 | 文件 | 符号 |
|---|---|---|
| 过滤器工厂 | ragen/trainer/rollout_filter.py | build_rollout_filter |
| 选组策略(top_p/k/min_p) | ragen/trainer/rollout_filter.py | RolloutFilter._select_top_groups |
| 奖励方差过滤 | ragen/trainer/rollout_filter.py | RewardRolloutFilter.filter |
| 熵过滤 | ragen/trainer/rollout_filter.py | EntropyRolloutFilter.filter |
| 梯 度分析分桶 | ragen/trainer/rollout_filter.py | RewardRolloutFilter.split_into_buckets |
| 接入主循环 + loss scaling | ragen/trainer/agent_trainer.py | fit(filter / filter_loss_scaling) · init_workers |
| soft reweight | ragen/trainer/agent_trainer.py | fit(soft_advantage_reweight) |