from unsloth import FastLanguageModel
import re
from trl import GRPOConfig, GRPOTrainer
from datasets import load_dataset, Dataset
max_seq_length = 1024  # Can be increased for longer reasoning traces
lora_rank = 32  # Larger rank = smarter, but slower
max_prompt_length = 256
# Load and prepare the dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()
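# Example: extract_xml_answer("<reasoning>\n...\n</reasoning>\n<answer>\n72\n</answer>") == "72"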
def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()
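# GSM8K ground-truth solutions end with "#### <final answer>", e.g.
# extract_hash_answer("... Natalia sold 48 + 24 = 72 clips. #### 72") == "72"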
# Uncomment the middle messages for 1-shot prompting
def get_gsm8k_questions(split: str = "train") -> Dataset:
    data = load_dataset("openai/gsm8k", "main")[split]  # type: ignore
    data = data.map(
        lambda x: {  # type: ignore
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": x["question"]},
            ],
            "answer": extract_hash_answer(x["answer"]),
        }
    )  # type: ignore
    return data  # type: ignore
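# After mapping, each row carries a chat-style "prompt" (system + user messages);
# the original "question" column is kept, and "answer" is replaced by the
# extracted final number (datasets.map merges returned fields into existing rows).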
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]["content"] for completion in completions]
    q = prompts[0][-1]["content"]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print(
        "-" * 20,
        f"Question:\n{q}",
        f"\nAnswer:\n{answer[0]}",
        f"\nResponse:\n{responses[0]}",
        f"\nExtracted:\n{extracted_responses[0]}",
    )
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]["content"] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks whether the completion follows the exact expected format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks whether the completion has the expected format, more loosely."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
def count_xml(text: str) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count
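# A completion matching XML_COT_FORMAT exactly (no trailing text) scores the full
# 0.5 (4 x 0.125); any text after the closing </answer> tag incurs a small
# per-character penalty.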
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
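# Illustrative reward-interface check (not executed here): the nested
# list-of-message-dicts shape below mirrors what GRPOTrainer passes for
# chat-formatted prompts.
#   demo = [[{"role": "assistant",
#             "content": XML_COT_FORMAT.format(reasoning="2 + 2 = 4", answer="4")}]]
#   int_reward_func(completions=demo)       -> [0.5]  (extracted "4" is a digit)
#   xmlcount_reward_func(completions=demo)  -> [0.5]  (all four tags appear exactly once)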
if __name__ == "__main__":
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/Qwen3-0.6B",
        max_seq_length=max_seq_length,
        load_in_4bit=False,  # False for LoRA 16-bit
        fast_inference=False,  # Set to True to enable vLLM fast inference
        max_lora_rank=lora_rank,
        gpu_memory_utilization=0.7,  # Reduce if out of memory
        device_map="xpu:0",
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=lora_rank,  # Choose any number > 0! Suggested: 8, 16, 32, 64, 128
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],  # Remove QKVO if out of memory
        lora_alpha=lora_rank,
        use_gradient_checkpointing="unsloth",  # Enables long-context fine-tuning
        random_state=3407,
    )
    dataset = get_gsm8k_questions()
    training_args = GRPOConfig(
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        logging_steps=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,  # Increase to 4 for smoother training
        num_generations=4,  # Decrease if out of memory
        max_prompt_length=max_prompt_length,
        max_completion_length=max_seq_length - max_prompt_length,
        # num_train_epochs=1,  # Set to 1 for a full training run
        max_steps=20,
        save_steps=250,
        max_grad_norm=0.1,
        report_to="none",  # Can use Weights & Biases
        output_dir="outputs",
    )
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            int_reward_func,
            correctness_reward_func,
        ],
        args=training_args,
        train_dataset=dataset,
        dataset_num_proc=1,  # Recommended on Windows
    )
    trainer.train()
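    # Optional post-training step (a minimal sketch, not part of the original
    # script): persist the trained LoRA adapter and tokenizer via the standard
    # PEFT / Transformers save_pretrained API. The "grpo_lora_adapter" output
    # path is an arbitrary choice.
    model.save_pretrained("grpo_lora_adapter")
    tokenizer.save_pretrained("grpo_lora_adapter")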