# 50 万上下文长度微调

我们在 Unsloth 中引入了新的算法，推动长上下文训练的极限，适用于 **任何 LLM 和 VLM**。像 gpt-oss-20b 这样的 LLM 现在可以达到 **500K+ 的上下文长度** 在单个 80GB H100 GPU 上，相比之前的 80K，且没有精度下降。

你可以达到 >**750K 的上下文窗口** 在 B200 192GB GPU 上。

> **在我们的** [**80GB A100 Colab 笔记本**](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/gpt_oss_\(20B\)_500K_Context_Fine_tuning.ipynb)**.**

上尝试 500K 上下文 gpt-oss-20b 微调。我们显著改进了 Unsloth 对内存使用模式、速度和上下文长度的处理：

* **VRAM 使用降低 60%** 与 **3.2 倍更长的上下文** 通过 Unsloth 的新 [融合与分块交叉熵](#unsloth-loss-refactoring-chunk-and-fuse) 损失，在速度或精度上没有退化
* Unsloth 的激活卸载得到增强，体现在 [**梯度检查点**](#unsloth-gradient-checkpointing-enhanced)
* 我们与 Snowflake 的 Stas Bekman 合作研究 [切片 MLP](#tiled-mlp-unlocking-500k)，使得上下文数量增加 2×

Unsloth 的算法使得 gpt-oss-20b QLoRA（4bit）在 H100 上可实现 290K 上下文（无精度损失），并在启用 Tiled MLP 时达到 500K+，整体实现 >**6.4× 更长的上下文长度。**

<figure><img src="https://2657992854-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2FxhOjnexMCB3dmuQFQ2Zq%2Fuploads%2F8Ha930qR5XXBOK7M7oiy%2Fline_chart_light_tiled.png?alt=media&#x26;token=51467f68-a77b-4037-b9d9-e668223868c5" alt="" width="563"><figcaption></figcaption></figure>

### 📐 Unsloth 损失重构：分块与融合

我们新的融合损失实现增加了 **动态序列分块**：我们不是一次性对整个序列计算语言模型头的 logits 和交叉熵，而是沿扁平化的序列维度处理可管理的切片。这将峰值内存从 GB 级别降低到较小的分块尺寸。每个分块仍通过 `torch.func.grad_and_value` 运行完整的融合前向+反向传递，并在必要时通过提升到 float32 保持混合精度的准确性。 **这些改变不会降低训练速度或精度。**

<figure><img src="https://2657992854-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2FxhOjnexMCB3dmuQFQ2Zq%2Fuploads%2FFF43WA1X8Y4vADBrCi8T%2Fline_chart_light.png?alt=media&#x26;token=7afc7f73-bc54-403a-9674-8a16841ec659" alt="" width="563"><figcaption></figcaption></figure>

关键创新在于 **分块大小在运行时自动选择** ，基于可用的 VRAM。

* 如果你有更多可用 VRAM，会使用更大的分块以获得更快的运行速度
* 如果你的 VRAM 更少，它会增加分块数量以避免内存溢出。

这 **消除了手动调优** 并使我们的算法在旧 GPU 和新 GPU、不同工作负载和不同序列长度下保持稳健。

{% hint style="success" %}
由于自动调优， **较小的上下文会使用更多 VRAM** （更少的分块）以 **避免不必要的开销**。对于上面的图表，我们调整了损失分块的数量以反映现实的 VRAM 等级。使用 80GB VRAM 时，这会带来 >3.2× 更长的上下文。
{% endhint %}

### 🏁 Unsloth 梯度检查点增强

我们的 [Unsloth 梯度检查点](https://unsloth.ai/blog/long-context) 算法， **于 2024 年 4 月推出**，迅速流行并成为行业标准，如今已被集成到大多数训练包中。它将激活卸载到 CPU 内存，从而允许 10 倍更长的上下文长度。我们的新增强使用了 CUDA 流和其他技巧，最多只增加 **0.1%** 训练开销且不影响精度。之前它增加了 1 到 3% 的训练开销。

{% code expandable="true" %}

```python
# 原始 Unsloth 版本发布于 2024 年 4 月 - LGPLv3 许可
class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
    @staticmethod
    @torch_amp_custom_fwd
    def forward(ctx, forward_function, hidden_states, *args):
        ctx.device = hidden_states.device
        saved_hidden_states = hidden_states.to("cpu", non_blocking = True)
        with torch.no_grad():
            output = forward_function(hidden_states, *args)
        ctx.save_for_backward(saved_hidden_states)
        ctx.forward_function, ctx.args = forward_function, args
        return output

    @staticmethod
    @torch_amp_custom_bwd
    def backward(ctx, dY):
        (hidden_states,) = ctx.saved_tensors
        hidden_states = hidden_states.to(ctx.device, non_blocking = True).detach()
        hidden_states.requires_grad_(True)
        with torch.enable_grad():
            (output,) = ctx.forward_function(hidden_states, *ctx.args)
        torch.autograd.backward(output, dY)
        return (None, hidden_states.grad,) + (None,)*len(ctx.args)
```

{% endcode %}

通过在激活一生成时立即卸载它们，我们将峰值激活占用降到最低，并在需要时准确地释放 GPU 内存。这大幅减轻了长上下文或大批量训练中的内存压力，其中单个解码器层的激活可能超过 2 GB。

> **因此，Unsloth 的新算法与梯度检查点为大多数改进（3.2×）做出了贡献，使得在单个 H100 上能够进行 290k 上下文长度的 QLoRA GPT-OSS 微调。**

### 🔓 切片（Tiled）MLP：解锁 500K+

在 [Stas Bekman](https://x.com/StasBekman) （Snowflake）的帮助下，我们将 Snowflake 的 Arctic 长序列训练中的 Tiled MLP 集成进来 [论文](https://arxiv.org/abs/2506.13996) 和博客文章。TiledMLP 通过在进行大型 MLP 投影前沿序列维度对隐藏状态进行切片，减少了激活内存并支持更长的序列长度。

**我们还引入了一些使用体验方面的改进：**

我们在切片前向重计算中保留随机数生成器（RNG）状态，以便 dropout 和其他随机操作在前向与后向重放之间保持一致。这保持了嵌套检查点计算的稳定性和数值一致性。

{% hint style="success" %}
我们的实现会自动修补任何名称或类型为 `mlp`的模块，所以 **几乎所有带有 MLP 模块的模型开箱即支持 Tiled MLP。**
{% endhint %}

**需要注意的权衡**

TiledMLP 通过额外的前向传递来节省 VRAM。因为它存在于一个被检查点的 transformer 块内并且自身以检查点风格编写，它实际上成为了一个嵌套检查点：一个 **MLP 现在每步大约执行 \~3 次前向和 1 次后向**。作为回报，我们几乎可以从 VRAM 中去除所有中间 MLP 激活，同时仍然支持极长的序列。

<figure><img src="https://2657992854-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2FxhOjnexMCB3dmuQFQ2Zq%2Fuploads%2FdeOJEEqucGYtbXbb7nqB%2Fbaseline_vs_unsloth_spike.png?alt=media&#x26;token=3b1cdfd3-dd24-4c94-b7ec-5d1366464afb" alt=""><figcaption></figcaption></figure>

图表比较了单个解码器层在长上下文训练步骤中前向和后向的活动内存时间线，左侧为未使用 Tiled MLP，右侧为使用 Tiled MLP。未使用 Tiled MLP 时，峰值 VRAM 出现在 MLP 的后向阶段；使用 Tiled MLP 时，峰值移到融合损失计算处。我们观察到约 40% 更低的 VRAM 使用，并且由于融合损失会根据可用 VRAM 动态分块，在更小的 GPU 上启用 Tiled MLP 时峰值会更小。

<figure><img src="https://2657992854-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2FxhOjnexMCB3dmuQFQ2Zq%2Fuploads%2FUCx0X7S5FvaD3hUsma5j%2Fbaseline_vs_unsloth_nospike.png?alt=media&#x26;token=a81b8639-21d0-43aa-a837-8209949e8742" alt=""><figcaption></figcaption></figure>

为表明交叉熵损失不是新的瓶颈，我们将其分块大小固定而非动态选择，然后将分块数量翻倍。这显著降低了与损失相关的内存峰值。现在最大内存在两种情况下都发生在后向阶段，整体时间相似，尽管 Tiled MLP 增加了少量开销：一个大的 GEMM 变成了许多顺序矩阵乘法，加上上文提到的额外前向传递。

总体而言，这个权衡是值得的：没有 Tiled MLP 时，长上下文训练大约需要 2 倍的内存，而有了 **Tiled MLP 单个 GPU 对于相同的上下文长度只需约 1.3× 的步时增加。**

**在 Unsloth 中启用 Tiled MLP：**

```py
model, tokenizer = FastLanguageModel.from_pretrained(
    ...,
    unsloth_tiled_mlp = True,
)
```

只需设置 `unsloth_tiled_mlp = True` 在 `from_pretrained` 中，Tiled MLP 即可启用。我们遵循 Arctic 论文的相同逻辑并选择 `num_shards = ceil(seq_len/hidden_size)`。每个切片将作用于与模型隐藏维度相同大小的序列长度，以平衡吞吐量和内存节省。

我们还讨论了 Tiled MLP 实际上执行 3 次前向和 1 次后向，相比之下普通的梯度检查点在 Stas Bekman 和 [DeepSpeed](https://github.com/deepspeedai/DeepSpeed/pull/7664) 的协作下说明，普通方法执行 2 次前向和 1 次后向，并且 DeepSpeed 提供了关于 Tiled MLP 的文档更新。

{% hint style="success" %}
下次微调出现内存不足时，尝试开启 `unsloth_tiled_mlp = True`。只要上下文长度超过 LLM 的隐藏维度，这应该能节省一些 VRAM。
{% endhint %}

***

**通过我们最新的更新，现在有可能在单个 GPU 上用更小的模型达到 1M 的上下文长度！**

**在我们的** [**80GB A100 Colab 笔记本**](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/gpt_oss_\(20B\)_500K_Context_Fine_tuning.ipynb)**.**

如果你看到这里，我们本周将发布一篇关于我们在训练速度方面最新改进的新博客，敬请关注并加入我们的 [Reddit r/unsloth](https://www.reddit.com/r/unsloth/) 或我们的文档。
