from unsloth import FastLanguageModel
import re
from trl import GRPOConfig, GRPOTrainer
from datasets import load_dataset, Dataset
max_seq_length = 1024  # Can be increased for longer reasoning traces
lora_rank = 32  # Larger rank = smarter, but slower
max_prompt_length = 256
# Load and prepare the dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()
def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()
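# GSM8K reference answers end with "#### <number>" (e.g. "... #### 72"),
# which extract_hash_answer pulls out as the target string.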
# Uncomment the middle messages for one-shot prompting
def get_gsm8k_questions(split: str = "train") -> Dataset:
    data = load_dataset("openai/gsm8k", "main")[split]  # type: ignore
    data = data.map(
        lambda x: {  # type: ignore
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": x["question"]},
            ],
            "answer": extract_hash_answer(x["answer"]),
        }
    )  # type: ignore
    return data  # type: ignore
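# The mapped dataset carries a chat-style "prompt" column plus the extracted
# numeric "answer" string consumed by correctness_reward_func.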
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]["content"] for completion in completions]
    q = prompts[0][-1]["content"]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print(
        "-" * 20,
        f"Question:\n{q}",
        f"\nAnswer:\n{answer[0]}",
        f"\nResponse:\n{responses[0]}",
        f"\nExtracted:\n{extracted_responses[0]}",
    )
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]["content"] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion strictly matches the expected format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets ".*?" span multi-line reasoning/answer content
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion loosely matches the expected format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
def count_xml(text: str) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
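# count_xml grants up to 0.125 per correctly placed tag (0.5 max) and applies a
# small per-character penalty for text trailing the closing </answer> tag.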
if __name__ == "__main__":
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/Qwen3-0.6B",
        max_seq_length=max_seq_length,
        load_in_4bit=False,  # False for LoRA in 16-bit
        fast_inference=False,  # Set True to enable vLLM fast inference
        max_lora_rank=lora_rank,
        gpu_memory_utilization=0.7,  # Reduce if out of memory
        device_map="xpu:0",  # Intel GPU (XPU) device
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=lora_rank,  # Choose any number > 0! Suggested: 8, 16, 32, 64, 128
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],  # Remove QKVO if out of memory
        lora_alpha=lora_rank,  # alpha = rank gives a LoRA scaling factor of 1
        use_gradient_checkpointing="unsloth",  # Enable long-context fine-tuning
        random_state=3407,
    )
    dataset = get_gsm8k_questions()
    training_args = GRPOConfig(
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        logging_steps=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,  # Increase to 4 for smoother training
        num_generations=4,  # Decrease if out of memory
        max_prompt_length=max_prompt_length,
        max_completion_length=max_seq_length - max_prompt_length,
        # num_train_epochs=1,  # Set to 1 for a full training run
        max_steps=20,
        save_steps=250,
        max_grad_norm=0.1,
        report_to="none",  # Can use Weights & Biases
        output_dir="outputs",
    )
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            int_reward_func,
            correctness_reward_func,
        ],
        args=training_args,
        train_dataset=dataset,
        dataset_num_proc=1,  # Recommended on Windows
    )
    trainer.train()
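    # Optional sketch (assumption, not part of the original script): persist the
    # trained LoRA adapter and tokenizer after training. The directory name
    # "grpo_lora" is a placeholder; save_pretrained is the standard PEFT /
    # transformers API for this.
    # model.save_pretrained("grpo_lora")
    # tokenizer.save_pretrained("grpo_lora")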