🌀GRPO RL avec contexte 7x plus long

Découvrez comment Unsloth permet l'ajustement RL à contexte ultra-long.

Le plus grand défi de l'apprentissage par renforcement (RL) est de prendre en charge de longues traces de raisonnement. Nous présentons de nouveaux algorithmes de mise en lot (batching) pour permettre ~contexte ~7x plus long (peut dépasser 12x) Entraînement RL sans dégradation de la précision ou de la vitesse par rapport à d'autres configurations optimisées qui utilisent FA3, des kernels et des pertes en morceaux.

  • Unsloth entraîne désormais gpt-oss QLoRA avec contexte 380K sur un seul GPU NVIDIA B200 de 192 Go

  • Qwen3-8B GRPO atteint contexte 110K sur un H100 à 80 Go VRAM via vLLM et QLoRA, et 65K pour gpt-oss avec LoRA en BF16.

  • Sur 24 Go VRAM, gpt-oss atteint 20K de contexte et 32K pour Qwen3-VL-8B QLoRA

  • Les exécutions RL Unsloth GRPO fonctionnent avec Llama, Gemma et tous les modèles prennent automatiquement en charge des contextes plus longs

Nos nouveaux kernels et algorithmes de déplacement de données et de mise en lot débloquent plus de contexte en :

circle-info

Vous pouvez combiner toutes les fonctionnalités d'Unsloth ensemble :

  1. pour gpt-oss à long contexte et notre FP8 RL entraînement Float8 dans et learrow-up-right point de contrôle asynchrone des gradients d'Unsloth

🎉et bien plus

Premiers pas Notebooks GRPO Pour commencer, vous pouvez utiliser n'importe quel

GPU L4

  • Adopter Unsloth pour vos tâches RL fournit un cadre robuste pour gérer efficacement des modèles à grande échelle. Pour utiliser efficacement les améliorations d'Unsloth :Recommandations matérielles

  • : Utilisation d'un NVIDIA H100 ou équivalent pour une utilisation optimale de la VRAM.Conseils de configuration : Assurez-vous que les paramètres et gradient_accumulation_steps batch_size

circle-check

pip install --upgrade --no-cache-dir unsloth unsloth_zoo Nos benchmarks mettent en évidence les économies de mémoire réalisées par rapport aux versions précédentes pour GPT OSS et Qwen3-8B. Les deux graphiques ci‑dessous (sansstandby ) ont été exécutés avec et batch_size = 4 gradient_accumulation_steps=2

, puisque standby, par conception, utilise toute la VRAM.

🔢Pour nos benchmarks, nous comparons BF16 GRPO à Hugging Face avec toutes les optimisations activées (tous les kernels de la bibliothèque kernels, Flash Attention 3, kernels de perte en morceaux, etc.) :

Découpage de longueur de séquence aplatie

Equation 1: Logit Memory (GB)=batch size×context length×vocab dim10243\text{Equation 1: } \text{Logit Memory (GB)} = \frac{\text{batch size} \times\text{context length} \times \text{vocab dim}}{1024^3}

Précédemment, Unsloth réduisait l'utilisation mémoire du RL en évitant la matérialisation complète du tenseur de logits en découpant sur la dimension du batch. Une estimation approximative de la VRAM requise pour matérialiser les logits durant la passe avant est montrée dans l'Équation (1). ) ont été exécutés avec, En utilisant cette formulation, une configuration avec, et context_length = 8192 vocab_dim = 128 000 requiert environ 3,3 Go de VRAM

pour stocker le tenseur de logits. Long Context gpt-oss Via l'année dernière, nous avons ensuite introduit une approche de perte fusionnée pour GRPO. Cette approche garantit qu'un seul échantillon de batch est traité à la fois, réduisant significativement l'utilisation mémoire maximale. Pour la même configuration, l'utilisation de la VRAM tombe à environ0,83 Go

Equation 2: Logit Memory (GB)=context length×vocab dim10243\text{Equation 2: }\text{Logit Memory (GB)} = \frac{\text{context length} \times \text{vocab dim}}{1024^3}
, comme reflété dans l'Équation (2).
Figure 1 : gpt-oss BF16 GRPO LoRA (Unsloth vs. HF avec toutes les optimisations activées)

Figure 2 : Qwen3-8B QLoRA GRPO LoRA (Unsloth vs. HF avec toutes les optimisations activées) Dans cette mise à jour, nous étendons la même idée en introduisant le découpage à travers la dimension de séquence également. Au lieu de matérialiser les logits pour l'intégralité de l'espace (batch_size × context_length)

d'un seul coup, nous aplatissons ces dimensions et les traitons en plus petits morceaux en utilisant un multiplicateur configurable. Cela permet à Unsloth de prendre en charge des contextes sensiblement plus longs sans augmenter l'utilisation mémoire maximale. Dans la Figure 5 ci‑dessous, nous utilisons un multiplicateur demax(4, context_length // 4096)) ont été exécutés avec, En utilisant cette formulation, une configuration avec, context_length = 8192, bien que n'importe quel multiplicateur puisse être spécifié selon le compromis mémoire–performance désiré. Avec ce paramètre, la même configuration d'exemple ( ) nécessite désormais seulement 0,207 Go de VRAM

Equation 3: Logit Memory (GB)=context lengthmultiplier×vocab dim10243\text{Equation 3: }\text{Logit Memory (GB)} = \frac{\frac{\text{context length}}{\text{multiplier}} \times \text{vocab dim}}{1024^3}
pour la matérialisation des logits.
Figure 3 : gpt-oss-20b (H100) Unsloth nouveau vs. ancien
Figure 4 : Qwen3-8B (H100) Unsloth nouveau vs. ancien
Figure 5 : gpt-oss-20b (H100)

Figure 6 : Qwen3-8B (B200) Cette mise à jour est reflétée dans le chunked_hidden_states_selective_log_softmaxcompilé ci‑dessous, qui prend désormais en charge le découpage à la fois sur les dimensions batch et séquence. Pour préserver le tenseur de logits ([batch_size, context_length, vocab_dim] ), il est toujours découpé sur la dimension batch. Le découpage supplémentaire sur la séquence est contrôlé via unsloth_logit_chunk_multiplier Dans la Figure 5 ci‑dessous, nous utilisons un multiplicateur dedans la configuration GRPO ; s'il n'est pas défini, il prend par défaut . Dans l'exemple ci‑dessous, input_ids_chunk[0]

  1. temperature=temperature,

  2. Nous utilisons torch.compile avec des options de compilation personnalisées pour réduire la VRAM et augmenter la vitesse.

  3. Tous les logits découpés sont convertis en float32 pour préserver la précision.

👻Nous prenons en charge le softcapping des logits, le scaling de la température et toutes les autres fonctionnalités.

Découpage des états cachés Nous avons également observé qu'à des longueurs de contexte plus longues, les états cachés peuvent devenir un contributeur significatif à l'utilisation mémoire. Pour la démonstration, nous supposeronshidden_states_dim=4096

Hidden States Memory (GB)=batch size×context length×hidden states dim10243\text{Hidden States Memory (GB)} = \frac{\text{batch size} \times\text{context length} \times \text{hidden states dim}}{1024^3}

. L'utilisation mémoire correspondante suit une formulation similaire au cas des logits, montrée ci‑dessous. Avec un et batch_size = 8context_length = 64000 , cela aboutirait à une utilisation de VRAM d'environ2 Go . Dans cette version, nous introduisons un découpage optionnel sur la dimension batch pour le tenseur des états cachés lors du calcul des log‑probabilités. Cela ferait que l'utilisation de la VRAM soit divisée par la taille du batch ou, dans ce cas, soit0,244 Go

Hidden States Memory (GB)=context length×hidden states dim10243\text{Hidden States Memory (GB)} = \frac{\text{context length} \times \text{hidden states dim}}{1024^3}

. Cela réduit la VRAM maximale requise pour matérialiser les états cachés, comme reflété dans l'équation mise à jour ci‑dessous : 500K Context Training Similaire à notre perte d'entropie croisée dans notre publication, la nouvelle implémentationajuste automatiquement le lotage des états cachés . Les utilisateurs peuvent également contrôler ce comportement viaunsloth_grpo_mini_batch . Les utilisateurs peuvent également contrôler ce comportement via . Cependant, augmenter

au‑delà de la valeur optimale peut introduire une légère augmentation des performances ou un ralentissement (généralement plus rapide) par rapport à l'ancienne fonction de perte.Cependant, lors d'une exécution GPT-OSS (context_length = 8192, batch_size = 4, gradient_accumulation_steps = 2 ), définir et unsloth_grpo_mini_batch = 1 unsloth_logit_chunk_multiplier = 4 entraîne peu ou pas de dégradation de la vitesse tout en réduisant l'utilisation de la VRAM d'environ 5 Go

circle-check

🌵documentation RL avancée

Déchargement des activations pour le log softmax Lors du développement de cette version, nous avons découvert que lorsque l'on carrelait (tiling) sur la dimension batch pour les états cachés, les activations n'étaient pas déchargées après le calcul fusionné des logits et des logprobs. Comme les logits sont calculés un batch à la fois en utilisanthidden_states[i] @ lm_head

, la logique existante de déchargement des activations et de checkpointing des gradients, conçue pour fonctionner dans la passe forward du modèle, ne s'appliquait pas dans ce cas.

circle-check

), la passe backward nécessite la même quantité de mémoire sur le GPU indépendamment du fait que les activations soient déchargées. Étant donné que le déchargement des activations introduit un léger ralentissement des performances sans réduire l'utilisation mémoire dans ce cas, cela n'apporte aucun bénéfice.

Configuration des paramètres : . Les utilisateurs peuvent également contrôler ce comportement via et ), il est toujours découpé sur la dimension batch. Le découpage supplémentaire sur la séquence est contrôlé viaSi vous ne configurez pas , nous ajusterons automatiquement ces deux paramètres

unsloth_logit_chunk_multiplier = 2 . Les utilisateurs peuvent également contrôler ce comportement via et ), il est toujours découpé sur la dimension batch. Le découpage supplémentaire sur la séquence est contrôlé via Une visualisation des optimisations et de

peut être vue dans le schéma ci‑dessous. . Les utilisateurs peuvent également contrôler ce comportement via Les 3 matrices représentent le lot global plus large ou ), il est toujours découpé sur la dimension batch. Le découpage supplémentaire sur la séquence est contrôlé via (représenté par le nombre de crochets noirs) et les lignes de chacune des matrices représentent la longueur de contexte que le

📼découpe de la longueur de séquence (représenté par le nombre de crochets rouges).

vLLM pour le RLPour les flux de travail RL, la phase d'inférence/génération est le principal goulot d'étranglement vLLMarrow-up-right. Pour y remédier, nous utilisons

, qui a accéléré la génération jusqu'à 11x par rapport à la génération normale. Depuis que GRPO a été popularisé l'année dernière, vLLM est devenu un composant central de la plupart des frameworks RL, y compris Unsloth. Nous souhaitons exprimer notre gratitude à l'équipe vLLM et à tous ses contributeurs pour leur travail, car ils jouent un rôle essentiel dans l'amélioration du RL d'Unsloth ! Notebooks GRPO Pour commencer, vous pouvez utiliser n'importe quel

- GSPO

Mis à jour

Ce contenu vous a-t-il été utile ?