🌀7 倍更长上下文的强化学习 GRPO

了解 Unsloth 如何实现超长上下文的 RL 微调。

强化学习(RL)最大的挑战是支持长推理轨迹。我们正在引入新的批处理算法以实现约7 倍更长的上下文 (可能超过 12 倍) RL 训练在准确性或速度上不劣于使用 FA3、内核和分块损失的其他优化设置。

  • Unsloth 现在使用以下配置训练 gpt-oss QLoRA: 380K 上下文 在单个 192GB 的 NVIDIA B200 GPU 上

  • Qwen3-8B GRPO 达到 110K 上下文 在 80GB VRAM 的 H100 上通过 vLLM 和 QLoRA,并且 65K 用于 gpt-oss 使用 BF16 LoRA。

  • 在 24GB VRAM 上,gpt-oss 达到 20K 上下文,Qwen3-VL 可达 32K, Qwen3-VL-8B QLoRA

  • Unsloth GRPO RL 可与 Llama、Gemma 及所有模型自动支持更长上下文一起运行

我们新的数据移动和批处理内核与算法解锁了更多 上下文 通过:

circle-info

您可以在 Unsloth 中将所有特性结合使用:

  1. Unsloth 的 权重共享 功能与 vLLMarrow-up-right 以及我们在 内存高效 RL

  2. Unsloth 的 Flex Attention 中的备用(Standby)特性,适用于长上下文的 gpt-oss,以及我们的 500K Context Training

  3. 中的 Float8 训练, FP8 RL 以及 Unsloth 的 异步梯度检查点(async gradient checkpointing)arrow-up-right 以及更多功能

🎉入门

要开始,您可以使用任何现有的 GRPO 笔记本 (或在本地更新 Unsloth):

将 Unsloth 应用于您的 RL 任务可为高效管理大规模模型提供稳健的框架。为了有效利用 Unsloth 的增强功能:

  • 硬件建议:建议使用 NVIDIA H100 或等效设备以实现最佳 VRAM 利用率。

  • 配置提示:请确保 batch_sizegradient_accumulation_steps 设置与您的计算资源对齐以获得最佳性能。

circle-check

我们的基准测试突出了与早期版本相比在 GPT OSS 和 Qwen3-8B 上实现的内存节省。下面两个图(不含 备用(standby))是在 batch_size = 4gradient_accumulation_steps=2 的情况下运行的,因为 standby 设计上会使用所有 VRAM。

在我们的基准中,我们将 BF16 GRPO 与在所有优化启用情况下的 Hugging Face 进行比较(kernels 库中的所有内核、Flash Attention 3、分块损失内核等):

🔢扁平化序列长度分块

以前,Unsloth 通过在批次维度上分块来避免 logits 张量的完全物化,从而减少了 RL 的内存使用。前向传递期间物化 logits 所需 VRAM 的粗略估计如公式(1)所示。

Equation 1: Logit Memory (GB)=batch size×context length×vocab dim10243\text{Equation 1: } \text{Logit Memory (GB)} = \frac{\text{batch size} \times\text{context length} \times \text{vocab dim}}{1024^3}

使用此公式,配置为 batch_size = 4, context_length = 8192,并且 vocab_dim = 128,000 将大约需要 3.3 GB 的 VRAM 来存储 logits 张量。

通过 Long Context gpt-oss 去年,我们随后为 GRPO 引入了融合损失方法。该方法确保一次仅处理单个批样本,从而显著降低峰值内存使用。在相同配置下,VRAM 使用降至约 0.83 GB,如公式(2)所示。

Equation 2: Logit Memory (GB)=context length×vocab dim10243\text{Equation 2: }\text{Logit Memory (GB)} = \frac{\text{context length} \times \text{vocab dim}}{1024^3}
图 1:gpt-oss BF16 GRPO LoRA(Unsloth vs. HF 在所有优化开启的情况)
图 2:Qwen3-8B QLoRA GRPO LoRA(Unsloth vs. HF 在所有优化开启的情况)

在本次更新中,我们通过引入沿 序列维度 的分块进一步扩展了相同思路。我们不再一次性为整个 (batch_size × context_length) 空间物化 logits,而是将这些维度扁平化并使用可配置的乘数按较小块处理。这使 Unsloth 在不增加峰值内存使用的情况下支持显著更长的上下文。

在下面的图 5 中,我们使用的乘数为 max(4, context_length // 4096),尽管可以根据所需的内存-性能权衡指定任意乘数。使用此设置,相同示例配置(batch_size = 4, context_length = 8192, vocab_dim = 128,000)现在仅需要 0.207 GB 的 VRAM 用于 logits 的物化。

Equation 3: Logit Memory (GB)=context lengthmultiplier×vocab dim10243\text{Equation 3: }\text{Logit Memory (GB)} = \frac{\frac{\text{context length}}{\text{multiplier}} \times \text{vocab dim}}{1024^3}
图 3:gpt-oss-20b(H100)Unsloth 新版 vs. 旧版
图 4:Qwen3-8B(H100)Unsloth 新版 vs. 旧版
图 5:gpt-oss-20b(H100)
图 6:Qwen3-8B(B200)

此更新反映在下方编译的 chunked_hidden_states_selective_log_softmax 中,该实现现在支持跨批次和序列两个维度的分块。为了保留 logits 张量( [batch_size, context_length, vocab_dim] ),它始终在批次维度上进行分块。额外的序列分块由 GRPO 配置中的 unsloth_logit_chunk_multiplier 控制;如果未设置,则默认为 max(4, context_length // 4096)。在下面的示例中, input_ids_chunk[0] 对应于优化 2 中隐藏状态小批次的大小。

  1. 我们使用带有自定义编译选项的 torch.compile 以减少 VRAM 并提高速度。

  2. 所有分块的 logits 都会被提升为 float32 以保留精度。

  3. 我们支持 logit 软上限、温度缩放以及所有其他功能。

👻隐藏状态分块

我们还观察到,在更长的上下文长度下,隐藏状态可能成为内存使用的重要来源。为演示起见,我们假设 hidden_states_dim=4096。相应的内存使用遵循与 logits 情况类似的公式,如下所示。

Hidden States Memory (GB)=batch size×context length×hidden states dim10243\text{Hidden States Memory (GB)} = \frac{\text{batch size} \times\text{context length} \times \text{hidden states dim}}{1024^3}

batch_size = 8context_length = 64000的情况下,这将导致大约 2 GB的 VRAM 使用。在此版本中,我们引入了在计算对数概率时对隐藏状态张量在批次维度上的可选分块。这将使 VRAM 使用按批次大小划分,在本例中为 0.244 GB。这减少了物化隐藏状态所需的峰值 VRAM,如下更新的公式所示:

Hidden States Memory (GB)=context length×hidden states dim10243\text{Hidden States Memory (GB)} = \frac{\text{context length} \times \text{hidden states dim}}{1024^3}

类似于我们在 500K Context Training 版本中对交叉熵损失所做的工作,新的实现 会自动调整隐藏状态的批处理大小。用户也可以通过 unsloth_grpo_mini_batch来控制此行为。然而,将 unsloth_grpo_mini_batch 增加到超过最佳值可能会引入轻微的性能提升或变慢(通常是更快),与之前的损失函数相比。

然而,在一次 GPT-OSS 运行中(context_length = 8192, batch_size = 4, gradient_accumulation_steps = 2),设置 unsloth_grpo_mini_batch = 1unsloth_logit_chunk_multiplier = 4 会导致 几乎不影响速度,同时将 VRAM 使用大约减少 5 GB 与旧版本的 Unsloth 相比。

circle-check

🌵为 log softmax 卸载激活

在本次发布的开发过程中,我们发现当在隐藏状态的批次维度上进行铺瓦(tiling)时,激活在融合的 logits 和 logprobs 计算之后并未被卸载。由于 logits 是使用 hidden_states[i] @ lm_head逐批次计算的,因此现有的激活卸载和梯度检查点逻辑(设计为在模型的前向传递内工作)在这种情况下并不适用。

为了解决此问题,我们添加了明确的逻辑以在模型前向传递之外卸载这些激活,如下面的 Python 伪代码所示:

circle-check

配置参数:

如果您不配置 unsloth_grpo_mini_batchunsloth_logit_chunk_multiplier,我们将为您 基于您可用的 VRAM 并根据上下文长度的大小自动调整这两个参数。 下面是如何在您的 GRPO 运行中更改这些变量:

下面的示意图展示了这些优化和 unsloth_grpo_mini_batchunsloth_logit_chunk_multiplier 的可视化效果。

这 3 个矩阵代表总体上更大的批次或 unsloth_grpo_mini_batch (由黑色方括号的数量表示),每个矩阵的行表示该 unsloth_logit_chunk_multiplier 通过(由红色方括号的数量表示)对序列长度进行分块的数量。

📼用于 RL 的 vLLM

对于 RL 工作流,推理/生成阶段是主要瓶颈。为了解决这一问题,我们使用了 vLLMarrow-up-right,与普通生成相比,它将生成速度提高了最多 11 倍。自从去年 GRPO 普及以来,vLLM 已成为包括 Unsloth 在内的大多数 RL 框架的核心组件。我们要向 vLLM 团队及所有贡献者表示感谢,因为他们在提升 Unsloth 的 RL 表现方面起到了关键作用!

要尝试更长上下文的 RL,您可以使用任何现有的 GRPO 笔记本 (或在本地更新 Unsloth):

致谢:非常感谢 Hugging Face 团队和其库为 Unsloth 提供支持并使之成为可能。

最后更新于

这有帮助吗?