🎱FP8 强化学习

使用 Unsloth 在 FP8 精度下训练强化学习 (RL) 和 GRPO。

我们正在为强化学习引入 FP8 精度训练,使得在……上实现 FP8 GRPO 成为可能, 消费级 GPU (RTX 40、50 等)。DeepSeek-R1 展示了 FP8 的强大能力,通过 Unsloth,Qwen3-1.7B 的 FP8 GRPO 现在仅在 5GB 显存.

更快的 RL 推理至关重要,因为它是 RL 中计算量最大负载。我们与 TorchAOarrow-up-right 来自 PyTorch 的团队合作,以实现性能提升且不损失精度。

  • 约 1.4× 更快 通过 RL 推理(via) vLLMarrow-up-right • 与 BF16 和 FP16 相比上下文长度提高 2 倍

  • 显存减少 60% 并且 上下文长度提高 10× 比其他 FP8 RL 实现更长的上下文

  • Unsloth 是 唯一的框架 能使 FP8 RL LoRA 在消费级 GPU(例如 NVIDIA GeForce RTX 40 和 50 系列)上运行。同时也支持 H100、H200、B200 等。

  • 使用 load_in_fp8 = TrueFastLanguageModel 中以启用 FP8 RL。

  • 虽然 Qwen3-8B 可放入 16GB 显存,但免费的 Colab NVIDIA Tesla T4 GPU 不支持 FP8。因此我们的笔记本使用 24GB L4 GPU,可容纳 Qwen3-14B.

笔记本: Qwen3-8B FP8 GRPOarrow-up-right 并且 Llama-3.2-1B FP8 GRPOarrow-up-right

circle-check

我们的 FP8 支持使用 Unsloth 的 权重共享功能,进一步减少显存使用约 50%,使得 上下文增加 10× 而无需精度损失。我们使用 vLLMarrow-up-right 用于快速推理,并且,我们的技术如 Unsloth 的 备用(Standby) 并且 灵活注意力(Flex Attention) 进一步减少显存使用。TorchAO 实现了按需的通用 FP8,所以 Llama、Gemma、Mistral 等都能工作。我们还已将 上传了 大多数 FP8 模型(包括 Qwen3)。

奖励图显示 FP8 遵循与 BF16 相同的趋势

🌻FP8 与 BF16 的训练比较

研究表明 FP8 训练在很大程度上可以匹配 BF16 的精度,如果你以 FP8 提供模型, 在相同精度下进行训练和部署 有助于保持精度。此外,在 H100 上 FP8 相比 BF16 在吞吐量上提高 1.6 倍,并具有 2 倍更低的内存使用。

权重尺度与 FP8 类型

量化训练存储低精度权重(例如 FP8)以及更高精度的尺度(FP16/BF16/FP32)。你可以通过近似公式恢复原始权重: 原始权重 ≈ 量化权重 * 权重尺度

尺度将权重的范围映射到 FP8 的可表示范围。更多尺度通常能提高精度,但尺度会消耗额外的高精度内存,因此这是一个权衡。 例如,DeepSeek R1arrow-up-right主要倾向于块量化。

根据 vLLM 的 llm-compressorarrow-up-right定义,有 3 种常见的 FP8 类型。我们在所有 3 种类型上对 Qwen3-8B 进行了基准测试,并检查了吞吐量、MMLU Pro 和 GQPA Diamond。我们发现 FP8 块式或按通道(-FP8-Dynamic)在 精度和吞吐量方面是最好的

类型
吞吐量
MMLU Pro
GQPA Diamond

Bfloat16 基线

11,367

62.04%

28.79%

块式(Block-wise)

每块(128x128)使用尺度

12,041

62.37%

29.29%

按通道(Per-Channel)

每行或每列使用 1 个尺度

12,963

61.89%

31.82%

按张量(Per-Tensor)

整个张量使用 1 个尺度

13,681

61.83%

27.78%

FP8 性能基准

通过 vLLM 使用 Unsloth 的 FP8 RL 推理通常比 BF16 快 1.4 倍。如果模型更大,你可能会看到更大的速度提升!

精度 训练损失 基准

我们测试了多个模型,包括 Qwen3-4B、8B、14B、Llama 3.2 1B、3B、Qwen3-VL-2B、Qwen3-VL 4B 等等。所有模型均以 BF16 和 FP8 进行了训练。如图所示, BF16 与 FP8 在 SFT 期间的损失曲线紧密跟随彼此。在训练损失方面,两种数据类型之间没有太大差别:

针对 GRPO,鉴于生成差异,目标是查看奖励图是否至少匹配而不发生发散(例如某些 Qwen3-14B 的运行可能不完全相同)

⛩️推理占 RL 训练的 96%

在 RL 中,我们必须调用 LLM / VLM 生成一些可能的候选解,然后对每个可能的解进行评分,接着 奖励好的解,惩罚错误的答案。为了实现最大效率,我们必须使推理几乎占训练运行的 100%。在 Unsloth 中,我们 设法使训练仅占整个 RL 运行的 <4%,96% 完全是 vLLM 推理。

例如对于 Qwen-3-8B,在较短序列长度上速度为 1.15× 更快,vLLM 的 FP8 推理本身(不含训练)吞吐量也快 1.15×。我们看到 Unsloth 中的 RL 运行在处理的 tokens 上也达到了 1.15× 的提升,这表明 训练开销在 Unsloth 中可以忽略不计。

🔢显存使用减少 60%

理论上,你会预期内存节省大致 等于模型权重的内存,因为:优化器状态仍以高精度存储,激活也以高精度存储(目前如此)。我们的发现与理论一致。对于 LoRA 微调,我们观察到: 约节省 30 GB 用于 Qwen3-32B,约节省 14 GB 用于 Qwen2.5-14B 并且 约节省 8 GB 用于 Qwen3-8B

对于 在 Qwen3-32B 上的 BF16 LoRA 微调, 我们在更大的批量大小下出现了 OOM(内存不足),不得不缩小批量。 FP8 变体没有此类问题,我们可以使用 更大的批量大小 而不会 OOM。

另外提醒,在 Unsloth 中我们共享 vLLM 的权重内存空间,如 内存高效 RL 中所介绍 —— 我们将这一技巧带入了 FP8 领域!

80GB GPU
推理引擎
训练引擎

模型权重

8GB 共享 FP8

<<< 共享

多用途

72GB 空间

KV 缓存

激活、梯度、优化器状态

要启用 Unsloth 备用(Standby) 用于 FP8(或 BF16)RL,只需在任何 Unsloth 导入之前将下面内容添加到所有 RL / GRPO 训练运行中:

如何使用 FP8 RL / 安装

只需更新 Unsloth 或在新的虚拟环境中安装 Unsloth,以用于 H100、L4、RTX 50x、RTX 40x、H200、B200 以及任何在 RTX 4090 之后发布的 NVIDIA GPU(消费级或数据中心级)。

更新 Unsloth: pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo或创建一个新环境:

然后使用 load_in_fp8 = True 你就可以开始了!我们会自动将模型名称映射到 Float8 变体,或者在运行时将模型转换为 Float8!

例如在 RTX 5090 上(提醒设置 os.environ["UNSLOTH_VLLM_STANDBY"] = "1" )

然后使用我们的 2 个 FP8 笔记本用于 RL:

💿实现 FP8 训练

我们的第一个参考点是 transformers,它已经以几种方式支持 FP8。其中一种是块量化的矩阵乘实现:当某层接收到 16 位激活时,它会对其进行量化并传递给自定义的 FP8 矩阵乘内核。在将其接线并在 NVIDIA H100 上基准测试后,我们看到的结果与预期相反:微调大约 慢 4× 比标准 BF16 微调更慢。

🔥TorchAO 合作

因此我们与 TorchAOarrow-up-right 团队合作(非常感谢 Andrewarrow-up-right)将 TorchAO 的 FP8 支持整合进我们的 RL 工作负载,观察到约 1.4× 的吞吐量加速 并且最多可达 60% 的模型内存使用减少。总体来说:

  • 我们将冻结的 LoRA 权重以 FP8 存储。

  • 在前向传递中,我们对输入激活应用动态 FP8 量化,同时将可训练的 LoRA 适配器保留为 BF16。

  • 这些 FP8 权重与 vLLM 模型权重共享相同缓冲区,因此内存中在任何时候只有一个 FP8 的模型副本(没有“双模型”内存开销)。

  • 在反向传递中,我们对 LoRA 权重解量化,因此所有梯度计算均在 BF16 中进行以获得更好的精度。

该通用设置适用于所有支持的 RL 算法,包括 GSPO、Dr. GRPO、PPO 和 DPO。

TorchAO 为训练和推理提供 PyTorch 原生的 FP8 支持,提供多种尺度粒度,包括张量级、行级和 128x128 块级(原型)。TorchAO 的 FP8 支持在 27B 规模下使用行级尺度时可将推理吞吐量提高最多 1.64xarrow-up-right 。详情请参见 TorchAO 的 FP8 自述(README)arrow-up-right.

TorchAO 的块量化 FP8 矩阵乘

我们使用了 TorchAO 的块量化 FP8 矩阵乘实现,提供了:

  • 达到 BF16 吞吐量的 80%

  • 而不降低损失或训练稳定性

因此有一段时间,这成为了我们的默认 FP8 矩阵乘后端,直到 FBGEMM 赶上 —— 如果你的 GPU 支持,我们现在默认使用 FBGEMM 的实现!当前版本的 Unsloth 可以根据已安装的软件自动选择最佳后端。如果你安装了合适的包,就不必将性能浪费掉 🙂

附注:我们也曾尝试 DeepSeek 的 DeepGEMM,但无法将其端到端完全集成以进行干净的苹果对苹果比较。

🐦按需 TorchAO FP8 量化

非常感谢 Andrewarrow-up-right 来自 TorchAO 的贡献,Unsloth 的 FP8 RL 还允许你在模型加载时进行按需量化并将其传递给 vLLM。这样,你无需自己显式量化模型(我们为你处理)。你可以通过在模型加载参数中设置 load_in_fp8 = True 来实现,如果找不到合适的预量化检查点,我们将执行离线 FP8。

🎉Unsloth 的 FP8 上传

为方便起见,我们已在 Hugging Face 上上传了 FP8 Dynamic 和 FP8 Block 模型。你可以将它们用于 FP8 训练或通过 vLLM/SGLang 等进行高效且快速的部署/服务。

FP8 Dynamic 在训练速度和显存使用上略优于 FP8 Block,但在精度上有小幅权衡。 请参见此处 获取我们完整的 FP8 量化列表,但此处列出最受欢迎的:

模型
FP8 上传

Qwen3(2507)

4B 指令版 — FP8arrow-up-right 4B 思考版 — FP8arrow-up-right 30B-A3B 指令版 — FP8arrow-up-right 30B-A3B 思考版 — FP8arrow-up-right

Qwen3-VL

4B 指令版 — FP8arrow-up-right 4B 思考版 — FP8arrow-up-right 8B 指令版 — FP8arrow-up-right 8B 思考版 — FP8arrow-up-right

Mistral Small 3.2

💁致谢

非常感谢整个 PyTorch 和 TorchAO 团队的帮助与合作!特别感谢:Andrew Or、Jerry Zhang、Supriya Rao、Scott Roy 和 Mergen Nachin 在 FP8 RL 方面的诸多讨论以及将其集成到 Unsloth 中的帮助!也感谢 Executorch 团队!

最后更新于

这有帮助吗?