RL Data Flow
导言
这篇文章只回答一个问题:一条 RL 样本从 prompt 进入系统,到 rollout、reward、logprob、advantage、loss、backward,最后回到下一轮训练时,数据到底怎么流、shape 怎么变、显存为什么涨。
先建立三张表
写这篇文章时,先不要急着解释算法名词,优先建立三张表:
- shape 表:logical shape、physical shard shape、mask shape。
- memory 表:参数、梯度、优化器状态、激活值、KV cache、临时 buffer。
- lifetime 表:每个 tensor 在哪个阶段生成、在哪个阶段被消费、在哪个阶段释放。
1. 为什么要先讲数据流¶
- shape mismatch 的根源通常不是公式错,而是没有把“逻辑 batch”和“物理 shard”分开。
- OOM 不是单纯显存不够,而是没有把峰值阶段和张量生命周期拆开。
- 多卡不均衡 不是一句“通信慢”能解释的,而要看 token 分布、micro batch 切分和 stage 时间。
2. RL 的端到端链路¶
2.1 主链路¶
prompt batch -> rollout / generation -> reward -> old/ref logprob -> advantage / return -> actor update -> next step
2.2 需要在文中讲清楚的语义¶
- sample 维度:一条 prompt 对应多少个 response / candidate。
- token 维度:每个阶段是按 token 处理,还是按 sample 聚合。
- group 维度:GRPO / group sampling 中一个 prompt 下面的多个 response 如何展开。
- micro batch 维度:训练时实际送进 forward/backward 的最小批。
- dynamic batch 维度:按 token 预算、max length、显存预算动态变化的 batch。
3. 推理输入到推理输出¶
3.1 输入侧¶
prompt_idsattention_maskposition_ids- 可能还有:
prompt_length、sample_id、group_id
3.2 输出侧¶
response_idsresponse_maskfull_sequencegeneration metadatalogits / logprobs(如果 rollout 后端需要返回)abort / timeout状态
3.3 典型 shape¶
下面是建议的“典型形式”,具体实现要回到代码确认。
prompt_ids: [B_prompt, S_prompt]response_ids: [B_prompt * G, S_resp]full_sequence: [B_prompt * G, S_prompt + S_resp]response_mask: [B_prompt * G, S_resp]logprobs: [B_prompt * G, S_resp]
3.4 关键解释点¶
- 动态 bs 不是固定样本数,而是受 token 数与调度器共同影响。
- padding 会把逻辑 shape 拉大,但不一定增加有效 token 计算。
- mask 决定了哪些 token 参与 loss、哪些 token 只是上下文。
4. reward 与 advantage 的数据流¶
4.1 reward 的来源¶
- reward model
- rule-based reward
- function reward
- 多 reward 融合
4.2 需要说明的 shape¶
- sample-level reward:
[B_prompt, G]或[B_prompt * G] - token-level reward:
[B_prompt * G, S_resp]或 broadcast 后的同形矩阵 - advantage / return:常见为按 sample 计算后广播到 response token
4.3 关键解释点¶
- 为什么有些 reward 是标量,有些是 token 级。
- 为什么 advantage 常常先按 sample 算,再扩展到 token 维度。
response_mask如何屏蔽 prompt token。
5. 训练输入到训练输出¶
5.1 训练阶段输入¶
old_logprobref_logprobnew_logprobadvantagesreturnsresponse_maskloss_mask
5.2 训练阶段输出¶
policy losskl lossentropy lossgrad normupdated params
5.3 需要在文中明确的 shape 关系¶
old_logprob: [B, S_resp]ref_logprob: [B, S_resp]new_logprob: [B, S_resp]advantages: [B]或[B, S_resp]loss_mat: [B, S_resp]loss_mask: [B, S_resp]
5.4 文章里要解释的点¶
- token-level loss 与 sample-level reward 的映射关系。
- 为什么
old_logprob要先保存下来。 - 为什么
KL和entropy是稳定性指标,不只是 loss 的附属项。
6. Shape ledger¶
6.1 这一节的目标¶
把每个阶段的 tensor 记录成统一表格,避免只靠脑补推 shape。
6.2 建议表头¶
| stage | tensor | logical shape | local shard shape | mask / broadcast | owner rank | lifetime | common bug |
|---|---|---|---|---|---|---|---|
6.3 建议至少覆盖的张量¶
- prompt / response / full_sequence
- attention mask / response mask
- old / ref / new logprob
- reward / advantage / return
- loss matrix / loss mask
- hidden states / activations
- KV cache
7. Memory ledger¶
7.1 显存拆分¶
- 参数:model weights。
- 梯度:backward 期间的梯度张量。
- 优化器状态:moment / variance 等。
- 激活值:forward 保存的中间状态。
- KV cache:rollout / inference 阶段常见大头。
- 通信 buffer:all-gather / reduce-scatter / ring 通信临时缓冲。
- 临时 tensor:loss、mask、拼接、索引等短生命周期张量。
7.2 估算原则¶
activation_bytes ≈ layers × local_tokens × hidden_size × bytes × factorkv_bytes ≈ layers × batch_local × seq_local × kv_heads × head_dim × 2 × bytespeak_memory ≈ params + grads + optimizer + activations + kv_cache + buffers
7.3 文章里要解释的点¶
- 哪一阶段最可能成为峰值。
allocated与reserved的差异意味着什么。- 动态 batch 如何改变实际峰值。
8. 并行切分总图¶
8.1 需要对齐的并行维度¶
- DP:切 batch。
- TP:切 hidden / head / linear 权重。
- PP:切 layer。
- SP:切 sequence 相关激活或中间状态。
- CP:切 context / sequence,并在 attention 语义上做通信。
8.2 这篇文章要强调的核心区别¶
- batch 切分 影响的是“有多少条样本同时跑”。
- sequence 切分 影响的是“每条样本内部怎么切 token”。
- shape mismatch 常常就是这两个切分维度混淆了。
9. SP 与 CP 的逻辑差异¶
9.1 SP:更偏激活切分¶
- 重点是降低激活显存。
- 目标是让 sequence 相关的中间状态在多个 rank 之间分片。
- 适合在 MLP / norm / residual 等可分块区域做通信隐藏。
9.2 CP:更偏上下文切分¶
- 重点是让长上下文 attention 能在更长序列上跑起来。
- 语义上更接近“把上下文拆开再重组”。
- 通常通信模式比 SP 更重,也更依赖 attention 的具体实现。
9.3 文中必须回答的三个问题¶
- 逻辑 shape 是什么:完整
[B, S, H]如何映射到 local shard。 - 哪些算子能本地算:哪些算子可以完全在 local shard 上执行。
- 通信怎样被掩盖:什么时候
T_compute(chunk) >= T_comm(chunk),能把通信藏进计算里。
9.4 判断准则¶
- 当 chunk 粒度足够大、通信可流水化、算子可以分块时,更容易实现 compute / comm overlap。
- 如果 attention 需要全局上下文而本地 shard 不足,则通信更难完全掩盖。
- 文中要明确:SP 和 CP 都可能切 sequence,但切分目的、通信位置、mask / position / KV 语义不同。
10. DFX 设计¶
10.1 设计目标¶
DFX 不只是日志,而是要同时回答:
- 正确性:shape 对不对、mask 对不对、token 对不对。
- 稳定性:KL、entropy、grad norm、abort ratio 是否健康。
- 性能:step time、tokens/s、stage time、MFU / SMA。
- 显存:allocated / reserved、峰值、碎片率。
- 负载:rank 间 token 与时间是否均衡。
- 数据质量:prompt length、response length、clip ratio、reward 分布。
10.2 建议的指标层次¶
- E2E 指标:step time、throughput、total tokens。
- 阶段指标:rollout / reward / logprob / ref / update 的耗时。
- 张量指标:shape、bytes、mask ratio、lifetime。
- 并行指标:rank 负载、通信时长、bubble、queue depth。
- 稳定性指标:KL、entropy、grad norm、clipfrac。
10.3 建议的日志字段¶
stageranktensor_namelogical_shapelocal_shapedtypenumelbytesmask_ratiolifetimecomm_typelatency_msowner
10.4 告警规则草案¶
shape mismatch:shape checksum 不一致。OOM:reserved 快接近物理上限,且碎片率升高。load imbalance:per-rank token 或 step time 偏差过大。training instability:KL / grad norm / entropy 同时异常波动。
11. 调试顺序¶
11.1 shape mismatch¶
- 先看 logical shape。
- 再看 local shard shape。
- 再看 mask / broadcast / reshape。
- 最后才看算子实现和并行配置。
11.2 OOM¶
- 先找峰值阶段。
- 再拆参数、梯度、优化器、激活、KV、buffer。
- 再看 dynamic batch / sequence split 是否改变了峰值。
11.3 多卡不均衡¶
- 先看 token 分布。
- 再看动态 batch 和 micro batch 切分。
- 再看 SP / CP / PP / TP 的通信等待。
- 最后看异步队列与 straggler。
12. 收束¶
- RL Infra 的核心不是某一个公式,而是 数据流 + shape + 生命周期 + 并行切分 + DFX 的统一视角。
- 只要把这五件事统一起来,后面再看 verl 的训练、推理、异步和 checkpoint,很多问题都会变得可解释。
- 第一篇的最终目标,是给后面所有文章提供一张共同的“坐标系”。