Finetune Mistral 14x faster

Unsloth open source:

| Model | GPU | Speedup |
|---|---|---|
| Mistral 7B | 1x A100 | 2.2x faster |
| Code Llama 34B | 1x A100 | 1.9x faster |
| Llama 7B | 1x A100 | 2.2x faster |
| Llama 7B | 1x T4 | 2x faster |

You can now QLoRA finetune Mistral 7B 2.2x faster on 1x A100 with 62% less memory, or 12.4GB peak VRAM. CodeLlama 34B is 1.9x faster on 1x A100, using 32% less memory, or 27GB peak VRAM. It finally doesn’t OOM!

Unsloth Pro version:

| Model | GPU | Speedup |
|---|---|---|
| Mistral 7B | 1x A100 | 14x faster |
| Code Llama 34B | 1x A100 | 13x faster |
| Llama 7B | 1x A100 | 21x faster |
| Llama 7B | 2x T4 | 28x faster |

Unsloth Pro version:

| Model | GPU | Peak VRAM reduction |
|---|---|---|
| Mistral 7B | 1x A100 | -70% |
| Code Llama 34B | 1x A100 | -50% |
| Llama 7B | 1x A100 | -71% |
| Llama 7B | 2x T4 | -44% |

```bash
# Ampere GPUs, CUDA 11.8
pip install "unsloth[cu118_ampere] @ git+https://github.com/unslothai/unsloth.git"
# Ampere GPUs, CUDA 12.1
pip install "unsloth[cu121_ampere] @ git+https://github.com/unslothai/unsloth.git"
# Ampere GPUs on Google Colab
pip install "unsloth[colab_ampere] @ git+https://github.com/unslothai/unsloth.git"
```

Benchmarking

We benchmark Unsloth against Hugging Face’s original implementation, and against adding Flash Attention 2 support, on 1x A100 via Google Colab. Flash Attention at most speeds up training by 1.2x, whilst with Unsloth’s open source package, training is 2.2x faster. “Unsloth Equal” is our PRO version, but under the condition that all settings and the loss curve stay the same. Under this scenario, we further boost training to 2.7x. Our MAX version can boost speeds on the LAION dataset to 21x!

All benchmarks use the following setup (unless some tests OOM, in which case we decrease the batch size for all tests):

```
QLoRA nf4 layers = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
QLoRA rank = 16, alpha = 16, dropout = 0
max_seq_length = 2048
learning_rate = 2e-4
weight_decay = 0.01
max_steps = 240
warmup_steps = 10
batch_size = 4
gradient_accumulation_steps = 4
lr_scheduler_type = "linear"
optimizer = "adamw_8bit", bfloat16
use_gradient_checkpointing = True
random_state = 3407
```

On Mistral 7B for 1x A100, Flash Attention v2 boosts training by 1.15x, whilst Unsloth Open boosts it by 2.15x, and we reduce memory by 62%. Again, "Unsloth Equal" which runs an equalized training run boosts speeds by 2.53x, and uses 69% less peak VRAM. Unsloth MAX boosts training by 13.7x.

You can click on "Code" to access our shared notebooks for reproducibility purposes. "Unsloth Equal" only shows the training losses and obscures our other codepaths.

At the end of this blog post, we provide the whole table of all benchmarks, and all 59 notebook links for reproducibility purposes.

Performance breakdowns bit by bit

1. Reduce data upcasting

By reducing upcasting of weights during QLoRA, we can easily save 7.2% of VRAM and make training take 21.7% less time.

2. Bitsandbytes bfloat16

Bitsandbytes internally uses float16, so we have to do an extra memory copy to convert it to bfloat16. We fix this internally, saving 9% of time.

3. Scaled Dot Product Attention

We use PyTorch's fast implementation of attention, saving 1.4% of time.

4, 5, 6. Causal Masking, Xformers, Flash Attention 2

By using a causal mask instead of a separate attention mask, we made things 8.1% faster, since we don't need to read the attention matrix. We then switched over to Xformers, which makes things another 8.1% faster and saves a whopping 39% of VRAM usage. Switching to Flash Attention v2 had no noticeable effect, since Xformers calls FA2 internally anyway.

7. Fast RoPE Embeddings

By implementing RoPE embeddings in OpenAI's Triton, we save another 7.6% of time. To do so, we must find the derivative of the RoPE function through manual inspection. Notice that RoPE can be rewritten as a matrix multiplication between a rotation matrix R and the original matrix Q. If we do this, the derivative is simply R transposed, since rotation matrices are orthogonal.

8. Fast RMS Layernorm

Unfortunately, the RMS Layernorm's derivative is much more involved. Carefully applying the chain rule yields an ugly derivative for the RMS Layernorm. We again implement this in OpenAI's Triton language, which boosts training by 3.1%.

9. Fast Cross Entropy Loss

The Cross Entropy Loss is again a bit more involved. We use the log trick, x = exp(log(x)), to derive the derivative. We also used Wikipedia to find that the derivative of the infamous logsumexp function is in fact the softmax function! We slash VRAM usage by 17%.
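The logsumexp-softmax relationship above is easy to check numerically. Below is a minimal plain-Python sketch (not our Triton kernel): a finite-difference gradient of logsumexp matches the softmax.

```python
import math

def logsumexp(x):
    # Numerically stable log-sum-exp: subtract the max before exponentiating.
    m = max(x)
    return m + math.log(sum(math.exp(v - m) for v in x))

def softmax(x):
    m = max(x)
    exps = [math.exp(v - m) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

# Finite-difference gradient of logsumexp at x should equal softmax(x).
x = [1.0, 2.0, 3.0]
eps = 1e-6
grad = []
for i in range(len(x)):
    bumped = x[:]
    bumped[i] += eps
    grad.append((logsumexp(bumped) - logsumexp(x)) / eps)

print([round(g, 4) for g in grad])
print([round(p, 4) for p in softmax(x)])
```

Both lines print the same vector, confirming d/dx logsumexp(x) = softmax(x).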

10, 11. Manual Autograd

By bracketing correctly, we can massively reduce the actual number of FLOPs during LoRA finetuning! Normally, PyTorch's autograd engine backpropagates through the graph from the end to the start. We find that by fusing multiple operations into one and bracketing the chained matrix multiplications correctly, the actual number of FLOPs is reduced.

If you bracket incorrectly, as PyTorch's autograd currently does, you first do the multiplication of X.T and dW. Take X to be of size (bsz, seq_len, d). We then reshape X to size (m, d), where m is simply bsz * seq_len and d is the attention dimension: 4096 in Llama 7b, 8192 in Llama 70b. dW is of size (m, h), where h is the MLP intermediate size: 11,008 for Llama 7b and 28,672 for Llama 70b. And B.T is the LoRA weight of size (h, r), where r is the rank of the LoRA matrix, which can be as small as 16 or 64.

We find that the slow path takes around (h * d)(m + r) FLOPs, while the fast path, which instead brackets around the second term, takes (m * r)(h + d) FLOPs. We can then divide the slow path by the fast path to get a speedup fraction.

To simplify, notice that r is normally quite small, say 16 or 64, while m can be very big: a batch size of 4 with a sequence length of 4096 makes m = 4 * 4096 = 16,384. This makes the r in (m + r) insignificant, so we drop the addition term. We cannot do the same for (m * r), since multiplying by 16 matters far more than adding 16. Doing so yields a simplified expression where the speedup is a function of the MLP intermediate size h, the attention dimension d, and the LoRA rank r.
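In symbols, the argument above gives the speedup fraction

$$
\text{speedup} \;=\; \frac{(h \cdot d)\,(m + r)}{(m \cdot r)\,(h + d)} \;\approx\; \frac{h \cdot d}{r \,(h + d)} \qquad \text{for } m \gg r.
$$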

For Llama 7b where h = 11,008 and d = 4,096 and r = 16, we get a speedup of 186.58.

For Llama 70b where h = 28,672 and d = 8,192 and r = 16, we get a speedup of 398.22.
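The numbers above can be checked directly from the simplified expression. A minimal sketch (`lora_mlp_speedup` is our own illustrative name, not an Unsloth API):

```python
def lora_mlp_speedup(h, d, r):
    # Simplified speedup from bracketing the chained matmul correctly:
    # slow path ~ (h*d)(m + r) FLOPs, fast path ~ (m*r)(h + d) FLOPs;
    # for m >> r the ratio reduces to h*d / (r*(h + d)).
    return (h * d) / (r * (h + d))

print(round(lora_mlp_speedup(11008, 4096, 16), 2))  # Llama 7b  -> 186.58
print(round(lora_mlp_speedup(28672, 8192, 16), 2))  # Llama 70b -> 398.22
```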

Other features

- 152334H managed to make Unsloth work with DPO! It's still preliminary support, but it seems to work via TRL.
- RandomInternetPreson managed to make Unsloth work on WSL! So preliminary Windows support is here!
- Other bug fixes: all vocab sizes up to 2^16 (65,536) are now supported, and grouped query attention now works correctly.
- GQA on older GPUs is now fully supported via Xformers: we had to manually reshape K and V to trick Xformers into doing a normal attention calculation. Sadly, Xformers does not support the backward pass for GQA.
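The reshape trick can be illustrated roughly as follows. This is a minimal sketch, not Unsloth's actual code (`expand_kv_for_gqa` is a hypothetical helper): each K/V head is repeated so a multi-head-attention-only kernel sees one K/V head per query head.

```python
def expand_kv_for_gqa(kv_heads, n_q_heads):
    """Repeat each K/V head so grouped-query attention looks like
    standard multi-head attention to a kernel that only supports MHA.
    kv_heads is a list of per-head tensors (placeholders here)."""
    n_kv = len(kv_heads)
    assert n_q_heads % n_kv == 0, "query heads must be a multiple of KV heads"
    group = n_q_heads // n_kv
    # Each KV head serves `group` consecutive query heads.
    return [head for head in kv_heads for _ in range(group)]

# 2 KV heads shared across 8 query heads: each KV head is repeated 4 times.
print(expand_kv_for_gqa(["kv0", "kv1"], 8))
```

In practice this repetition would be done with a zero-copy expand of K and V, so the backward pass still only touches the original grouped heads.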

FAQ

- Q: Do you support Mixtral?
  A: We're working on it!

- Q: How do we buy PRO or MAX?
  A: We're working on a platform now. Stay tuned!

- Q: Do we reduce FLOPs?
  A: Yes.

- Q: Does full finetuning work on the open source version?
  A: No. See the Issue. All optimizations are turned off, so you will see no noticeable speed improvement other than from Flash Attention and some Triton kernels.

- Q: Is LoRA (not QLoRA) supported?
  A: Yes. Pass load_in_4bit = False.

- Q: Does the PRO / MAX version support full finetuning and pretraining?
  A: Yes.

Thank you for reading! 🦥

Daniel Han, 13 December 2023

Full benchmarking tables
