
What Breaks When You Train RL on a Production MoE Model

Source: Hugging Face

Supervised fine-tuning is forgiving. You can use different kernels for forward and backward passes, approximate certain operations differently across devices, and generally tolerate small numerical divergences without observable effect on the loss. Reinforcement learning for language models does not offer the same tolerance, and the LinkedIn and HuggingFace retrospective on GPT-OSS agentic RL, published in January 2026, demonstrates exactly why.

The core requirement in on-policy RL is that the importance sampling ratio π(a|s) / π_old(a|s) should equal 1.0 for freshly-sampled data. Any deviation, even from floating-point-level differences between two forward passes, corrupts the gradient signal. In practice this means the system used to generate rollouts and the system used to compute training gradients must agree, at the token level, on log-probabilities. For a model with a conventional dense transformer architecture, this is largely guaranteed by using the same kernel code. For a Mixture of Experts model with learnable attention sink parameters, it turned out to be a substantial engineering problem.
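A two-line check makes the fragility concrete (the log-probability values here are illustrative):

```python
import math

# On-policy requirement: for freshly-sampled data, pi == pi_old, so the
# importance ratio exp(log_prob - old_log_prob) must be exactly 1.0.
log_prob = old_log_prob = -2.31
assert math.exp(log_prob - old_log_prob) == 1.0

# A floating-point-level gap between two forward passes already breaks it:
recomputed = -2.3100001
assert math.exp(log_prob - recomputed) != 1.0
```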

The Model and the Framework

GPT-OSS is OpenAI’s open-weight language model, available in 20B and 120B parameter variants, built as a Mixture of Experts architecture. The training was done using verl, a framework that pairs FSDP (Fully Sharded Data Parallel) for training with vLLM or SGLang for rollout inference. That separation of training and inference engines is standard practice for throughput, but it creates a gap that becomes critical in RL.

The evaluation tasks span single-turn math reasoning (GSM8K), multi-turn tool use where the model calls a code execution environment across multiple steps (ReTool), and an out-of-distribution instruction-following benchmark called VerifyIf. Hardware: 16 H200 GPU nodes for the 20B model in bf16 precision, with context configured at 8k tokens prompt and up to 16k tokens maximum response.

Problem One: MoE Routing Non-Determinism

The first failure mode appears immediately. In PPO, you compute old log-probabilities π_old(a|s) during rollout, then later recompute π(a|s) during the gradient update step, expecting the ratio to be exactly 1.0 for on-policy data. In a dense transformer, two forward passes with identical weights produce identical outputs. In a MoE model, they may not.

MoE gating networks route each token to a subset of experts. The routing decision involves a top-k selection over expert scores, and the floating-point arithmetic in that selection can differ between two separate forward passes, particularly when running in different parallelism configurations or with different batch layouts. The result is nonzero clip ratios on freshly-sampled data, which should be impossible under correct on-policy training.
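A toy illustration (not the GPT-OSS router) of how this can happen: when two expert scores are nearly tied, the floating-point order in which one score is accumulated is already enough to flip a top-1 selection.

```python
def top1_expert(scores):
    # argmax; ties resolve to the lowest expert index
    return max(range(len(scores)), key=lambda i: scores[i])

# Expert 1's score is the same three terms accumulated in different orders,
# as can happen under different batch layouts or parallelism configurations:
rollout_scores  = [0.6, (0.1 + 0.2) + 0.3]   # expert 1 -> 0.6000000000000001
training_scores = [0.6, 0.1 + (0.2 + 0.3)]   # expert 1 -> 0.6

assert top1_expert(rollout_scores) == 1    # routed to expert 1 during rollout
assert top1_expert(training_scores) == 0   # routed to expert 0 during training
```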

When data is freshly sampled and therefore on-policy, recomputing log-probabilities serves no purpose except to introduce routing divergence. Setting old_log_prob = log_prob.detach() eliminates the mismatch. The clip ratio drops to zero, as it should for data that has not been replayed. The fix is specific to freshly-sampled trajectories; replayed data still requires the second forward pass for the importance weight to be meaningful.
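A minimal sketch of why the fix works (illustrative numbers; clip_fraction is a hypothetical helper, not verl's API):

```python
import math

def clip_fraction(log_probs, old_log_probs, eps=0.2):
    # fraction of tokens whose PPO ratio falls outside [1 - eps, 1 + eps]
    ratios = [math.exp(lp - olp) for lp, olp in zip(log_probs, old_log_probs)]
    clipped = sum(1 for r in ratios if r < 1 - eps or r > 1 + eps)
    return clipped / len(ratios)

# Recomputing old log-probs lets routing divergence creep in:
train_lp   = [-1.20, -0.35, -2.10]
recomputed = [-1.45, -0.35, -1.80]   # a token landed on different experts
assert clip_fraction(train_lp, recomputed) > 0.0

# The fix, old_log_prob = log_prob.detach(): ratios are exactly 1.0
assert clip_fraction(train_lp, train_lp) == 0.0
```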

Problem Two: Training-Inference Kernel Mismatch

The second problem has a different character. Inference engines like vLLM and SGLang use heavily optimized custom kernels for throughput, including custom FlashAttention forks, quantization paths, and attention backends that differ numerically from the FSDP training path. Token-level log-probability values computed during rollout differ from those recomputed during training, and the divergence compounds across a sequence.
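The compounding is worth quantifying with a back-of-the-envelope example (illustrative magnitudes, not measured values from the retrospective):

```python
import math

# Per-token log-prob drift compounds multiplicatively over a sequence:
# a 1e-3-nat gap per token, over a 1000-token response, shifts the
# sequence-level importance weight by a factor of e (about 2.72).
per_token_gap = 1e-3
response_len = 1000
seq_weight = math.exp(per_token_gap * response_len)
assert 2.7 < seq_weight < 2.75
```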

The symptom is exploding gradient norms and flat reward curves even on simple tasks like GSM8K. The partial remedy is rollout correction, a sequence-level importance sampling technique documented in the verl rollout correction documentation, which stabilizes gradient norms. It does not fully close the gap, however. Stable gradient norms are a necessary precondition for convergence, but the root cause remained in the attention computation itself, and freezing the attention layers during ablation confirmed this directly.
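The idea behind sequence-level rollout correction can be sketched as follows; verl's actual implementation may differ in its clipping and normalization details:

```python
import math

def seq_importance_weight(train_lps, rollout_lps, max_weight=2.0):
    # one weight per sequence: exp of the summed train-rollout log-prob gaps,
    # truncated to cap the variance a mismatched rollout can inject
    log_w = sum(t - r for t, r in zip(train_lps, rollout_lps))
    return min(math.exp(log_w), max_weight)

# Small per-token kernel divergence between rollout and training passes:
train_lps   = [-1.01, -0.52, -2.03]
rollout_lps = [-1.00, -0.50, -2.00]
w = seq_importance_weight(train_lps, rollout_lps)
assert 0.9 < w < 1.0   # the whole sequence is mildly down-weighted
```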

Problem Three: The Attention Sink Divergence

This is the central problem the retrospective resolves. GPT-OSS implements learnable attention sinks, a concept introduced in Xiao et al. (2023). In standard attention, the softmax over key-query scores distributes probability mass entirely across the sequence tokens. With attention sinks, each attention head has a learnable scalar parameter that participates in the softmax as a virtual token, giving the distribution somewhere to concentrate mass other than content positions:

# Standard attention
scores = Q @ K.T / sqrt(d_k)            # [B, H, N, N]
probs  = softmax(scores, dim=-1)        # rows sum to 1 over sequence positions
output = probs @ V

# GPT-OSS attention with learnable sinks
# sink_param: one learnable scalar per head, broadcast to [B, H, N, 1]
extended  = concat([scores, sink_param], dim=-1)    # [B, H, N, N+1]
probs     = softmax(extended, dim=-1)               # rows sum to 1 over sequence + sink
probs_seq = probs[..., :-1]                         # drop the sink column; rows now sum to < 1
output    = probs_seq @ V

The sink column receives probability mass proportional to exp(sink_param) relative to the content scores but contributes nothing to the output. The backward pass through this mechanism requires computing gradients through a truncated softmax, which is non-standard and not implemented in either FlashAttention v2 or the standard FA v3 release.
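The mass accounting is easy to verify with a scalar re-implementation of a single attention row (a sketch for intuition, not the kernel code):

```python
import math

def softmax(xs):
    m = max(xs)                          # shift for numerical stability
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]

def sink_attention_row(scores, sink_param, values):
    # softmax over content scores plus the sink logit, then drop the sink
    probs = softmax(scores + [sink_param])
    content_probs, sink_prob = probs[:-1], probs[-1]
    output = sum(p * v for p, v in zip(content_probs, values))
    return content_probs, sink_prob, output

scores, values = [1.0, 0.5, -0.2], [2.0, -1.0, 0.5]
content, sink_p, out = sink_attention_row(scores, 0.8, values)

# Rows sum to 1 only over sequence + sink; the content mass alone sums
# to less than 1, and the sink contributes nothing to the output.
assert abs(sum(content) + sink_p - 1.0) < 1e-12
assert sum(content) < 1.0
```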

vLLM maintains a custom FlashAttention fork with sink support. The training path in verl uses standard FlashAttention v2, which has no sink support whatsoever. The log-probability divergence visible in the experiments is structural: rollout and training are computing fundamentally different attention operations for every head in the model, with the sink probability absorbing mass differently depending on which kernel executes.

To fix this, the team implemented the FlashAttention v3 sink backward pass from scratch. The forward pass was adapted from the vLLM FA fork; the backward was derived analytically. The key gradient expression through the extended softmax is:

With content probabilities P_ij and sink probability P_ih per query position i, and writing dL/dP_ij for the gradient flowing back from the attention output (the sink probability itself receives no direct gradient, since its column is discarded):

dL/dS_h = -sum_i( P_ih * sum_{j in content}( P_ij * dL/dP_ij ) )

Because output_i = sum_j P_ij V_j, the inner sum collapses to dO_i · O_i, the same per-row dot product that FlashAttention's backward pass already computes.
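This gradient can be sanity-checked numerically. Using a toy loss L = sum_ij P_ij * g_ij (so dL/dP_ij = g_ij, and the discarded sink column receives no direct gradient), the analytic expression -sum_i( P_ih * sum_j( P_ij * dL/dP_ij ) ) should match a central finite difference in the sink logit:

```python
import math

def probs_with_sink(scores, sink):
    # extended softmax over content scores plus the sink logit
    exps = [math.exp(x) for x in scores] + [math.exp(sink)]
    z = sum(exps)
    return [e / z for e in exps]     # last entry is the sink probability

def loss(rows, sink, g):
    # L = sum_i sum_j P_ij * g_ij; only content probabilities enter the loss
    total = 0.0
    for scores, grads in zip(rows, g):
        p = probs_with_sink(scores, sink)
        total += sum(pij * gij for pij, gij in zip(p[:-1], grads))
    return total

def sink_grad_analytic(rows, sink, g):
    # dL/dS_h = -sum_i( P_ih * sum_j( P_ij * dL/dP_ij ) )
    total = 0.0
    for scores, grads in zip(rows, g):
        p = probs_with_sink(scores, sink)
        total += -p[-1] * sum(pij * gij for pij, gij in zip(p[:-1], grads))
    return total

rows = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]   # two query rows of scores
g    = [[1.0, -2.0, 0.5], [0.3, 0.7, -1.0]]   # upstream gradients dL/dP_ij
s, eps = 0.25, 1e-6

fd = (loss(rows, s + eps, g) - loss(rows, s - eps, g)) / (2 * eps)
analytic = sink_grad_analytic(rows, s, g)
assert abs(fd - analytic) < 1e-6
```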

The effect on training is unambiguous in the published figures: without the fix, GSM8K reward curves plateau or collapse; with FA3 plus sink support, convergence is faster and gradient norms remain stable throughout. The VerifyIf task, more sensitive to gradient stability, shows outright training collapse without the fix, along with entropy divergence that confirms the model is being pushed in contradictory directions.

Problem Four: MoE Memory Materialization

The fourth problem is operational. Running GPT-OSS under FSDP for log-probability computation triggered a CUDA out-of-memory error: an attempt to allocate 180 GiB on a GPU with 139.72 GiB total capacity.

The source was a code path in Hugging Face Transformers that materializes all expert tensors simultaneously, appropriate for inference batch processing but not for training iteration. FSDP was invoking the module in .eval() mode, which selected the inference path. Patching the mode selection to force the sequential expert-iteration path resolved the OOM; the issue is tracked in HuggingFace Transformers issue #40073.
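A toy contrast of the two code paths (hypothetical shapes; the real Transformers MoE code is considerably more involved): materializing every expert's output at once and iterating experts sequentially produce identical results, but peak activation memory scales with the number of experts in the first case and is roughly constant in the second.

```python
# Each "expert" is a scalar weight applied to a hidden value; top-1 routing.
experts = [2.0, -1.0, 0.5, 3.0]          # 4 experts
hidden  = [1.0, 2.0, 3.0]                # 3 tokens
routes  = [0, 3, 1]                      # chosen expert per token

# Inference-style path: materialize all experts' outputs for all tokens
# (peak memory ~ num_experts x num_tokens), then select the routed rows.
all_outputs = [[w * h for h in hidden] for w in experts]
dense = [all_outputs[e][t] for t, e in enumerate(routes)]

# Sequential path: iterate experts, touching only their routed tokens
# (peak memory ~ num_tokens).
sequential = [0.0] * len(hidden)
for e, w in enumerate(experts):
    for t, r in enumerate(routes):
        if r == e:
            sequential[t] = w * hidden[t]

assert dense == sequential == [2.0, 6.0, -3.0]
```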

The expanded context window also required sequence parallelism, splitting the sequence across GPUs with all-to-all communication before and after attention layers. Non-attention layers require no cross-device communication in this scheme, and per-GPU activation memory scales inversely with the parallelism degree. Variable-length sequence support in FlashAttention v3 integrates cleanly with this layout.
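The sharding arithmetic can be sketched in a few lines (a conceptual sketch; the real layout also handles padding and variable-length batches):

```python
# Sequence-parallel layout: each of sp_degree GPUs holds N / sp_degree tokens,
# so per-GPU activation memory scales inversely with the parallelism degree.
def shard(tokens, sp_degree, rank):
    chunk = len(tokens) // sp_degree
    return tokens[rank * chunk:(rank + 1) * chunk]

tokens = list(range(16))          # a 16-token sequence
sp = 4
shards = [shard(tokens, sp, r) for r in range(sp)]

# Each rank holds 1/sp of the activations...
assert all(len(s) == len(tokens) // sp for s in shards)
# ...and regathering the shards (conceptually, the communication around
# attention) reconstructs the full sequence.
assert [t for s in shards for t in s] == tokens
```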

The Broader Pattern

Most RL training writeups describe clean experimental results. This one describes four failures, their root causes, and the fixes. That kind of engineering transparency is uncommon and more useful than a polished results table, particularly for teams evaluating whether to attempt similar work.

The training-inference kernel mismatch problem is not unique to GPT-OSS. Any model that uses custom attention variants in its inference kernels, including quantized attention, sliding window attention with custom cache management, or non-standard attention backends, will have some version of this problem when combined with an RL training loop that uses a different kernel stack. The severity scales with how much the attention semantics differ; for a learnable architectural parameter like an attention sink scalar, the divergence is large enough to prevent convergence entirely.

Related work on MoE RL stabilization raises a parallel concern about routing alignment: routing decisions affect which parameters receive gradients, so a systematic routing mismatch between rollout and training is a structural bias in which parts of the model learn at all, not merely numerical noise.

Supervised fine-tuning tolerates these kernel-level differences; on-policy RL tolerates essentially none of them. The important audit before attempting RL on a non-standard model architecture is comparing what each inference kernel computes against what the training kernel computes, and determining how those differences propagate across a sequence. The GPT-OSS work is a concrete record of what it costs to discover those differences empirically rather than in advance.

The FA3 sink backward implementation is planned for upstream release following internal review, which would make it available for other teams using GPT-OSS with verl. The ReTool recipe is already publicly available for reference.
