🎱Apprentissage par renforcement en FP8

Entraînez l'apprentissage par renforcement (RL) et GRPO en précision FP8 avec Unsloth.

Nous introduisons l'entraînement en précision FP8 pour le RL, rendant FP8 GRPO désormais possible sur GPU grand public (RTX 40, 50, etc.). DeepSeek-R1 a démontré la puissance du FP8 et avec Unsloth, Qwen3-1.7B FP8 GRPO fonctionne maintenant simplement sur 5 Go de VRAM.

Une inférence RL plus rapide est critique car c'est la charge de calcul la plus intensive en RL. Nous avons collaboré avec TorchAOarrow-up-right de PyTorch pour permettre des gains de performance sans perte de précision.

  • ≈1,4× plus rapide inférence RL via vLLMarrow-up-right • contexte 2x plus long vs. BF16 et FP16

  • 60% de VRAM en moins et 10× plus long contexte que d'autres implémentations FP8 RL

  • Unsloth est le seul framework à faire fonctionner FP8 RL LoRA sur les GPU grand public (p. ex. NVIDIA GeForce RTX 40 et 50 Series). Fonctionne aussi sur H100, H200, B200, etc.

  • Utilisez load_in_fp8 = True dans FastLanguageModel pour activer FP8 RL.

  • Bien que Qwen3-8B tienne dans 16 Go de VRAM, les GPU NVIDIA Tesla T4 gratuits de Colab ne prennent pas en charge FP8. Donc nos notebooks utilisent des GPU L4 de 24 Go qui contiennent Qwen3-14B.

Notebooks : Qwen3-8B FP8 GRPOarrow-up-right et Llama-3.2-1B FP8 GRPOarrow-up-right

circle-check

Notre support FP8 utilise la fonctionnalité de partage de poids, réduisant l'utilisation de la VRAM de 50%, permettant 10× plus de contexte sans perte de précision. Nous utilisons vLLMarrow-up-right pour une inférence rapide et, nos techniques comme Unsloth Standby et et notre fonctionnalité Standby dans pour réduire encore l'utilisation de la VRAM. TorchAO permet le FP8 universel à la volée, donc Llama, Gemma, Mistral et d'autres fonctionnent. Nous avons aussi téléversé la plupart des modèles FP8 (y compris Qwen3).

Le graphique de récompense montre que FP8 suit la même tendance que BF16

🌻FP8 vs BF16 Entraînement

La recherche montre que l'entraînement en FP8 peut en grande partie égaler la précision du BF16 et si vous servez des modèles en FP8, entraîner et servir dans la même précision aide à préserver la précision. De plus, FP8 vs BF16 offre un débit 1,6× supérieur sur les H100 et utilise 2× moins de mémoire.

Échelles de poids et types FP8

L'entraînement quantifié stocke un poids en basse précision (par ex., FP8) plus une échelle en haute précision (FP16/BF16/FP32). Vous récupérez approximativement le poids original via : original_weight ≈ quantized_weight * weight_scale

L'échelle mappe la plage du poids à la plage représentable par le FP8. Plus d'échelles améliorent généralement la précision, mais les échelles coûtent de la mémoire en haute précision, donc c'est un compromis. DeepSeek R1arrow-up-right, par exemple, favorise principalement la quantification par bloc.

Il existe 3 types FP8 courants tels que définis par llm-compressorarrow-up-right. Nous avons benchmarké Qwen3-8B sur les 3 types, et avons aussi vérifié le débit, MMLU Pro et GQPA Diamond. Nous trouvons que FP8 bloc-wise ou par canal (-FP8-Dynamic) est le meilleur en termes de précision et de débit.

Type
Débit
MMLU Pro
GQPA Diamond

Baseline Bfloat16

11,367

62.04%

28.79%

Bloc-wise

Échelles par bloc (128X128)

12,041

62.37%

29.29%

Par canal

1 échelle par ligne ou colonne

12,963

61.89%

31.82%

Par tenseur

1 échelle pour tout le tenseur

13,681

61.83%

27.78%

Benchmarks de performance FP8

L'inférence RL Unsloth FP8 via vLLM est généralement 1,4x plus rapide que BF16. Vous pouvez observer encore plus d'améliorations de vitesse si le modèle est plus grand !

Précision Benchmarks de perte d'entraînement

Nous avons testé plusieurs modèles incluant Qwen3-4B, 8B, 14B, Llama 3.2 1B, 3B, Qwen3-VL-2B, Qwen3-VL 4B et bien d'autres. Tous ont été entraînés à la fois en BF16 et FP8. Comme vu dans les graphiques, les courbes de perte pendant le SFT pour BF16 et FP8 se suivent de près. Il n'y a pas grand-chose à choisir entre les deux types de données en termes de perte d'entraînement :

Pour GRPO spécifiquement, en raison des différences de génération, le but est de voir si les graphiques de récompense au moins correspondent et ne divergent pas (parfois, par ex. les runs Qwen3-14B peuvent ne pas être exactement similaires)

⛩️L'inférence = 96% de l'entraînement RL

En RL, nous devons appeler le LLM / VLM pour générer quelques solutions candidates pour une exécution, puis nous notons chaque solution possible et récompenser les bonnes solutions, et pénaliser les mauvaises réponses. Pour atteindre une efficacité maximale, nous devons rendre l'inférence presque 100% du run d'entraînement. Dans Unsloth, nous avons réussi à faire en sorte que l'entraînement ne prenne que <4% de l'ensemble du run RL, 96% étant purement de l'inférence vLLM.

Par exemple pour Qwen-3-8B, qui est 1,15x plus rapide sur des longueurs de séquence plus courtes, vLLM FP8 lui-même pour l'inférence (sans entraînement) a aussi un débit 1,15x plus rapide. Nous voyons notre run RL dans Unsloth obtenir aussi 1,15x plus rapide sur les tokens traités, montrant à quel point le surcoût d'entraînement est négligeable dans Unsloth.

🔢60% moins d'utilisation mémoire

En théorie, vous vous attendriez à ce que les économies de mémoire soient à peu près égales à la mémoire des poids du modèle, parce que : les états de l'optimiseur sont toujours stockés en haute précision et les activations sont aussi stockées en haute précision (pour l'instant). Nos constats correspondent à la théorie. Pour le fine-tuning LoRA, nous avons observé : ≈30 Go économisés pour Qwen3-32B, ≈14 Go économisés pour Qwen2.5-14B et ≈8 Go économisés pour Qwen3-8B

Pour Fine-tuning LoRA BF16 sur Qwen3-32B, nous faisions des OOM à des tailles de batch plus élevées et avons dû réduire la taille du batch. Le variant FP8 n'a pas eu de tels problèmes, et nous pouvions utiliser des tailles de batch plus grandes sans OOM.

Aussi rappel : dans Unsloth nous partageons l'espace mémoire de vLLM pour les poids comme introduit dans RL efficace en mémoire — nous avons apporté cette astuce au domaine FP8 !

GPU 80 Go
Moteur d'inférence
Moteur d'entraînement

Poids du modèle

8Go PARTAGÉS FP8

<<< PARTAGÉ

Polyvalent

Espace de 72 Go

Cache KV

Activations, gradients, états de l'optimiseur

Pour permettre la Unsloth Standby pour le RL en FP8 (ou BF16), ajoutez simplement ce qui suit à tous les runs d'entraînement RL / GRPO avant tout import Unsloth :

Comment utiliser FP8 RL / installation

Mettez simplement à jour Unsloth ou installez Unsloth dans un nouvel environnement virtuel pour H100, L4, RTX 50x, RTX 40x, H200s, B200s, et tout GPU NVIDIA (grand public ou datacenter) sorti après le RTX 4090.

Pour mettre à jour Unsloth : pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zooOu créez un nouvel environnement :

Puis utilisez load_in_fp8 = True et c'est parti ! Nous mapperons automatiquement le nom du modèle à la variante Float8, ou nous convertirons le modèle en Float8 à la volée !

Par exemple sur une RTX 5090 (rappel de définir os.environ["UNSLOTH_VLLM_STANDBY"] = "1" )

Ensuite, utilisez nos 2 notebooks FP8 pour le RL :

💿Implémentation de l'entraînement FP8

Notre premier point de référence était transformers, qui prend déjà en charge le FP8 de plusieurs manières. L'une d'elles est une implémentation matmul quantifiée par bloc : lorsqu'une couche reçoit des activations 16 bits, elle les quantifie et les passe à un noyau matmul FP8 personnalisé. Après avoir connecté cela et benchmarké sur une NVIDIA H100, nous avons vu le contraire de ce que nous voulions : le fine-tuning est devenu environ 4× plus lent que le fine-tuning standard en BF16.

🔥Collab TorchAO

Nous avons donc travaillé avec l'équipe TorchAOarrow-up-right (énormes remerciements à Andrewarrow-up-right) pour incorporer le support FP8 de TorchAO dans nos charges RL et avons observé environ 1,4× de débit plus élevé et jusqu'à 60% d'utilisation mémoire du modèle en moins. À un niveau élevé :

  • Nous stockons les poids LoRA gelés en FP8.

  • Pendant la passe avant, nous appliquons une quantification FP8 dynamique aux activations d'entrée, tout en gardant les adaptateurs LoRA entraînables en BF16.

  • Ces poids FP8 partagent les mêmes buffers que les poids du modèle vLLM, donc il n'y a qu'une seule copie FP8 du modèle en mémoire à tout moment (pas de surcharge mémoire de « double modèle »).

  • Dans la passe arrière, nous déquantifions les poids LoRA afin que tout le calcul des gradients soit effectué en BF16 pour une meilleure précision.

Cette configuration générale fonctionne pour tous les algorithmes RL supportés, y compris GSPO, Dr. GRPO, PPO et DPO.

TorchAO fournit un support FP8 natif PyTorch pour l'entraînement et l'inférence, offrant une variété de granularités d'échelle incluant tensorwise, row-wise et blockwise 128x128 (prototype). Le support FP8 de TorchAO peut améliorer le débit d'inférence jusqu'à 1,64x à l'échelle 27Barrow-up-right avec granularité d'échelle par ligne. Pour plus de détails, visitez le README FP8arrow-up-right.

Matmul FP8 quantifié par bloc de TorchAO

Nous avons utilisé l'implémentation matmul FP8 quantifiée par bloc de TorchAO qui a fourni :

  • 80% du débit BF16

  • Sans dégrader la perte ni la stabilité d'entraînement

Donc pendant un certain temps, cela est devenu notre backend matmul FP8 par défaut, jusqu'à ce que FBGEMM rattrape son retard — nous utilisons désormais par défaut l'implémentation de FBGEMM, si votre GPU la supporte ! La version actuelle d'Unsloth peut choisir automatiquement le meilleur backend en fonction de ce qui est installé. Si vous avez les bons paquets, vous n'avez pas à laisser de performance sur la table 🙂

PS : Nous avons également expérimenté DeepGEMM de DeepSeek, mais n'avons pas pu l'intégrer complètement de bout en bout pour obtenir des comparaisons propres et équitables.

🐦Quantification FP8 à la volée TorchAO

Un immense merci à Andrewarrow-up-right de TorchAO, Unsloth FP8 RL vous permet aussi de quantifier le modèle à la volée en effectuant la quantification lors du chargement du modèle et en la transmettant ensuite à vLLM. De cette façon, vous n'avez pas besoin de quantifier explicitement le modèle vous-même (nous nous en chargeons). Vous pouvez le faire en définissant load_in_fp8 = True dans les arguments de chargement du modèle, et nous ferons du FP8 hors ligne si nous ne trouvons pas de checkpoint pré-quantifié approprié.

🎉Téléversements FP8 Unsloth

Pour plus de commodité, nous avons téléversé des modèles FP8 Dynamic et FP8 Block sur Hugging Face. Vous pouvez les utiliser pour l'entraînement FP8 ou aussi pour le service/déploiement efficace et rapide via vLLM/SGLang etc.

FP8 Dynamic offre un entraînement légèrement plus rapide et une utilisation de VRAM inférieure à FP8 Block, mais avec un petit compromis sur la précision. Voir ici pour notre liste complète de quantifications FP8, mais voici les plus populaires :

💁Remerciements

Un énorme merci à l'ensemble des équipes PyTorch et TorchAO pour leur aide et collaboration ! Un grand merci en particulier à : Andrew Or, Jerry Zhang, Supriya Rao, Scott Roy et Mergen Nachin pour leur aide lors de nombreuses discussions sur le FP8 RL, et pour l'intégration dans Unsloth ! Merci également à l'équipe Executorch !

Mis à jour

Ce contenu vous a-t-il été utile ?