🎱FP8 強化学習

Unsloth で FP8 精度で強化学習(RL)と GRPO をトレーニングします。

RL向けのFP8精度トレーニングを導入しており、これによりFP8でのGRPOが今や可能になりました(実行環境: コンシューマー向けGPU (RTX 40、50など)。DeepSeek-R1はFP8の強力さを示し、Unsloth を使えば Qwen3-1.7B の FP8 GRPO が今やただ 5GBのVRAMで動作します.

RLにおいて最も計算負荷の高いワークロードであるため、より高速なRL推論は極めて重要です。我々は TorchAOarrow-up-right (PyTorch側)と協力して、精度を損なうことなく性能向上を実現しました。

  • 約1.4×高速化 によるRL推論 vLLMarrow-up-right • BF16やFP16と比べて2倍長いコンテキスト

  • 60%少ないVRAM および 10倍長い 他のFP8 RL実装よりも長いコンテキスト

  • Unslothは 唯一のフレームワーク で、コンシューマー向けGPU(例:NVIDIA GeForce RTX 40および50シリーズ)上でFP8 RL LoRAを動作させます。H100、H200、B200などでも動作します。

  • 使用する load_in_fp8 = TrueFastLanguageModel 内に設定することでFP8 RLを有効にします。

  • Qwen3-8Bは16GBのVRAMに収まりますが、無料のColabのNVIDIA Tesla T4 GPUは FP8をサポートしていません。したがって我々のノートブックでは Qwen3-14Bが収まる24GBのL4 GPUを使用しています.

ノートブック: Qwen3-8B FP8 GRPOarrow-up-right および Llama-3.2-1B FP8 GRPOarrow-up-right

circle-check

我々のFP8サポートはUnslothの 重み共有機能を利用しており、VRAM使用量をさらに 50%削減し、 10倍以上の コンテキストを精度劣化なしに可能にします。我々は高速推論のために vLLMarrow-up-right を使い、またUnslothのような手法(例: Standby および Flex Attention )でVRAM使用量をさらに削減します。TorchAOはオンザフライでの汎用FP8を可能にするため、Llama、Gemma、Mistralなども動作します。さらに我々は をアップロードしました ほとんどのFP8モデル(Qwen3を含む)。

報酬プロットはFP8がBF16と同じ傾向に従うことを示しています

🌻FP8対BF16のトレーニング

研究ではFP8トレーニングが概ねBF16の精度に匹敵できることが示されており、モデルをFP8でサーブする場合、 同じ精度でのトレーニングとサービング が精度維持に役立ちます。またFP8はBF16に比べてH100上でスループットが1.6倍高く、メモリ使用量は2倍低くなります。

重みスケールとFP8の種類

量子化されたトレーニングでは、低精度の重み(例:FP8)と高精度のスケール(FP16/BF16/FP32)を保存します。おおよそ次の式で元の重みを復元できます: original_weight ≈ quantized_weight * weight_scale

スケールは重みのレンジをFP8の表現可能なレンジにマッピングします。スケール数が多いほど精度が改善することが多いですが、スケールは追加の高精度メモリを消費するためトレードオフになります。 DeepSeek R1arrow-up-rightは例えば、主にブロック量子化を好みます。

vLLMの llm-compressorarrow-up-rightで定義される3つの一般的なFP8タイプがあります。我々はQwen3-8Bを3種類すべてでベンチマークし、スループット、MMLU Pro、GQPA Diamondもチェックしました。我々の結論は FP8のブロック単位またはチャネル毎(-FP8-Dynamic)が最良である ということです(精度とスループットの観点で)。

タイプ
スループット
MMLU Pro
GQPA Diamond

Bfloat16ベースライン

11,367

62.04%

28.79%

ブロック単位

ブロックごとのスケール(128×128)

12,041

62.37%

29.29%

チャネル単位

行または列ごとに1つのスケール

12,963

61.89%

31.82%

テンソル単位

テンソル全体に1つのスケール

13,681

61.83%

27.78%

FP8パフォーマンスベンチマーク

vLLM経由のUnsloth FP8 RL推論は一般にBF16より約1.4倍高速です。モデルが大きいほどさらに速度改善が見られる可能性があります!

精度 トレーニング損失ベンチマーク

我々はQwen3-4B、8B、14B、Llama 3.2 1B、3B、Qwen3-VL-2B、Qwen3-VL-4Bなど多くのモデルをテストしました。すべてBF16とFP8の両方で訓練しました。プロットに見られるように、 SFT中のBF16とFP8の損失曲線はお互いにほぼ一致します。トレーニング損失の観点では2つのデータ型の間に大きな差はありません:

GRPOに関しては、生成の違いのため、報酬プロットが少なくとも一致して発散しないかを確認することが目的です(例えばQwen3-14Bの実行は必ずしも完全に同一にはならないことがあります)。

⛩️推論はRLトレーニングの96%に相当

RLでは、LLM / VLMを呼び出していくつかの候補解を生成し、それぞれの解を評価し、 良い解には報酬を与え、悪い回答にはペナルティを与えます。最大効率を達成するために、推論をトレーニング実行のほぼ100%にする必要があります。Unslothでは、 我々はトレーニングを全RL実行のわずか<4%に抑え、96%を純粋なvLLM推論にすることに成功しました。

例えばQwen-3-8Bでは、短いシーケンス長で1.15×高速化され、推論(トレーニングを含まない)におけるvLLM FP8自体のスループットも1.15×高速でした。Unslothにおける我々のRL実行でも処理トークン数で1.15×の高速化が確認され、 Unslothではトレーニングオーバーヘッドが無視できるほど小さいことが示されます。

🔢メモリ使用量を60%削減

理論的には、メモリ節約は概ね モデルの重みメモリに等しいと期待されます。なぜなら:オプティマイザ状態は依然として高精度で保存され、活性化も高精度で保存されるからです(現時点では)。我々の観察は理論と一致します。LoRA微調整では次を観察しました: 約30GB節約 用に Qwen3-32Bで約14GB節約 用に Qwen2.5-14B および 約8GB節約 用に Qwen3-8B

に関して BF16でのLoRA微調整において Qwen3-32Bでは、大きめのバッチサイズでOOMが発生し、バッチを縮小する必要がありました。 FP8バリアントではそのような問題は発生せず、我々は より大きなバッチサイズ をOOMなしで使用できました。

またUnslothでは、vLLMの重み用メモリ空間を共有する機能を導入しています(詳細は メモリ効率の良い RL で紹介しています)- このトリックをFP8領域にも持ち込みました!

80GB GPU
推論エンジン
トレーニングエンジン

モデル重み

8GB 共有 FP8

<<< 共有

多目的

72GBの空間

KVキャッシュ

活性化、勾配、オプティマイザ状態

レイヤーとしてではありません。これが量子化を複雑にします。特にMoE/MLPエキスパートは20Bパラメータのうち約19Bを占めます。 Unsloth Standby FP8(またはBF16)RL用には、Unslothのインポート前にすべてのRL / GRPOトレーニング実行に以下を追加してください:

FP8 RL の使い方 / インストール方法

単にUnslothをアップデートするか、H100、L4、RTX 50x、RTX 40x、H200、B200、およびRTX 4090以降にリリースされた任意のNVIDIA GPU(コンシューマーまたはデータセンター)用に新しい仮想環境へUnslothをインストールしてください。

Unslothをアップデートするには: pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zooあるいは新しい環境を作る場合:

その後、次を使用します load_in_fp8 = True これで準備完了です!モデル名をFloat8バリアントに自動マップするか、オンザフライでモデルをFloat8に変換します!

例えばRTX 5090上では(設定を忘れずに、 os.environ["UNSLOTH_VLLM_STANDBY"] = "1" )

それから我々の2つのFP8ノートブックをRL用に使ってください:

💿FP8トレーニングの実装

我々の最初の参照点は transformersで、ここでは既にいくつかの方法でFP8をサポートしています。その一つがブロック量子化されたmatmul実装です:あるレイヤが16ビット活性化を受け取ると、それを量子化してカスタムのFP8 matmulカーネルに渡します。これをNVIDIA H100で配線してベンチマークしたところ、我々が望んだ結果とは逆になり:微調整が 4倍遅く なり、標準のBF16微調整より遅くなりました。

🔥 TorchAOとのコラボ

そこで我々は TorchAOarrow-up-right チームと協力し(特に Andrewarrow-up-rightに大きな感謝を)、TorchAOのFP8サポートを我々のRLワークロードに組み込み、約 1.4×のスループット向上 を確認し、最大で 60%のモデルメモリ使用量削減を達成しました。概要として:

  • 我々は固定されたLoRA重みをFP8で保存します。

  • フォワードパスでは入力活性化に動的FP8量子化を適用し、訓練可能なLoRAアダプタはBF16のままにします。

  • これらのFP8重みはvLLMのモデル重みと同じバッファを共有するため、メモリ上にFP8のモデルコピーは一つだけ(“二重モデル”のメモリオーバーヘッドはなし)です。

  • バックワードパスではLoRA重みをデクオンタイズして、すべての勾配計算をBF16で行い精度を確保します。

この一般的なセットアップは、我々がサポートするすべてのRLアルゴリズム(例えば や Dr. GRPO のような他のものに設定することもできる点に注意。、Dr. GRPO、PPO、DPOなど)で動作します。

TorchAOはトレーニングと推論の両方に対するPyTorchネイティブのFP8サポートを提供し、テンソル単位、行単位、128x128のブロック単位(プロトタイプ)など様々なスケーリング粒度を提供します。TorchAOのFP8サポートは、行単位スケーリング粒度で27B規模において推論スループットを最大 1.64倍向上させることができますarrow-up-right 。詳細はTorchAOの FP8 READMEarrow-up-right.

をご覧ください。

TorchAOのブロック量子化FP8 matmul

  • 我々はTorchAOのブロック量子化FP8 matmul実装を使用し、以下を得ました:

  • BF16の約80%のスループット

損失やトレーニングの安定性を損なうことなく

しばらくの間、これは我々のデフォルトのFP8 matmulバックエンドとなりましたが、FBGEMMが追いつくまでのことでした。現在のUnslothはインストール状況に基づいて最適なバックエンドを自動選択できます。適切なパッケージがあれば、性能を無駄にする必要はありません 🙂

🐦追記:DeepSeekのDeepGEMMも試しましたが、エンドツーエンドで完全に統合してクリーンに比較できる状態にすることはできませんでした。

オンザフライのTorchAO FP8量子化 Andrewarrow-up-right 大変感謝します: load_in_fp8 = True TorchAOの貢献により、Unsloth FP8 RLはモデルロード時にオンザフライで量子化を行い、それをvLLMに渡すことができます。この方法では、ユーザーが明示的にモデルを量子化する必要はありません(我々が処理します)。モデルロード引数で

🎉load_in_fp8 = True, # ブロックFP8なら"block"、行FP8ならTrue、無効ならFalse

UnslothのFP8アップロード vLLM/SGLang など。

FP8 Dynamic は FP8 Block より若干高速なトレーニングと低い VRAM 使用量を提供しますが、精度に小さなトレードオフがあります。 利便性のため、我々はHugging FaceにFP8 DynamicおよびFP8 Blockモデルをアップロードしました。これらはFP8トレーニングや、 を介した効率的で高速なサービング/デプロイに利用できます。

モデル
こちらをご覧ください

Qwen3(2507)

14B — FP8arrow-up-right 32B — FP8arrow-up-right 4B インストラクト — FP8arrow-up-right 4B シンキング — FP8arrow-up-right

14B — FP8arrow-up-right 32B — FP8arrow-up-right 3B ベース — FP8arrow-up-right 30B-A3B インストラクト — FP8arrow-up-right

Llama 3.1

3B ベース — 動的なarrow-up-right · Blockarrow-up-right 3B インストラクト — 動的なarrow-up-right · Blockarrow-up-right で我々のFP8量子化一覧を確認できますが、ここではよく使われるものを示します: 動的なarrow-up-right · Blockarrow-up-right

Qwen3

8B インストラクト — FP8arrow-up-right 8B ベース — FP8arrow-up-right 4B — FP8arrow-up-right 0.6B — FP8arrow-up-right 1.7B — FP8arrow-up-right 8B — FP8arrow-up-right

Llama 3.3

で我々のFP8量子化一覧を確認できますが、ここではよく使われるものを示します: 動的なarrow-up-right · Blockarrow-up-right

Granite 4.0

30B-A3B シンキング — FP8 Dynamicarrow-up-right 8B シンキング — FP8 Dynamicarrow-up-right

Mistral Small 3.2

💁h-small —

270m —

最終更新

役に立ちましたか?