You can now finetune Google's latest Gemma 2 (9B) model 2x faster and use 63.2% less memory than Flash Attention 2 (FA2) + Hugging Face (HF). Gemma 2 (27B) is 1.9x faster and uses 51% less VRAM. We also worked with the amazing Gemma & HF team to fix some minor bugs which you can read about further below.
You can now finetune Gemma 2 (27B) with QLoRA to 9.7K context lengths (Gemma 2's native maximum is 8K) with Unsloth on a 40GB GPU, whilst HF+FA2 only fits around 3K. Unsloth can do 11K context lengths on a 24GB card for the 9B model, whilst HF+FA2 fits 2.6K. That's 4-5x longer contexts with Unsloth!
We uploaded a Gemma 2 (9B) Colab notebook, as well as pre-quantized 4bit models for 4x faster downloading, covering both the Instruct and Base versions of Gemma 2.
We've also upgraded our Phi-3 mini support for Microsoft's new update. Make sure you're using our latest Phi-3 mini notebook and model.
Gemma 2 benchmarks
| Model | VRAM | 🦥 Unsloth speed | 🦥 VRAM reduction | 🦥 Longer context | 🤗 Hugging Face + FA2 |
|---|---|---|---|---|---|
| Gemma 2 (9B) | 80GB | 2x | 63.2% | 4-5x longer | 1x |
| Gemma 2 (27B) | 80GB | 1.9x | 51% | 3x longer | 1x |
We tested with the Alpaca dataset, a batch size of 2, 4 gradient accumulation steps, rank = 32, and QLoRA applied to all linear layers (q, k, v, o, gate, up, down).
🌠 4-5x longer context lengths
By using Unsloth's offloaded gradient checkpointing, training is only around 2% slower, yet VRAM usage drops by 30%! We now enable it by default in Unsloth with `use_gradient_checkpointing = "unsloth"`.
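For reference, here's a minimal sketch of this setup in Unsloth (the 4bit repo name and sequence length are assumptions; adjust them to the model you're actually training):

```python
from unsloth import FastLanguageModel

# Load a pre-quantized 4bit Gemma 2 checkpoint (repo name assumed here).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gemma-2-9b-bnb-4bit",
    max_seq_length = 8192,
    load_in_4bit = True,
)

# QLoRA on all linear layers with rank 32, as in our benchmarks, plus
# Unsloth's offloaded gradient checkpointing.
model = FastLanguageModel.get_peft_model(
    model,
    r = 32,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    use_gradient_checkpointing = "unsloth",
)
```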
Interestingly, because of attention logit softcapping, Flash Attention is sadly not supported, so the attention matrices use O(N^2) memory, whereas FA2 computes the QK^T matrix on the fly without ever materializing it. This means that, until Flash Attention supports softcapping, memory usage for Gemma 2's attention will be quadratic in the sequence length.
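As a rough sketch of where the quadratic memory comes from (this is illustrative only and skips details such as causal masking, grouped-query attention and Gemma 2's query pre-scaling):

```python
import torch
import torch.nn.functional as F

def naive_softcapped_attention(q, k, v, softcap: float = 50.0):
    # q, k, v: (batch, heads, N, head_dim). Without Flash Attention the full
    # (N x N) score matrix is materialized, so memory grows as O(N^2) in the
    # sequence length N.
    scores = torch.matmul(q, k.transpose(-1, -2)) / (q.shape[-1] ** 0.5)
    scores = torch.tanh(scores / softcap) * softcap  # attention logit softcapping
    return torch.matmul(F.softmax(scores, dim=-1), v)
```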
Unsloth allows you to do 3x longer contexts for the 27B model, and 4-5x longer contexts for the 9B model!
Gemma 2 (27B) max. context length
| GPU VRAM | Unsloth (new) | Hugging Face + FA2 |
|---|---|---|
| 16 GB | 285 | OOM |
| 24 GB | 3,436 | 675 |
| 40 GB | 9,737 | 3,116 |
| 48 GB | 12,888 | 4,337 |
| 80 GB | 25,491 | 9,221 |
In all our experiments, we used QLoRA with a rank of 32 and applied LoRA adapters to all linear layers (q, k, v, o, gate, up, down). We used a batch size of 1 and repeated the data to fill the maximum context window.
Gemma 2 (9B) max. context length
| GPU VRAM | Unsloth | Hugging Face + FA2 |
|---|---|---|
| 8 GB | 74 | OOM |
| 12 GB | 2,842 | 284 |
| 16 GB | 5,609 | 1,070 |
| 24 GB | 11,145 | 2,642 |
| 40 GB | 22,215 | 5,787 |
| 48 GB | 27,750 | 7,359 |
| 80 GB | 49,891 | 13,649 |
We likewise tested the 9B model with repeated data to use up the maximum context length.
👨‍💻 Derivative of the Softcapping Function
We had to implement the softcapping mechanism for Gemma 2 ourselves. We used Desmos' numeric differentiation and integration features to verify our derivatives were correct, and used some trigonometric identities to derive the gradient of the softcapping function. We also reduce VRAM usage by 500MB or more by fusing the softcapping mechanism into the cross entropy loss calculation, where the derivatives are needed as well. By fusing them in, we do not have to keep a copy of the logits before the softcapping operation, reducing VRAM usage. We verified the accuracy of our gradients by confirming that the losses match up.
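As a sketch (the cap value here is illustrative; Gemma 2 uses separate caps for the attention scores and the lm head logits), softcapping is softcap(x) = cap * tanh(x / cap), so its derivative is 1 - tanh(x / cap)^2, i.e. sech^2(x / cap). You can sanity check the analytical gradient against PyTorch autograd:

```python
import torch

def softcap(x: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    # softcap(x) = cap * tanh(x / cap)
    return cap * torch.tanh(x / cap)

def softcap_grad(x: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    # d/dx [cap * tanh(x / cap)] = 1 - tanh(x / cap)^2 = sech^2(x / cap)
    return 1.0 - torch.tanh(x / cap) ** 2

# Compare the analytical gradient with autograd, analogous to checking it
# against Desmos' numeric differentiation.
x = torch.randn(1024, dtype=torch.float64, requires_grad=True)
softcap(x).sum().backward()
assert torch.allclose(x.grad, softcap_grad(x.detach()))
```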
📈 Softcapping investigation
For the 9B model, we find it's a must to turn on softcapping for the lm head logits, whilst attention softcapping is generally recommended as well. If you do not turn it on for the lm head, your training loss will be incorrect.
For the 27B model, we tested all 4 combinations of attention softcapping and logit softcapping (both on, only one on, or neither on). We find the 27B model is vastly more sensitive to softcapping, and we show you must turn on softcapping for both the attention and the lm head logits. This is not optional as it is for the 9B model, where turning off attention softcapping only somewhat impacts the loss.
All in all, we recommend turning on softcapping for the attention and lm head logits for both the 9B and 27B models. This does mean Flash Attention cannot be used anymore. For Unsloth, we use torch.compile to fuse the softcapping and the attention matrix calculation.
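A minimal sketch of the idea (not Unsloth's exact kernels): wrapping the score computation in torch.compile lets the divide, tanh and multiply be fused together rather than each writing out its own full-size intermediate tensor.

```python
import torch

@torch.compile(fullgraph=True)
def softcapped_scores(q, k, softcap: float = 50.0):
    # The elementwise softcapping ops are fused by the compiler instead of
    # producing separate (N x N) temporaries.
    scores = torch.matmul(q, k.transpose(-1, -2))
    return torch.tanh(scores / softcap) * softcap
```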
🐛 Gemma 2 PyTorch fixes
We also provided 2 fixes to the official Gemma 2 PyTorch repo! See Pull Request 67. They follow our previous Gemma bug fixes, where we showed you must carefully downcast and upcast certain areas of the code, mainly because GPU mixed precision training behaves somewhat differently from TPU mixed precision training.
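As an illustration of the kind of precision issue involved (a hypothetical sketch of the general idea, not the exact contents of Pull Request 67): Gemma-style RMSNorm is one place where doing the arithmetic in float32 and only downcasting at the end avoids accuracy loss when training in bfloat16 or float16.

```python
import torch

def rmsnorm_fp32(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize in float32 even if the model runs in bfloat16/float16,
    # then cast the result back to the input dtype.
    dtype = x.dtype
    x = x.float()
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    # Gemma's RMSNorm scales by (1 + weight); keep this multiply in float32 too.
    return (x * (1.0 + weight.float())).to(dtype)
```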
We also found the Gemma team padded the tokenizer vocabulary by 128, which is pretty cool! This was most likely done to make training somewhat faster, since matrix dimensions that are multiples of a large power of 2 tend to map more efficiently onto accelerator hardware.
🪟 Phi-3 mini update
Thanks to Microsoft, Phi-3 mini has received an amazing new update, so we've updated all of our infrastructure to support it. We also uploaded Phi-3's original model if you would still like to use that.
Open the image below in a new tab to see all the new benchmarks:
🌎 AI World's Fair 2024
A huge thank you to everyone who showed up to our 2 sessions at the AI Engineer World's Fair last week. We met so many amazing people and couldn't be more thankful to Swyx and his amazing team for organizing it. We did a 3-hour workshop, absolutely loved how interactive the audience was, and hopefully some of you managed to grab some stickers.
Watch our 20 minute lightning talk about fixing bugs in open source models on YouTube.
💕 Thank you!
Feel free to support us via our Ko-fi donation page. Huge shout out to: Creivailty, kearm, MrDragonFox, Sebastien, Fimbul, Jeff, Steffen, Andrew & Shailendra who are new supporters! 🙏
As always, be sure to join our Discord server for help or just to show your support! You can also follow us on Twitter and Substack.