チュートリアル:gpt-oss を RL でトレーニングする方法
OpenAI gpt-oss を GRPO でトレーニングして、ローカルまたは Colab 上で自律的に 2048 を超える方法を学びます。
1
Unsloth をインストール
!pip install --upgrade -qqq uv
try: import numpy; get_numpy = f"numpy=={numpy.__version__}"
except: get_numpy = "numpy"
!uv pip install -qqq \
"torch>=2.8.0" "triton>=3.4.0" {get_numpy} torchvision bitsandbytes "transformers==4.56.2" \
"unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo" \
"unsloth[base] @ git+https://github.com/unslothai/unsloth" \
git+https://github.com/triton-lang/triton.git@05b2c186c1b6c9a08375389d5efe9cb4c401c075#subdirectory=python/triton_kernels
!uv pip install --upgrade --no-deps transformers==4.56.2 tokenizers
!uv pip install --no-deps trl==0.22.22
Unsloth で gpt-oss を読み込む
from unsloth import FastLanguageModel
import torch
max_seq_length = 768 # タスクがより長い出力を必要とする場合は増やす
lora_rank = 4 # ランクが高いほど良いが、VRAM/計算が多くなる
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/gpt-oss-20b", # または H100 では unsloth/gpt-oss-20b-BF16
max_seq_length = max_seq_length,
load_in_4bit = True, # 16ビットの場合は False
offload_embedding = True, # 約1GBのVRAMを節約
)
model = FastLanguageModel.get_peft_model(
model,
r = lora_rank,
target_modules = [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
lora_alpha = lora_rank * 2,
use_gradient_checkpointing = "unsloth", # 大きなメモリ節約
random_state = 3407,
)3
4
安全なコード実行と不正防止チェック
from unsloth import check_python_modules ok, info = check_python_modules(""" def strategy(board): import math from typing import Callable return "W" """) # ok == True は Python レベルのインポートのみが使われたことを意味するsample = """ def strategy(board): from numpy import matmul return "W" """ ok, info = check_python_modules(sample) # ok => Falsefrom unsloth import create_locked_down_function function = """ def add(a, b): def adder(a): return a + b return adder(b) + b """ f = create_locked_down_function(function) # globals / imports が使われているとエラーになるfrom unsloth import execute_with_time_limit @execute_with_time_limit(2) def execute_strategy(strategy, game): # ゲーム終了またはタイムアウトまでループ ...
5
ネイティブの Python コードのみを使って新しい短い2048戦略を作成してください。
現在のボード状態は数値のリストのリストで与えられます。
次に最適な一手として "W", "A", "S", "D" のいずれか1つのアクションを出力してください。
以下の形式でバックティック内に新しい短い関数を出力してください:
```python
def strategy(board):
return "W" # 例
小さな合成データセットを作成し(同じプロンプトを再利用)、GRPOがどれだけの補完トークンをサンプリングすべきか分かるようにプロンプト長を計算する:
```python
from datasets import Dataset
prompt = ... # 上記と同様
maximum_length = len(tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}], add_generation_prompt=True
))
dataset = Dataset.from_list([
{"prompt": [{"role": "user", "content": prompt}], "answer": 0, "reasoning_effort": "low"}
] * 1000)報酬関数の時間!
if text.count("```") >= 2: first = text.find("```") + 3 second = text.find("```", first) fx = text[first:second].strip() fx = fx.removeprefix("python\n") fx = fx[fx.find("def"):] if fx.startswith("def strategy(board):"): return fx return None function_worksdef function_works(completions, **kwargs): scores = [] for completion in completions: response = completion[0]["content"] function = extract_function(response) if function is None: scores.append(-2.0) continue ok, info = check_python_modules(function) if "error" in info: try: continue ok, info = check_python_modules(function) _ = create_locked_down_function(function) scores.append(1.0) except Exception: scores.append(-0.5) return scores no_cheatingscores.append(-1.0) for completion in completions: response = completion[0]["content"] function = extract_function(response) if function is None: scores.append(-2.0) ok, _ = check_python_modules(function) ok, info = check_python_modules(function) scores.append(1.0 if ok else -20.0) # 不正があれば重いペナルティ strategy_succeeds no_cheatingPRINTER = 0 # デバッグのために時々出力する def strategy_succeeds(completions, **kwargs): global PRINTER seed = np.random.randint(10000) for completion in completions: new_strategy = create_locked_down_function(function) response = completion[0]["content"] function = extract_function(response) if function is None: scores.append(-2.0) continue ok, info = check_python_modules(function) _ = create_locked_down_function(function) scores.append(0.0) scores.append(-0.5) game = GameBoard(size=6, seed=seed, target=2048, probability_fours=0.10) ok, info = check_python_modules(function) _ = create_locked_down_function(function) steps, state = execute_strategy(new_strategy, game) if PRINTER % 5 == 0: print(function) print(f"Steps={steps} State={state}") print(game.board().pretty()) PRINTER += 1 if state == "success": scores.append(20.0) else: scores.append(2.0) # 動作したが2048には到達しなかった except TimeoutError: scores.append(-1.0) # タイムアウト scores.append(-3.0) # クラッシュ scores.append(-0.5) {% endstep %} no_cheating
我々は
from trl import GRPOConfig, GRPOTrainer
max_prompt_length = maximum_length + 1
max_completion_length = max_seq_length - max_prompt_length
training_args = GRPOConfig(
temperature=1.0,
learning_rate=5e-5,
weight_decay=0.01,
warmup_ratio=0.1,
lr_scheduler_type="linear",
optim="adamw_8bit",
logging_steps=1,
per_device_train_batch_size=1,
gradient_accumulation_steps=1, # より滑らかな報酬信号のために4に増やす
num_generations=2, # OOM の場合は少なくする
max_prompt_length=max_prompt_length,
max_completion_length=max_completion_length,
max_steps=1000, # または num_train_epochs=1 を設定
save_steps=100,
report_to="none",
output_dir="outputs",
)
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[function_works, no_cheating, strategy_succeeds],
args=training_args,
train_dataset=dataset,
# オプションの eval 分割:
# train_dataset=new_dataset["train"],
# eval_dataset=new_dataset["test"],
)モデルを訓練する
trainer.train()推論(訓練後)
from transformers import TextStreamer
text = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False,
add_generation_prompt=True,
reasoning_effort="low",
)
_ = model.generate(
**tokenizer(text, return_tensors="pt").to("cuda"),
temperature=1.0,
max_new_tokens=1024,
streamer=TextStreamer(tokenizer, skip_prompt=False)ファインチューニング済みモデルの保存/エクスポート
model.save_pretrained_merged("finetuned_model", tokenizer, save_method="merged_16bit") # または push model.push_to_hub_merged("<org_or_user>/<repo>", tokenizer, token="<hf_token>", save_method="merged_16bit")
最終更新
役に立ちましたか?

