# 长上下文 gpt-oss 训练

我们很高兴介绍针对 OpenAI gpt-oss 训练的 Unsloth 弹性注意力（Flex Attention）支持，它可以实现 **>8× 更长的上下文长度**, **>50% 更少的显存使用** 以及 **>1.5× 更快的训练（且无精度下降）** 相比包括使用 Flash Attention 3 (FA3) 在内的所有实现。Unsloth 弹性注意力使得可以在 **60K 上下文长度** 在 80GB VRAM 的 H100 GPU 上进行 BF16 LoRA。此外：

* 你可以 [现在可导出/保存](#new-saving-to-gguf-vllm-after-gpt-oss-training) 你的 QLoRA 微调 gpt-oss 模型到 llama.cpp、vLLM、Ollama 或 HF
* 我们 [**修复了 gpt-oss 训练中的**](#bug-fixes-for-gpt-oss) **损失值趋于无穷大** 在 float16 GPU（例如 T4 Colab）上
* 我们 [修复了 gpt-oss 实现中的](#bug-fixes-for-gpt-oss) 与 Unsloth 无关的问题，最显著的是确保 `swiglu_limit = 7.0` 在 transformers 的 MXFP4 推理期间被正确应用

## 🦥介绍 Unsloth 弹性注意力支持

有了 Unsloth 的弹性注意力支持，单个 80GB VRAM 的 H100 在 QLoRA 下最多可处理 81K 上下文长度，在 BF16 LoRA 下可处理 60K 上下文长度！这些提升适用于 **二者** gpt-oss-20b 和 **gpt-oss-120b**！你使用的上下文长度越大，从 Unsloth 弹性注意力中获得的提升就越多：

<figure><img src="/files/7943cc723ba8b27e27ff9397c954412404a17e99" alt="" width="563"><figcaption></figcaption></figure>

相比之下，所有其他非 Unsloth 的实现在线 80GB GPU 上的最大上下文长度为 9K，并且即便使用 FA3 也只能达到 15K。但 **FA3 不适合用于 gpt-oss 训练，因为它缺乏对 attention sinks 的反向传播支持**。所以如果你之前在使用 FA3 进行 gpt-oss 训练，我们建议你 **暂时不要使用它** 因此，在没有 Unsloth 的情况下，你在 80GB VRAM 上能获得的最大上下文长度约为 \~9K。

使用 Unsloth 弹性注意力进行训练至少带来 1.3× 的加速，随上下文长度增加收益更大，最多可达 2× 加速。因为弹性注意力随上下文扩展而扩展，较长的序列在 VRAM 使用和训练时间上带来更大的节省，正如 [在此描述](#unsloths-flex-attention-implementation).

非常感谢 Rohan Pandey 提供的 [弹性注意力实现](https://x.com/khoomeik/status/1955693558914310608)，它直接启发了 Unsloth 弹性注意力实现的发展。

## :dark\_sunglasses: 注意力 Sink（汇点）

OpenAI 的 GPT OSS 模型使用了一种 **交替模式：滑动窗口注意力、全注意力**、滑动窗口注意力，依此类推（SWA、FA、SWA、FA 等）。每个滑动窗口只关注 **128 个标记** （包括当前标记），因此计算量大幅减少。然而，这也意味着由于滑动窗口较小，长上下文的检索和推理变得无效。大多数实验室通过将滑动窗口扩展到 2048 或 4096 个标记来修复这个问题。

OpenAI 借鉴了 **注意力 Sink（汇点）** 来自《具有注意力汇点的高效流式语言模型（Efficient Streaming Language Models with Attention Sinks）》 [论文中的方法](https://arxiv.org/abs/2309.17453) 该论文表明你可以使用一个小的滑动窗口，但你必须在第一个标记上添加全局注意力！论文下面提供了一个很好的示意图：

<figure><img src="/files/3496f27e5a9efcfefbd904521944091004a63bd4" alt=""><figcaption></figcaption></figure>

论文发现 **注意力机制似乎对前几个标记（1 到 4）分配了大量权重**，在滑动窗口操作过程中如果移除了这些“重要”的前几个标记，它们就会消失，从而导致糟糕的长上下文检索。

如果我们绘制对数困惑度（值越高越差），并在预训练模型设定的上下文长度之后进行长上下文推理，我们会看到困惑度急剧上升（不理想）。然而红线（使用 Attention Sinks）保持较低，这非常好！

<figure><img src="/files/c2d1b7dca761893dffa8ba735fd3e9f751e2403e" alt=""><figcaption></figcaption></figure>

论文还表明 [“注意力偏差一个（Attention Is Off By One）”方法](https://www.evanmiller.org/attention-is-off-by-one.html) 确实部分有效，但也必须添加一些额外的 sink 标记以获得更低的困惑度。 **论文显示添加一个可学习的单个 sink 标记效果非常好！ 这也是 OpenAI 在 GPT-OSS 中所做的！**

<figure><img src="/files/97083cda46a2b623793651efb4fe8bcf87f59017" alt=""><figcaption></figcaption></figure>

## :triangular\_ruler:Unsloth 的弹性注意力实现

灵活注意力（Flex Attention） <https://pytorch.org/blog/flexattention/> 非常强大，因为它为实践者提供了两条自定义注意力机制的路线 - 一个 **分数修正器（f）** 和一个 **掩码函数（M）**.

这些 **分数修正器（f）** 允许我们在 softmax 操作之前编辑注意力 logits，并且 **掩码函数（M）** 允许我们在不需要某些操作时跳过它们（例如滑动窗口注意力只看到最后 128 个标记）。

<mark style="background-color:green;">**关键在于弹性注意力提供了带任意分数修正器和掩码函数的快速自动生成 Triton 内核！**</mark>

<p align="center"><span class="math">\sigma\bigg(s\times\bold{f}(QK^T+\bold{M})\bigg)</span><br></p>

这意味着我们可以使用弹性注意力来实现注意力汇点（attention sinks）！在 [OpenAI 的原始 GPT-OSS 仓库](#implementations-for-sink-attention) 和 HuggingFace 的 transformers 实现中都提供了实现单个注意力 sink 的方法。

```python
combined_logits = torch.cat([attn_weights, sinks], dim=-1)
probs = F.softmax(combined_logits, dim=-1)
scores = probs[..., :-1]
```

上述显示我们将 sink 串联在 `Q @ K.T` 的最末端，进行 softmax，然后移除最后一列（即 sink 标记）。

通过使用一些来自 [Flex Attention 的 Github 仓库](https://github.com/meta-pytorch/attention-gym)的可视化工具，我们可以将其可视化。假设序列长度为 16，滑动窗口为 5。左侧是最后的 sink 列（默认实现），右侧是如果我们将 sink 位置移动到索引 0（我们的实现）。

{% columns %}
{% column %}
***在末端的 sink 位置（默认）***

<figure><img src="/files/388639f182c06846b841713f2f068fd1c6864310" alt=""><figcaption></figcaption></figure>
{% endcolumn %}

{% column %}
***将 sink 位置移动到索引 0***

<figure><img src="/files/2871b9a4521af4aa240df1daaeac4789b9286538" alt=""><figcaption></figcaption></figure>
{% endcolumn %}
{% endcolumns %}

**有趣的发现**：官方的 Flex Attention 滑动窗口实现将窗口大小视为最后几个标记的数量 **再加一** 因为它包括当前标记。HuggingFace 和 GPT OSS 的实现严格只看到最后 N 个标记。即下面来自 <https://pytorch.org/blog/flexattention/> 以及 <https://github.com/meta-pytorch/attention-gym>:

{% code overflow="wrap" %}

```python
def sliding_window_causal(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    window_mask = q_idx - kv_idx <= SLIDING_WINDOW 
    return causal_mask & window_mask
```

{% endcode %}

{% columns %}
{% column %}
默认 Flex Attention（3+1 个标记）

<figure><img src="/files/d9a0e337a817f92a214dfac892ca9fc9db85ad58" alt=""><figcaption></figcaption></figure>
{% endcolumn %}

{% column %}
HuggingFace、GPT-OSS（3+0 个标记）

<figure><img src="/files/cda8865f91de29b0f2fb82e822ef021c34a78732" alt=""><figcaption></figcaption></figure>
{% endcolumn %}
{% endcolumns %}

我们还通过 OpenAI 官方的 GPT-OSS 实现确认了我们是在关注最后 N 个还是 N+1 个标记： <https://github.com/openai/gpt-oss/blob/main/gpt_oss/torch/model.py>

```python
mask = torch.triu(Q.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
if sliding_window > 0:
    mask += torch.tril(
        mask.new_full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window
    )
```

<figure><img src="/files/3f483810942ef68e3cf30250965c54315d189d6c" alt=""><figcaption></figcaption></figure>

我们看到只关注最后 3 个标记（而不是 3+1）！这意味着应当使用 `<= SLIDING_WINDOW`，而改为使用 `< SLIDING_WINDOW` （即使用小于，而不是等于）。

```python
def sliding_window_causal(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    window_mask = q_idx - kv_idx <= SLIDING_WINDOW # 默认 Flex Attention
    window_mask = q_idx - kv_idx <  SLIDING_WINDOW # GPT-OSS 版本
    return causal_mask & window_mask
```

另外由于我们将 sink 标记索引移动到第一列，我们必须对 q\_idx 加 1 以正确索引：

```python
def causal_mask_with_sink(batch, head, q_idx, kv_idx):
    """
      0 1 2 3     0 1 2 3
    0 X X       1   X
    1 X X X     2   X X
    2 X X X X   3   X X X
    """
    # 我们加上 (q_idx + 1) 因为第一列是 sink 标记
    causal_mask = (q_idx + 1) >= kv_idx
    sink_first_column = kv_idx == 0
    return causal_mask | sink_first_column
```

为了确认我们索引为 0 的实现，我们验证了训练损失与标准 Hugging Face 运行（不使用 Unsloth 弹性注意力）保持一致，如我们的图表所示：

<figure><img src="/files/00c1f2f93e429ca2264409a516a91d63edaa23d7" alt="" width="375"><figcaption></figcaption></figure>

## :scroll: 注意力 sink 的数学推导

还有另一种在不对 K 和 V 进行填充的情况下计算注意力 sink 的方法。我们首先注意到 softmax 操作会做什么，现在我们想要带 sinks 的第二个版本作为一个标量：\\

$$
A(x) = \frac{\exp(x\_i)}{\sum{\exp{(x\_i)}}} \\
A\_{sink}(x) = \frac{\exp(x\_i)}{\exp{(s)}+ \sum{\exp{(x\_i)}}}
$$

我们可以通过 Flex Attention 获取 logsumexp，方法为 `return_lse = True` ，因此我们这样做：

$$
A(x) = \frac{\exp(x\_i)}{\sum{\exp{(x\_i)}}} \\
\frac{\exp(x\_i)}{\exp{(s)}+ \sum{\exp{(x\_i)}}} =  \frac{\exp(x\_i)}{\sum{\exp{(x\_i)}}} \frac{\sum{\exp{(x\_i)}}}{\exp{(s)}+ \sum{\exp{(x\_i)}}} \\
\text{LSE}(x) = \text{logsumexp}(x) = \log{\sum\exp(x\_i)} \\
\exp{(\text{LSE}(x))} = \exp{\big(\log{\sum\exp(x\_i)}\big)} = \sum\exp(x\_i)
$$

现在我们可以很容易推导出带 sink 的注意力版本。不过我们确实发现这个过程的误差比零填充方法略高，所以我们仍然默认使用原始版本。

## 💾**新：在 gpt-oss 训练后保存为 GGUF、vLLM**

你现在可以 QLoRA 微调 gpt-oss 并直接保存、导出或合并模型为 **llama.cpp**, **你现在可以在微调流程中直接使用**，或 **HF** —— 不仅限于 Unsloth。我们希望很快发布一个免费的笔记本。

之前，任何 QLoRA 微调的 gpt-oss 模型都被限制只能在 Unsloth 中运行。我们通过引入合并 **MXFP4** **本地格式** 使用 `save_method="mxfp4"` 以及 **在 LoRA 合并过程中按需去量化 MXFP4** 基模型（如 gpt-oss）使得可以 **使用以下方式以 bf16 格式导出你的微调模型** `save_method="merged_16bit"` .

这些 **MXFP4** 本地合并格式相比 **bf16 格式**具有显著性能提升：它最多使用 75% 更少的磁盘空间，减少 50% 的 VRAM 消耗，加速合并 5-10 倍，并使得转换为 **GGUF** 格式变得更快。

微调你的 gpt-oss 模型后，你可以将其合并为 **MXFP4** 格式，命令如下：

```python
model.save_pretrained_merged(save_directory, tokenizer, save_method="mxfp4")
```

如果你想合并模型并推送到 hugging-face hub，使用：

```python
model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token, save_method="mxfp4")
```

要在合并后的模型上运行推理，你可以使用 vLLM 和 Llama.cpp 等工具。OpenAI 推荐以下 [推理设置](/docs/zh/mo-xing/gpt-oss-how-to-run-and-fine-tune.md#recommended-settings) 适用于两种模型： `temperature=1.0`, `top_p=1.0`, `top_k=0`

#### :sparkles: 保存到 Llama.cpp

1. 获取最新的 `llama.cpp` 在 [GitHub（在此）](https://github.com/ggml-org/llama.cpp)。你也可以按照下面的构建说明。若没有 GPU 或仅想使用 CPU 推理，请将 `-DGGML_CUDA=ON` 改为 `-DGGML_CUDA=OFF` 。

   ```bash
   apt-get update
   apt-get install pciutils build-essential cmake curl libcurl4-openssl-dev -y
   git clone https://github.com/ggml-org/llama.cpp
   cmake llama.cpp -B llama.cpp/build \
       -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=ON -DLLAMA_CURL=ON
   cmake --build llama.cpp/build --config Release -j --clean-first --target llama-cli llama-gguf-split
   cp llama.cpp/build/bin/llama-* llama.cp
   ```
2. 转换该 **MXFP4** 合并模型：

   ```bash
   python3 llama.cpp/convert_hf_to_gguf.py gpt-oss-finetuned-merged/ --outfile gpt-oss-finetuned-mxfp4.gguf
   ```
3. 在量化模型上运行推理：

   ```bash
   llama.cpp/llama-cli --model gpt-oss-finetuned-mxfp4.gguf \
       --jinja -ngl 99 --threads -1 --ctx-size 16384 \
       --temp 1.0 --top-p 1.0 --top-k 0 \
        -p "The meaning to life and the universe is"
   ```

<details>

<summary><span data-gb-custom-inline data-tag="emoji" data-code="2728">✨</span> 保存到 SGLang</summary>

1. 从源码构建 SGLang:\\

   ```bash
   # 从源码构建
   git clone https://github.com/sgl-project/sglang
   cd sglang
   pip3 install pip --upgrade
   pip3 install -e "python[all]"

   # ROCm 6.3
   pip3 install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/rocm6.3
   git clone https://github.com/triton-lang/triton
   cd python/triton_kernels
   pip3 install .

   # hopper
   pip3 install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu126
   pip3 install sgl-kernel==0.3.2

   # blackwell cu128
   pip3 install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
   pip3 install https://github.com/sgl-project/whl/releases/download/v0.3.2/sgl_kernel-0.3.2+cu128-cp39-abi3-manylinux2014_x86_64.whl

   # blackwell cu129
   pip3 install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu129
   pip3 install https://github.com/sgl-project/whl/releases/download/v0.3.2/sgl_kernel-0.3.2-cp39-abi3-manylinux2014_x86_64.whl
   ```
2. 启动 SGLang 服务器：\\

   ```bash
   python3 -m sglang.launch_server --model-path ./gpt-oss-finetuned-merged/
   ```
3. 运行推理：\\

   ```python
   import requests
   from sglang.utils import print_highlight

   url = f"http://localhost:8000/v1/chat/completions"

   data = {
       "model": "gpt-oss-finetuned-merged",
       "messages": [{"role": "user", "content": "What is the capital of France?"}],
   }

   response = requests.post(url, json=data)
   print_highlight(response.json())
   ```

</details>

### :diamonds:直接微调 gpt-oss

我们还通过实现补丁来增加直接微调 gpt-oss 模型的支持，这些补丁允许加载本地 MXFP4 量化格式。这使得可以在少于 24GB VRAM 的情况下加载“openai/gpt-oss”模型并进行 QLoRA 微调。只需使用以下方式加载模型：

```python
from unsloth import FastLanguageModel
    # model_name = "unsloth/gpt-oss-20b-BF16", 
    model_name = "unsloth/gpt-oss-20b",
    dtype = dtype, # None 表示自动检测
    max_seq_length = max_seq_length, # 可为长上下文选择任意长度！
    load_in_4bit = True,  # 使用 4 位量化以减少内存
    full_finetuning = False, # [新!] 我们现在支持全量微调！
    # token = "hf_...", # 如果使用受限模型可提供令牌
)
```

使用以下方式添加一个 Peft 层 `FastLanguageModel.get_peft_model` 并在该 Peft 模型上运行 SFT 微调。

## 🐛 gpt-oss 的错误修复

我们 [最近与 Hugging Face 合作](https://github.com/huggingface/transformers/pull/40197) 通过使用 OpenAI 的内核并确保 `swiglu_limit = 7.0` 在 MXFP4 推理期间被正确应用，从而解决了推理问题。

根据用户反馈，我们发现延长的 QLoRA 训练（超过 60 步）可能导致 **损失发散并最终报错**。该问题仅发生在不支持 BF16 而回退到 F16 的设备上（例如 T4 GPU）。重要的是，它并未影响在 A100 或 H100 GPU 上的 QLoRA 训练，也未影响在 f16 GPU 上的 LoRA 训练。

**经过大量调查，我们现已在所有 GPU 配置（包括仅支持 F16 的 GPU）上使训练损失行为一致。**&#x5982;果你之前因该问题而遇到问题，建议使用我们更新的 gpt-oss 笔记本！

<figure><img src="/files/e2285c3b271c174ed61c0fbaeaee9a071de11e83" alt=""><figcaption></figcaption></figure>

我们不得不进行大量实验，使 float16 的训练损失曲线与 bfloat16 机器（蓝线）等效。我们发现以下情况：

1. **纯 float16 会在第 50 步时发散到无穷**
2. **我们发现 MoE 中的下投影存在巨大的异常值**
3. **激活值必须以 bfloat16 或 float32 保存**

**下面显示了 GPT OSS 20B 的绝对幅值激活，其中一些会出现极点——这将在 float16 机器上溢出，因为 float16 的最大范围为 65504。**

**我们在 Unsloth 中修复了这个问题，所以所有 float16 训练开箱即可正常工作！**

<figure><img src="/files/99e8886c8c7f150df0d580822ca3ae8e256667a4" alt=""><figcaption></figcaption></figure>

## :1234: Sink 注意力的实现

OpenAI 的 sink 标记实现位于 [此处提供](https://github.com/openai/gpt-oss/blob/main/gpt_oss/torch/model.py)。我们在下面提供其实现：

{% code fullWidth="false" %}

```python
def sdpa(Q, K, V, S, sm_scale, sliding_window=0):
    # sliding_window == 0 表示没有滑动窗口
    n_tokens, n_heads, q_mult, d_head = Q.shape
    assert K.shape == (n_tokens, n_heads, d_head)
    assert V.shape == (n_tokens, n_heads, d_head)
    K = K[:, :, None, :].expand(-1, -1, q_mult, -1)
    V = V[:, :, None, :].expand(-1, -1, q_mult, -1)
    S = S.reshape(n_heads, q_mult, 1, 1).expand(-1, -1, n_tokens, -1)
    mask = torch.triu(Q.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
    if sliding_window > 0:
        mask += torch.tril(
            mask.new_full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window
        )
    QK = torch.einsum("qhmd,khmd->hmqk", Q, K) * sm_scale
    QK += mask[None, None, :, :]
    QK = torch.cat([QK, S], dim=-1)
    W = torch.softmax(QK, dim=-1)
    W = W[..., :-1]
    attn = torch.einsum("hmqk,khmd->qhmd", W, V)
    return attn.reshape(n_tokens, -1)
```

{% endcode %}

HuggingFace transformers 的实现位于 [此处提供](https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_oss/modeling_gpt_oss.py)。我们也在下面提供它：

{% code fullWidth="false" %}

```python
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
    combined_logits = torch.cat([attn_weights, sinks], dim=-1)

    # 这在原始实现中没有，且会略微影响结果；它防止在 BF16/FP16 下溢出
    # 在 batch size > 1 的训练时我们将最大值钳位。

    combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
    probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
    scores = probs[..., :-1]  # 我们在此处丢弃 sink
    attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
```

{% endcode %}


---

# Agent Instructions: Querying This Documentation

If you need additional information that is not directly available in this page, you can query the documentation dynamically by asking a question.

Perform an HTTP GET request on the current page URL with the `ask` query parameter:

```
GET https://unsloth.ai/docs/zh/mo-xing/gpt-oss-how-to-run-and-fine-tune/long-context-gpt-oss-training.md?ask=<question>
```

The question should be specific, self-contained, and written in natural language.
The response will contain a direct answer to the question and relevant excerpts and sources from the documentation.

Use this mechanism when the answer is not explicitly present in the current page, you need clarification or additional context, or you want to retrieve related documentation sections.
