# FP16 vs BF16 pour le RL

### Float16 vs Bfloat16

Il y avait un article intitulé "**Défaire le décalage entraînement-inférence via FP16**" <https://arxiv.org/pdf/2510.26788> montrant comment l'utilisation de la précision float16 peut être considérablement meilleure que l'utilisation de bfloat16 lors d'un apprentissage par renforcement.

<figure><img src="https://550366147-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2FxhOjnexMCB3dmuQFQ2Zq%2Fuploads%2Frec4qe1aQS0xyMzGvS9c%2Fimage.png?alt=media&#x26;token=2137e766-0f1f-48ec-b25f-2292d6f149f4" alt=""><figcaption></figcaption></figure>

En fait, plus la génération est longue, plus c'est pire lors de l'utilisation de bfloat16 :

<figure><img src="https://550366147-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2FxhOjnexMCB3dmuQFQ2Zq%2Fuploads%2FWs7ioB2lraTbDbUCOAnn%2Fimage.png?alt=media&#x26;token=ac2b4f8e-210f-4bcc-bcbb-6e68f80781a6" alt=""><figcaption></figcaption></figure>

Nous avons mené une enquête, et **constatons que le float16 est plus stable** que le bfloat16 avec des normes de gradient bien plus petites voir <https://x.com/danielhanchen/status/1985557028295827482> et <https://x.com/danielhanchen/status/1985562902531850472>

{% columns %}
{% column width="50%" %}

<figure><img src="https://550366147-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2FxhOjnexMCB3dmuQFQ2Zq%2Fuploads%2FhvQ1W5wtV6TTfsetp7y2%2FG44d7ZFbIAANBBd.jpg?alt=media&#x26;token=35181a07-de3e-4321-b54e-4436b4a201ff" alt=""><figcaption></figcaption></figure>

<figure><img src="https://550366147-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2FxhOjnexMCB3dmuQFQ2Zq%2Fuploads%2F62HkxnGcaKvxnSxbZMZu%2FG44c20SbwAAGo8j.jpg?alt=media&#x26;token=e0c7ecb8-6f0c-4ecf-b1a0-50f1b2a9a807" alt=""><figcaption></figcaption></figure>
{% endcolumn %}

{% column width="50%" %}

<figure><img src="https://550366147-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2FxhOjnexMCB3dmuQFQ2Zq%2Fuploads%2Fsi18IkGqE4IuUvzroyHh%2FG44ix5FbQAM0L5l.jpg?alt=media&#x26;token=bc3b97ce-5df4-4b69-aa50-a8e339f21601" alt=""><figcaption></figcaption></figure>
{% endcolumn %}
{% endcolumns %}

### :exploding\_head:Bug d'attention en cascade sur A100

Comme indiqué par <https://x.com/RichardYRLi/status/1984858850143715759> et <https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda>, les anciennes versions de vLLM (avant 0.11.0) avaient des mécanismes d'attention défectueux pour les A100 et GPU similaires. Veuillez mettre à jour vLLM ! Nous désactivons également par défaut l'attention en cascade dans vLLM lors de l'apprentissage par renforcement Unsloth si nous détectons une ancienne version de vLLM.

<figure><img src="https://550366147-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2FxhOjnexMCB3dmuQFQ2Zq%2Fuploads%2FnkCLRVIIGLADXBSCe58e%2Fimage.png?alt=media&#x26;token=6669642f-8690-44bf-b2de-6aa89acf2332" alt=""><figcaption></figcaption></figure>

Différents matériels modifient également les résultats, où les GPU plus récents et plus coûteux présentent une moindre différence KL entre l'inférence et l'entraînement :

<figure><img src="https://550366147-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2FxhOjnexMCB3dmuQFQ2Zq%2Fuploads%2FaroTTz68zzyofy6nagtH%2Fimage.webp?alt=media&#x26;token=3be09506-b8a0-42eb-8d17-af72496a9cd1" alt=""><figcaption></figcaption></figure>

### :fire:Utiliser le float16 dans Unsloth RL

Pour utiliser la précision float16 dans Unsloth GRPO et RL, il vous suffit de définir `dtype = torch.float16` et nous nous occupons du reste !

{% code overflow="wrap" %}

```python
from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Peut être augmenté pour des traces de raisonnement plus longues
lora_rank = 32 # Rang plus grand = plus intelligent, mais plus lent

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-4B-Base",
    max_seq_length = max_seq_length,
    load_in_4bit = False, # False pour LoRA 16bit
    fast_inference = True, # Activer l'inférence rapide vLLM
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.9, # Réduire si mémoire insuffisante
    
    dtype = torch.float16, # Utiliser torch.float16, torch.bfloat16
)
```

{% endcode %}
