🧩上級強化学習ドキュメント

Unsloth を GRPO と併用する際の上級ドキュメント設定。

バッチ処理、生成、トレーニングパラメータに関して Unsloth を使った GRPO の詳細ガイド:

トレーニングパラメータ

  • beta (float、デフォルト 0.0):KL係数。

    • 0.0 ⇒ 参照モデルが読み込まれていません(メモリ使用量が少なく、より高速)。

    • 高い beta は方策を参照方策により近づけるよう制約します。

  • num_iterations (int、デフォルト 1):バッチあたりの PPO エポック数(アルゴリズム中の μ)。 勾配蓄積ステップ内でデータをリプレイします。例: 2 = 各蓄積ステップでの順伝播が2回。

  • epsilon (float、デフォルト 0.2):トークン単位の対数確率比のクリッピング値(標準の ε では比の典型範囲 ≈ [-1.2, 1.2])。

  • delta (float、オプション):を有効にします upper のクリッピング境界を 両側 GRPO が設定されているとき。もし 認証用のSSH公開鍵なら、標準の GRPO クリッピングが使用されます。推奨 > 1 + ε が有効な場合(INTELLECT-2 レポートによる)。

  • epsilon_high (float、オプション):上限イプシロン。未設定の場合はデフォルトで epsilon が使用されます。DAPO は以下を推奨します 0.28.

  • importance_sampling_level (“token” | “sequence”、デフォルト "token"):

    • "token":生のトークン毎の比(トークンごとに1つの重み)。

    • "sequence":トークンごとの比を平均してシーケンスレベルの単一比にします。 GSPO はシーケンスレベルのサンプリングがシーケンスレベルの報酬に対してより安定した学習を示すことが多いと示しています。

  • reward_weights (list[float]、オプション):報酬ごとの重み。もし 認証用のSSH公開鍵なら、すべての重み = 1.0 になります。

  • scale_rewards (str|bool、デフォルト "group"):

    • True または "group":各グループ内の 標準偏差でスケーリングします (グループ内の単位分散)。

    • "batch":各グループ内の :バッチ全体にわたる標準偏差でスケーリングします (PPO-Lite に準拠)。

    • False または "none": :スケーリングしない。Dr. GRPO は標準偏差スケーリングによる困難性バイアスを避けるためにスケーリングしないことを推奨しています。

  • loss_type (str、デフォルト "dapo"):

    • "grpo":シーケンス長で正規化します(長さバイアス;非推奨)。

    • "dr_grpo":を用いて正規化します グローバル定数 (Dr. GRPO で導入;長さバイアスを除去)。定数 ≈ max_completion_length.

    • "dapo" (デフォルト):グローバルに蓄積されたバッチ内の アクティブトークンで正規化します (DAPO で導入;長さバイアスを除去)。

    • "bnpo":グローバルに蓄積されたバッチ内の :ローカルバッチ内のアクティブトークンのみ(結果はローカルバッチサイズによって変動する可能性があります; のときは GRPO と等しくなります per_device_train_batch_size == 1).

  • mask_truncated_completions (bool、デフォルト False): もし Trueなら、切り詰められた完了は損失から除外されます(安定性のために DAPO によって推奨)。 注意:このフラグにはいくつか KL の問題があるため、無効にすることを推奨します。

    # mask_truncated_completions が有効な場合、completion_mask 内の切り詰められた完了をゼロにします
    if self.mask_truncated_completions:
        truncated_completions = ~is_eos.any(dim=1)
        completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()

    これにより多くの完了が切り詰められている場合、すべての completion_mask エントリがゼロになることがあり、 n_mask_per_reward = 0 となり、KL が NaN になる可能性があります。 詳細についてはarrow-up-right

  • vllm_importance_sampling_correction (bool、デフォルト True): を適用します 切り詰め重要度サンプリング(TIS) は、生成(例:vLLM / fast_inference)がトレーニングバックエンドと異なるときのオフポリシー効果を補正します。 Unsloth では、これは vLLM/fast_inference を使用している場合に自動的に True に設定されます 。それ以外の場合は False.

  • vllm_importance_sampling_cap (float、デフォルト 2.0): TIS の切り詰めパラメータ C ;安定化のために重要度サンプリング比に上限を設定します。

  • dtype float16 または bfloat16 を選ぶ場合は、以下を参照してください RL における FP16 と BF16 の比較

生成パラメータ

  • temperature (float、デフォルト 1.0): サンプリングの温度。温度が高いほど生成はよりランダムになります。学習に役立つ多様性を得るために、比較的高めの(1.0)温度を使用することを確認してください。

  • top_p (float、オプション、デフォルト 1.0): 考慮する上位トークンの累積確率を制御する浮動小数点。値は (0, 1] の範囲でなければなりません。すべてのトークンを考慮するには 1.0 に設定します。

  • top_k (int、オプション): top-k フィルタリングのために保持する最も高確率の語彙トークン数。None の場合、top-k フィルタリングは無効になり、すべてのトークンが考慮されます。

  • min_p (float、オプション): 最小トークン確率。最も可能性の高いトークンの確率でスケーリングされます。0.0〜1.0 の値でなければなりません。典型的な値は 0.01〜0.2 の範囲です。

  • repetition_penalty (float、オプション、デフォルト 1.0): プロンプトおよびこれまでに生成されたテキストに出現したかどうかに基づいて新しいトークンを罰則する浮動小数点。値 > 1.0 はモデルに新しいトークンの使用を促し、値 < 1.0 はトークンの繰り返しを促します。

  • steps_per_generation: (int、オプション): 生成あたりのステップ数。None の場合、デフォルトは gradient_accumulation_stepsです。これは generation_batch_size.

circle-info

との相互排他的なパラメータです。 per_device_train_batch_size このパラメータを弄るのは少し混乱するため、バッチサイズに関しては

と勾配蓄積を編集することを推奨します

バッチ & スループットパラメータ

  • バッチを制御するパラメータtrain_batch_size :サンプル数 プロセスごと ステップごと。 もしこの整数が num_generationsより小さい場合、デフォルトで num_generations.

  • steps_per_generationマイクロバッチ が寄与する 一つの生成の 損失計算(順伝播のみ)。 新しいデータバッチは毎 steps_per_generation ステップごとに生成されます;逆伝播のタイミングは gradient_accumulation_steps.

  • num_processes:分散トレーニングプロセスの数(例:GPU / ワーカー)。

  • gradient_accumulation_steps (別名 gradient_accumulation):適用するために蓄積するマイクロバッチの数、 逆伝播とオプティマイザ更新を行います。

  • 有効バッチサイズ:

    更新前に勾配へ寄与するサンプルの総数(すべてのプロセスとステップを通じて)。

  • 世代あたりのオプティマイザステップ:

    例: 4 / 2 = 2.

  • num_generations:生成される世代の数 プロンプトあたり (適用中 その後に 計算する effective_batch_size)。 一つの生成サイクルにおける ユニークなプロンプト数 は以下の通りです:

    GRPO が機能するためには > 2 である必要があります。 必要です。

GRPO バッチ例

以下の表は、バッチがステップを通じてどのように流れるか、オプティマイザ更新がいつ発生するか、および新しいバッチがどのように生成されるかを示します。

例 1

生成サイクル A

ステップ
バッチ
注記

0

[0,0,0]

1

[1,1,1]

→ オプティマイザ更新(蓄積 = 2 到達)

2

[2,2,2]

3

[3,3,3]

オプティマイザ更新

生成サイクル B

ステップ
バッチ
注記

0

[4,4,4]

1

[5,5,5]

→ オプティマイザ更新(蓄積 = 2 到達)

2

[6,6,6]

3

[7,7,7]

オプティマイザ更新

例 2

生成サイクル A

ステップ
バッチ
注記

0

[0,0,0]

1

[1,1,1]

2

[2,2,2]

3

[3,3,3]

オプティマイザ更新(蓄積 = 4 到達)

生成サイクル B

ステップ
バッチ
注記

0

[4,4,4]

1

[5,5,5]

2

[6,6,6]

3

[7,7,7]

オプティマイザ更新(蓄積 = 4 到達)

例 3

生成サイクル A

ステップ
バッチ
注記

0

[0,0,0]

1

[0,1,1]

2

[1,1,3]

3

[3,3,3]

オプティマイザ更新(蓄積 = 4 到達)

生成サイクル B

ステップ
バッチ
注記

0

[4,4,4]

1

[4,5,5]

2

[5,5,6]

3

[6,6,6]

オプティマイザ更新(蓄積 = 4 到達)

例 4

生成サイクル A

ステップ
バッチ
注記

0

[0,0,0, 1,1,1]

1

[2,2,2, 3,3,3]

オプティマイザ更新(蓄積 = 2 到達)

生成サイクル B

ステップ
バッチ
注記

0

[4,4,4, 5,5,5]

1

[6,6,6, 7,7,7]

オプティマイザ更新(蓄積 = 2 到達)

簡易式リファレンス

最終更新

役に立ちましたか?