RL: Training Inference Mismatch
导言
- 25年,RL训练崩溃归因于训推不一致;
- 为此提出了很多方法,TIS,Router Replay,FP16训推,batch一致性...
- 如何判断 模型当前训推不一致,并找到不一致实现处,是实践的要点。
基本概念¶
KL 散度¶
Reverse KL Divergence¶
GRPO 使用的是 Reverse KL(采样自当前策略 π_θ,参考策略为 π_ref):
\[D_{KL}(\pi_\theta \| \pi_{ref}) = \mathbb{E}_{x \sim \pi_\theta}\left[\log\frac{\pi_\theta(x)}{\pi_{ref}(x)}\right]\]
在 LLM 场景下,按 token 级别计算:
三种 KL 估计方法¶
由于直接计算 KL 的期望成本高,实践中采用蒙特卡洛近似。常见有三种估计器 [[49]][[50]]:
| 估计器 | 公式 | 特性 | 适用场景 |
|---|---|---|---|
| k1 | -log(r),其中 r = π_ref/π_θ |
无偏但方差极大,梯度不含 π_ref | PPO 中作为 reward shaping,不适合作为独立 KL loss |
| k2 | 0.5 * (log(r))² |
有偏但方差低,梯度等价于 Reverse KL | ✅ GRPO 推荐(VeRL/TRL 默认) |
| k3 | (r - 1) - log(r) |
无偏、方差低,但梯度等价于 Forward KL | 需注意:采样分布不匹配时可能不稳定 |
关键代码逻辑(VeRL/TRL 实现):
# 假设已获取 per-token logps
log_ratio = ref_logps - actor_logps # log(π_ref/π_θ)
# k1: 原始 KL (高方差)
kl_k1 = -log_ratio
# k2: 平方近似 (低方差,推荐)
kl_k2 = 0.5 * (log_ratio ** 2)
# k3: Bregman 形式 (无偏,但梯度对应 Forward KL)
kl_k3 = (log_ratio.exp() - 1) - log_ratio
# 最终 loss 加入
loss = policy_loss + beta * kl_loss_type(actor_logps, ref_logps)
🔍 为什么 k2 更推荐?
- k2 的梯度:∇_θ [0.5*(log r)²] = (log r) · ∇_θ log π_θ,恰好匹配 Reverse KL 的实用梯度形式 [[49]]
- k3 虽然无偏,但其梯度对应 Forward KL,在 π_θ 与 π_ref 差距较大时,重要性采样权重π_ref/π_θ可能爆炸,导致训练不稳定
监控指标的计算流程(每步 RL)
1️⃣ 采样阶段:
- 对每个 query,用当前策略 π_θ 生成 G 个 completions
- 同时用参考策略 π_ref 计算相同 tokens 的 log-probs
2️⃣ 计算 per-token KL:
log_ratio = log(π_ref) - log(π_θ)
kl_token = kl_loss_type(log_ratio) # k1/k2/k3 选一
3️⃣ 聚合为标量指标:
- 按 completion 平均:kl_per_seq = mean(kl_token over tokens)
- 按 batch 平均:kl_monitor = mean(kl_per_seq over all completions)
4️⃣ 用于:
- 📊 监控:TensorBoard/W&B 记录 kl 曲线,判断策略漂移
- ⚖️ 正则:loss += β * kl_monitor(β=0.001~0.04,依任务调整)
四、实践建议
-
配置选择(以 VeRL 为例)[[13]][[41]]:
-
监控阈值参考:
- KL < 0.01:策略变化过小,可能学习缓慢
- KL ∈ [0.01, 0.1]:健康更新区间
-
KL > 0.2:策略漂移过大,需检查 β 或奖励设计
-
调试技巧:
- 同时记录
kl_k2和kl_k3,若二者差异显著,说明 π_θ 与 π_ref 已偏离较大 - 若 KL 持续上升且 reward 不增,考虑定期重置 reference model(DeepSeek-R1 实践)[[50]]
log p¶
log_p 是模型对每个位置实际生成的 token 计算出的对数概率(Log Probability),属于标量值,而 hidden_state 是 Transformer 层输出的高维稠密向量。两者在计算链路、数据形态和用途上完全不同。
🔍 log_p 的完整计算链路
Input Tokens
↓ [Embedding]
Hidden States (layer 0)
↓ [Transformer Blocks × N]
Final Hidden States: h ∈ ℝ^{B×L×D} ← 这才是你问的 hidden_state
↓ [LM Head: Linear + Bias]
Logits: z ∈ ℝ^{B×L×V} ← 词表大小 V 的未归一化得分
↓ [Log Softmax]
Log Probabilities: log_p_all ∈ ℝ^{B×L×V}
↓ [Gather 实际 token ID]
log_p ∈ ℝ^{B×L} ← 你问的 log_p(每个位置一个标量)
📐 维度与形态对比
| 概念 | 形状 | 数据类型 | 物理含义 |
|---|---|---|---|
hidden_state |
[B, L, D] |
float32/16 | Transformer 输出的上下文表征向量 |
logits |
[B, L, V] |
float32/16 | 词表每个 token 的原始得分 |
log_p |
[B, L] |
float32/16 | 当前策略下,每个位置真实 token 的对数概率 |
💡
D通常为 4096/7680 等,V为词表大小(如 32k/128k),而log_p已坍缩到[B, L],只保留实际生成 token 的概率信息。
💻 代码级直观实现(PyTorch)
import torch
import torch.nn.functional as F
# 假设 model 已加载,input_ids 为 [B, L]
outputs = model(input_ids=input_ids, return_dict=True)
logits = outputs.logits # [B, L, V]
# 1. 计算全词表 log softmax
log_probs_all = F.log_softmax(logits, dim=-1) # [B, L, V]
# 2. 提取实际 token 对应的 log_p
# input_ids.unsqueeze(-1) -> [B, L, 1]
token_log_p = log_probs_all.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) # [B, L]
# 3. 通常只保留 response 部分(prompt 和 padding 用 mask 过滤)
response_mask = (input_ids >= tokenizer.vocab_size) # 示例:假设 response token ID 较大
valid_log_p = token_log_p * response_mask # 后续计算 KL/Loss 时会 mask 掉无效位置
⚠️ 常见误区澄清
| 误区 | 正确理解 |
|---|---|
“log_p 就是 hidden_state” |
hidden_state 是向量表征;log_p 是经过 LM Head + LogSoftmax + Gather 后的标量概率 |
“推理时不需要 log_p” |
推理(生成)时模型会隐式计算它用于采样;RL 训练时需显式保存用于 loss |
“log_p 越大越好” |
仅表示模型对该 token 更自信;RL 中需与 reward/advantage 结合,盲目最大化会导致 mode collapse |
| “KL 直接用 hidden_state 算” | KL 是概率分布距离,必须基于 log_p;hidden_state 是特征空间,无法直接算分布散度 |
📌 总结
log_p= 每个位置真实 token 的对数概率,形状[B, L]- 由
hidden_state → logits → log_softmax → gather得到,不是 hidden_state 本身 - 在 GRPO 中用于:KL 正则、importance ratio、策略梯度、loss mask
- 训练时需同时保存
actor_logp和ref_logp,推理时通常不显式输出但底层会计算
训推一致性¶
LogP Diff vs KL 散度:本质区别与阈值含义
| 维度 | 🔹 LogP Diff (训推一致性) | 🔹 KL 散度 (RL 正则/监控) |
|---|---|---|
| 比较对象 | 同一模型,train mode vs infer mode | 两个策略,actor π_θ vs reference π_ref |
| 核心目标 | 验证数值计算一致性(精度对齐) | 控制策略更新幅度(防止分布崩溃) |
| 数学形式 | δ = \|log_p^train - log_p^infer\| |
KL = 𝔼[log(π_θ/π_ref)] 或其近似 |
| 是否取绝对值 | ✅ 是,关注偏差大小 | ❌ 否,保留方向信息(谁更自信) |
| 是否加权 | ❌ 所有 token 平等对待 | ✅ 高概率 token 贡献更大(p·log(p/q)) |
| 量纲/单位 | log 空间的相对误差(无单位比值) | 信息论距离(nat,无单位但数值意义不同) |
| 典型阈值 | rel_diff < 0.01 (1%) | KL < 0.01~0.1 (依任务/β系数调整) |
# 在 trainer 的 logging 阶段
def compute_consistency_and_kl_metrics(batch):
metrics = {}
# 🔹 训推一致性(仅调试阶段启用)
if config.debug_consistency:
with torch.no_grad():
logp_train = model_train(input_ids).logps # train mode
logp_infer = model_infer(input_ids).logps # eval mode
rel_diff = (logp_train - logp_infer).abs() / (logp_train.abs() + 1e-8)
metrics["debug/consistency_rel_diff"] = verl_F.masked_mean(
rel_diff, batch["response_mask"]
).item()
# 🔹 KL 散度(训练必选)
log_ratio = batch["ref_logps"] - batch["actor_logps"]
kl_token = 0.5 * (log_ratio ** 2) # k2
metrics["actor/kl_loss"] = agg_loss(
kl_token, batch["response_mask"], config.loss_agg_mode
).item()
return metrics