Muon Optimizer + FSDP
导言
muon 优化器在FSDP场景下 ,xtuner以及业界先进方法是如何实现的。
xtuner的实现是根据相同shape串性把tensor来all2all,专家tensor不够fsdp_size时,还需要padding。并且内存快照时发现all2all要申请一个大buffer 35B 128卡 sp4 256k,好像有10GB左右。
这份文档围绕两个问题展开:
- FSDP + Muon 的 All2All 在内存快照中申请了约 10GB 大 buffer,该如何避免?是否可以分块通信?
- veScale-FSDP 中关于 RaggedShard + Muon 的设计有什么值得学习的地方?
先给结论:
- 10GB All2All buffer 大概率不是 Muon 必然开销,而是 bucket 粒度过大、padding、临时 pack/unpack buffer、Newton-Schulz workspace 叠加导致的峰值。
- 可以分块通信,但应优先“按矩阵组分块”,不要把单个 Muon 矩阵切碎后分别正交化,否则会改变优化器语义。
- exact Muon 的下界是:某个 rank 至少要持有一个完整 2D 矩阵及其 Newton-Schulz 临时 workspace。这个下界无法通过普通 All2All 消除。
- veScale-FSDP 的关键思想不是简单换一个 collective,而是把“结构感知布局”做进 FSDP:RaggedShard 表示不规则 sharding,Planner 减少 padding,DBuffer 做持久化零拷贝通信 buffer。
背景:FSDP 与 Muon 的冲突¶
Muon 的核心更新不是逐元素操作,而是对 2D 矩阵的 momentum 做近似正交化:
AdamW 可以在 FSDP local shard 上直接做,因为它是 element-wise update。 但 Muon 不行,因为:
也就是说,在每个 FSDP rank 的 shard 上独立跑 Newton-Schulz,不等价于对完整矩阵跑 Muon。
因此 exact FSDP + Muon 通常需要:
- 每个 rank 持有矩阵的一个 shard;
- 通过 All2All / gather 把完整矩阵重组到某些 rank;
- 这些 rank 执行 Newton-Schulz;
- 再把更新后的 shard 发回原 FSDP layout。
XTuner 那种“同 shape tensor 分桶,然后 All2All;数量不够 FSDP size 时 padding”的方案,本质上是在实现这个 exact 语义。
问题一:为什么 All2All buffer 会到 10GB¶
假设:
如果实现一次性处理一个大 bucket,那么每个 rank 至少可能需要:
input pack buffer ≈ C * P
output full buffer ≈ C * P
Newton-Schulz workspace ≈ alpha * C * P
reverse All2All buffer ≈ C * P
padding/alignment ≈ extra
所以峰值近似是:
其中 alpha 取决于 Newton-Schulz 实现,可能是 1 到 3 甚至更高。
如果一个 35B 模型里某些 MLP 矩阵本身就有数百 MB,而一个 bucket 让每个 rank 同时接收多个完整矩阵,那么 10GB 峰值并不意外。
还要注意:SP4、256k sequence length 本身通常不直接决定 Muon All2All buffer 大小。Muon buffer 主要由模型参数形状、FSDP group、bucket 策略、dtype 决定。但 256k 长序列会让 activation、grad bucket、allocator reserved memory 压力更大,导致 optimizer step 的临时 buffer 更容易成为 OOM 触发点。
分块通信:可以,但要按矩阵分块¶
推荐分块方式¶
应该把一个大 Muon bucket 拆成多个 micro-bucket:
big bucket:
[W0, W1, W2, ..., W127]
micro-bucket 0:
[W0, W1, ..., W15]
micro-bucket 1:
[W16, W17, ..., W31]
...
每个 micro-bucket 执行:
1. pack local momentum shards
2. All2All: shard layout -> full matrix layout
3. Newton-Schulz on full matrices
4. All2All: full update layout -> original shard layout
5. apply update
6. reuse buffer for next micro-bucket
关键是:每个 Muon 单元仍然是完整矩阵。
不推荐的分块方式¶
不要把单个矩阵切成多个 row block 分别做 Muon:
这不再是原始 Muon,而是 block-wise Muon / approximate Muon。它可能可用,但优化器语义已经变了,loss 曲线和稳定性都需要重新验证。
例外情况是 fused tensor:
这类 tensor 可以先按语义拆成多个逻辑矩阵,再分别做 Muon。这个拆分是合理的,因为 Muon 单元本来就应该是每个逻辑矩阵,而不是整个 fused 大 tensor。
内存控制方案¶
1. 给 Muon 设置独立 bucket cap¶
不要沿用 FSDP 的通信 bucket 大小,也不要把所有 same-shape tensor 一次性 All2All。
建议引入:
muon:
comm_bucket_cap_mb: 512 # 起步建议 512MB 或 1GB
max_inflight_buckets: 1 # 内存紧张时先设 1
preallocate_workspace: true
dtype: bf16
估算 micro-bucket 大小:
def estimate_muon_peak(full_bytes, ns_alpha=2.0, reverse_buffer=True):
# full_bytes: 当前 micro-bucket 中每个 rank 需要持有的完整矩阵总 bytes
base = 2 * full_bytes # input + output
ns = ns_alpha * full_bytes
reverse = full_bytes if reverse_buffer else 0
return base + ns + reverse
选择 micro-bucket 时应满足:
如果当前看到 10GB buffer,可以先把 cap 降到:
然后观察 optimizer step time 的变化。通常这会增加 All2All 次数,但能显著降低峰值显存。
2. 预分配并复用 workspace¶
不要每个 bucket 都:
而应该在 optimizer 初始化时预分配固定 workspace:
workspace = MuonWorkspace(
in_buf=torch.empty(max_bytes, dtype=torch.bfloat16, device="cuda"),
out_buf=torch.empty(max_bytes, dtype=torch.bfloat16, device="cuda"),
ns_buf=torch.empty(max_ns_bytes, dtype=torch.bfloat16, device="cuda"),
)
每个 micro-bucket 只使用 narrow/view:
这样可以避免:
- PyTorch caching allocator 反复申请大块内存;
- stream lifetime 导致旧 buffer 不能及时复用;
- memory fragmentation;
- snapshot 中出现多个临时大块并存。
veScale-FSDP 的 DBuffer 思想也类似:用持久化 distributed buffer 和地址映射来避免反复 copy/alloc。
3. 限制 in-flight bucket 数量¶
为了 overlap,很多实现会同时挂多个异步 All2All:
这有利于性能,但会增加峰值显存。
在 35B、128 卡、SP4、256k 这种 activation 压力很高的场景,建议先关闭 aggressive overlap:
稳定后再尝试:
不要一开始就让 3 到 4 个 Muon chunks 同时在飞。
4. 尽量使用 BF16 通信和计算 buffer¶
检查 All2All buffer 的 dtype。
如果 momentum 或 update buffer 是 FP32:
Muon 的大规模实现通常会尽量让通信和 Newton-Schulz 主路径使用 BF16 / FP16 / Tensor Core 友好的格式。 但需要注意:这可能影响数值稳定性。建议至少记录:
如果 BF16 Muon 不稳定,可以只对 Newton-Schulz 的某些归一化标量保留 FP32,而不是让整个 All2All buffer 变成 FP32。
5. 用 all_to_all_single 的 split sizes 或 ragged all-to-all 减少 padding¶
XTuner 的 same-shape bucket 通常要求:
如果 expert tensor 数量少于 FSDP size,会产生严重 padding:
可以改成两条路径:
dense/common shape:
equal-size all_to_all_single,走高带宽路径
expert/small/irregular shape:
ragged all_to_all / all_to_all_single(split_sizes) / gather-to-root
如果后端支持 variable split sizes,优先避免 dummy tensor padding。 如果 variable all-to-all 性能不理想,小 bucket 可以退回 P2P gather/scatter 或 AdamW。
6. 跨 layer 合并 expert bucket¶
不要按 layer 做 expert bucket:
应该跨 layer 合并:
例如:
num_layers = 32
experts_per_layer = 8
fsdp_size = 64
按层 bucket:
每层 8 个,padding 87.5%
跨层 bucket:
32 * 8 = 256 个
可切成 4 个 64-bucket,几乎无 padding
这是 MoE + FSDP + Muon 中非常关键的优化。
7. 对不划算的参数退回 AdamW¶
你前面已经观察到:在 Qwen 35B SFT 中,Muon 每步 loss 下降未必比 AdamW 快。
因此在 SFT 场景没必要追求 Muon 覆盖率 100%。建议:
Muon:
attention q/k/v/o projection
dense MLP gate/up/down
大多数 routed expert MLP
AdamW:
embedding
lm_head
norm
bias
router/gate
LoRA 参数
padding ratio 过高的小 expert bucket
极大且导致显存峰值的个别矩阵
可以加策略:
if bucket.padding_ratio > 0.5:
use_adamw(bucket)
if estimated_muon_peak(bucket) > hard_cap:
split_bucket_or_use_adamw(bucket)
if sft and eval_not_improved_by_muon:
reduce_muon_coverage()
8. HSDP 降低 FSDP group size¶
veScale-FSDP 论文也给出类似经验:不要盲目扩大 FSDP group size,必要时用 HSDP 控制 collective group。
例如总共 128 卡:
方案 A:
fsdp_size = 128
dp_replicas = 1
方案 B:
fsdp_size = 64
dp_replicas = 2
方案 C:
fsdp_size = 32
dp_replicas = 4
较小的 fsdp_size 通常可以:
- 减少 expert padding;
- 降低 collective group 复杂度;
- 改善 NCCL latency 和 LCM rounding;
- 让 bucket 更容易规划。
代价是:
- 每卡参数 shard、grad shard、optimizer state 变大;
- DP replica 之间还需要同步梯度或 optimizer state;
- 总显存不一定下降,需要实测。
对于 35B SFT,如果 activation 才是主要压力,HSDP 未必能直接省显存;但如果 All2All padding 和大 bucket 是主要问题,HSDP 很值得试。
一个推荐的分块执行伪代码¶
class MuonFSDPExecutor:
def __init__(self, fsdp_group, workspace_cap_bytes, max_inflight=1):
self.fsdp_group = fsdp_group
self.workspace = preallocate_workspace(workspace_cap_bytes)
self.max_inflight = max_inflight
def step_bucket(self, bucket):
chunks = plan_micro_buckets(
bucket,
cap_bytes=self.workspace.cap_bytes,
cost_fn=estimate_muon_peak,
)
for chunk in chunks:
# 1. 本地 momentum update
local_m_shards = update_momentum_local(chunk)
# 2. pack 到持久化 input buffer
in_view = self.workspace.pack(local_m_shards)
# 3. shard layout -> full matrix layout
out_view = self.workspace.alloc_output(chunk)
dist.all_to_all_single(
output=out_view,
input=in_view,
group=self.fsdp_group,
)
# 4. 每个 rank 对自己负责的完整矩阵跑 Newton-Schulz
full_mats = unpack_full_matrices(out_view, chunk)
updates = []
for mat in full_mats:
updates.append(newton_schulz(mat))
# 5. pack update,反向 All2All 回原 FSDP shard layout
update_view = self.workspace.pack(updates)
shard_update_view = self.workspace.alloc_shard_output(chunk)
dist.all_to_all_single(
output=shard_update_view,
input=update_view,
group=self.fsdp_group,
)
# 6. apply local shard update
apply_update_local(chunk.params, shard_update_view)
# 7. workspace 逻辑释放,下一 chunk 复用
self.workspace.reset()
重点:
veScale-FSDP 的核心思想¶
veScale-FSDP 论文认为,传统 FSDP 的 element-wise 或 row-wise fixed sharding 难以支持结构感知训练,例如 Muon、Shampoo、block-wise quantization。它提出三个关键组件:
RaggedShard¶
RaggedShard 是一种 DTensor placement,用来表达不规则 sharding。
传统 Shard(0) 通常要求均匀切分:
RaggedShard 允许:
也允许:
这对于 Muon 很有用,因为可以把某个完整矩阵重分布到一个 root rank:
original FSDP placement:
each rank owns a shard
RaggedShard(root):
only root owns the full 2D matrix
other ranks own empty tensor
veScale 文档中也明确提到,Muon 的 Newton-Schulz 需要完整 2D 参数矩阵,RaggedShard 可以通过 DTensor.redistribute 表达 gather -> compute -> scatter 这个过程。
Structure-aware Planner¶
如果只是把 RaggedShard tensor 简单拼起来,可能出现:
veScale 的 planner 目标是:
这点对 Muon 和 MoE 都很重要。
对于你看到的 expert tensor padding,veScale 的思路不是简单“补 dummy tensor 到 fsdp_size”,而是做全局 layout planning,尽量减少 padding 和 LCM rounding。
论文中还给出经验:不要使用过大的 FSDP group size,可以通过 HSDP 控制 shard group,并通过离线模拟选择 padding 最小的 FSDP size。
DBuffer¶
DBuffer 是 veScale-FSDP 的通信 buffer 抽象。
它的目标是:
1. 持久化分配通信 buffer
2. 多 tensor group-level 操作
3. zero-copy access
4. in-place communication/computation
5. 降低 PyTorch allocator fragmentation
这正好对应你看到的 10GB 临时 All2All buffer 问题。
如果没有 DBuffer,一个朴素实现通常会反复:
这会在 memory snapshot 中出现大量临时大块。
借鉴 DBuffer 后,应把 Muon 的通信区改成:
veScale 的 Distributed Muon 流程¶
veScale-FSDP 论文中的 Muon 逻辑可以概括为:
for each 2D parameter w:
g = grad(w)
u = MomentumUpdate(g, m)
p = original placement(u)
r = SelectRoot() # 负载均衡选择 root
o = Redistribute(u, RaggedShard(r)) # root 持有完整矩阵
o = NewtonSchulz(o) # 只有 root 真正计算
o = Redistribute(o, p) # 回到原 FSDP shard
w = w - lr * o
这个设计和 XTuner same-shape All2All 的目标类似,都是 exact Muon。 但抽象层次不同:
| 方案 | 核心思路 | 优点 | 风险 |
|---|---|---|---|
| XTuner-style same-shape All2All | 同 shape tensor 批量重排 | 简单,高带宽,容易实现 | padding、大 bucket、大 buffer |
| veScale RaggedShard | 用 placement 表达不规则 gather/scatter | 语义清晰,减少 padding,适合结构感知 optimizer | 需要 DTensor/RaggedShard/Planner 支撑 |
| DBuffer | 持久化通信 buffer 和地址映射 | 降低 allocator 峰值和 copy | 工程复杂度更高 |
| Rooted gather | 每个矩阵选 root rank | 避免同 shape 数量不足 padding | 需要负载均衡和异步 overlap |
推荐架构¶
如果你要从当前 XTuner-style FSDP + Muon 演进,我建议分三层做。
第一层:保留 exact All2All,但加 micro-bucket¶
这是最容易落地的改造。
目标:
把 10GB 临时 buffer 降到 1GB ~ 2GB 可控范围
做法:
1. same-shape bucket 保留
2. bucket 内按 workspace_cap 切 micro-bucket
3. max_inflight 先设为 1
4. 所有临时 buffer 预分配并复用
建议配置:
muon:
exact: true
comm_dtype: bf16
workspace_cap_mb: 1024
max_inflight_buckets: 1
preallocate_workspace: true
fallback_min_fill_ratio: 0.5
第二层:MoE expert 改成全局规划¶
针对 expert tensor:
1. 跨 layer 合并 same-shape experts
2. padding ratio > 50% 的 bucket 退回 AdamW 或走 ragged path
3. 如果 expert 是 [E, out, in] fused layout,并且 shard 在 E 维,则本地 per-expert Muon,不需要 All2All
判断逻辑:
if is_expert_batch_sharded(param):
# local shard already owns complete expert matrices
run_local_per_expert_muon(param)
elif bucket.padding_ratio <= 0.5:
run_all2all_muon(bucket)
else:
run_adamw(bucket)
第三层:引入 RaggedShard / Rooted Muon¶
当 same-shape + padding 已经成为主要瓶颈时,再引入 veScale 风格设计:
1. 每个 Muon 矩阵选择一个 root rank
2. 使用 ragged placement 表示 root 持有完整矩阵
3. redistribute 到 root
4. root 上 Newton-Schulz
5. redistribute 回原 placement
6. 使用 planner 控制 root 负载和 buffer cap
root 选择可以按 estimated cost 做负载均衡:
def select_root(matrix, rank_load):
cost = matrix.numel() * matrix.dtype.itemsize
root = min(rank_load, key=rank_load.get)
rank_load[root] += cost
return root
不要简单 round-robin,因为不同矩阵大小差异很大。
针对 35B / 128 卡 / SP4 / 256k 的建议¶
优先级如下:
-
先确认 All2All buffer dtype 如果是 FP32,优先改成 BF16。
-
把 Muon bucket cap 降到 512MB 或 1GB 先牺牲一点 optimizer step time,换取显存稳定。
-
关闭多 chunk overlap
max_inflight_buckets=1,稳定后再开到 2。 -
预分配 Muon workspace 不要在每个 step、每个 bucket 动态
torch.empty大 tensor。 -
跨 layer 合并 expert bucket 避免每层 expert 数量小于 FSDP size 导致巨量 padding。
-
padding ratio 高的 expert bucket 退回 AdamW SFT 中 Muon 未必带来收益,不值得为这些参数付出 10GB buffer。
-
评估 HSDP 比如:
用离线脚本估算 padding 和 per-rank state,再实测。
- 检查 optimizer step 前 activation 是否真正释放 256k 下 activation 压力极大。可以诊断性地在 optimizer 前插入同步,确认是否是 stream lifetime 导致临时 buffer 共存:
这不是最终性能方案,但可以帮助定位峰值来源。
需要记录的指标¶
建议每个 Muon bucket 打印:
bucket_name
logical_shape
num_real_tensors
num_padded_tensors
padding_ratio
full_bytes_per_rank
estimated_comm_buffer_bytes
estimated_ns_workspace_bytes
actual_allocated_before
actual_allocated_after
actual_peak_allocated
all2all_in_time_ms
newton_schulz_time_ms
all2all_out_time_ms
示例:
logger.info(
"[muon_bucket] key=%s real=%d padded=%d pad=%.2f "
"full_rank=%.2fGB comm=%.2fGB ns=%.2fGB "
"t_a2a_in=%.2fms t_ns=%.2fms t_a2a_out=%.2fms",
bucket.key,
bucket.real_count,
bucket.padded_count,
bucket.padding_ratio,
full_bytes_per_rank / 2**30,
comm_bytes / 2**30,
ns_bytes / 2**30,
t_in,
t_ns,
t_out,
)
没有这些指标,很难判断 10GB 是来自:
bucket 太大
padding 太多
dtype 不对
NS workspace 太大
多 chunk overlap
allocator fragmentation
activation 未释放
结论¶
对于你的场景,最实际的路线是:
短期:
XTuner-style same-shape All2All 保留
加 micro-bucket + workspace cap + BF16 + max_inflight=1
padding 高的 expert bucket 退回 AdamW
中期:
跨 layer expert bucket
expert-batch-sharded fast path
HSDP 调整 fsdp_size
长期:
学 veScale-FSDP
引入 RaggedShard / rooted redistribution / planner / persistent DBuffer
一句话总结:
FSDP + Muon 的核心不是“能不能 All2All”,而是“能否在保持完整矩阵 Muon 语义的同时,把重分布、padding、workspace 和 allocator 生命周期都纳入统一规划”。XTuner 的实现解决了 correctness,veScale-FSDP 的设计进一步解决了 layout、padding 和 buffer 生命周期问题。
参考资料¶
- veScale-FSDP paper: veScale-FSDP: Flexible and High-Performance FSDP at Scale 1
- veScale RaggedShard 文档: RaggedShard Placement 2
- veScale GitHub: volcengine/veScale 3
- PyTorch / TorchTitan FSDP2 notes: torchtitan FSDP documentation 4
- Microsoft Dion / Muon distributed implementation: microsoft/dion 5