openaigpt-oss 长上下文训练

我们很高兴为 OpenAI gpt-oss 训练推出 Unsloth Flex Attention 支持,它能够实现 >8× 更长的上下文长度, >50% 更少的 VRAM 占用>1.5× 更快的训练(且不会降低准确率) 相比所有实现,包括使用 Flash Attention 3(FA3)的实现。Unsloth Flex Attention 使得可以在 60K 上下文长度 的 80GB VRAM H100 GPU 上进行 BF16 LoRA 训练。此外:

  • 你现在可以 导出/保存 你的 QLoRA 微调 gpt-oss 模型到 llama.cpp、vLLM、Ollama 或 HF

  • 我们 修复了 gpt-oss 训练中 loss 变为无穷大 在 float16 GPU 上(例如 T4 Colab)

  • 我们 修复了 gpt-oss 实现中的 与 Unsloth 无关的问题,尤其是确保 swiglu_limit = 7.0 在 transformers 中的 MXFP4 推理过程中被正确应用

🦥介绍 Unsloth Flex Attention 支持

借助 Unsloth 的 Flex Attention 支持,一块 80GB VRAM 的 H100 可以使用 QLoRA 处理最高 81K 上下文长度,使用 BF16 LoRA 处理 60K 上下文!这些提升适用于 两者 gpt-oss-20b 和 gpt-oss-120b!你使用的上下文长度越长,从 Unsloth Flex Attention 获得的收益就越大:

相比之下,所有其他非 Unsloth 实现最多只能在 80GB GPU 上达到 9K 上下文长度,而使用 FA3 也只能达到 15K 上下文。不过, FA3 不适合用于 gpt-oss 训练,因为它不支持 attention sink 的反向传播。所以如果你之前在 gpt-oss 训练中使用 FA3,我们建议你 暂时不要使用它 。因此,在 80GB VRAM 上不使用 Unsloth 时,你能获得的最大上下文长度约为 ~9K。

使用 Unsloth Flex Attention 训练至少可带来 1.3× 的速度提升,并且随着上下文长度增加,收益会继续扩大,最高可快 2×。由于 Flex Attention 会随上下文规模扩展,更长的序列在显存和训练时间上都能带来更大的节省,正如 这里所述.

非常感谢 Rohan Pandey 提供他的 Flex Attention 实现arrow-up-right,这直接启发了 Unsloth 的 Flex Attention 实现开发。

🕶️ Attention Sinks

OpenAI 的 GPT OSS 模型使用了一种 滑动窗口注意力与全注意力交替的模式,即滑动窗口注意力、全注意力、滑动窗口注意力、全注意力,等等(SWA、FA、SWA、FA,依此类推)。每个滑动窗口只关注 128 个 token (包括当前 token),因此计算量大幅减少。然而,这也意味着由于滑动窗口太小,长上下文检索和推理会变得无用。大多数实验室会通过将滑动窗口扩展到 2048 或 4096 个 token 来解决这一问题。

OpenAI 借鉴了 Attention Sinks 《Efficient Streaming Language Models with Attention Sinks》 论文arrow-up-right 该论文表明你可以使用一个较小的滑动窗口,但必须在第一个 token 上添加全局注意力!论文下方提供了一个很好的示意:

论文发现, 注意力机制似乎会给最前面的几个 token(1 到 4)分配大量权重,而在滑动窗口操作中移除它们后,这些“重要”的前几个 token 就会消失,并导致糟糕的长上下文检索。

如果我们绘制 log perplexity(越高越差),并在预训练模型设定的上下文长度之后进行长上下文推理,可以看到 perplexity 会飙升(不好)。然而红线(使用 Attention Sinks)保持较低,这非常好!

论文还表明, Attention Is Off By One 方法arrow-up-right 确实有部分效果,但还必须再添加几个额外的 sink token 才能获得更低的 perplexity。 论文表明,添加一个可学习的单个 sink token 效果非常出色! 而这正是 OpenAI 为 GPT-OSS 所做的!

📐Unsloth 的 Flex Attention 实现

Flex Attention https://pytorch.org/blog/flexattention/arrow-up-right 非常强大,因为它为从业者提供了两种对注意力机制的自定义方式——一个 分数修饰器(f) 以及一个 掩码函数(M).

分数修饰器(f) 允许我们在 softmax 操作之前修改注意力 logits,而 掩码函数(M) 允许我们在不需要某些操作时跳过它们(例如滑动窗口注意力只看最后 128 个 token)。

其巧妙之处在于,Flex Attention 提供了可快速自动生成的 Triton 内核,并支持任意分数修饰器和掩码函数!

σ(s×f(QKT+M))\sigma\bigg(s\times\bold{f}(QK^T+\bold{M})\bigg)

这意味着我们可以使用 Flex Attention 来实现 attention sinks!实现单个 attention sink 的方式既见于 OpenAI 原始的 GPT-OSS 仓库 也见于 HuggingFace 的 transformers 实现。

以上显示我们将 sink 拼接到 Q @ K.T 的最末端,进行 softmax,然后移除最后一列,也就是 sink token。

通过使用 Flex Attention 的 Github 仓库arrow-up-right中的一些可视化工具,我们可以将其可视化。假设序列长度为 16,滑动窗口为 5。左边是最后一个 sink 列(默认实现),右边是将 sink 位置移动到索引 0(我们的实现)。

sink 位置在末尾(默认)

将 sink 位置移动到索引 0

有趣的发现:官方的 Flex Attention 滑动窗口实现将窗口大小视为最后若干 token 的数量 再加一 ,因为它包含当前 token。HuggingFace 和 GPT OSS 的实现则严格只看最后 N 个 token。也就是说,下面来自 https://pytorch.org/blog/flexattention/arrow-up-righthttps://github.com/meta-pytorch/attention-gymarrow-up-right:

默认 Flex Attention(3+1 个 token)

HuggingFace,GPT-OSS(3+0 个 token)

我们还通过 OpenAI 官方的 GPT-OSS 实现确认了这里到底是关注最后 N 个还是 N+1 个 token: https://github.com/openai/gpt-oss/blob/main/gpt_oss/torch/model.pyarrow-up-right

我们看到这里只关注最后 3 个 token(不是 3+1)!这意味着我们应该使用 <= SLIDING_WINDOW,使用 < SLIDING_WINDOW (也就是使用小于号,而不是等于号)。

另外,由于我们把 sink token 的索引移到了第一个位置,因此需要给 q_idx 加 1 才能正确索引:

为了确认我们索引 0 的实现,我们验证了训练 loss 与标准 Hugging Face 运行(不使用 Unsloth Flex Attention)保持一致,如我们的图所示:

📜 attention sinks 的数学推导

还有另一种计算 attention sinks 的方法,不需要对 K 和 V 进行 padding。我们首先注意到 softmax 操作会,并且我们现在希望先将带 sinks 的第二种版本视为一个标量:\

A(x)=exp(xi)exp(xi)Asink(x)=exp(xi)exp(s)+exp(xi)A(x) = \frac{\exp(x_i)}{\sum{\exp{(x_i)}}} \\ A_{sink}(x) = \frac{\exp(x_i)}{\exp{(s)}+ \sum{\exp{(x_i)}}}

我们可以通过以下方式从 Flex Attention 获得 logsumexp: return_lse = True ,然后我们这样做:

A(x)=exp(xi)exp(xi)exp(xi)exp(s)+exp(xi)=exp(xi)exp(xi)exp(xi)exp(s)+exp(xi)LSE(x)=logsumexp(x)=logexp(xi)exp(LSE(x))=exp(logexp(xi))=exp(xi)A(x) = \frac{\exp(x_i)}{\sum{\exp{(x_i)}}} \\ \frac{\exp(x_i)}{\exp{(s)}+ \sum{\exp{(x_i)}}} = \frac{\exp(x_i)}{\sum{\exp{(x_i)}}} \frac{\sum{\exp{(x_i)}}}{\exp{(s)}+ \sum{\exp{(x_i)}}} \\ \text{LSE}(x) = \text{logsumexp}(x) = \log{\sum\exp(x_i)} \\ \exp{(\text{LSE}(x))} = \exp{\big(\log{\sum\exp(x_i)}\big)} = \sum\exp(x_i)

现在我们可以很容易地推导出带 sink 的注意力版本。不过我们发现,这个过程的误差略高于零填充方法,因此我们仍然默认使用原始版本。

💾新功能:gpt-oss 训练后保存为 GGUF、vLLM

你现在可以对 gpt-oss 进行 QLoRA 微调,并直接将模型保存、导出或合并到 llama.cpp, vLLM,或 HF ——不只是 Unsloth。我们希望很快会发布一个免费笔记本。

此前,任何经过 QLoRA 微调的 gpt-oss 模型都只能在 Unsloth 中运行。我们通过引入将模型合并为 MXFP4 原生格式 的能力,并使用 save_method="mxfp4"按需反量化 MXFP4 基础模型(如 gpt-oss)使得 可以使用 save_method="merged_16bit" .

MXFP4 原生合并格式相比 bf16 格式具有显著的性能提升:它最多可节省 75% 的磁盘空间,将 VRAM 占用减少 50%,使合并速度提升 5-10 倍,并且能更快地转换为 GGUF 格式。

在微调完你的 gpt-oss 模型后,你可以将其合并为 MXFP4 格式,命令如下:

如果你更喜欢将模型合并后推送到 hugging-face hub,请使用:

要在合并后的模型上运行推理,你可以使用 vLLM 和 Llama.cpp 等。OpenAI 为这两个模型推荐这些 推理设置temperature=1.0, top_p=1.0, top_k=0

保存到 Llama.cpp

  1. 获取最新的 llama.cppGitHub 这里arrow-up-right。你也可以按照下面的构建说明操作。将 -DGGML_CUDA=ON 改为 -DGGML_CUDA=OFF 如果你没有 GPU,或者只想进行 CPU 推理。

  2. 转换 MXFP4 合并后的模型:

  3. 在量化后的模型上运行推理:

chevron-right 保存到 SGLanghashtag
  1. 从源码构建 SGLang:\

  2. 启动 SGLang 服务器:\

  3. 运行推理:\

♦️直接微调 gpt-oss

我们还通过实现补丁加入了对直接微调 gpt-oss 模型的支持,这些补丁允许加载原生 MXFP4 量化格式。这使得可以在少于 24GB 的 VRAM 下加载 'openai/gpt-oss' 模型,并对其进行 QLoRA 微调。只需使用以下方式加载模型:

使用以下方式添加一个 Peft 层 FastLanguageModel.get_peft_model 然后在该 Peft 模型上运行 SFT 微调。

🐛gpt-oss 的错误修复

我们 最近与 Hugging Face 合作arrow-up-right 通过使用 OpenAI 的 kernel 并确保 swiglu_limit = 7.0 在 MXFP4 推理过程中被正确应用,来解决推理问题。

根据用户反馈,我们发现延长的 QLoRA 训练运行(超过 60 步)可能会导致 loss 发散并最终报错。这个问题只会发生在不支持 BF16、而是回退到 F16 的设备上(例如 T4 GPU)。重要的是,它不会影响在 A100 或 H100 GPU 上的 QLoRA 训练,也不会影响在 f16 GPU 上的 LoRA 训练。

经过大量调查,我们现在已将所有 GPU 配置下的训练 loss 行为对齐,包括仅支持 F16 的 GPU。如果你之前因此遇到问题,我们建议使用我们新的更新版 gpt-oss 笔记本!

我们不得不做很多很多实验,才能让 float16 的训练 loss 曲线与 bfloat16 机器(蓝线)保持一致。我们发现如下:

  1. 纯 float16 会在第 50 步发散到无穷大

  2. 我们发现 MoE 中的下投影存在巨大的异常值

  3. 激活值必须以 bfloat16 或 float32 保存

下图显示了 GPT OSS 20B 的绝对幅值激活,其中一些确实会骤增——由于 float16 的最大范围是 65504,这将在 float16 机器上溢出。

我们在 Unsloth 中修复了这一点,因此所有 float16 训练都可以开箱即用!

🔢 Sink Attention 的实现

OpenAI 的 sink token 实现 可在此处获取arrow-up-right。我们在下方提供:

HuggingFace transformers 的实现是 可在此处获取arrow-up-right。我们也在下方提供:

最后更新于

这有帮助吗?