# gpt-oss 长上下文训练

我们很高兴为 OpenAI gpt-oss 训练推出 Unsloth Flex Attention 支持，它能够实现 **>8× 更长的上下文长度**, **>50% 更少的 VRAM 占用** 和 **>1.5× 更快的训练（且不会降低准确率）** 相比所有实现，包括使用 Flash Attention 3（FA3）的实现。Unsloth Flex Attention 使得可以在 **60K 上下文长度** 的 80GB VRAM H100 GPU 上进行 BF16 LoRA 训练。此外：

* 你现在可以 [导出/保存](#new-saving-to-gguf-vllm-after-gpt-oss-training) 你的 QLoRA 微调 gpt-oss 模型到 llama.cpp、vLLM、Ollama 或 HF
* 我们 [**修复了 gpt-oss 训练中**](#bug-fixes-for-gpt-oss) **loss 变为无穷大** 在 float16 GPU 上（例如 T4 Colab）
* 我们 [修复了 gpt-oss 实现中的](#bug-fixes-for-gpt-oss) 与 Unsloth 无关的问题，尤其是确保 `swiglu_limit = 7.0` 在 transformers 中的 MXFP4 推理过程中被正确应用

## 🦥介绍 Unsloth Flex Attention 支持

借助 Unsloth 的 Flex Attention 支持，一块 80GB VRAM 的 H100 可以使用 QLoRA 处理最高 81K 上下文长度，使用 BF16 LoRA 处理 60K 上下文！这些提升适用于 **两者** gpt-oss-20b 和 **gpt-oss-120b**！你使用的上下文长度越长，从 Unsloth Flex Attention 获得的收益就越大：

<figure><img src="/files/7943cc723ba8b27e27ff9397c954412404a17e99" alt="" width="563"><figcaption></figcaption></figure>

相比之下，所有其他非 Unsloth 实现最多只能在 80GB GPU 上达到 9K 上下文长度，而使用 FA3 也只能达到 15K 上下文。不过， **FA3 不适合用于 gpt-oss 训练，因为它不支持 attention sink 的反向传播**。所以如果你之前在 gpt-oss 训练中使用 FA3，我们建议你 **暂时不要使用它** 。因此，在 80GB VRAM 上不使用 Unsloth 时，你能获得的最大上下文长度约为 \~9K。

使用 Unsloth Flex Attention 训练至少可带来 1.3× 的速度提升，并且随着上下文长度增加，收益会继续扩大，最高可快 2×。由于 Flex Attention 会随上下文规模扩展，更长的序列在显存和训练时间上都能带来更大的节省，正如 [这里所述](#unsloths-flex-attention-implementation).

非常感谢 Rohan Pandey 提供他的 [Flex Attention 实现](https://x.com/khoomeik/status/1955693558914310608)，这直接启发了 Unsloth 的 Flex Attention 实现开发。

## :dark\_sunglasses: Attention Sinks

OpenAI 的 GPT OSS 模型使用了一种 **滑动窗口注意力与全注意力交替的模式**，即滑动窗口注意力、全注意力、滑动窗口注意力、全注意力，等等（SWA、FA、SWA、FA，依此类推）。每个滑动窗口只关注 **128 个 token** （包括当前 token），因此计算量大幅减少。然而，这也意味着由于滑动窗口太小，长上下文检索和推理会变得无用。大多数实验室会通过将滑动窗口扩展到 2048 或 4096 个 token 来解决这一问题。

OpenAI 借鉴了 **Attention Sinks** 《Efficient Streaming Language Models with Attention Sinks》 [论文](https://arxiv.org/abs/2309.17453) 该论文表明你可以使用一个较小的滑动窗口，但必须在第一个 token 上添加全局注意力！论文下方提供了一个很好的示意：

<figure><img src="/files/3496f27e5a9efcfefbd904521944091004a63bd4" alt=""><figcaption></figcaption></figure>

论文发现， **注意力机制似乎会给最前面的几个 token（1 到 4）分配大量权重**，而在滑动窗口操作中移除它们后，这些“重要”的前几个 token 就会消失，并导致糟糕的长上下文检索。

如果我们绘制 log perplexity（越高越差），并在预训练模型设定的上下文长度之后进行长上下文推理，可以看到 perplexity 会飙升（不好）。然而红线（使用 Attention Sinks）保持较低，这非常好！

<figure><img src="/files/c2d1b7dca761893dffa8ba735fd3e9f751e2403e" alt=""><figcaption></figcaption></figure>

论文还表明， [Attention Is Off By One 方法](https://www.evanmiller.org/attention-is-off-by-one.html) 确实有部分效果，但还必须再添加几个额外的 sink token 才能获得更低的 perplexity。 **论文表明，添加一个可学习的单个 sink token 效果非常出色！ 而这正是 OpenAI 为 GPT-OSS 所做的！**

<figure><img src="/files/97083cda46a2b623793651efb4fe8bcf87f59017" alt=""><figcaption></figcaption></figure>

## :triangular\_ruler:Unsloth 的 Flex Attention 实现

Flex Attention <https://pytorch.org/blog/flexattention/> 非常强大，因为它为从业者提供了两种对注意力机制的自定义方式——一个 **分数修饰器（f）** 以及一个 **掩码函数（M）**.

该 **分数修饰器（f）** 允许我们在 softmax 操作之前修改注意力 logits，而 **掩码函数（M）** 允许我们在不需要某些操作时跳过它们（例如滑动窗口注意力只看最后 128 个 token）。

<mark style="background-color:green;">**其巧妙之处在于，Flex Attention 提供了可快速自动生成的 Triton 内核，并支持任意分数修饰器和掩码函数！**</mark>

<p align="center"><span class="math">\sigma\bigg(s\times\bold{f}(QK^T+\bold{M})\bigg)</span><br></p>

这意味着我们可以使用 Flex Attention 来实现 attention sinks！实现单个 attention sink 的方式既见于 [OpenAI 原始的 GPT-OSS 仓库](#implementations-for-sink-attention) 也见于 HuggingFace 的 transformers 实现。

```python
combined_logits = torch.cat([attn_weights, sinks], dim=-1)
probs = F.softmax(combined_logits, dim=-1)
scores = probs[..., :-1]
```

以上显示我们将 sink 拼接到 `Q @ K.T` 的最末端，进行 softmax，然后移除最后一列，也就是 sink token。

通过使用 [Flex Attention 的 Github 仓库](https://github.com/meta-pytorch/attention-gym)中的一些可视化工具，我们可以将其可视化。假设序列长度为 16，滑动窗口为 5。左边是最后一个 sink 列（默认实现），右边是将 sink 位置移动到索引 0（我们的实现）。

{% columns %}
{% column %}
***sink 位置在末尾（默认）***

<figure><img src="/files/388639f182c06846b841713f2f068fd1c6864310" alt=""><figcaption></figcaption></figure>
{% endcolumn %}

{% column %}
***将 sink 位置移动到索引 0***

<figure><img src="/files/2871b9a4521af4aa240df1daaeac4789b9286538" alt=""><figcaption></figcaption></figure>
{% endcolumn %}
{% endcolumns %}

**有趣的发现**：官方的 Flex Attention 滑动窗口实现将窗口大小视为最后若干 token 的数量 **再加一** ，因为它包含当前 token。HuggingFace 和 GPT OSS 的实现则严格只看最后 N 个 token。也就是说，下面来自 <https://pytorch.org/blog/flexattention/> 和 <https://github.com/meta-pytorch/attention-gym>:

{% code overflow="wrap" %}

```python
def sliding_window_causal(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    window_mask = q_idx - kv_idx <= SLIDING_WINDOW 
    return causal_mask & window_mask
```

{% endcode %}

{% columns %}
{% column %}
默认 Flex Attention（3+1 个 token）

<figure><img src="/files/d9a0e337a817f92a214dfac892ca9fc9db85ad58" alt=""><figcaption></figcaption></figure>
{% endcolumn %}

{% column %}
HuggingFace，GPT-OSS（3+0 个 token）

<figure><img src="/files/cda8865f91de29b0f2fb82e822ef021c34a78732" alt=""><figcaption></figcaption></figure>
{% endcolumn %}
{% endcolumns %}

我们还通过 OpenAI 官方的 GPT-OSS 实现确认了这里到底是关注最后 N 个还是 N+1 个 token： <https://github.com/openai/gpt-oss/blob/main/gpt_oss/torch/model.py>

```python
mask = torch.triu(Q.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
if sliding_window > 0:
    mask += torch.tril(
        mask.new_full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window
    )
```

<figure><img src="/files/3f483810942ef68e3cf30250965c54315d189d6c" alt=""><figcaption></figcaption></figure>

我们看到这里只关注最后 3 个 token（不是 3+1）！这意味着我们应该使用 `<= SLIDING_WINDOW`，使用 `< SLIDING_WINDOW` （也就是使用小于号，而不是等于号）。

```python
def sliding_window_causal(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    window_mask = q_idx - kv_idx <= SLIDING_WINDOW # 默认 Flex Attention
    window_mask = q_idx - kv_idx <  SLIDING_WINDOW # GPT-OSS 版本
    return causal_mask & window_mask
```

另外，由于我们把 sink token 的索引移到了第一个位置，因此需要给 q\_idx 加 1 才能正确索引：

```python
def causal_mask_with_sink(batch, head, q_idx, kv_idx):
    """
      0 1 2 3     0 1 2 3
    0 X X       1   X
    1 X X X     2   X X
    2 X X X X   3   X X X
    """
    # 我们加上 (q_idx + 1)，因为第一列是 sink token
    causal_mask = (q_idx + 1) >= kv_idx
    sink_first_column = kv_idx == 0
    return causal_mask | sink_first_column
```

为了确认我们索引 0 的实现，我们验证了训练 loss 与标准 Hugging Face 运行（不使用 Unsloth Flex Attention）保持一致，如我们的图所示：

<figure><img src="/files/00c1f2f93e429ca2264409a516a91d63edaa23d7" alt="" width="375"><figcaption></figcaption></figure>

## :scroll: attention sinks 的数学推导

还有另一种计算 attention sinks 的方法，不需要对 K 和 V 进行 padding。我们首先注意到 softmax 操作会，并且我们现在希望先将带 sinks 的第二种版本视为一个标量：\\

$$
A(x) = \frac{\exp(x\_i)}{\sum{\exp{(x\_i)}}} \\
A\_{sink}(x) = \frac{\exp(x\_i)}{\exp{(s)}+ \sum{\exp{(x\_i)}}}
$$

我们可以通过以下方式从 Flex Attention 获得 logsumexp： `return_lse = True` ，然后我们这样做：

$$
A(x) = \frac{\exp(x\_i)}{\sum{\exp{(x\_i)}}} \\
\frac{\exp(x\_i)}{\exp{(s)}+ \sum{\exp{(x\_i)}}} =  \frac{\exp(x\_i)}{\sum{\exp{(x\_i)}}} \frac{\sum{\exp{(x\_i)}}}{\exp{(s)}+ \sum{\exp{(x\_i)}}} \\
\text{LSE}(x) = \text{logsumexp}(x) = \log{\sum\exp(x\_i)} \\
\exp{(\text{LSE}(x))} = \exp{\big(\log{\sum\exp(x\_i)}\big)} = \sum\exp(x\_i)
$$

现在我们可以很容易地推导出带 sink 的注意力版本。不过我们发现，这个过程的误差略高于零填充方法，因此我们仍然默认使用原始版本。

## 💾**新功能：gpt-oss 训练后保存为 GGUF、vLLM**

你现在可以对 gpt-oss 进行 QLoRA 微调，并直接将模型保存、导出或合并到 **llama.cpp**, **vLLM**，或 **HF** ——不只是 Unsloth。我们希望很快会发布一个免费笔记本。

此前，任何经过 QLoRA 微调的 gpt-oss 模型都只能在 Unsloth 中运行。我们通过引入将模型合并为 **MXFP4** **原生格式** 的能力，并使用 `save_method="mxfp4"` 和 **按需反量化 MXFP4** 基础模型（如 gpt-oss）使得 **可以使用** `save_method="merged_16bit"` .

该 **MXFP4** 原生合并格式相比 **bf16 格式**具有显著的性能提升：它最多可节省 75% 的磁盘空间，将 VRAM 占用减少 50%，使合并速度提升 5-10 倍，并且能更快地转换为 **GGUF** 格式。

在微调完你的 gpt-oss 模型后，你可以将其合并为 **MXFP4** 格式，命令如下：

```python
model.save_pretrained_merged(save_directory, tokenizer, save_method="mxfp4")
```

如果你更喜欢将模型合并后推送到 hugging-face hub，请使用：

```python
model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token, save_method="mxfp4")
```

要在合并后的模型上运行推理，你可以使用 vLLM 和 Llama.cpp 等。OpenAI 为这两个模型推荐这些 [推理设置](/docs/zh/mo-xing/gpt-oss-how-to-run-and-fine-tune.md#recommended-settings) ： `temperature=1.0`, `top_p=1.0`, `top_k=0`

#### :sparkles: 保存到 Llama.cpp

1. 获取最新的 `llama.cpp` 在 [GitHub 这里](https://github.com/ggml-org/llama.cpp)。你也可以按照下面的构建说明操作。将 `-DGGML_CUDA=ON` 改为 `-DGGML_CUDA=OFF` 如果你没有 GPU，或者只想进行 CPU 推理。

   ```bash
   apt-get update
   apt-get install pciutils build-essential cmake curl libcurl4-openssl-dev -y
   git clone https://github.com/ggml-org/llama.cpp
   cmake llama.cpp -B llama.cpp/build \
       -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=ON -DLLAMA_CURL=ON
   cmake --build llama.cpp/build --config Release -j --clean-first --target llama-cli llama-gguf-split
   cp llama.cpp/build/bin/llama-* llama.cpp
   ```
2. 转换 **MXFP4** 合并后的模型：

   ```bash
   python3 llama.cpp/convert_hf_to_gguf.py gpt-oss-finetuned-merged/ --outfile gpt-oss-finetuned-mxfp4.gguf
   ```
3. 在量化后的模型上运行推理：

   ```bash
   llama.cpp/llama-cli --model gpt-oss-finetuned-mxfp4.gguf \
       --jinja -ngl 99 --threads -1 --ctx-size 16384 \
       --temp 1.0 --top-p 1.0 --top-k 0 \
        -p "生命、宇宙以及一切的意义是"
   ```

<details>

<summary><span data-gb-custom-inline data-tag="emoji" data-code="2728">✨</span> 保存到 SGLang</summary>

1. 从源码构建 SGLang：\\

   ```bash
   # 从源码构建
   git clone https://github.com/sgl-project/sglang
   cd sglang
   pip3 install pip --upgrade
   pip3 install -e "python[all]"

   # ROCm 6.3
   pip3 install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/rocm6.3
   git clone https://github.com/triton-lang/triton
   cd python/triton_kernels
   pip3 install .

   # hopper
   pip3 install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu126
   pip3 install sgl-kernel==0.3.2

   # blackwell cu128
   pip3 install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
   pip3 install https://github.com/sgl-project/whl/releases/download/v0.3.2/sgl_kernel-0.3.2+cu128-cp39-abi3-manylinux2014_x86_64.whl

   # blackwell cu129
   pip3 install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu129
   pip3 install https://github.com/sgl-project/whl/releases/download/v0.3.2/sgl_kernel-0.3.2-cp39-abi3-manylinux2014_x86_64.whl
   ```
2. 启动 SGLang 服务器：\\

   ```bash
   python3 -m sglang.launch_server --model-path ./gpt-oss-finetuned-merged/
   ```
3. 运行推理：\\

   ```python
   import requests
   from sglang.utils import print_highlight

   url = f"http://localhost:8000/v1/chat/completions"

   data = {
       "model": "gpt-oss-finetuned-merged",
       "messages": [{"role": "user", "content": "法国的首都是什么？"}],
   }

   response = requests.post(url, json=data)
   print_highlight(response.json())
   ```

</details>

### :diamonds:直接微调 gpt-oss

我们还通过实现补丁加入了对直接微调 gpt-oss 模型的支持，这些补丁允许加载原生 MXFP4 量化格式。这使得可以在少于 24GB 的 VRAM 下加载 'openai/gpt-oss' 模型，并对其进行 QLoRA 微调。只需使用以下方式加载模型：

```python
model, tokenizer = FastLanguageModel.from_pretrained(
    # model_name = "unsloth/gpt-oss-20b-BF16", 
    model_name = "unsloth/gpt-oss-20b",
    dtype = dtype, # 自动检测时设为 None
    max_seq_length = max_seq_length, # 为长上下文任意选择！
    load_in_4bit = True,  # 使用 4bit 量化以减少内存
    full_finetuning = False, # [NEW!] 我们现在支持全参数微调！
    # token = "hf_...", # 如果使用受限模型，请使用这个
)
```

使用以下方式添加一个 Peft 层 `FastLanguageModel.get_peft_model` 然后在该 Peft 模型上运行 SFT 微调。

## 🐛gpt-oss 的错误修复

我们 [最近与 Hugging Face 合作](https://github.com/huggingface/transformers/pull/40197) 通过使用 OpenAI 的 kernel 并确保 `swiglu_limit = 7.0` 在 MXFP4 推理过程中被正确应用，来解决推理问题。

根据用户反馈，我们发现延长的 QLoRA 训练运行（超过 60 步）可能会导致 **loss 发散并最终报错**。这个问题只会发生在不支持 BF16、而是回退到 F16 的设备上（例如 T4 GPU）。重要的是，它不会影响在 A100 或 H100 GPU 上的 QLoRA 训练，也不会影响在 f16 GPU 上的 LoRA 训练。

**经过大量调查，我们现在已将所有 GPU 配置下的训练 loss 行为对齐，包括仅支持 F16 的 GPU**。如果你之前因此遇到问题，我们建议使用我们新的更新版 gpt-oss 笔记本！

<figure><img src="/files/e2285c3b271c174ed61c0fbaeaee9a071de11e83" alt=""><figcaption></figcaption></figure>

我们不得不做很多很多实验，才能让 float16 的训练 loss 曲线与 bfloat16 机器（蓝线）保持一致。我们发现如下：

1. **纯 float16 会在第 50 步发散到无穷大**
2. **我们发现 MoE 中的下投影存在巨大的异常值**
3. **激活值必须以 bfloat16 或 float32 保存**

**下图显示了 GPT OSS 20B 的绝对幅值激活，其中一些确实会骤增——由于 float16 的最大范围是 65504，这将在 float16 机器上溢出。**

**我们在 Unsloth 中修复了这一点，因此所有 float16 训练都可以开箱即用！**

<figure><img src="/files/99e8886c8c7f150df0d580822ca3ae8e256667a4" alt=""><figcaption></figcaption></figure>

## :1234: Sink Attention 的实现

OpenAI 的 sink token 实现 [可在此处获取](https://github.com/openai/gpt-oss/blob/main/gpt_oss/torch/model.py)。我们在下方提供：

{% code fullWidth="false" %}

```python
def sdpa(Q, K, V, S, sm_scale, sliding_window=0):
    # sliding_window == 0 表示没有滑动窗口
    n_tokens, n_heads, q_mult, d_head = Q.shape
    assert K.shape == (n_tokens, n_heads, d_head)
    assert V.shape == (n_tokens, n_heads, d_head)
    K = K[:, :, None, :].expand(-1, -1, q_mult, -1)
    V = V[:, :, None, :].expand(-1, -1, q_mult, -1)
    S = S.reshape(n_heads, q_mult, 1, 1).expand(-1, -1, n_tokens, -1)
    mask = torch.triu(Q.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
    if sliding_window > 0:
        mask += torch.tril(
            mask.new_full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window
        )
    QK = torch.einsum("qhmd,khmd->hmqk", Q, K) * sm_scale
    QK += mask[None, None, :, :]
    QK = torch.cat([QK, S], dim=-1)
    W = torch.softmax(QK, dim=-1)
    W = W[..., :-1]
    attn = torch.einsum("hmqk,khmd->qhmd", W, V)
    return attn.reshape(n_tokens, -1)
```

{% endcode %}

HuggingFace transformers 的实现是 [可在此处获取](https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_oss/modeling_gpt_oss.py)。我们也在下方提供：

{% code fullWidth="false" %}

```python
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
    combined_logits = torch.cat([attn_weights, sinks], dim=-1)

    # 这并不在原始实现中，并且会轻微影响结果；它可防止 BF16/FP16 中的溢出
    # 在 bsz>1 训练时我们会对最大值进行截断。

    combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
    probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
    scores = probs[..., :-1]  # 我们在这里去掉 sink
    attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
```

{% endcode %}


---

# Agent Instructions: Querying This Documentation

If you need additional information that is not directly available in this page, you can query the documentation dynamically by asking a question.

Perform an HTTP GET request on the current page URL with the `ask` query parameter:

```
GET https://unsloth.ai/docs/zh/mo-xing/gpt-oss-how-to-run-and-fine-tune/long-context-gpt-oss-training.md?ask=<question>
```

The question should be specific, self-contained, and written in natural language.
The response will contain a direct answer to the question and relevant excerpts and sources from the documentation.

Use this mechanism when the answer is not explicitly present in the current page, you need clarification or additional context, or you want to retrieve related documentation sections.
