lightbulb-cfl-ongpt-oss 強化学習

あなたは今やOpenAIを訓練できます gpt-oss RLおよびGRPOで Unslotharrow-up-right。Unslothは現在提供しています 最速の推論 (3倍高速)、 最小のVRAM使用量 (50%少ない)および 最長のコンテキスト (8倍長い)を、いかなる実装よりもgpt-ossのRLで実現しています — 精度の低下なしで。 gpt-ossでの強化学習(RL)はまだvLLMと互換性がないため、gpt-oss向けに約21トークン/秒で3倍高速な推論を提供するために、Transformersの推論コードを書き直す必要がありました。BF16では、Unslothは特にVRAM使用量に対して優れており、他のどのRL実装よりも50%少ないVRAMで最速の推論(約30トークン/秒)を達成します。vLLMがRLに対応次第、私たちの 50%の重み共有機能 をサポートする予定です。

Unslothを使えば、gpt-oss-20bを15GBのVRAMでGRPOで訓練でき、 無料で Colab上で行えます。埋め込みオフロード(embedding offloading)を導入し、 offload_embeddingsを介して使用量を1GB削減しました。Unslothの新しい推論は TTSモデル( GPU(A100、H100や古いT4を含む)上でより高速に動作します。gpt-oss-120bは120GBのVRAM GPUにうまく収まります。

Unslothはgpt-ossの4ビットRLをサポートする唯一のフレームワークです。全ての性能向上はUnsloth独自の 重み共有, Flex Attention, Standby およびカスタムカーネルによるものです。

circle-exclamation

⚡ 推論を大幅に高速化する

推論はRL訓練において重要です。なぜなら報酬関数(ここを参照してください より詳細な説明のため)を最大化する前に候補解を生成する必要があるからです。vLLMなしでgpt-ossの最速推論速度を達成するため、私たちはTransformersの推論コードを書き直し、Unslothのようなカスタムアルゴリズムを含む多くの革新を統合しました Flex Attention、および内部で特別なフラグを使用(コンボカーネルのような)しました。gpt-oss向けの新しい推論コードは、すでに最適化されたベースライン(ネイティブTransformersの2倍高速)と比較評価されました。 torch.compile vLLMはgpt-ossのRLをサポートしていません。なぜならgpt-ossのBF16訓練とLoRAをサポートしていないからです。Unslothがなければ、完全精度のBF16での訓練のみが動作し、これにより

メモリ使用量 800%超高くなります。多くのフレームワークはFA3(VRAM使用を削減し速度を上げる)をデフォルトで有効にしますが、 これは不正確な訓練損失を引き起こします。参照: FA3リポジトリのIssue 1797arrow-up-right を参照してください。しかしFA3は長いコンテキストの訓練を不可能にするため無効にする必要があります。FA3はO(N)メモリ使用を行う一方で、素朴な注意はO(N^2)で膨張します。したがって、attention sinksを微分可能にするために、私たちは Unsloth Flex Attention.

を実装しました。私たちはgpt-ossのRL推論をBitsandBytesの4ビットでベンチマークし、BF16についても別途テストしました。Unslothの4ビット推論は約4倍高速で、BF16も特にVRAM使用において効率的です。

Unslothのgpt-oss RLの最良の点は、BF16をサポートしないGPUであっても動作することです。私たちの無料のgpt-oss-20b Colabノートブックは古い15GBのT4 GPUを使用するので、推論例はよく動作します!

🛠️ gpt-oss Flex Attention の問題点と特異性

左パディングで生成を動作させるために、attention sinksの実装を こちらに記載 変更する必要がありました。logsumexpを取得し、シグモイド活性化を適用して下記のように注意重みを変更する必要がありました:

A(X)=σ(1dQKT)VA(X)=exp1dQKTexp1dQKTVLSE=logexp1dQKTAsinks(X)=A(X)σ(LSEsinks)A(X) = \sigma \bigg( \frac{1}{\sqrt{d}}QK^T \bigg)V \\ A(X) = \frac{\exp{\frac{1}{\sqrt{d}}QK^T}}{\sum{\exp{\frac{1}{\sqrt{d}}QK^T}}}V \\ \text{LSE} = \log{\sum{\exp{\frac{1}{\sqrt{d}}QK^T}}} \\ A_{sinks}(X) = A(X) \odot \sigma (\text{LSE} - \text{sinks})

推論中の左パディングマスキングもgpt-ossでは厄介な問題でした。KVキャッシュの事前充填を考慮するだけでなく、バッチ生成における各プロンプトごとの異なるパッドトークン数も考慮する必要があり、それがブロックマスクの保存方法を変えました。以下にそのような例を示します:

通常の因果マスク:

一般的な場合の推論(デコード)では

もし同じマスキング戦略を素朴に使うと、これは失敗します:

生成(デコードフェーズ)では、通常注意行列の最後の行のみが重要です。なぜなら一つのクエリトークンが全ての以前のキーに注意を向けるからです。もし因果マスクを素朴に適用すると(q_idx ≥ k_idx)、単一のクエリのインデックスが0である一方でキーがn_k存在するため、これが失敗します。これを修正するにはマスク生成時にオフセットが必要で、どのトークンに注意を向けるかを決定します。しかし素朴なアプローチは遅く、オフセットが各ステップで変わるためマスクとカーネルの再生成を強制します。私たちはキャッシュとコンパイル最適化でこれを解決しました。

より難しいのはバッチ生成です。シーケンスは長さが異なるため、パディングがマスク生成を複雑にします。Flex Attentionには多くの 課題arrow-up-right があり、動的マスクは厄介です。さらに、コンパイルされていない場合はフォールバックでイーガー注意に戻り、これは遅くメモリを大量に消費します(シーケンス長に対して二次的=quadratic、線形ではない)。

こちらからの引用: https://github.com/meta-pytorch/attention-gym/issues/15#issuecomment-2284148665arrow-up-right

これを_compile=Trueで呼び出す必要があります。私たちは本質的にブロックマスクをフルのQ_LEN x KV_LEN行列にマッピングしてブロックマスクを生成します。コンパイルがないと、この完全なものを具現化する必要があり、長いシーケンスでOOMを引き起こす可能性があります。

さらに、次のように実行する必要があります: flex_attention = torch.compile(flex_attention)。コンパイルがない場合、flexは非融合のイーガー実装にフォールバックし、デバッグには優れていますが、はるかに遅くスコア行列全体を具現化します。

最終的に、マスクはKVキャッシュのプレフィル対デコード、バッチとシーケンスごとのパディングトークンを動的に処理し、かつ torch.compile フレンドリーであり、スライディングウィンドウをサポートする必要があります。

🔍 Flash Attention の調査

私たちが探ったもう一つの興味深い方向はFlash Attentionの統合試験でした。利点は広く認められていますが、一つの制限はgpt-ossの逆伝播におけるattention sinksをサポートしていない点です。これを回避するため、Attention出力とFlashAttentionが提供するlogsumexp値のみに対して動作するように注意機構を再構築しました。これらの利点を考えると、試す価値は明白に思えました。

しかし、すぐに問題が現れ始めました。最初の数層は期待通りに振る舞いましたが、後半の層、特に18層から24層では、transformersのイーガーモード実装と大きく異なる出力を生成しました。重要なのは、この差異は誤差の蓄積によるものではなく、各層における入力はどの方法でも同一であるという点です。さらなる検証のため、私たちは結果をUnslothの FlexAttention.

とも比較しました。これは、なぜ最後の数層だけがFlash Attention実装と他の実装でこれほど劇的な差を示すのかをさらに調査する必要があります。

triangle-exclamation

⚠️ 報酬ハッキングに対抗できますか?

強化学習(RL)の究極の目的は、ある報酬(例えば速度、収益、ある指標)を最大化することです。しかしRLは 不正を行うことがあります。 RLアルゴリズムが実際にタスクを遂行することなく報酬を増やすためにトリックを覚えたり何かを悪用したりする場合、これは「報酬ハッキング(Reward Hacking)".

の原因です。モデルがコーディングチャレンジを通過するためにユニットテストを改変することを学ぶのはこのためであり、これらは実世界での展開における重大な障害となります。他の良い例は ウィキペディアarrow-up-right.

私たちの 無料の gpt-oss RL ノートブックarrow-up-right では、コード生成の設定で報酬ハッキングに対抗する方法を探り、一般的なエラーモードに対する具体的な解決策を示しています。モデルがタイミング関数を編集したり、他のライブラリに外注したり、結果をキャッシュしたり、完全に不正を行ったりするのを観察しました。対策を講じた後、我々のモデルは巧妙な不正ではなく真に最適化された行列乗算カーネルを生成するようになりました。

🏆報酬ハッキング(Reward Hacking)

強化学習中の報酬ハッキングの一般的な例には以下があります:

怠惰さ

RLはNumpy、Torch、その他のライブラリを使うことを学び、最適化されたCUDAカーネルを呼び出します。生成されたコードが標準外のPythonライブラリをインポートしているかを検査することで、RLアルゴリズムが最適化コードを呼び出すのを防ぐことができます。

キャッシュと不正

RLは出力結果をキャッシュすることを学び、Pythonのグローバル変数を調べて実際の出力を見つけることを学びます。

大きな偽の行列でキャッシュを消去することにより、RLアルゴリズムがキャッシュデータを使用するのを防げます。また、複数のループや反復で慎重にベンチマークする必要があります。

不正行為

RLはタイミング関数を編集して経過時間を0として出力することを学びます。グローバルやキャッシュされた変数を使用するのを防ぐために、RLの ローカル(locals) および グローバル(globals)を制限します。また、 exec で関数を作成するので、出力を空の辞書に保存する必要があります。さらに、 types.FunctionType(f.__code__, {}) を介したグローバル変数アクセスを禁止します。\

チュートリアル:RLでgpt-ossを訓練する方法

LLMは複雑な環境を含むタスクで苦戦することが多い。しかし、 強化学習 (RL)やカスタム 報酬関数を設計することで、これらの課題を克服できる。

RLはオートカーネルや戦略生成のようなタスクに適用できる。本チュートリアルでは、 gpt-ossGRPO と Unsloth を用いて自律的に2048を攻略する方法を示す。

私たちのノートブックには、プロセス全体を案内するステップバイステップのガイドが含まれています。

あなたが作るもの:

  • gpt-oss-20b を訓練し、モデルが自動で2048に勝てるようにする

  • モデルが対話できる最小限の2048環境を作成する

  • を定義する 報酬関数

    1. 生成された戦略がコンパイルされ実行されることを確認し、

    2. 報酬ハッキングを防ぐ(外部インポートを禁止)

    3. 実際のゲーム成功を報酬化する

  • 推論を実行してモデルをエクスポートする(MXFP4 4ビットまたはマージ済み FP16)

circle-info

ハードウェア: 2048の例は無料のColab T4で実行されますが、訓練は遅くなります。A100/H100の方がはるかに高速です。4ビット読み込み+LoRAにより、20Bモデルを控えめなVRAMに収めることができます

最終更新

役に立ちましたか?