💎使用 Unsloth 让 MoE 模型微调快 12 倍

使用 Unsloth 指南在本地训练 MoE LLM。

我们正在推出速度快约 12 倍的专家混合(MoE)LLM 训练, 显存减少超过 35% 以及 上下文长度延长约 6 倍 这得益于我们全新的 MoE Triton 内核和新的数学优化,且准确率不受影响。

  • Unsloth 现在支持对以下 MoE 架构进行快速训练,包括 gpt-oss, Qwen3 (30B、235B、VL、Coder)、DeepSeek R1, V3 以及 GLM(4.6, 4.7, Flash).

  • gpt-oss-20b 微调只需 12.8 GB 显存。Qwen3-30B-A3B(16-bit LoRA)使用 63GB。

  • 我们的内核可用于数据中心 GPU(B200、H100), 消费级 以及较旧的 GPU(例如 RTX 3090),并支持 FFT、LoRA 和 QLoRA。

我们与 🤗Hugging Face 合作,使用 PyTorch 全新的 torch._grouped_mm 函数,使所有 MoE 训练运行标准化。Transformers v5 最近通过比 v4 快约 6 倍的 MoE 进行了优化,而 Unsloth 借助自定义 Triton grouped‑GEMM + LoRA 内核进一步推进,带来 额外的 约 2 倍提速、显存减少超过 35% 和上下文长度增加超过 6 倍(相较 v4 总体提速 12-30 倍)。

试试我们的 Unsloth Notebook,用于快速 MoE 训练:

🦥 Unsloth MoE Triton 内核

torch._grouped_mm 之外(见 Faster MoE Training),我们还创建了自定义 Triton MoE 内核,在某些情况下甚至更快。它们也 向后兼容 更旧的硬件(如 A100)以及旧版 PyTorch。

在 A100 上,我们的 Triton 内核快约 2.5 倍torch._grouped_mm。这些内核还带有一次性的自动调优步骤,用于选择最佳内核配置。

自动调优在训练开始时只需约 2 分钟一次,但在 A100 上可使完整运行比 _grouped_mm快 35%,对于更长的运行来说非常值得。

circle-check

🧭 自动后端选择

我们的核心创新是 Split LoRA 方法 ,用于高效 MoE;与 Transformers v5 + torch._grouped_mm相比,它可减少约 35% 内存,并使训练速度提升 2 倍。自定义 torch._grouped_mm + 我们的 Triton 内核比 Transformers v4 快约 12-30 倍。

circle-exclamation

Unsloth 会根据你的硬件自动选择以下后端之一:

后端
优化

grouped_mm

torch._grouped_mm - 从 T4 一直到 B200 都可用,但针对 H100+ 进行了优化。

unsloth_triton

Unsloth Triton 内核——会在 A100 及更旧的 PyTorch 版本上自动启用。

native_torch

原生 PyTorch。它慢 12 倍,但我们的显存减少仍然存在!

你也可以自己切换它们:

circle-check

❓什么是 torch._grouped_mm?

以前,专家混合(MoE)权重被存储为 ModuleList 形式的每个专家线性层。执行前向传播的唯一实际方法是对所有专家进行 for 循环,这既昂贵又非最优。

PyTorch 最近引入了 grouped_mmarrow-up-right 来直接解决这个瓶颈。同时,我们也提供了自己针对 MoE 优化的 Triton 内核。这也与 Transformers 的一个关键变化一致:从 Transformers v5 开始,专家权重被存储为 单个 nn.Parameterarrow-up-right,这使得 grouped_mm 成为更快的 MoE 训练和推理的自然选择。

所以 transformers 4.57.6arrow-up-right 变更为:

transformers 5.0.0arrow-up-right 风格:

torch._grouped_mm 从 NVIDIA T4 开始的 GPU 都可运行,我们已经在 H100、A100、B200 和 RTX 6000 Pro 上验证过,因此支持范围很广。

我们之前还为 gpt-oss 引入了 Unsloth Flex Attention ,这些优化应该会让它更高效。

📊 内核结果 + 基准测试

下面是在不同序列长度下,训练速度和内存使用情况相对于 Transformers v5 的对比(v5 已经使用了 torch._grouped_mm 用于 MoE)。对于 gpt-oss BF16 MoE 训练,我们在 NVIDIA B200 上看到训练速度快 7 倍,显存减少 36% 。对于 Qwen3-30B-A3B,则快 1.8 倍,且 GLM 4.7 Flash 在 RTX PRO 6000 上快 2.1 倍。所有基准测试都使用 LoRA rank = 64,并将所有 LoRA 模块放在 MoE 层(gate、up、down)上。

gpt-oss 基准测试

我们微调了 unsloth/gpt-oss-20b-BF16arrow-up-right 用于基准测试。Unsloth 在 16K 上下文长度下速度快 7 倍,显存少用 36%。Transformers v5 + TRL 会发生显存溢出,而 Unsloth 不会。另外,在这种情况下,得益于我们的 Long Context gpt-oss以及我们的 MoE 内核,随着序列长度增加,提速也会进一步提升。

与 transformers v4 的比较
上下文长度
Unsloth(毫秒)
TF v5(毫秒)
Unsloth 显存(GB)
TF v5 显存(GB)
提速
显存节省

1024

275.35

376.99

40.91

43.88

1.4 倍

6.76%

2048

292.88

696.57

41.83

44.93

2.4 倍

6.89%

4096

370.30

1785.89

43.68

49.86

4.8 倍

12.39%

8192

712.33

5226.86

47.43

73.80

7.3 倍

35.73%

16384

1775.80

OOM

55.13

OOM

不适用

不适用

Qwen3 基准测试

NVIDIA B200上,我们看到 Qwen3-30B-A3B LoRA 的速度提升约 1.7 倍、内存效率提升约 35%,而且在更长的序列长度下显存节省还会进一步改善。

Qwen3-Next 和 Coder 令人惊讶地可以在单张 B200 GPU 上以 bf16 LoRA 运行。

在 H100 GPU 上,我们的表现显著优于基线,训练速度最高达到 1.77 倍 ,同时在 4K 上下文长度微调时还能节省约 5.3GB。虽然我们可以无缝扩展到 8192 上下文长度,但 Transformers v5 + TRL 在 8K 时会 OOM。请注意,我们在 8K 时使用的内存比基线在 4K 时还少,因此我们可以继续把上下文长度推得更高。

上下文长度
Unsloth(毫秒)
TF v5(毫秒)
Unsloth 显存(GB)
TF v5 显存(GB)
提速
显存节省

1024

366.3

628.3

80.88

104.80

1.7x

2.06%

2048

467.0

745.3

80.88

104.81

1.6x

2.57%

4096

711.6

975.5

80.89

104.80

1.4 倍

5.08%

8192

1376.6

1633.5

80.90

104.81

1.2x

9.17%

16384

3182.2

3407.9

85.53

116.61

1.1x

15.26%

GLM 4.7 基准测试

Unsloth 实现了 吞吐量快 2.6 倍,显存减少超过 15% ,适用于 GLM 4.7 Flash 的所有批量大小。GLM 4.7 Flash 是一个 30B MoE(3B 活跃参数)的智能体与代码模型,采用类似 DeepSeek MoE 风格的配置,具有 64 个路由专家和 1 个共享专家。我们将 Unsloth MoE 训练与新的优化版 Transformers v5 进行了基准对比。

请使用下面我们新的 GLM 4.7 Flash Colab Notebook:

GLM 4.7 Flash MoE Notebook A100 80GB
上下文长度
Unsloth(毫秒)
TF v5(毫秒)
Unsloth 显存(GB)
TF v5 显存(GB)
提速
显存节省

512

1145.0

2992.1

57.81

60.89

2.6 倍

6.51%

1024

1298.9

3323.3

58.76

62.55

2.6 倍

6.22%

2048

1831.9

4119.3

60.09

67.32

2.3 倍

9.46%

4096

2883.9

5646.1

63.34

76.78

2 倍

14.83%

⚡更快的 LoRA MoE 训练

在 Transformers/PEFT 中,通常的做法是 将 LoRA 适配器合并到基础权重中 然后再运行 MoE 计算(尤其因为 MoE 常常使用 nn.Parameter 而不是 nn.Linear)。问题在于,这种合并实际上会 把 LoRA delta(针对所有专家)具体化 lora_B @ lora_A.t,这 非常占内存.

Unsloth 避免了这一点。我们之前已经用同样的思路优化了通用 LoRA 训练和推理,现在我们也将其应用到了 MoE + LoRA 。数学上完全一致,因此损失、梯度和输出都保持不变。唯一改变的是 操作顺序,这得益于矩阵乘法的结合律。通过这种重排序,我们获得了显著的提速和显存减少。

circle-exclamation

这些优化 默认启用 ,用于使用 Unsloth 训练 MoE 模型时(尤其是 Qwen-3 MoE、gpt-oss 以及上文提到的模型)。你可以通过 UNSLOTH_MOE_BACKEND 环境变量切换实现:要么是 torch._grouped_mm Triton 内核 要么是 基础的 PyTorch for 循环,具体取决于兼容性和偏好。我们默认使用 grouped_mm 以获得最佳性能和广泛支持。

📚 实现细节

LoRA 是一种参数高效微调方法:它不是更新完整的权重矩阵,而是训练一个参数少得多的低秩“适配器”,从而大幅减少优化器内存。

如果原始权重的形状为 (m, n),LoRA 会添加两个可训练矩阵,形状分别为 (m, r) 以及 (r, n)。它们的乘积是 (m, n),但你只需要跟踪以下部分的优化器状态和梯度:

  • m*r + r*n 个参数(LoRA),而不是

  • m*n 个参数(全量微调)

circle-info

在 MoE 微调中,不建议微调路由层,所以我们默认将其禁用。

对于典型的 MLP 层, m ≈ 4096, n ≈ 12k, 且 r ≈ 64,那大约是 约 100 万个 LoRA 参数 vs 约 4800 万个完整参数 - 大约 ~2%, 通常几乎没有准确率损失。

MoE LoRA 改变了情况

MoE 层不同,因为你有 E 个专家 MLP 并行,因此任何按专家进行的更改(例如添加 LoRA)都会在所有专家上按比例扩展。

Qwen3‑30B‑A3B为例:隐藏维度 m=2048,中间层维度 n=768, ,E=128 个专家,每个 token 激活 k=8 个。每个专家:

  • gate_proj 以及 up_proj: (m, n)=(2048, 768)

  • down_proj: (n, m)=(768, 2048)

使用 LoRA rank r=64时,每个投影会增加 r*(m+n)=64*(2048+768)=180,224 个参数/专家(约 11% 一个 2048×768 矩阵的 r/n = 64/768 相对于典型的 MLP 设置来说很大,例如在 r/n = 64/25600Qwen3-32Barrow-up-right 中,规模相近。

如果你把这部分具体化到 所有 专家上,内存会迅速累积。而且由于 gate_proj 以及 up_proj 通常会融合为 gate_up_proj,你通常会把两者一起具体化,这大致会使开销/峰值内存翻倍。

在内存方面,对于序列长度 s、E 个专家以及 k 个被选中时,两种方法都有以下常见情况

从这里开始,两者开始分化。对于 peft 的方法,我们有

对于 Unsloth 的 split LoRA 方法,我们执行以下操作

现在让我们来看 Qwen3-30B-A3B 的情况。

E = 128, k = 8, m = 2048, n = 768。 代入这些值后,我们得到 s < 32K。

PEFT params:EmnUnsloth Split LoRA params:ks(r+n)In typical LoRA we have:rnSplit LoRA is better when:Emn>ksn  =  Em>ksFor Qwen3-30B-A3B, we haveE=128,k=8,m=2048,n=768So, Split LoRA is mathematically better whens<Emnkn=32K\begin{aligned} \text{PEFT params} &:\quad Emn \\ \text{Unsloth Split LoRA params} &:\quad ks(r+n) \\ \text{In typical LoRA we have} &:\quad r \ll n \\ \text{Split LoRA is better when} &:\quad Emn > ksn \;=\; Em > ks \\ \\ \text{For Qwen3-30B-A3B, we have} \\ E &= 128, \quad k = 8, \quad m = 2048, \quad n = 768 \\ \\ \text{So, Split LoRA is mathematically better when} \\ s &< \frac{Emn}{kn} = 32K \end{aligned}

在计算方面,对于序列长度 s, E 个专家以及 top k 个被选中,我们计算的是:

Δ=AB,ARm×r,  BRr×n2mnr flops per expert loraW=W+Δmn flopsXWXRs×m,  WRm×n2smn flopsMoE peft lora flops=E(2mnr+mn)+2ksmn\begin{aligned} \Delta = AB, A \in \mathbb{R}^{m \times r}, \; B \in \mathbb{R}^{r \times n} &\quad \Rightarrow \quad 2mnr \text{ flops per expert lora} \\ \\ W' = W + \Delta \quad &\Rightarrow \quad mn \text{ flops} \\ \\ XW' \quad | \quad X \in \mathbb{R}^{s \times m}, \; W' \in \mathbb{R}^{m \times n} \quad &\Rightarrow \quad 2smn \text{ flops} \\ \\ \text{MoE peft lora flops} &= E\big(2mnr + mn\big) + 2k\,smn \end{aligned}

对于前面提到的 Unsloth split LoRA,我们有

XW=2smn flopsY=XA,=2smr(applied only to routed token–expert pairs) Z=YB=2srnMoE split lora flops=2k(smn+smr+srn)Crossover condition:2ksr(m+n)>2Emn(r+1/2)s>Emnk(m+n)×(1+12r)For Qwen3-30B-A3B with:E=128,  m=2048,  n=768,  k=8s    16K tokens\begin{aligned} XW &= 2smn \text{ flops} \\ Y = XA, &= 2smr \quad \text{(applied only to routed token--expert pairs)} \\ \ Z = YB &= 2srn \\ \text{MoE split lora flops} &= 2k\big(smn + smr + srn\big) \\ \text{Crossover condition} &:\quad 2ksr(m+n) > 2Emn(r+1/2) \Rightarrow s > \frac{Emn}{k(m+n)} \times (1+ \frac{1}{2r}) \\ \\ \text{For Qwen3-30B-A3B with} &: E = 128,\; m = 2048,\; n = 768,\; k = 8 \\ \\ \Rightarrow \quad s & \;\approx\; 16\text{K tokens} \end{aligned}

从分析角度看,Split LoRA 更优直到 s > Emn/k(m+n) ,这大约相当于 16K 个 token,适用于 Qwen3-30B-A3B 风格的模型。

最后,一些提速来自 减少内存流量:现代 GPU 往往是 带宽受限的,因此传输更少的数据有时比 FLOPs 更重要。一个粗略的提速估计为 Emn / [k·s·(m+n)],因此它强烈依赖于 s、E、k以及矩阵形状。

🔮 模型支持

Unsloth 支持对以下 Qwen、gpt-oss、DeepSeek 和 GLM 模型进行更快的 MoE 训练:

  • Qwen3 (Thinking 和 Instruct):VL • 2507 • Coder

  • gpt-oss:20B • 120B • safeguard

  • GLM:4.5 • 4.6 • 4.6-Air • 4.7 • 4.7-Flash

  • DeepSeek:V3 • R1 • V3.1 • V3.2

我们可能尚未上传某些 MoE 模型,但 Unsloth 仍应支持它们。

📈 更多基准测试

gpt-oss BF16 基准测试

包含与 Transformers v4 的训练速度比较

上下文长度
Unsloth(毫秒)
TF v5(毫秒)
TF v4(毫秒)
提速

1024

275.35

376.99

2111.18

1.37 倍

2048

292.88

696.57

2626.80

2.38 倍

4096

370.30

1785.89

4027.93

4.82 倍

8192

712.33

5226.86

8513.52

7.34 倍

16384

1775.80

OOM

OOM

不适用

内存显存使用

上下文长度
Unsloth 显存(GB)
TF v5 显存(GB)
TF v4 显存(GB)
显存节省

1024

40.91

43.88

89.75

6.76%

2048

41.83

44.93

90.47

6.89%

4096

43.68

49.86

92.72

12.39%

8192

47.43

73.80

100.3

35.73%

16384

55.13

OOM

OOM

不适用

🎉 Unsloth 重要更新

  1. 作为我们 MoE 发布的一部分,我们还让 Gemma-3 现在默认使用 Flex-Attention ,而且这在 float16 设置下也适用(之前存在无穷大问题,我们在不久前已经解决)。 Gemma-3 现在使用 O(N) 内存而不是 O(N^2) 内存,训练速度快 3 倍以上 (随着上下文长度增长,扩展效果更好)。之前的 Unsloth 版本会 OOM。

上下文
旧版峰值显存
新版峰值显存
显存节省

1K

20.1 GB

20.1 GB

0 GB(0%)

2K

21.5 GB

21.1 GB

0.3 GB(2%)

4K

27.7 GB

23.3 GB

4.5 GB(16%)

8K

52.3 GB

27.5 GB

24.8 GB(47%)

16K

OOM

36.0 GB

--

24K

OOM

44.6 GB

--

32K

OOM

53.1 GB

--

48K

OOM

38.4 GB

--

64K

OOM

44.7 GB

--

  1. 视觉微调现在支持仅图片和文本数据的混合数据!

  2. trl==0.27.1 以及 transformers==5.1.0 都得到良好支持——此前我们 120 个 notebook 的覆盖率只有 30%,但现在已超过 80% 覆盖率——我们计划在接下来的几天内将其提升到 100%。

circle-check

致谢

我们感谢 Hugging Face 团队与我们合作,为社区改进 MoE 训练。

我们也真诚感谢 torchao 团队,尤其是 Vasily Kuznetsov(vkuzo),感谢他帮助我们启用 grouped_mm 对 float16 的支持,使其能够在 T4 上运行,并保持与 A100 的向后兼容性。

最后更新于

这有帮助吗?