🌀7 倍更长上下文的强化学习 GRPO
了解 Unsloth 如何实现超长上下文的 RL 微调。
强化学习(RL)最大的挑战是支持长推理轨迹。我们正在引入新的批处理算法以实现约7 倍更长的上下文 (可能超过 12 倍) RL 训练在准确性或速度上不劣于使用 FA3、内核和分块损失的其他优化设置。
Unsloth 现在使用以下配置训练 gpt-oss QLoRA: 380K 上下文 在单个 192GB 的 NVIDIA B200 GPU 上
在 24GB VRAM 上,gpt-oss 达到 20K 上下文,Qwen3-VL 可达 32K, Qwen3-VL-8B QLoRA
Unsloth GRPO RL 可与 Llama、Gemma 及所有模型自动支持更长上下文一起运行
我们新的数据移动和批处理内核与算法解锁了更多 上下文 通过:
动态 扁平化序列分块 以避免物化巨大的 logits 张量并且
卸载 log softmax 激活,这可以防止随时间静默增长的内存占用。
您可以在 Unsloth 中将所有特性结合使用:
Unsloth 的 Flex Attention 中的备用(Standby)特性,适用于长上下文的 gpt-oss,以及我们的 500K Context Training
中的 Float8 训练, FP8 RL 以及 Unsloth 的 异步梯度检查点(async gradient checkpointing) 以及更多功能
🎉入门
要开始,您可以使用任何现有的 GRPO 笔记本 (或在本地更新 Unsloth):
将 Unsloth 应用于您的 RL 任务可为高效管理大规模模型提供稳健的框架。为了有效利用 Unsloth 的增强功能:
硬件建议:建议使用 NVIDIA H100 或等效设备以实现最佳 VRAM 利用率。
配置提示:请确保
batch_size和gradient_accumulation_steps设置与您的计算资源对齐以获得最佳性能。
将 Unsloth 更新到最新的 Pypi 版本以获取最新更新:
我们的基准测试突出了与早期版本相比在 GPT OSS 和 Qwen3-8B 上实现的内存节省。下面两个图(不含 备用(standby))是在 batch_size = 4 和 gradient_accumulation_steps=2 的情况下运行的,因为 standby 设计上会使用所有 VRAM。
在我们的基准中,我们将 BF16 GRPO 与在所有优化启用情况下的 Hugging Face 进行比较(kernels 库中的所有内核、Flash Attention 3、分块损失内核等):
🔢扁平化序列长度分块
以前,Unsloth 通过在批次维度上分块来避免 logits 张量的完全物化,从而减少了 RL 的内存使用。前向传递期间物化 logits 所需 VRAM 的粗略估计如公式(1)所示。
使用此公式,配置为 batch_size = 4, context_length = 8192,并且 vocab_dim = 128,000 将大约需要 3.3 GB 的 VRAM 来存储 logits 张量。
通过 Long Context gpt-oss 去年,我们随后为 GRPO 引入了融合损失方法。该方法确保一次仅处理单个批样本,从而显著降低峰值内存使用。在相同配置下,VRAM 使用降至约 0.83 GB,如公式(2)所示。


在本次更新中,我们通过引入沿 序列维度 的分块进一步扩展了相同思路。我们不再一次性为整个 (batch_size × context_length) 空间物化 logits,而是将这些维度扁平化并使用可配置的乘数按较小块处理。这使 Unsloth 在不增加峰值内存使用的情况下支持显著更长的上下文。
在下面的图 5 中,我们使用的乘数为 max(4, context_length // 4096),尽管可以根据所需的内存-性能权衡指定任意乘数。使用此设置,相同示例配置(batch_size = 4, context_length = 8192, vocab_dim = 128,000)现在仅需要 0.207 GB 的 VRAM 用于 logits 的物化。




此更新反映在下方编译的 chunked_hidden_states_selective_log_softmax 中,该实现现在支持跨批次和序列两个维度的分块。为了保留 logits 张量( [batch_size, context_length, vocab_dim] ),它始终在批次维度上进行分块。额外的序列分块由 GRPO 配置中的 unsloth_logit_chunk_multiplier 控制;如果未设置,则默认为 max(4, context_length // 4096)。在下面的示例中, input_ids_chunk[0] 对应于优化 2 中隐藏状态小批次的大小。
我们使用带有自定义编译选项的 torch.compile 以减少 VRAM 并提高速度。
所有分块的 logits 都会被提升为 float32 以保留精度。
我们支持 logit 软上限、温度缩放以及所有其他功能。
👻隐藏状态分块
我们还观察到,在更长的上下文长度下,隐藏状态可能成为内存使用的重要来源。为演示起见,我们假设 hidden_states_dim=4096。相应的内存使用遵循与 logits 情况类似的公式,如下所示。
在 batch_size = 8 和 context_length = 64000的情况下,这将导致大约 2 GB的 VRAM 使用。在此版本中,我们引入了在计算对数概率时对隐藏状态张量在批次维度上的可选分块。这将使 VRAM 使用按批次大小划分,在本例中为 0.244 GB。这减少了物化隐藏状态所需的峰值 VRAM,如下更新的公式所示:
类似于我们在 500K Context Training 版本中对交叉熵损失所做的工作,新的实现 会自动调整隐藏状态的批处理大小。用户也可以通过 unsloth_grpo_mini_batch来控制此行为。然而,将 unsloth_grpo_mini_batch 增加到超过最佳值可能会引入轻微的性能提升或变慢(通常是更快),与之前的损失函数相比。
然而,在一次 GPT-OSS 运行中(context_length = 8192, batch_size = 4, gradient_accumulation_steps = 2),设置 unsloth_grpo_mini_batch = 1 和 unsloth_logit_chunk_multiplier = 4 会导致 几乎不影响速度,同时将 VRAM 使用大约减少 5 GB 与旧版本的 Unsloth 相比。

注意: 在图 3 和图 4 中,我们使用了最大的有效批次大小,在此设置中为 8。有效批次大小计算为 batch_size × gradient_accumulation_steps,得到 4 × 2 = 8。有关有效批次大小在 RL 中如何工作的更深入解释,请参见我们的 高级 RL 文档.
🌵为 log softmax 卸载激活
在本次发布的开发过程中,我们发现当在隐藏状态的批次维度上进行铺瓦(tiling)时,激活在融合的 logits 和 logprobs 计算之后并未被卸载。由于 logits 是使用 hidden_states[i] @ lm_head逐批次计算的,因此现有的激活卸载和梯度检查点逻辑(设计为在模型的前向传递内工作)在这种情况下并不适用。
为了解决此问题,我们添加了明确的逻辑以在模型前向传递之外卸载这些激活,如下面的 Python 伪代码所示:
注意: 仅当在批次维度上进行分块或当 unsloth_grpo_mini_batch > 1时,此特性才有效。如果在前向传递期间一次性物化所有隐藏状态(即 unsloth_grpo_mini_batch = 1),则无论是否卸载激活,反向传递都需要相同量的 GPU 内存。由于在这种情况下激活卸载会引入轻微的性能减慢且并不减少内存使用,因此并无益处。
✨配置参数:
如果您不配置 unsloth_grpo_mini_batch 和 unsloth_logit_chunk_multiplier,我们将为您 基于您可用的 VRAM 并根据上下文长度的大小自动调整这两个参数。 下面是如何在您的 GRPO 运行中更改这些变量:
下面的示意图展示了这些优化和 unsloth_grpo_mini_batch 和 unsloth_logit_chunk_multiplier 的可视化效果。

这 3 个矩阵代表总体上更大的批次或 unsloth_grpo_mini_batch (由黑色方括号的数量表示),每个矩阵的行表示该 unsloth_logit_chunk_multiplier 通过(由红色方括号的数量表示)对序列长度进行分块的数量。
📼用于 RL 的 vLLM
对于 RL 工作流,推理/生成阶段是主要瓶颈。为了解决这一问题,我们使用了 vLLM,与普通生成相比,它将生成速度提高了最多 11 倍。自从去年 GRPO 普及以来,vLLM 已成为包括 Unsloth 在内的大多数 RL 框架的核心组件。我们要向 vLLM 团队及所有贡献者表示感谢,因为他们在提升 Unsloth 的 RL 表现方面起到了关键作用!
要尝试更长上下文的 RL,您可以使用任何现有的 GRPO 笔记本 (或在本地更新 Unsloth):
致谢:非常感谢 Hugging Face 团队和其库为 Unsloth 提供支持并使之成为可能。
最后更新于
这有帮助吗?

