🌀GRPO RL avec contexte 7x plus long
Découvrez comment Unsloth permet l'ajustement RL à contexte ultra-long.
Le plus grand défi de l'apprentissage par renforcement (RL) est de prendre en charge de longues traces de raisonnement. Nous présentons de nouveaux algorithmes de mise en lot (batching) pour permettre ~contexte ~7x plus long (peut dépasser 12x) Entraînement RL sans dégradation de la précision ou de la vitesse par rapport à d'autres configurations optimisées qui utilisent FA3, des kernels et des pertes en morceaux.
Unsloth entraîne désormais gpt-oss QLoRA avec contexte 380K sur un seul GPU NVIDIA B200 de 192 Go
Sur 24 Go VRAM, gpt-oss atteint 20K de contexte et 32K pour Qwen3-VL-8B QLoRA
Les exécutions RL Unsloth GRPO fonctionnent avec Llama, Gemma et tous les modèles prennent automatiquement en charge des contextes plus longs
Nos nouveaux kernels et algorithmes de déplacement de données et de mise en lot débloquent plus de contexte en :
Découpage dynamique de séquences aplaties en morceaux pour éviter la matérialisation de tenseurs de logits massifs et
Déchargement des activations de log softmax ce qui empêche la croissance silencieuse de la mémoire au fil du temps.
Vous pouvez combiner toutes les fonctionnalités d'Unsloth ensemble :
La partage de poids d'Unsloth vLLM avec RL efficace en mémoire
La et notre fonctionnalité Standby dans Flex Attention 500K Context Training
🎉et bien plus
Premiers pas Notebooks GRPO Pour commencer, vous pouvez utiliser n'importe quel
GPU L4
Adopter Unsloth pour vos tâches RL fournit un cadre robuste pour gérer efficacement des modèles à grande échelle. Pour utiliser efficacement les améliorations d'Unsloth :Recommandations matérielles
: Utilisation d'un NVIDIA H100 ou équivalent pour une utilisation optimale de la VRAM.Conseils de configuration
: Assurez-vous que les paramètresetgradient_accumulation_stepsbatch_size
sont alignés sur vos ressources informatiques pour de meilleures performances.
pip install --upgrade --no-cache-dir unsloth unsloth_zoo Nos benchmarks mettent en évidence les économies de mémoire réalisées par rapport aux versions précédentes pour GPT OSS et Qwen3-8B. Les deux graphiques ci‑dessous (sansstandby ) ont été exécutés avec et batch_size = 4 gradient_accumulation_steps=2
, puisque standby, par conception, utilise toute la VRAM.
🔢Pour nos benchmarks, nous comparons BF16 GRPO à Hugging Face avec toutes les optimisations activées (tous les kernels de la bibliothèque kernels, Flash Attention 3, kernels de perte en morceaux, etc.) :
Découpage de longueur de séquence aplatie
Précédemment, Unsloth réduisait l'utilisation mémoire du RL en évitant la matérialisation complète du tenseur de logits en découpant sur la dimension du batch. Une estimation approximative de la VRAM requise pour matérialiser les logits durant la passe avant est montrée dans l'Équation (1). ) ont été exécutés avec, En utilisant cette formulation, une configuration avec, et context_length = 8192 vocab_dim = 128 000 requiert environ 3,3 Go de VRAM
pour stocker le tenseur de logits. Long Context gpt-oss Via l'année dernière, nous avons ensuite introduit une approche de perte fusionnée pour GRPO. Cette approche garantit qu'un seul échantillon de batch est traité à la fois, réduisant significativement l'utilisation mémoire maximale. Pour la même configuration, l'utilisation de la VRAM tombe à environ0,83 Go


Figure 2 : Qwen3-8B QLoRA GRPO LoRA (Unsloth vs. HF avec toutes les optimisations activées) Dans cette mise à jour, nous étendons la même idée en introduisant le découpage à travers la dimension de séquence également. Au lieu de matérialiser les logits pour l'intégralité de l'espace (batch_size × context_length)
d'un seul coup, nous aplatissons ces dimensions et les traitons en plus petits morceaux en utilisant un multiplicateur configurable. Cela permet à Unsloth de prendre en charge des contextes sensiblement plus longs sans augmenter l'utilisation mémoire maximale. Dans la Figure 5 ci‑dessous, nous utilisons un multiplicateur demax(4, context_length // 4096)) ont été exécutés avec, En utilisant cette formulation, une configuration avec, context_length = 8192, bien que n'importe quel multiplicateur puisse être spécifié selon le compromis mémoire–performance désiré. Avec ce paramètre, la même configuration d'exemple ( ) nécessite désormais seulement 0,207 Go de VRAM




Figure 6 : Qwen3-8B (B200) Cette mise à jour est reflétée dans le chunked_hidden_states_selective_log_softmaxcompilé ci‑dessous, qui prend désormais en charge le découpage à la fois sur les dimensions batch et séquence. Pour préserver le tenseur de logits ([batch_size, context_length, vocab_dim] ), il est toujours découpé sur la dimension batch. Le découpage supplémentaire sur la séquence est contrôlé via unsloth_logit_chunk_multiplier Dans la Figure 5 ci‑dessous, nous utilisons un multiplicateur dedans la configuration GRPO ; s'il n'est pas défini, il prend par défaut . Dans l'exemple ci‑dessous, input_ids_chunk[0]
temperature=temperature,
Nous utilisons torch.compile avec des options de compilation personnalisées pour réduire la VRAM et augmenter la vitesse.
Tous les logits découpés sont convertis en float32 pour préserver la précision.
👻Nous prenons en charge le softcapping des logits, le scaling de la température et toutes les autres fonctionnalités.
Découpage des états cachés Nous avons également observé qu'à des longueurs de contexte plus longues, les états cachés peuvent devenir un contributeur significatif à l'utilisation mémoire. Pour la démonstration, nous supposeronshidden_states_dim=4096
. L'utilisation mémoire correspondante suit une formulation similaire au cas des logits, montrée ci‑dessous. Avec un et batch_size = 8context_length = 64000 , cela aboutirait à une utilisation de VRAM d'environ2 Go . Dans cette version, nous introduisons un découpage optionnel sur la dimension batch pour le tenseur des états cachés lors du calcul des log‑probabilités. Cela ferait que l'utilisation de la VRAM soit divisée par la taille du batch ou, dans ce cas, soit0,244 Go
. Cela réduit la VRAM maximale requise pour matérialiser les états cachés, comme reflété dans l'équation mise à jour ci‑dessous : 500K Context Training Similaire à notre perte d'entropie croisée dans notre publication, la nouvelle implémentationajuste automatiquement le lotage des états cachés . Les utilisateurs peuvent également contrôler ce comportement viaunsloth_grpo_mini_batch . Les utilisateurs peuvent également contrôler ce comportement via . Cependant, augmenter
au‑delà de la valeur optimale peut introduire une légère augmentation des performances ou un ralentissement (généralement plus rapide) par rapport à l'ancienne fonction de perte.Cependant, lors d'une exécution GPT-OSS (context_length = 8192, batch_size = 4, gradient_accumulation_steps = 2 ), définir et unsloth_grpo_mini_batch = 1 unsloth_logit_chunk_multiplier = 4 entraîne peu ou pas de dégradation de la vitesse tout en réduisant l'utilisation de la VRAM d'environ 5 Go

Remarque : par rapport aux anciennes versions d'Unsloth. Dans les Figures 3 et 4, nous utilisons la taille de batch effective maximale, qui est 8 dans cette configuration. La taille de batch effective est calculée commebatch_size × gradient_accumulation_steps , donnant4 × 2 = 8 . Pour une explication plus approfondie de la façon dont fonctionnent les tailles de batch effectives en RL, voir notre.
🌵documentation RL avancée
Déchargement des activations pour le log softmax Lors du développement de cette version, nous avons découvert que lorsque l'on carrelait (tiling) sur la dimension batch pour les états cachés, les activations n'étaient pas déchargées après le calcul fusionné des logits et des logprobs. Comme les logits sont calculés un batch à la fois en utilisanthidden_states[i] @ lm_head
, la logique existante de déchargement des activations et de checkpointing des gradients, conçue pour fonctionner dans la passe forward du modèle, ne s'appliquait pas dans ce cas.
Remarque : torch.autograd.backward(output, grad_output) Cette fonctionnalité n'est efficace que lors du découpage sur la dimension batch ou lorsqueunsloth_grpo_mini_batch > 1 ), définir. Si tous les états cachés sont matérialisés d'un seul coup pendant la passe forward (c'est‑à‑dire,
✨), la passe backward nécessite la même quantité de mémoire sur le GPU indépendamment du fait que les activations soient déchargées. Étant donné que le déchargement des activations introduit un léger ralentissement des performances sans réduire l'utilisation mémoire dans ce cas, cela n'apporte aucun bénéfice.
Configuration des paramètres : . Les utilisateurs peuvent également contrôler ce comportement via et ), il est toujours découpé sur la dimension batch. Le découpage supplémentaire sur la séquence est contrôlé viaSi vous ne configurez pas , nous ajusterons automatiquement ces deux paramètres
unsloth_logit_chunk_multiplier = 2 . Les utilisateurs peuvent également contrôler ce comportement via et ), il est toujours découpé sur la dimension batch. Le découpage supplémentaire sur la séquence est contrôlé via Une visualisation des optimisations et de

peut être vue dans le schéma ci‑dessous. . Les utilisateurs peuvent également contrôler ce comportement via Les 3 matrices représentent le lot global plus large ou ), il est toujours découpé sur la dimension batch. Le découpage supplémentaire sur la séquence est contrôlé via (représenté par le nombre de crochets noirs) et les lignes de chacune des matrices représentent la longueur de contexte que le
📼découpe de la longueur de séquence (représenté par le nombre de crochets rouges).
vLLM pour le RLPour les flux de travail RL, la phase d'inférence/génération est le principal goulot d'étranglement vLLM. Pour y remédier, nous utilisons
, qui a accéléré la génération jusqu'à 11x par rapport à la génération normale. Depuis que GRPO a été popularisé l'année dernière, vLLM est devenu un composant central de la plupart des frameworks RL, y compris Unsloth. Nous souhaitons exprimer notre gratitude à l'équipe vLLM et à tous ses contributeurs pour leur travail, car ils jouent un rôle essentiel dans l'amélioration du RL d'Unsloth ! Notebooks GRPO Pour commencer, vous pouvez utiliser n'importe quel
gpt-oss-20b Pour essayer le RL à plus long contexte, vous pouvez utiliser n'importe quel
- GSPO
Mis à jour
Ce contenu vous a-t-il été utile ?

