gpt-oss強化学習

今ならOpenAIを訓練できます gpt-oss RL と GRPO で、 Unsloth。Unsloth は今や 最速の推論 (3倍高速)、 最小のVRAM使用量 (50%削減)と 最長のコンテキスト (8倍長い)を、gpt-ossのRLに対してどの実装よりも提供します — 精度低下なしで。 gpt-ossでの強化学習(RL)はまだvLLM互換ではないため、私たちは推論コードをTransformersコードから書き直し、約21トークン/秒でgpt-ossに3倍高速な推論を実現しました。BF16でも、Unslothは特にVRAM使用量に関して最速の推論(約30トークン/秒)を達成し、他のどのRL実装よりも50%少ないVRAMで動作します。私たちは 50%の重み共有機能 を、vLLMがRLに対応次第サポートする予定です。

Unslothを使えば、15GB VRAMのColab上でgpt-oss-20bをGRPOで訓練でき、 無料 です。私たちは埋め込みのオフロードを導入し、 offload_embeddingsを通じて使用量をさらに1GB削減しました。Unlothの新しい推論は、A100、H100、旧型T4を含む あらゆる GPU上でより高速に動作します。gpt-oss-120bは120GB VRAMのGPUに問題なく収まります。

Unslothはgpt-oss向けの4ビットRLをサポートする唯一のフレームワークです。すべての性能向上は、Unsloth独自の 重み共有, Flex Attention, スタンバイ とカスタムカーネルによるものです。

⚡推論を大幅に高速化

推論はRL学習において極めて重要です。というのも、何らかの報酬関数を最大化する前に候補解を生成するために必要だからです(こちらを参照 より詳しい説明はこちら)。vLLMなしでgpt-ossに対して最速の推論速度を実現するため、私たちはTransformersの推論コードを書き直し、Unslothのようなカスタムアルゴリズムを含む多くの革新を統合しました Flex Attention。さらに torch.compile 内で特別なフラグ(combo kernelsのようなもの)を使用しています。gpt-oss向けの新しい推論コードは、すでに最適化済みのベースライン(ネイティブTransformersより2倍高速)と比較評価されました。

vLLMはgpt-oss向けのBF16学習とLoRAサポートがないため、gpt-ossのRLをサポートしていません。Unslothがなければ、完全精度のBF16による訓練のみが機能し、 メモリ使用量が 800%以上増加します。多くのフレームワークはFA3(Flash Attention 3)をデフォルトで有効にします(VRAM使用量を減らし、速度を向上させます) が、これにより誤った訓練損失が発生します。参照: Issue 1797 がFA3リポジトリにあります。とはいえ、FA3は長文コンテキスト学習を妨げるため無効化する必要があります。FA3はO(N)のメモリ使用量なのに対し、素朴なattentionはO(N^2)に膨れ上がるからです。そこで、attention sinksを微分可能にするために、私たちは Unsloth Flex Attention.

gpt-oss RLの推論を、BitsandBytes 4ビットでベンチマークし、BF16でも別途テストしました。Unslothの4ビット推論は約4倍高速で、BF16も特にVRAM使用量においてより効率的です。

Unslothのgpt-oss RLの最も優れた点は、BF16をサポートしないGPUでも、どのGPUでも動作できることです。無料のgpt-oss-20b Colabノートブックでは旧型の15GB T4 GPUを使用しているため、推論例も問題なく動作します!

🛠️ gpt-oss Flex Attention の問題点と癖

attention sinksの実装を こちらで説明されているように 変更する必要がありました。左パディングで生成を動作させるためです。logsumexpを取得し、以下のようにsigmoid活性化を適用してattention重みを変更する必要がありました:

A(X)=σ(1dQKT)VA(X)=exp1dQKTexp1dQKTVLSE=logexp1dQKTAsinks(X)=A(X)σ(LSEsinks)A(X) = \sigma \bigg( \frac{1}{\sqrt{d}}QK^T \bigg)V \\ A(X) = \frac{\exp{\frac{1}{\sqrt{d}}QK^T}}{\sum{\exp{\frac{1}{\sqrt{d}}QK^T}}}V \\ \text{LSE} = \log{\sum{\exp{\frac{1}{\sqrt{d}}QK^T}}} \\ A_{sinks}(X) = A(X) \odot \sigma (\text{LSE} - \text{sinks})

推論時の左パディングされたマスキングも、gpt-ossでは対処が難しい問題でした。トークン生成時のKV Cache prefillを考慮するだけでなく、バッチ生成における各プロンプトのパディングトークン数の違いも考慮する必要があり、ブロックマスクの保存方法が変わります。その例は以下のとおりです:

通常の因果マスク:

一般的な推論(デコード)では

同じマスキング戦略を単純に使うと、これは失敗します:

生成(デコード段階)では、通常、attention行列の最後の行だけが重要です。というのも、1つのクエリトークンが過去のすべてのキー トークンに注意を向けるだけだからです。因果マスク(q_idx ≥ k_idx)を単純に適用すると、単一のクエリのインデックスは0なのに対し、n_k個のキー トークンがあるため失敗します。これを修正するには、どのトークンに注意を向けるかを決めるためのマスク生成時のオフセットが必要です。しかし、素朴な方法ではオフセットが各ステップで変わるため、マスクとカーネルの再生成が必要になり遅いです。私たちはキャッシュとコンパイルの最適化でこれを解決しました。

より難しいのはバッチ生成です。シーケンスの長さが異なるため、パディングがマスク生成を複雑にします。Flex Attentionには多くの 課題があり 、動的マスクは扱いが難しいです。さらに、コンパイルされていない場合はeager attentionにフォールバックし、遅くてメモリ消費も大きくなります(シーケンス長に対して二次対線形)。

引用元 https://github.com/meta-pytorch/attention-gym/issues/15#issuecomment-2284148665

これは _compile=True で呼び出す必要があります。私たちは基本的に、ブロックマスクを生成するために、あなたのブロックマスクを完全な Q_LEN x KV_LEN 行列にマッピングしています。コンパイルなしでは、この全体を実体化する必要があり、長いシーケンスではOOMを引き起こす可能性があります。

また、 flex_attention = torch.compile(flex_attention)を実行する必要があります。コンパイルなしでは、flexは非融合のeager実装にフォールバックします。デバッグには便利ですが、はるかに遅く、完全なスコア行列を実体化します。

最終的に、このマスクはKV Cacheを用いたprefillとdecode、シーケンスごとのバッチおよびパディングトークンを動的に扱い、 torch.compile 互換性があり、スライディングウィンドウをサポートしていなければなりません。

🔍 Flash Attention の検証

私たちが探ったもう1つの興味深い方向は、Flash Attentionの統合を試みることでした。その利点は広く認識されていますが、1つの制約として、gpt-ossのバックワードパスでattention sinksをサポートしていないことがあります。これを回避するため、attention機構を再構成し、FlashAttentionが容易に提供するattention出力とlogsumexp値のみに作用するようにしました。こうした利点を考えると、試すのは当然の選択に思えました。

しかし、すぐに問題が見え始めました。最初の数層は期待どおりに動作した一方で、後半の層、特に18層から24層では、transformersのeagerモード実装から大きく逸脱した出力が生成されました。重要なのは、この不一致は誤差の蓄積では説明できないということです。というのも、各方法への入力は各層で同一だからです。さらに検証するため、Unslothとも比較しました FlexAttention.

なぜ最後の数層だけが、flash attention実装と他の実装との間でこれほど劇的に異なるのか、さらなる調査が必要です。

⚠️ 報酬ハッキングに対抗できるか?

RLの究極の目的は、何らかの報酬(たとえば速度、収益、指標)を最大化することです。しかしRLは 不正を働けます。 RLアルゴリズムがコツを学んだり、報酬を増やすために何かを悪用したりして、最終的に本来のタスクを実行していない場合、これは「報酬ハッキング".

」と呼ばれます。これは、モデルがコーディング課題を通すためにユニットテストを改変することを学んでしまう理由であり、実世界への展開における重大な障害です。ほかにも良い例としては Wikipedia.

私たちの 無料gpt-oss RLノートブック で、コード生成設定における報酬ハッキングへの対抗方法を探り、一般的な失敗モードに対する具体的な解決策を示します。モデルが時間計測関数を編集したり、他のライブラリに処理を外注したり、結果をキャッシュしたり、露骨に不正を働く様子を確認しました。対策後は、モデルは巧妙なごまかしではなく、真に最適化された行列積カーネルを生成します。

🏆報酬ハッキング

RL中の報酬ハッキングの一般的な例には以下があります:

怠慢

RLはNumpyやTorchなど、最適化されたCUDAカーネルを呼び出す他のライブラリを使うことを学習します。生成されたコードが標準外のPythonライブラリをimportしているかを確認することで、RLアルゴリズムが最適化済みコードを呼び出すのを止められます。

キャッシュと不正

RLは出力結果をキャッシュすることを学習し、Pythonのグローバル変数を調べることで実際の出力を見つけることを学習します。

大きな偽の行列でキャッシュを消去することで、RLアルゴリズムがキャッシュ済みデータを使うのを止められます。また、複数のループとターンで慎重にベンチマークする必要があります。

不正

RLは時間計測関数を編集して、0秒と出力させることを学習します。RLアルゴリズムがグローバル変数やキャッシュ変数を使うのを止めるために、 ローカル および グローバルへのアクセスを制限します。また、関数の作成には exec を使うので、出力を空のdictに保存する必要があります。さらに、次を通じたグローバル変数アクセスも禁止します types.FunctionType(f.__code__, {})\

チュートリアル: RLでgpt-ossを訓練する方法

LLMは複雑な環境を伴うタスクにしばしば苦戦します。しかし、適用することで 強化学習 (RL)とカスタムの 報酬関数を設計することで、これらの課題は克服できます。

RLは、自動カーネル生成や戦略作成などのタスク向けに適応できます。このチュートリアルでは、2048を自律的に攻略できるように、 gpt-ossGRPO とUnslothで学習させる方法を示します。

私たちのノートブックには、プロセス全体を進めるためのステップバイステップガイドがすでに含まれています。

作成するもの:

  • モデルが自動的に2048で勝てるようにgpt-oss-20bを学習させる

  • モデルがやり取りできる最小限の2048環境を作成する

  • 定義する 報酬関数 内容:

    1. 生成された戦略がコンパイルおよび実行できることを確認し、

    2. 報酬ハッキングを防止し(外部インポートを禁止)、さらに

    3. 実際のゲーム成功に報酬を与える

  • 推論を実行し、モデルをエクスポートする(MXFP4 4ビットまたはマージ済みFP16)

ハードウェア: 2048の例は無料のColab T4で動作しますが、訓練は遅くなります。A100/H100の方がはるかに高速です。4ビット読み込み + LoRAにより、控えめなVRAMで20Bモデルを収められます

最終更新

役に立ちましたか?