🎱FP8 Reinforcement Learning

Trainiere Reinforcement Learning (RL) und GRPO in FP8-Präzision mit Unsloth.

Wir führen FP8-Genauigkeitstraining für RL ein und machen FP8 GRPO jetzt möglich auf Consumer-GPUs (RTX 40, 50 usw.). DeepSeek-R1 zeigte, wie leistungsfähig FP8 sein kann, und mit Unsloth funktioniert Qwen3-1.7B FP8 GRPO jetzt nur auf 5 GB VRAM.

Schnellere RL-Inferenz ist entscheidend, da sie die rechenintensivste Arbeitslast im RL ist. Wir haben mit TorchAOarrow-up-right von PyTorch zusammengearbeitet, um Leistungssteigerungen ohne Genauigkeitsverlust zu ermöglichen.

  • ~1,4× schneller RL-Inferenz über vLLMarrow-up-right • 2x längerer Kontext gegenüber BF16 und FP16

  • 60% weniger VRAM und 10× länger Kontext als andere FP8-RL-Implementierungen

  • Unsloth ist das einzige Framework das FP8 RL LoRA auf Consumer-GPUs (z. B. NVIDIA GeForce RTX 40 und 50 Serie) zum Laufen bringt. Funktioniert auch auf H100, H200, B200 usw.

  • Verwenden Sie load_in_fp8 = True innerhalb FastLanguageModel um FP8 RL zu aktivieren.

  • Obwohl Qwen3-8B in 16 GB VRAM passt, unterstützen freie Colab NVIDIA Tesla T4 GPUs kein FP8. Daher verwenden unsere Notebooks 24GB L4 GPUs, die Qwen3-14B aufnehmen.

Notebooks: Qwen3-8B FP8 GRPOarrow-up-right und Llama-3.2-1B FP8 GRPOarrow-up-right

circle-check

Unsere FP8-Unterstützung nutzt Unsloths Gewichts-Sharing-Funktion, wodurch der VRAM-Verbrauch um weitere 50%reduziert wird und 10× mehr Kontext ohne Genauigkeitsverlust ermöglicht. Wir verwenden vLLMarrow-up-right für schnelle Inferenz und Techniken wie Unsloth Standby und Flex Attention um den VRAM-Verbrauch weiter zu reduzieren. TorchAO ermöglicht universelles, on-the-fly FP8, sodass Llama, Gemma, Mistral & mehr funktionieren. Wir haben außerdem hochgeladen die meisten FP8-Modelle (einschließlich Qwen3).

Der Reward-Plot zeigt, dass FP8 dem gleichen Trend wie BF16 folgt

🌻FP8 vs BF16 Training

Forschung zeigt, dass FP8-Training die BF16-Genauigkeit weitgehend erreichen kann, und wenn du Modelle in FP8 bereitstellst, Training und Bereitstellung in derselben Präzision trägt das zur Erhaltung der Genauigkeit bei. Außerdem liefert FP8 gegenüber BF16 auf H100s eine 1,6x höhere Durchsatzrate und verwendet 2x weniger Speicher.

Gewichtsskalen & FP8-Typen

Quantisiertes Training speichert ein niedrigpräzises Gewicht (z. B. FP8) plus eine höherpräzise Skala (FP16/BF16/FP32). Du stellst das ursprüngliche Gewicht ungefähr so wieder her: original_weight ≈ quantized_weight * weight_scale

Die Skala bildet den Wertebereich des Gewichts in den darstellbaren Bereich von FP8 ab. Mehr Skalen verbessern normalerweise die Genauigkeit, aber Skalen kosten zusätzlichen hochpräzisen Speicher, daher ist es ein Kompromiss. DeepSeek R1arrow-up-right, zum Beispiel, bevorzugt überwiegend Block-Quantisierung.

Es gibt 3 gängige FP8-Typen, wie sie von vLLM's definiert werden llm-compressorarrow-up-right. Wir haben Qwen3-8B für alle 3 Typen benchmarked und auch Durchsatz, MMLU Pro und GQPA Diamond geprüft. Wir stellen fest, dass FP8 Block-Wise oder Per-Channel (-FP8-Dynamic) das beste in Bezug auf Genauigkeit und Durchsatz ist.

Typ
Durchsatz
MMLU Pro
GQPA Diamond

Bfloat16-Baseline

11,367

62.04%

28.79%

Block-wise

Skalen pro Block (128X128)

12,041

62.37%

29.29%

Per-Channel

1 Skala pro Zeile oder Spalte

12,963

61.89%

31.82%

Per-Tensor

1 Skala für das gesamte Tensor

13,681

61.83%

27.78%

FP8 Leistungs-Benchmarks

Unsloth FP8 RL-Inferenz über vLLM ist generell 1,4x schneller als BF16. Bei größeren Modellen sind noch höhere Geschwindigkeitsverbesserungen möglich!

Genauigkeit Trainingsverlust-Benchmarks

Wir testeten mehrere Modelle, darunter Qwen3-4B, 8B, 14B, Llama 3.2 1B, 3B, Qwen3-VL-2B, Qwen3-VL 4B und viele mehr. Alle wurden sowohl in BF16 als auch in FP8 trainiert. Wie in den Plots zu sehen ist, verfolgen die Verlustkurven während SFT für BF16 und FP8 einander eng. Es gibt kaum einen Unterschied zwischen den beiden Datentypen hinsichtlich des Trainingsverlusts:

Speziell für GRPO ist das Ziel aufgrund von Unterschiedlichkeiten bei der Generierung zu sehen, ob sich die Reward-Plots zumindest angleichen und nicht auseinanderlaufen (manchmal sind z. B. Qwen3-14B-Läufe nicht exakt ähnlich)

⛩️Inference = 96% des RL-Trainings

Im RL müssen wir das LLM/VLM aufrufen, um einige mögliche Kandidatenlösungen für einen Lauf zu generieren, dann bewerten wir jede mögliche Lösung und belohnen gute Lösungen und bestrafen schlechte Antworten. Um maximale Effizienz zu erreichen, müssen wir die Inferenz nahezu 100% des Trainingslaufs ausmachen. In Unsloth haben wir es geschafft, das Training auf weniger als 4% des gesamten RL-Laufs zu reduzieren, wobei 96% reine vLLM-Inferenz sind.

Zum Beispiel für Qwen-3-8B, das bei kürzeren Sequenzlängen 1,15x schneller ist, ist vLLM FP8 selbst für Inferenz (ohne Training) ebenfalls 1,15x schneller im Durchsatz. Wir sehen, dass unser RL-Lauf in Unsloth ebenfalls 1,15x schneller bei verarbeiteten Tokens ist, was zeigt, wie der Trainings-Overhead in Unsloth vernachlässigbar ist.

🔢60% weniger Speicherverbrauch

Theoretisch würde man erwarten, dass die Speicherersparnis ungefähr dem Gewichtsspeicher des Modells entspricht, weil: Optimizer-Zustände werden weiterhin in hoher Präzision gespeichert und Aktivierungen werden ebenfalls in hoher Präzision gespeichert (vorerst). Unsere Erkenntnisse stimmen mit der Theorie überein. Für LoRA-Finetuning beobachteten wir: ~30 GB eingespart für Qwen3-32B, ~14 GB eingespart für Qwen2.5-14B und ~8 GB eingespart für Qwen3-8B

Für BF16 LoRA-Finetuning auf Qwen3-32B, wir hatten OOMs bei größeren Batch-Größen und mussten die Batch verkleinern. Die FP8-Variante hatte solche Probleme nicht, und wir konnten größere Batch-Größen ohne OOM verwenden.

Außerdem Erinnerung: In Unsloth teilen wir den vLLM-Speicherraum für die Gewichte wie eingeführt in Speichereffizientes RL - wir haben diesen Trick in den FP8-Bereich übertragen!

80GB GPU
Inference Engine
Training Engine

Modellgewichte

8GB GETEILTER FP8

<<< GETEILT

Mehrzweck

72GB Platz

KV-Cache

Aktivierungen, Gradienten, Optimizer-Zustände

nn.Linear Unsloth Standby Für FP8 (oder BF16) RL füge einfach das Folgende zu allen RL/GRPO-Trainingsläufen hinzu, bevor irgendein Unsloth-Import erfolgt:

Wie man FP8 RL verwendet / Installation

Aktualisiere einfach Unsloth oder installiere Unsloth in einer neuen virtuellen Umgebung für H100, L4, RTX 50x, RTX 40x, H200s, B200s und jede NVIDIA-GPU (Consumer- oder Rechenzentrumsklasse), die nach der RTX 4090 veröffentlicht wurde.

Um Unsloth zu aktualisieren: pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zooOder erstelle eine neue Umgebung:

Verwenden Sie dann load_in_fp8 = True und du bist startklar! Wir werden den Modellnamen automatisch der Float8-Variante zuordnen, oder wir konvertieren das Modell on-the-fly in Float8!

Zum Beispiel auf einer RTX 5090 (Erinnerung: setze os.environ["UNSLOTH_VLLM_STANDBY"] = "1" )

Verwende dann unsere 2 FP8-Notebooks für RL:

💿Implementierung von FP8-Training

Unser erster Referenzpunkt war transformers, der FP8 bereits auf mehrere Arten unterstützt. Eine davon ist eine block-quantisierte Matmul-Implementierung: wenn eine Schicht 16-Bit-Aktivierungen erhält, quantisiert sie diese und übergibt sie an einen benutzerdefinierten FP8-Matmul-Kernel. Nachdem wir das verdrahtet und auf einer NVIDIA H100 gebenchmarkt hatten, sahen wir das Gegenteil von dem, was wir wollten: Finetuning wurde etwa 4× langsamer als standardmäßiges BF16-Finetuning.

🔥TorchAO Zusammenarbeit

Also arbeiteten wir mit dem TorchAOarrow-up-right Team (großer Dank an Andrewarrow-up-right) zusammen, um TorchAOs FP8-Unterstützung in unsere RL-Workloads zu integrieren und sahen etwa 1,4× höheren Durchsatz und bis zu 60% geringeren Modellspeicherverbrauch. Auf hoher Ebene:

  • Wir speichern die eingefrorenen LoRA-Gewichte in FP8.

  • Während des Vorwärtspasses wenden wir dynamische FP8-Quantisierung auf die Eingabeaktivierungen an, während die trainierbaren LoRA-Adapter in BF16 bleiben.

  • Diese FP8-Gewichte teilen die gleichen Puffer wie die vLLM-Modellgewichte, sodass zu jedem Zeitpunkt nur eine einzige FP8-Kopie des Modells im Speicher ist (keine doppelte Modell-Speicherüberhead).

  • Im Rückwärtspass dequantisieren wir die LoRA-Gewichte, sodass alle Gradientenberechnungen in BF16 für bessere Genauigkeit erfolgen.

Dieses allgemeine Setup funktioniert über alle unterstützten RL-Algorithmen hinweg, einschließlich GSPO, Dr. GRPO, PPO und DPO.

TorchAO bietet native PyTorch FP8-Unterstützung für Training und Inferenz und bietet eine Vielzahl von Skalierungsgranularitäten, einschließlich tensorweise, zeilenweise und 128x128 blockweise (Prototyp). TorchAOs FP8-Unterstützung kann die Inferenzdurchsatzrate um bis zu 1,64x bei 27B-Skalierungarrow-up-right mit zeilenweiser Skalierungsgranularität verbessern. Für weitere Details besuche das TorchAO FP8 READMEarrow-up-right.

TorchAOs block-quantisiertes FP8-Matmul

Wir verwendeten TorchAOs block‑quantisierte FP8-Matmul-Implementierung, die Folgendes lieferte:

  • 80% des BF16-Durchsatzes

  • Ohne Verlust der Trainingsstabilität oder Verschlechterung des Loss

Eine Zeit lang wurde dies unser Standard-FP8-Matmul-Backend, bis FBGEMM aufholte – wir verwenden jetzt standardmäßig FBGEMMs Implementierung, wenn deine GPU sie unterstützt! Die aktuelle Version von Unsloth kann automatisch das beste Backend basierend auf der Installation wählen. Wenn du die richtigen Pakete hast, musst du keine Performance-Potenziale verschenken 🙂

PS: Wir haben auch mit DeepSeeks DeepGEMM experimentiert, konnten es aber nicht vollständig end-to-end integrieren, um saubere, vergleichbare Tests durchzuführen.

🐦On-the-fly TorchAO FP8-Quantisierung

Großer Dank an Andrewarrow-up-right von TorchAO, Unsloth FP8 RL ermöglicht es dir auch, das Modell on-the-fly zu quantisieren, indem die Quantisierung zur Modellladezeit durchgeführt und an vLLM übergeben wird. Auf diese Weise musst du das Modell nicht explizit selbst quantisieren (wir übernehmen das für dich). Du kannst dies tun, indem du load_in_fp8 = True in den Modell-Ladeargumenten setzt, und wir führen eine Offline-FP8 durch, falls wir keinen geeigneten vor-quantisierten Checkpoint finden.

🎉Unsloth FP8 Uploads

Zur Bequemlichkeit haben wir FP8 Dynamic- und FP8 Block-Modelle auf Hugging Face hochgeladen. Du kannst sie für FP8-Training oder auch für effizientes & schnelles Serving/Deployment über vLLM/SGLang usw.

FP8 Dynamic bietet etwas schnelleres Training und geringeren VRAM-Verbrauch als FP8 Block, jedoch mit einem kleinen Genauigkeitskompromiss. Sieh hier für unsere vollständige Liste der FP8-Quantis, aber hier die populärsten:

💁Danksagungen

Großer Dank an das gesamte PyTorch- und TorchAO-Team für ihre Hilfe und Zusammenarbeit! Ein besonderer Dank geht an: Andrew Or, Jerry Zhang, Supriya Rao, Scott Roy und Mergen Nachin für die Unterstützung in vielen Diskussionen zu FP8 RL und bei der Integration in Unsloth! Auch Dank an das Executorch-Team!

Zuletzt aktualisiert

War das hilfreich?