🌀7x 長いコンテキストでの GRPO 強化学習
Unsloth が超長コンテキスト RL ファインチューニングを可能にする方法を学びます。
強化学習(RL)の最大の課題は長い推論トレースをサポートすることです。私たちは新しいバッチ処理アルゴリズムを導入して〜7倍長いコンテキスト (場合によっては12倍以上) FA3、カーネル、およびチャンク化損失を使用する他の最適化されたセットアップと比較して、精度や速度の劣化なしでのRLトレーニング。
Unslothは現在gpt-oss QLoRAを 380Kコンテキスト 単一の192GB NVIDIA B200 GPUでトレーニングします
24GB VRAMでは、gpt-ossは20Kコンテキストに達し、32Kは 。-8B QLoRA
Unsloth GRPO RLはLlama、Gemmaおよびすべてのモデルで長いコンテキストを自動サポートします
私たちの新しいデータ移動およびバッチ処理カーネルとアルゴリズムはより多くの コンテキスト を可能にします:
動的な 平坦化されたシーケンスチャンク処理 巨大なロジットテンソルの具現化を避けるためと
ログソフトマックスのオフロード アクティベーションは時間経過による静かなメモリ増加を防ぎます。
Unslothのすべての機能を組み合わせることができます:
Unslothの 重み共有 機能は vLLM と私たちのスタンバイ機能と組み合わせて メモリ効率の良い RL
Unslothの Flex Attention 長コンテキストgpt-oss用と私たちの 500K Context Training
のFloat8トレーニング、 FP8 RL およびUnslothの 非同期勾配チェックポイント など多くの機能
🎉始め方
開始するには、既存の任意の GRPOノートブック (またはローカルでUnslothを更新):
gpt-oss-20b や Dr. GRPO のような他のものに設定することもできる点に注意。
UnslothをRLタスクに採用することは、大規模モデルを効率的に管理するための堅牢なフレームワークを提供します。Unslothの改善を効果的に活用するには:
ハードウェアの推奨事項:VRAMの最適な利用のためにNVIDIA H100または同等の使用を推奨します。
構成のヒント:次を確認してください
batch_sizeおよびgradient_accumulation_steps設定が最良のパフォーマンスのために計算資源と一致していること。
最新のアップデートを入手するにはUnslothを最新のPypiリリースに更新してください:
私たちのベンチマークは、GPT OSSおよびQwen3-8Bに対して以前のバージョンと比較したメモリ節約を示しています。下の両方のプロットは( スタンバイなしで) batch_size = 4 および gradient_accumulation_steps=2 で実行されました。スタンバイは設計上すべてのVRAMを使用するためです。
私たちのベンチマークでは、BF16 GRPOをHugging Faceのすべての最適化(カーネルライブラリ内のすべてのカーネル、Flash Attention 3、チャンク化損失カーネルなど)を有効にした状態と比較します:
🔢平坦化されたシーケンス長チャンク処理
以前、Unslothはバッチ次元でのチャンク処理を通じてロジットテンソルの完全な具現化を回避することでRLのメモリ使用量を削減しました。順伝播中にロジットを具現化するために必要なVRAMの概算は式(1)に示されています。
この定式化を使用すると、次の構成は batch_size = 4, context_length = 8192、および vocab_dim = 128,000 の場合、概ね必要になります 3.3 GBのVRAM ロジットテンソルを格納するために。
去年の Long Context gpt-oss を経て、私たちはGRPO向けに融合損失アプローチを導入しました。このアプローチは一度に単一のバッチサンプルのみを処理することを保証し、ピークメモリ使用量を大幅に低減します。同じ構成では、VRAM使用量は概ね 0.83 GBに低下し、式(2)に反映されています。


このアップデートでは、同じアイデアをさらに拡張し、 シーケンス次元 にもまたチャンク処理を導入します。バッチ_size × context_length全体のためにロジットを一度に具現化する代わりに、これらの次元を平坦化して構成可能な乗数を使って小さいチャンクで処理します。これにより、Unslothはピークメモリ使用量を増やすことなく大幅に長いコンテキストをサポートできます。 (batch_size × context_length) 空間を一度に処理する代わりに、これらの次元を平坦化して小さなチャンクで処理します。
下の図5では、乗数として max(4, context_length // 4096)を使用していますが、望ましいメモリとパフォーマンスのトレードオフに応じて任意の乗数を指定できます。この設定では、同じ例の構成(batch_size = 4, context_length = 8192, vocab_dim = 128,000)は現在必要とするのはわずか 0.207 GBのVRAM ロジット具現化のために。




このアップデートはコンパイルされた chunked_hidden_states_selective_log_softmax にも反映されており、現在はバッチ次元とシーケンス次元の両方でのチャンク処理をサポートします。ロジットテンソル([batch_size, context_length, vocab_dim])を保持するために、常にバッチ次元でチャンク化されます。追加のシーケンスチャンク処理は unsloth_logit_chunk_multiplier によってGRPO構成で制御されます;未設定の場合はデフォルトで max(4, context_length // 4096)となります。下の例では、 input_ids_chunk[0] は最適化2における隠れ状態ミニバッチのサイズに対応します。
私たちはカスタムコンパイルオプション付きのtorch.compileを利用してVRAMを削減し、速度を向上させます。
すべてのチャンク化されたロジットは精度を保つためにfloat32にアップキャストされます。
私たちはロジットソフトキャッピング、温度スケーリングおよびその他すべての機能をサポートします。
👻隠れ状態のチャンク処理
また、より長いコンテキスト長では隠れ状態がメモリ使用量の大きな要因になることを観察しました。デモのために、次を仮定します hidden_states_dim=4096。対応するメモリ使用量はロジットの場合と同様の定式化に従い、以下に示されます。
での batch_size = 8 および context_length = 64000の場合、これは概ねVRAM使用量が 2 GBになります。このリリースでは、対数確率計算中の隠れ状態テンソルに対してバッチ次元でのオプションのチャンク処理を導入します。これによりVRAM使用量はバッチサイズで割られ、この場合は 0.244 GBになります。これにより隠れ状態を具現化するために必要なピークVRAMが削減され、以下の更新された式に反映されています:
私たちのリリースでのクロスエントロピー損失と同様に、 500K Context Training 新しい実装は 自動的に隠れ状態のバッチ処理を調整します。ユーザーはこの挙動を以下で制御することもできます unsloth_grpo_mini_batch経由で。ただし、 unsloth_grpo_mini_batch を最適値より大きくすると、以前の損失関数と比較して若干の性能向上または低下(通常は高速化)を引き起こす可能性があります。
しかし、GPT-OSSの実行中(context_length = 8192, batch_size = 4, gradient_accumulation_steps = 2)、 unsloth_grpo_mini_batch = 1 および unsloth_logit_chunk_multiplier = 4 を設定すると、 ほとんどまたはまったく速度低下がなく、VRAM使用量を約5 GB削減します 以前のUnslothバージョンと比較して。

注意: 図3および図4では、このセットアップでの最大有効バッチサイズ(この場合8)を使用しています。有効バッチサイズは次のように計算されます batch_size × gradient_accumulation_steps、したがって 4 × 2 = 8となります。RLにおける有効バッチサイズの動作の詳細な説明は、私たちの 高度なRLドキュメント.
🌵をご参照ください。
ログソフトマックスのためのアクティベーションのオフロード このリリースの開発中に、隠れ状態をバッチ次元でタイル処理すると、融合されたロジットとlogprobsの計算後にアクティベーションがオフロードされていないことを発見しました。ロジットはhidden_states[i] @ lm_head
を使用して一度に1バッチずつ計算されるため、モデルの順伝播内で動作するように設計された既存のアクティベーションオフロードおよび勾配チェックポイントのロジックはこのケースには適用されませんでした。
注意: torch.autograd.backward(output, grad_output) この機能はバッチ次元でチャンク処理する場合、またはunsloth_grpo_mini_batch > 1 unsloth_grpo_mini_batch = 1のときにのみ効果的です。順伝播中にすべての隠れ状態が一度に具現化される場合(すなわち
✨)、逆伝播はアクティベーションがオフロードされているかどうかに関わらず同じ量のGPUメモリを必要とします。この場合、アクティベーションのオフロードはメモリ使用量を減らさずにわずかなパフォーマンス低下を招くため、利点はありません。
パラメータの設定: unsloth_grpo_mini_batch および unsloth_logit_chunk_multiplierもしあなたが を設定しない場合、私たちは 利用可能なVRAMおよびコンテキスト長のサイズに基づいてこれら二つのパラメータを自動的に調整します
unsloth_logit_chunk_multiplier = 2 unsloth_grpo_mini_batch および unsloth_logit_chunk_multiplier 最適化と

の可視化は下の図で見ることができます。 unsloth_grpo_mini_batch 3つの行列は全体の大きなバッチまたは unsloth_logit_chunk_multiplier (黒い角括弧の数で表現)を表し、各行列の行はシーケンス長をどのように
📼がチャンク化するかを表します(赤い角括弧の数で表現)。
RL向けvLLMRLワークフローでは、推論/生成フェーズが主要なボトルネックです vLLM。これに対処するために、私たちは
を利用しており、通常の生成と比べて最大11倍の生成高速化を実現しています。GRPOが昨年普及して以来、vLLMはUnslothを含むほとんどのRLフレームワークのコアコンポーネントでした。UnslothのRLをより良くするために重要な役割を果たしているvLLMチームとその全ての貢献者に感謝の意を表します! GRPOノートブック (またはローカルでUnslothを更新):
gpt-oss-20b より長いコンテキストのRLを試すには、既存の任意の
- GSPOを使用できます、
最終更新
役に立ちましたか?

