🌀Reinforcement Learning GRPO mit 7x längerem Kontext

Lerne, wie Unsloth ultra-langes Kontext-RL-Finetuning ermöglicht.

Die größte Herausforderung des Reinforcement Learning (RL) besteht darin, lange Schlussfolgerungsketten zu unterstützen. Wir führen neue Batch-Algorithmen ein, um etwa7x längeren Kontext (kann mehr als 12x sein) RL-Training ohne Genauigkeits- oder Geschwindigkeitsverschlechterung gegenüber anderen optimierten Setups, die FA3, Kernel und chunked losses verwenden.

  • Unsloth trainiert jetzt gpt-oss QLoRA mit 380K Kontext auf einer einzelnen 192GB NVIDIA B200 GPU

  • Qwen3-8B GRPO erreicht 110K Kontext auf einer 80GB VRAM H100 via vLLM und QLoRA, und 65K für gpt-oss mit BF16 LoRA.

  • Auf 24GB VRAM erreicht gpt-oss 20K Kontext und 32K für Qwen3-VL-8B QLoRA

  • Unsloth GRPO RL läuft mit Llama, Gemma & alle Modelle unterstützen automatisch längere Kontexte

Unsere neuen Datenbewegungs- und Batch-Kernel und -Algorithmen erschließen mehr Kontext durch:

circle-info

Sie können alle Funktionen in Unsloth zusammen kombinieren:

  1. Unsloths Gewichts-Teilungs- Funktion mit vLLMarrow-up-right und unserer Standby-Funktion in Speichereffizientes RL

  2. Unsloths Flex Attention für langen Kontext gpt-oss und unser 500K Context Training

  3. Float8-Training in FP8 RL und Unsloths asynchrones Gradient-Checkpointingarrow-up-right und vieles mehr

🎉Erste Schritte

Um loszulegen, können Sie jedes vorhandene GRPO-Notebooks (oder aktualisieren Sie Unsloth, wenn lokal):

Die Einführung von Unsloth für Ihre RL-Aufgaben bietet ein robustes Framework zur effizienten Verwaltung großskaliger Modelle. Um Unsloths Verbesserungen effektiv zu nutzen:

  • Hardware-Empfehlungen: Verwendung einer NVIDIA H100 oder eines gleichwertigen Modells für optimale VRAM-Nutzung.

  • Konfigurations-Tipps: Stellen Sie sicher, dass batch_size und gradient_accumulation_steps Einstellungen mit Ihren Rechenressourcen für beste Leistung übereinstimmen.

circle-check

Unsere Benchmarks heben die erzielten Speicherersparnisse im Vergleich zu früheren Versionen für GPT OSS und Qwen3-8B hervor. Beide untenstehenden Diagramme (ohne standby) wurden mit batch_size = 4 und gradient_accumulation_steps=2 , da Standby per Design den gesamten VRAM verwendet.

Für unsere Benchmarks vergleichen wir BF16 GRPO mit Hugging Face mit allen aktivierten Optimierungen (alle Kernel in der Kernbibliothek, Flash Attention 3, chunked loss-Kernel, etc.):

🔢Abgeflachte Sequenzlängen-Chunking

Früher reduzierte Unsloth die Speichernutzung von RL, indem die vollständige Materialisierung des Logits-Tensors durch Chunking über die Batch-Dimension vermieden wurde. Eine grobe Schätzung des VRAMs, der benötigt wird, um Logits während des Forward-Passes zu materialisieren, wird in Gleichung (1) gezeigt.

Equation 1: Logit Memory (GB)=batch size×context length×vocab dim10243\text{Equation 1: } \text{Logit Memory (GB)} = \frac{\text{batch size} \times\text{context length} \times \text{vocab dim}}{1024^3}

Unter Verwendung dieser Formulierung würde eine Konfiguration mit batch_size = 4, context_length = 8192, und vocab_dim = 128.000 ungefähr benötigen 3,3 GB VRAM um den Logits-Tensor zu speichern.

Via Long Context gpt-oss letztes Jahr haben wir dann einen fused loss-Ansatz für GRPO eingeführt. Dieser Ansatz stellt sicher, dass jeweils nur eine einzelne Batch-Probe verarbeitet wird, und reduziert dadurch die Spitzen-Speicherauslastung erheblich. Unter derselben Konfiguration sinkt die VRAM-Nutzung auf ungefähr 0,83 GB, wie in Gleichung (2) dargestellt.

Equation 2: Logit Memory (GB)=context length×vocab dim10243\text{Equation 2: }\text{Logit Memory (GB)} = \frac{\text{context length} \times \text{vocab dim}}{1024^3}
Abbildung 1: gpt-oss BF16 GRPO LoRA (Unsloth vs. HF mit allen Optimierungen an)
Abbildung 2: Qwen3-8B QLoRA GRPO LoRA (Unsloth vs. HF mit allen Optimierungen an)

In diesem Update erweitern wir dieselbe Idee, indem wir Chunking über die Sequenz-Dimension ebenfalls einführen. Anstatt Logits für die gesamte (batch_size × context_length) Fläche auf einmal zu materialisieren, flachen wir diese Dimensionen ab und verarbeiten sie in kleineren Stücken mithilfe eines konfigurierbaren Multiplikators. Dadurch kann Unsloth wesentlich längere Kontexte unterstützen, ohne die Spitzen-Speicherauslastung zu erhöhen.

In Abbildung 5 unten verwenden wir einen Multiplikator von max(4, context_length // 4096), obwohl jeder Multiplikator je nach gewünschtem Speicher–Leistungs-Kompromiss angegeben werden kann. Mit dieser Einstellung benötigt die gleiche Beispielkonfiguration (batch_size = 4, context_length = 8192, vocab_dim = 128.000) nun nur noch 0,207 GB VRAM für die Materialisierung der Logits.

Equation 3: Logit Memory (GB)=context lengthmultiplier×vocab dim10243\text{Equation 3: }\text{Logit Memory (GB)} = \frac{\frac{\text{context length}}{\text{multiplier}} \times \text{vocab dim}}{1024^3}
Abbildung 3: gpt-oss-20b (H100) Unsloth neu vs. alt
Abbildung 4: Qwen3-8B (H100) Unsloth neu vs. alt
Abbildung 5: gpt-oss-20b (H100)
Abbildung 6: Qwen3-8B (B200)

Dieses Update spiegelt sich im kompilierten chunked_hidden_states_selective_log_softmax unten wider, das jetzt Chunking sowohl über die Batch- als auch die Sequenzdimension unterstützt. Um den Logits-Tensor ([batch_size, context_length, vocab_dim]) zu erhalten, wird er immer über die Batch-Dimension gechunked. Zusätzliches Sequenz-Chunking wird über unsloth_logit_chunk_multiplier in der GRPO-Konfiguration gesteuert; wenn es nicht gesetzt ist, ist der Standardwert max(4, context_length // 4096). Im folgenden Beispiel entspricht input_ids_chunk[0] der Größe der Hidden-States-Mini-Batches in Optimierung 2.

  1. Wir nutzen torch.compile mit benutzerdefinierten Compile-Optionen, um VRAM zu reduzieren und die Geschwindigkeit zu erhöhen.

  2. Alle gechunkten Logits werden in float32 hochskaliert, um die Genauigkeit zu bewahren.

  3. Wir unterstützen Logit-Softcapping, Temperature-Scaling und alle anderen Funktionen.

👻Hidden-States-Chunking

Wir haben auch beobachtet, dass bei längeren Kontextlängen die Hidden States einen erheblichen Beitrag zur Speichernutzung leisten können. Zur Demonstration nehmen wir an hidden_states_dim=4096. Die entsprechende Speichernutzung folgt einer ähnlichen Formulierung wie im Logits-Fall, die unten gezeigt wird.

Hidden States Memory (GB)=batch size×context length×hidden states dim10243\text{Hidden States Memory (GB)} = \frac{\text{batch size} \times\text{context length} \times \text{hidden states dim}}{1024^3}

Mit einem batch_size = 8 und context_length = 64000, würde dies zu einer VRAM-Nutzung von ungefähr führen 2 GB. In dieser Version führen wir optionales Chunking über die Batch-Dimension für den Hidden-States-Tensor während der Log-Probabilitätsberechnung ein. Dadurch würde die VRAM-Nutzung durch die Batch-Größe geteilt oder in diesem Fall 0,244 GB. Dies reduziert den Spitzen-VRAM, der zur Materialisierung der Hidden States erforderlich ist, wie in der aktualisierten Gleichung unten reflektiert:

Hidden States Memory (GB)=context length×hidden states dim10243\text{Hidden States Memory (GB)} = \frac{\text{context length} \times \text{hidden states dim}}{1024^3}

Ähnlich wie bei unserem Cross-Entropy-Loss in unserer 500K Context Training Version, passt die neue Implementierung automatisch das Hidden-State-Batching an. Benutzer können dieses Verhalten auch über unsloth_grpo_mini_batchsteuern. Allerdings kann die Erhöhung von unsloth_grpo_mini_batch über den optimalen Wert hinaus eine leichte Leistungsverbesserung oder -verlangsamung (normalerweise schneller) im Vergleich zur vorherigen Verlustfunktion verursachen.

Während eines GPT-OSS-Laufs (context_length = 8192, batch_size = 4, gradient_accumulation_steps = 2) führt das Setzen von unsloth_grpo_mini_batch = 1 und unsloth_logit_chunk_multiplier = 4 dazu, dass kaum bis keine Geschwindigkeitsverschlechterung auftritt, während der VRAM-Verbrauch um etwa 5 GB reduziert wird im Vergleich zu älteren Versionen von Unsloth.

circle-check

🌵Auslagern von Aktivierungen für Log-Softmax

Während der Entwicklung dieser Version entdeckten wir, dass beim Tiling über die Batch-Dimension für Hidden States die Aktivierungen nach der Berechnung der fused logits und logprobs nicht ausgelagert wurden. Da Logits jeweils für eine Batch mit hidden_states[i] @ lm_headberechnet werden, galt die vorhandene Aktivierungs-Auslagerungs- und Gradient-Checkpointing-Logik, die innerhalb des Forward-Passes des Modells arbeitet, in diesem Fall nicht.

Um dies zu beheben, haben wir explizite Logik hinzugefügt, um diese Aktivierungen außerhalb des Forward-Passes des Modells auszulagern, wie im folgenden Python-Pseudocode gezeigt:

circle-check

Parameter konfigurieren:

Wenn Sie unsloth_grpo_mini_batch und unsloth_logit_chunk_multipliernicht konfigurieren, werden wir diese beiden Parameter automatisch für Sie abstimmen basierend auf Ihrem verfügbaren VRAM und abhängig von der Größe Ihrer Kontextlänge. Unten jedoch erfahren Sie, wie Sie diese Variablen in Ihrem GRPO-Lauf ändern können:

Eine Visualisierung der Optimierungen und unsloth_grpo_mini_batch und unsloth_logit_chunk_multiplier ist im folgenden Diagramm zu sehen.

Die 3 Matrizen repräsentieren die insgesamt größere Batch oder unsloth_grpo_mini_batch (dargestellt durch die Anzahl der schwarzen Klammern) und die Zeilen jeder der Matrizen repräsentieren die Kontextlänge, durch die die unsloth_logit_chunk_multiplier die Sequenzlänge chunkt (dargestellt durch die Anzahl der roten Klammern).

📼vLLM für RL

Für RL-Workflows ist die Inferenz-/Generierungsphase der Hauptengpass. Um dem zu begegnen, nutzen wir vLLMarrow-up-right, das die Generierung im Vergleich zur normalen Generierung um bis zu 11x beschleunigt hat. Seit GRPO im letzten Jahr populär wurde, ist vLLM ein Kernbestandteil der meisten RL-Frameworks, einschließlich Unsloth. Wir möchten dem vLLM-Team und allen Mitwirkenden unseren Dank aussprechen, da sie eine entscheidende Rolle dabei spielen, Unsloths RL zu verbessern!

Um RL mit längerem Kontext zu testen, können Sie jedes vorhandene GRPO-Notebooks (oder aktualisieren Sie Unsloth, wenn lokal):

Danksagungen: Ein großer Dank an das Hugging Face-Team und die Bibliotheken für die Unterstützung von Unsloth und das Ermöglichen dieses Fortschritts.

Zuletzt aktualisiert

War das hilfreich?