# 長文コンテキストでのgpt-oss学習

OpenAIのgpt-oss学習向けに、Unsloth Flex Attention対応を導入できることをうれしくお知らせします。これにより **>8倍長いコンテキスト長**, **>50%少ない VRAM 使用量** および **1.5倍超の高速学習（精度低下なし）** Flash Attention 3（FA3）を使う実装を含むすべての実装と比べて実現します。Unsloth Flex Attentionにより、 **6万トークンのコンテキスト長** を80GB VRAMのH100 GPUでBF16 LoRAとして学習できます。さらに：

* 今すぐ [エクスポート／保存して](#new-saving-to-gguf-vllm-after-gpt-oss-training) QLoRAで微調整したgpt-ossモデルをllama.cpp、vLLM、Ollama、またはHFに出力できます
* 私たちは [**gpt-oss学習の**](#bug-fixes-for-gpt-oss) **損失が無限大になる問題を** float16 GPU（T4 Colabなど）で修正し、
* 私たちは [gpt-oss実装の](#bug-fixes-for-gpt-oss) Unslothには無関係な問題、特に次を正しく適用することを確認しました `swiglu_limit = 7.0` transformersでのMXFP4推論中に適切に適用されるようにしました

## 🦥 Unsloth Flex Attention対応の導入

UnslothのFlex Attention対応により、単一の80GB VRAM H100でQLoRAなら最大8.1万トークン、BF16 LoRAなら6万トークンのコンテキスト長を扱えます。これらの改善は **両方の** gpt-oss-20b と **gpt-oss-120b**! 使用するコンテキスト長が長いほど、Unsloth Flex Attentionによる恩恵も大きくなります。

<figure><img src="/files/112104c100881e328fac4707f14d1081d315bda3" alt="" width="563"><figcaption></figcaption></figure>

比較すると、Unsloth以外の実装はすべて80GB GPUでは最大9Kのコンテキスト長にとどまり、FA3でも15Kまでしか到達できません。しかし、 **FA3はattention sinkのバックワードパスをサポートしていないため、gpt-oss学習には不向きです**。そのため、これまでgpt-oss学習でFA3を使っていた場合は、 **今は使わないこと** をおすすめします。したがって、80GB VRAMでUnslothなしに得られる最大コンテキスト長は約9Kです。

Unsloth Flex Attentionでの学習は少なくとも1.3倍高速になり、コンテキスト長が長くなるほど改善幅も大きくなり、最大で2倍高速になります。Flex Attentionはコンテキストに応じてスケールするため、長い系列ほどVRAMと学習時間の両方でより大きな節約が得られます。 [こちらで説明されているように](#unsloths-flex-attention-implementation).

Rohan Pandeyによる [Flex Attentionの実装](https://x.com/khoomeik/status/1955693558914310608)に大きな感謝を。これがUnslothのFlex Attention実装の開発に直接インスピレーションを与えました。

## :dark\_sunglasses: Attention Sinks

OpenAIのGPT OSSモデルは、 **スライディングウィンドウ注意とフル注意を交互に繰り返すパターン**（SWA, FA, SWA, FA, ...）を使用しています。各スライディングウィンドウが注目するのは **128トークン** （現在のトークンを含む）だけなので、計算量は大幅に削減されます。しかし、その一方でウィンドウが小さいため、長文コンテキストの検索や推論はほぼ使い物になりません。多くの研究室ではこれを、スライディングウィンドウを2048または4096トークンに拡張することで解決しています。

OpenAIは **Attention Sinks** Efficient Streaming Language Models with Attention Sinks [論文の手法を活用しました](https://arxiv.org/abs/2309.17453) 。この論文では、小さなスライディングウィンドウを使いつつ、最初のトークンにグローバル注意を加える必要があることを示しています。下図はその良い例です：

<figure><img src="/files/e51ca02790b5dc44267d7094e1f31ce052d88b48" alt=""><figcaption></figcaption></figure>

論文では、 **注意機構が最初の数トークン（1〜4）に大きな重みを割り当てるように見える**ことが示されており、スライディングウィンドウ処理中にそれらを取り除くと、これらの「重要な」最初の数トークンが消えてしまい、長文コンテキストの検索性能が悪化します。

log perplexity（高いほど悪い）を描き、事前学習モデルの設定コンテキスト長を超えて長文コンテキスト推論を行うと、perplexityが急上昇するのが分かります（良くありません）。しかし、赤線（Attention Sinksを使用）は低いままで、これは非常に良いことです！

<figure><img src="/files/3481ec42c3439bc156f46c8a8abb0ac02ad2e1a8" alt=""><figcaption></figcaption></figure>

論文ではさらに、 [Attention Is Off By One手法](https://www.evanmiller.org/attention-is-off-by-one.html) も部分的には有効だが、より低いperplexityを得るには追加のsinkトークンも必要だと示しています。 **論文では、学習可能な単一のsinkトークンを追加するだけで驚くほどうまくいくことが示されています！ そしてOpenAIはGPT-OSSでこれを実施しました！**

<figure><img src="/files/6d744a38b6fad2e2170736362aae8f4f5a2eac75" alt=""><figcaption></figcaption></figure>

## :triangular\_ruler:UnslothのFlex Attention実装

Flex Attention <https://pytorch.org/blog/flexattention/> は非常に強力です。というのも、注意機構に対して実務者に2つのカスタマイズ手段を提供するからです。すなわち **スコア修飾子（f）** と **マスキング関数（M）**.

この **スコア修飾子（f）** で、softmax演算の前にattention logitsを編集できます。また **マスキング関数（M）** では、必要ない演算をスキップできます（例えばスライディングウィンドウ注意は最後の128トークンだけを見る）。

<mark style="background-color:green;">**ポイントは、Flex Attentionが任意のスコア修飾子とマスキング関数を備えた高速な自動生成Tritonカーネルを提供することです！**</mark>

<p align="center"><span class="math">\sigma\bigg(s\times\bold{f}(QK^T+\bold{M})\bigg)</span><br></p>

つまり、Flex Attentionを使ってattention sinkを実装できるということです。単一のattention sinkの実装は、 [OpenAIの元のGPT-OSSリポジトリ](#implementations-for-sink-attention) とHuggingFaceのtransformers実装の両方にあります。

```python
combined_logits = torch.cat([attn_weights, sinks], dim=-1)
probs = F.softmax(combined_logits, dim=-1)
scores = probs[..., :-1]
```

上の式は、sinkを `Q @ K.T` の最後に連結し、softmaxを行い、sinkトークンである最後の列を取り除いていることを示しています。

次のような可視化ユーティリティを使うと [Flex AttentionのGitHubリポジトリの](https://github.com/meta-pytorch/attention-gym)これを可視化できます。系列長が16、スライディングウィンドウが5だったとします。左は最後のsink列（デフォルト実装）、右はsinkの位置をインデックス0に移動した場合（私たちの実装）です。

{% columns %}
{% column %}
***末尾にsinkを置く（デフォルト）***

<figure><img src="/files/b904016c65a3e32d6acb51f2d082645f4d7800fd" alt=""><figcaption></figcaption></figure>
{% endcolumn %}

{% column %}
***sinkの位置をインデックス0に移動***

<figure><img src="/files/bdd2f6be076b401b6061053a5fa6305882167edb" alt=""><figcaption></figcaption></figure>
{% endcolumn %}
{% endcolumns %}

**興味深い発見**：公式のFlex Attentionのスライディングウィンドウ実装では、ウィンドウサイズを最後のトークン数 **＋1** として扱います。これは現在のトークンを含むためです。HuggingFaceおよびGPT OSSの実装は、厳密には最後のNトークンのみを見ます。つまり、以下は <https://pytorch.org/blog/flexattention/> および <https://github.com/meta-pytorch/attention-gym>:

{% code overflow="wrap" %}

```python
def sliding_window_causal(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    window_mask = q_idx - kv_idx <= SLIDING_WINDOW 
    return causal_mask & window_mask
```

{% endcode %}

{% columns %}
{% column %}
デフォルトのFlex Attention（3+1トークン）

<figure><img src="/files/9dcc358493f2b32365827734106aeff7e0794009" alt=""><figcaption></figcaption></figure>
{% endcolumn %}

{% column %}
HuggingFace、GPT-OSS（3+0トークン）

<figure><img src="/files/ef660edee28cb0f5e56cec6fe1649ad49d870746" alt=""><figcaption></figcaption></figure>
{% endcolumn %}
{% endcolumns %}

また、OpenAIの公式GPT-OSS実装を通じて、ここで最後のNトークンを見るのか、それともN+1トークンを見るのかも確認しました： <https://github.com/openai/gpt-oss/blob/main/gpt_oss/torch/model.py>

```python
mask = torch.triu(Q.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
if sliding_window > 0:
    mask += torch.tril(
        mask.new_full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window
    )
```

<figure><img src="/files/2f4c4a54589c1531e755467f0f5342aa2b5299f2" alt=""><figcaption></figcaption></figure>

そして、注目されるのは最後の3トークンのみ（3+1ではない）であることが分かります！ つまり、次の代わりに `<= SLIDING_WINDOW`、次を使用: `< SLIDING_WINDOW` （つまり、等号を含めず「より小さい」を使う）。

```python
def sliding_window_causal(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    window_mask = q_idx - kv_idx <= SLIDING_WINDOW # デフォルトのFlex Attention
    window_mask = q_idx - kv_idx <  SLIDING_WINDOW # GPT-OSS版
    return causal_mask & window_mask
```

また、sinkトークンのインデックスを最初に移動したので、正しくインデックスするにはq\_idxに1を加える必要があります：

```python
def causal_mask_with_sink(batch, head, q_idx, kv_idx):
    """
      0 1 2 3     0 1 2 3
    0 X X       1   X
    1 X X X     2   X X
    2 X X X X   3   X X X
    """
    # 最初の列はsinkトークンなので(q_idx + 1)を加える
    causal_mask = (q_idx + 1) >= kv_idx
    sink_first_column = kv_idx == 0
    return causal_mask | sink_first_column
```

インデックス0への実装を確認するため、学習損失が標準のHugging Face実行（Unsloth Flex Attentionなし）と一致することを、次のグラフで検証しました：

<figure><img src="/files/bde7659ec4e86a3dd7407fe584af917332eaeff3" alt="" width="375"><figcaption></figcaption></figure>

## :scroll: attention sinkの数学的導出

KとVをパディングせずにattention sinkを計算する別の方法があります。まずsoftmax演算が何をするかに注目し、今のところsink付きの第2版をスカラーとして扱いたいと思います：\\

$$
A(x) = \frac{\exp(x\_i)}{\sum{\exp{(x\_i)}}} \\
A\_{sink}(x) = \frac{\exp(x\_i)}{\exp{(s)}+ \sum{\exp{(x\_i)}}}
$$

Flex Attentionからlogsumexpを取得するには `return_lse = True` を使うので、次のようにします：

$$
A(x) = \frac{\exp(x\_i)}{\sum{\exp{(x\_i)}}} \\
\frac{\exp(x\_i)}{\exp{(s)}+ \sum{\exp{(x\_i)}}} =  \frac{\exp(x\_i)}{\sum{\exp{(x\_i)}}} \frac{\sum{\exp{(x\_i)}}}{\exp{(s)}+ \sum{\exp{(x\_i)}}} \\
\text{LSE}(x) = \text{logsumexp}(x) = \log{\sum\exp(x\_i)} \\
\exp{(\text{LSE}(x))} = \exp{\big(\log{\sum\exp(x\_i)}\big)} = \sum\exp(x\_i)
$$

これでattentionのsink版を簡単に導出できます。ただし、この手順はゼロパディング方式よりやや誤差が大きいことが分かったため、引き続き元の方式をデフォルトにしています。

## 💾**新機能: gpt-oss 学習後の GGUF、vLLM への保存**

これで gpt-oss を QLoRA でファインチューニングし、モデルを直接 **llama.cpp**, **vLLM**、または **HF** へ保存、エクスポート、またはマージできます。Unsloth だけではありません。無料ノートブックをできるだけ早く公開する予定です。

これまでは、QLoRAで微調整したgpt-ossモデルはUnsloth内でしか実行できませんでした。そこで、 **MXFP4** **ネイティブ形式** でマージする機能を追加し、 `save_method="mxfp4"` および **MXFP4 のオンデマンド逆量子化** gpt-ossのようなベースモデルを使って **微調整済みモデルをbf16形式で書き出せるようにしました。以下を使います：** `save_method="merged_16bit"` .

この **MXFP4** ネイティブマージ形式は、 **bf16 形式**と比較して大きな性能向上を提供します。ディスク容量を最大 75% 節約し、VRAM 消費を 50% 削減し、マージを 5〜10 倍高速化し、さらに **GGUF** 形式への変換を大幅に高速化します。

gpt-oss モデルのファインチューニング後、それを **MXFP4** 形式に次のようにマージできます:

```python
model.save_pretrained_merged(save_directory, tokenizer, save_method="mxfp4")
```

モデルをマージしてhugging-face hubにpushしたい場合は、次を使ってください：

```python
model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token, save_method="mxfp4")
```

マージ済みモデルで推論を行うには、vLLMやLlama.cppなどを使えます。OpenAIは両モデルに対して次の [推論設定](/docs/jp/moderu/gpt-oss-how-to-run-and-fine-tune.md#recommended-settings) を推奨しています： `temperature=1.0`, `top_p=1.0`, `top_k=0`

#### :sparkles: Llama.cpp への保存

1. 最新の `llama.cpp` を [GitHub こちら](https://github.com/ggml-org/llama.cpp)から取得してください。以下のビルド手順に従うこともできます。 `-DGGML_CUDA=ON` を `-DGGML_CUDA=OFF` に変更してください。GPU がない場合、または CPU 推論だけを使いたい場合です。

   ```bash
   apt-get update
   apt-get install pciutils build-essential cmake curl libcurl4-openssl-dev -y
   git clone https://github.com/ggml-org/llama.cpp
   cmake llama.cpp -B llama.cpp/build \\
       -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=ON -DLLAMA_CURL=ON
   cmake --build llama.cpp/build --config Release -j --clean-first --target llama-cli llama-gguf-split
   cp llama.cpp/build/bin/llama-* llama.cpp
   ```
2. 次を変換します **MXFP4** マージ済みモデル:

   ```bash
   python3 llama.cpp/convert_hf_to_gguf.py gpt-oss-finetuned-merged/ --outfile gpt-oss-finetuned-mxfp4.gguf
   ```
3. 量子化済みモデルで推論を実行する:

   ```bash
   llama.cpp/llama-cli --model gpt-oss-finetuned-mxfp4.gguf \
       --jinja -ngl 99 --threads -1 --ctx-size 16384 \
       --temp 1.0 --top-p 1.0 --top-k 0 \
        -p "人生と宇宙の意味は"
   ```

<details>

<summary><span data-gb-custom-inline data-tag="emoji" data-code="2728">✨</span> SGLangへの保存</summary>

1. SGLangをソースからビルド：\\

   ```bash
   # ソースからビルド
   git clone https://github.com/sgl-project/sglang
   cd sglang
   pip3 install pip --upgrade
   pip3 install -e "python[all]"

   # ROCm 6.3
   pip3 install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/rocm6.3
   git clone https://github.com/triton-lang/triton
   cd python/triton_kernels
   pip3 install .

   # hopper
   pip3 install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu126
   pip3 install sgl-kernel==0.3.2

   # blackwell cu128
   pip3 install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
   pip3 install https://github.com/sgl-project/whl/releases/download/v0.3.2/sgl_kernel-0.3.2+cu128-cp39-abi3-manylinux2014_x86_64.whl

   # blackwell cu129
   pip3 install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu129
   pip3 install https://github.com/sgl-project/whl/releases/download/v0.3.2/sgl_kernel-0.3.2-cp39-abi3-manylinux2014_x86_64.whl
   ```
2. SGLangサーバーを起動：\\

   ```bash
   python3 -m sglang.launch_server --model-path ./gpt-oss-finetuned-merged/
   ```
3. 推論を実行：\\

   ```python
   import requests
   from sglang.utils import print_highlight

   url = f"http://localhost:8000/v1/chat/completions"

   data = {
       "model": "gpt-oss-finetuned-merged",
       "messages": [{"role": "user", "content": "フランスの首都はどこですか？"}],
   }

   response = requests.post(url, json=data)
   print_highlight(response.json())
   ```

</details>

### :diamonds:gpt-ossを直接ファインチューニング

ネイティブMXFP4量子化フォーマットの読み込みを可能にするパッチを実装し、gpt-ossモデルを直接ファインチューニングできるようにもしました。これにより、'openai/gpt-oss'モデルを24GB未満のVRAMで読み込み、QLoRAで微調整できます。単純に以下を使ってモデルを読み込んでください：

```python
model, tokenizer = FastLanguageModel.from_pretrained(
    # model_name = "unsloth/gpt-oss-20b-BF16", 
    model_name = "unsloth/gpt-oss-20b",
    dtype = dtype, # 自動検出の場合は None
    max_seq_length = max_seq_length, # 長いコンテキスト用に任意の値を選択！
    load_in_4bit = True,  # メモリ削減のための 4 ビット量子化
    full_finetuning = False, # [NEW!] フルファインチューニングに対応しました！
    # token = "hf_...", # gated モデルを使う場合はこれを使用
)
```

Peftレイヤーを追加して `FastLanguageModel.get_peft_model` Peftモデル上でSFTファインチューニングを実行します。

## 🐛 gpt-oss向けバグ修正

私たちは [最近、Hugging Faceと協力して](https://github.com/huggingface/transformers/pull/40197) OpenAIのカーネルを使用し、MXFP4推論中に `swiglu_limit = 7.0` が正しく適用されるようにして、推論上の問題を解決しました。

ユーザーからのフィードバックに基づき、長時間のQLoRA学習（60ステップ超）で **損失が発散して最終的にエラーになる**可能性があることを発見しました。この問題はBF16非対応でF16にフォールバックするデバイス（例：T4 GPU）でのみ発生していました。重要なのは、A100やH100 GPUでのQLoRA学習、またf16 GPUでのLoRA学習には影響しなかったことです。

**徹底的に調査した結果、F16に制限されたGPUを含むすべてのGPU構成で学習損失の挙動を揃えることができました**。この問題で以前困っていた場合は、新しい更新版のgpt-ossノートブックの使用をおすすめします！

<figure><img src="/files/dc604a45f4533be2d20756d13cf8ea078fadd88c" alt=""><figcaption></figcaption></figure>

float16の学習損失曲線をbfloat16マシン（青線）と同等にするために、何度も何度も実験を行う必要がありました。次のことが分かりました：

1. **純粋なfloat16は50ステップ目で無限大になります**
2. **MoEのダウンプロジェクションに非常に大きな外れ値があることが分かりました**
3. **活性化はbfloat16またはfloat32で保存する必要があります**

**以下はGPT OSS 20Bの絶対値活性化を示しており、非常に大きなスパイクがいくつかあります。float16の最大範囲は65504なので、float16マシンではオーバーフローします。**

**これをUnslothで修正したので、すべてのfloat16学習がそのまま動作します！**

<figure><img src="/files/2d4ee694f50f5e8b62b893668c6ed76f48bf3bd2" alt=""><figcaption></figcaption></figure>

## :1234: Sink Attentionの実装

OpenAIのsinkトークン実装は [こちらで提供されています](https://github.com/openai/gpt-oss/blob/main/gpt_oss/torch/model.py)。以下に示します：

{% code fullWidth="false" %}

```python
def sdpa(Q, K, V, S, sm_scale, sliding_window=0):
    # sliding_window == 0 はスライディングウィンドウなしを意味する
    n_tokens, n_heads, q_mult, d_head = Q.shape
    assert K.shape == (n_tokens, n_heads, d_head)
    assert V.shape == (n_tokens, n_heads, d_head)
    K = K[:, :, None, :].expand(-1, -1, q_mult, -1)
    V = V[:, :, None, :].expand(-1, -1, q_mult, -1)
    S = S.reshape(n_heads, q_mult, 1, 1).expand(-1, -1, n_tokens, -1)
    mask = torch.triu(Q.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
    if sliding_window > 0:
        mask += torch.tril(
            mask.new_full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window
        )
    QK = torch.einsum("qhmd,khmd->hmqk", Q, K) * sm_scale
    QK += mask[None, None, :, :]
    QK = torch.cat([QK, S], dim=-1)
    W = torch.softmax(QK, dim=-1)
    W = W[..., :-1]
    attn = torch.einsum("hmqk,khmd->qhmd", W, V)
    return attn.reshape(n_tokens, -1)
```

{% endcode %}

HuggingFace transformersの実装は [こちらで提供されています](https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_oss/modeling_gpt_oss.py)。以下にも示します：

{% code fullWidth="false" %}

```python
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
    combined_logits = torch.cat([attn_weights, sinks], dim=-1)

    # これは元の実装にはなく、結果にわずかに影響します。BF16/FP16でのオーバーフローを防ぎます
    # bsz>1で学習するとき、最大値をクリップします。

    combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
    probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
    scores = probs[..., :-1]  # ここでsinkを取り除く
    attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
```

{% endcode %}


---

# Agent Instructions: Querying This Documentation

If you need additional information that is not directly available in this page, you can query the documentation dynamically by asking a question.

Perform an HTTP GET request on the current page URL with the `ask` query parameter:

```
GET https://unsloth.ai/docs/jp/moderu/gpt-oss-how-to-run-and-fine-tune/long-context-gpt-oss-training.md?ask=<question>
```

The question should be specific, self-contained, and written in natural language.
The response will contain a direct answer to the question and relevant excerpts and sources from the documentation.

Use this mechanism when the answer is not explicitly present in the current page, you need clarification or additional context, or you want to retrieve related documentation sections.
