🌀7x 長いコンテキストでの GRPO 強化学習

Unsloth が超長コンテキスト RL ファインチューニングを可能にする方法を学びます。

強化学習(RL)の最大の課題は長い推論トレースをサポートすることです。私たちは新しいバッチ処理アルゴリズムを導入して〜7倍長いコンテキスト (場合によっては12倍以上) FA3、カーネル、およびチャンク化損失を使用する他の最適化されたセットアップと比較して、精度や速度の劣化なしでのRLトレーニング。

  • Unslothは現在gpt-oss QLoRAを 380Kコンテキスト 単一の192GB NVIDIA B200 GPUでトレーニングします

  • Qwen3-8B GRPOは達成します 110Kコンテキスト 80GB VRAMのH100で via vLLM およびQLoRA、そして 65K 用に gpt-oss BF16 LoRAで。

  • 24GB VRAMでは、gpt-ossは20Kコンテキストに達し、32Kは -8B QLoRA

  • Unsloth GRPO RLはLlama、Gemmaおよびすべてのモデルで長いコンテキストを自動サポートします

私たちの新しいデータ移動およびバッチ処理カーネルとアルゴリズムはより多くの コンテキスト を可能にします:

circle-info

Unslothのすべての機能を組み合わせることができます:

  1. Unslothの 重み共有 機能は vLLMarrow-up-right と私たちのスタンバイ機能と組み合わせて メモリ効率の良い RL

  2. Unslothの Flex Attention 長コンテキストgpt-oss用と私たちの 500K Context Training

  3. のFloat8トレーニング、 FP8 RL およびUnslothの 非同期勾配チェックポイントarrow-up-right など多くの機能

🎉始め方

開始するには、既存の任意の GRPOノートブック (またはローカルでUnslothを更新):

UnslothをRLタスクに採用することは、大規模モデルを効率的に管理するための堅牢なフレームワークを提供します。Unslothの改善を効果的に活用するには:

  • ハードウェアの推奨事項:VRAMの最適な利用のためにNVIDIA H100または同等の使用を推奨します。

  • 構成のヒント:次を確認してください batch_size および gradient_accumulation_steps 設定が最良のパフォーマンスのために計算資源と一致していること。

circle-check

私たちのベンチマークは、GPT OSSおよびQwen3-8Bに対して以前のバージョンと比較したメモリ節約を示しています。下の両方のプロットは( スタンバイなしで) batch_size = 4 および gradient_accumulation_steps=2 で実行されました。スタンバイは設計上すべてのVRAMを使用するためです。

私たちのベンチマークでは、BF16 GRPOをHugging Faceのすべての最適化(カーネルライブラリ内のすべてのカーネル、Flash Attention 3、チャンク化損失カーネルなど)を有効にした状態と比較します:

🔢平坦化されたシーケンス長チャンク処理

以前、Unslothはバッチ次元でのチャンク処理を通じてロジットテンソルの完全な具現化を回避することでRLのメモリ使用量を削減しました。順伝播中にロジットを具現化するために必要なVRAMの概算は式(1)に示されています。

Equation 1: Logit Memory (GB)=batch size×context length×vocab dim10243\text{Equation 1: } \text{Logit Memory (GB)} = \frac{\text{batch size} \times\text{context length} \times \text{vocab dim}}{1024^3}

この定式化を使用すると、次の構成は batch_size = 4, context_length = 8192、および vocab_dim = 128,000 の場合、概ね必要になります 3.3 GBのVRAM ロジットテンソルを格納するために。

去年の Long Context gpt-oss を経て、私たちはGRPO向けに融合損失アプローチを導入しました。このアプローチは一度に単一のバッチサンプルのみを処理することを保証し、ピークメモリ使用量を大幅に低減します。同じ構成では、VRAM使用量は概ね 0.83 GBに低下し、式(2)に反映されています。

Equation 2: Logit Memory (GB)=context length×vocab dim10243\text{Equation 2: }\text{Logit Memory (GB)} = \frac{\text{context length} \times \text{vocab dim}}{1024^3}
図1: gpt-oss BF16 GRPO LoRA(Unsloth vs. すべての最適化を有効にしたHF)
図2: Qwen3-8B QLoRA GRPO LoRA(Unsloth vs. すべての最適化を有効にしたHF)

このアップデートでは、同じアイデアをさらに拡張し、 シーケンス次元 にもまたチャンク処理を導入します。バッチ_size × context_length全体のためにロジットを一度に具現化する代わりに、これらの次元を平坦化して構成可能な乗数を使って小さいチャンクで処理します。これにより、Unslothはピークメモリ使用量を増やすことなく大幅に長いコンテキストをサポートできます。 (batch_size × context_length) 空間を一度に処理する代わりに、これらの次元を平坦化して小さなチャンクで処理します。

下の図5では、乗数として max(4, context_length // 4096)を使用していますが、望ましいメモリとパフォーマンスのトレードオフに応じて任意の乗数を指定できます。この設定では、同じ例の構成(batch_size = 4, context_length = 8192, vocab_dim = 128,000)は現在必要とするのはわずか 0.207 GBのVRAM ロジット具現化のために。

Equation 3: Logit Memory (GB)=context lengthmultiplier×vocab dim10243\text{Equation 3: }\text{Logit Memory (GB)} = \frac{\frac{\text{context length}}{\text{multiplier}} \times \text{vocab dim}}{1024^3}
図3: gpt-oss-20b(H100)Unsloth新旧比較
図4: Qwen3-8B(H100)Unsloth新旧比較
図5: gpt-oss-20b(H100)
図6: Qwen3-8B(B200)

このアップデートはコンパイルされた chunked_hidden_states_selective_log_softmax にも反映されており、現在はバッチ次元とシーケンス次元の両方でのチャンク処理をサポートします。ロジットテンソル([batch_size, context_length, vocab_dim])を保持するために、常にバッチ次元でチャンク化されます。追加のシーケンスチャンク処理は unsloth_logit_chunk_multiplier によってGRPO構成で制御されます;未設定の場合はデフォルトで max(4, context_length // 4096)となります。下の例では、 input_ids_chunk[0] は最適化2における隠れ状態ミニバッチのサイズに対応します。

  1. 私たちはカスタムコンパイルオプション付きのtorch.compileを利用してVRAMを削減し、速度を向上させます。

  2. すべてのチャンク化されたロジットは精度を保つためにfloat32にアップキャストされます。

  3. 私たちはロジットソフトキャッピング、温度スケーリングおよびその他すべての機能をサポートします。

👻隠れ状態のチャンク処理

また、より長いコンテキスト長では隠れ状態がメモリ使用量の大きな要因になることを観察しました。デモのために、次を仮定します hidden_states_dim=4096。対応するメモリ使用量はロジットの場合と同様の定式化に従い、以下に示されます。

Hidden States Memory (GB)=batch size×context length×hidden states dim10243\text{Hidden States Memory (GB)} = \frac{\text{batch size} \times\text{context length} \times \text{hidden states dim}}{1024^3}

での batch_size = 8 および context_length = 64000の場合、これは概ねVRAM使用量が 2 GBになります。このリリースでは、対数確率計算中の隠れ状態テンソルに対してバッチ次元でのオプションのチャンク処理を導入します。これによりVRAM使用量はバッチサイズで割られ、この場合は 0.244 GBになります。これにより隠れ状態を具現化するために必要なピークVRAMが削減され、以下の更新された式に反映されています:

Hidden States Memory (GB)=context length×hidden states dim10243\text{Hidden States Memory (GB)} = \frac{\text{context length} \times \text{hidden states dim}}{1024^3}

私たちのリリースでのクロスエントロピー損失と同様に、 500K Context Training 新しい実装は 自動的に隠れ状態のバッチ処理を調整します。ユーザーはこの挙動を以下で制御することもできます unsloth_grpo_mini_batch経由で。ただし、 unsloth_grpo_mini_batch を最適値より大きくすると、以前の損失関数と比較して若干の性能向上または低下(通常は高速化)を引き起こす可能性があります。

しかし、GPT-OSSの実行中(context_length = 8192, batch_size = 4, gradient_accumulation_steps = 2)、 unsloth_grpo_mini_batch = 1 および unsloth_logit_chunk_multiplier = 4 を設定すると、 ほとんどまたはまったく速度低下がなく、VRAM使用量を約5 GB削減します 以前のUnslothバージョンと比較して。

circle-check

🌵をご参照ください。

ログソフトマックスのためのアクティベーションのオフロード このリリースの開発中に、隠れ状態をバッチ次元でタイル処理すると、融合されたロジットとlogprobsの計算後にアクティベーションがオフロードされていないことを発見しました。ロジットはhidden_states[i] @ lm_head

を使用して一度に1バッチずつ計算されるため、モデルの順伝播内で動作するように設計された既存のアクティベーションオフロードおよび勾配チェックポイントのロジックはこのケースには適用されませんでした。

circle-check

)、逆伝播はアクティベーションがオフロードされているかどうかに関わらず同じ量のGPUメモリを必要とします。この場合、アクティベーションのオフロードはメモリ使用量を減らさずにわずかなパフォーマンス低下を招くため、利点はありません。

パラメータの設定: unsloth_grpo_mini_batch および unsloth_logit_chunk_multiplierもしあなたが を設定しない場合、私たちは 利用可能なVRAMおよびコンテキスト長のサイズに基づいてこれら二つのパラメータを自動的に調整します

unsloth_logit_chunk_multiplier = 2 unsloth_grpo_mini_batch および unsloth_logit_chunk_multiplier 最適化と

の可視化は下の図で見ることができます。 unsloth_grpo_mini_batch 3つの行列は全体の大きなバッチまたは unsloth_logit_chunk_multiplier (黒い角括弧の数で表現)を表し、各行列の行はシーケンス長をどのように

📼がチャンク化するかを表します(赤い角括弧の数で表現)。

RL向けvLLMRLワークフローでは、推論/生成フェーズが主要なボトルネックです vLLMarrow-up-right。これに対処するために、私たちは

を利用しており、通常の生成と比べて最大11倍の生成高速化を実現しています。GRPOが昨年普及して以来、vLLMはUnslothを含むほとんどのRLフレームワークのコアコンポーネントでした。UnslothのRLをより良くするために重要な役割を果たしているvLLMチームとその全ての貢献者に感謝の意を表します! GRPOノートブック (またはローカルでUnslothを更新):

- GSPOを使用できます、

最終更新

役に立ちましたか?