mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Fix tokenizer + docs (#62)
* Patch tokenizer * Update _utils.py * Update _utils.py * Update _utils.py * Cleanup * Add comments to functions * Update rope_embedding.py * Update rope_embedding.py * Update llama.py * New logos! * Update README.md
This commit is contained in:
parent
808211931d
commit
353432271d
12 changed files with 155 additions and 155 deletions
39
README.md
39
README.md
|
|
@ -1,24 +1,25 @@
|
|||
<div class="align-center">
|
||||
<img src="./images/unsloth new logo.png" width="350" />
|
||||
<a href="https://discord.gg/u54VK8m8tk"><img src="./images/Discord.png" width="160"></a>
|
||||
<a href="https://colab.research.google.com/drive/1lBzz5KeZJKXjvivbYvmGarix9Ao6Wxe5?usp=sharing"><img src="./images/try live demo green.png" width="130"></a>
|
||||
<a href="https://colab.research.google.com/drive/1Dyauq4kTZoLewQ1cApceUQVNcnnNTzg_?usp=sharing"><img src="./images/try live demo green.png" height="60"></a>
|
||||
<a href="https://discord.gg/u54VK8m8tk"><img src="./images/Discord.png" height="60"></a>
|
||||
</div>
|
||||
|
||||
## 2-5x faster 60% less memory local QLoRA finetuning
|
||||
## Finetune Mistral, Llama 2-5x faster with 50% less memory!
|
||||
| Llama 7b | Mistral 7b | CodeLlama 34b | Llama 7b Kaggle 2x T4 |
|
||||
|-----------------------------|-----------------------------|-------------------------|------------------------|
|
||||
| **2.2x faster, -43% VRAM** | **2.2x faster, -62% VRAM** | **1.9x faster, -27% VRAM** | **5.5x faster, -44% VRAM** |
|
||||
| [Free Colab Llama + Alpaca example](https://colab.research.google.com/drive/1lBzz5KeZJKXjvivbYvmGarix9Ao6Wxe5?usp=sharing) | [Free Colab Mistral + Alpaca example](https://colab.research.google.com/drive/1Dyauq4kTZoLewQ1cApceUQVNcnnNTzg_?usp=sharing) | [Colab A100 example](https://colab.research.google.com/drive/1y7A0AxE3y8gdj4AVkl2aZX47Xu3P1wJT?usp=sharing) | [Kaggle Alpaca example](https://www.kaggle.com/danielhanchen/unsloth-alpaca-t4-ddp) |
|
||||
| [Colab A100 example](https://colab.research.google.com/drive/1YIPY_18xm-K0iJDgvNkRoJsgkPMPAO3G?usp=sharing) | [Colab A100 example](https://colab.research.google.com/drive/1SKrKGV-BZoU4kv5q3g0jtE_OhRgPtrrQ?usp=sharing) | (59 more examples if you scroll down) | [Kaggle Slim Orca example](https://www.kaggle.com/danielhanchen/unsloth-slimorca-t4-ddp) |
|
||||
| **Free** Llama <a href="https://colab.research.google.com/drive/1lBzz5KeZJKXjvivbYvmGarix9Ao6Wxe5?usp=sharing"><img src="./images/Colab.png" height="20"> | **Free** Mistral <a href="https://colab.research.google.com/drive/1Dyauq4kTZoLewQ1cApceUQVNcnnNTzg_?usp=sharing"><img src="./images/Colab.png" height="20"> | A100 Colab <a href="https://colab.research.google.com/drive/1y7A0AxE3y8gdj4AVkl2aZX47Xu3P1wJT?usp=sharing"><img src="./images/Colab.png" height="20"> | **Free** Kaggle A <a href="https://www.kaggle.com/danielhanchen/unsloth-alpaca-t4-ddp"><img src="./images/Kaggle.png" height="20"> |
|
||||
| A100 Colab <a href="https://colab.research.google.com/drive/1YIPY_18xm-K0iJDgvNkRoJsgkPMPAO3G?usp=sharing"><img src="./images/Colab.png" height="20"> | A100 Colab <a href="https://colab.research.google.com/drive/1SKrKGV-BZoU4kv5q3g0jtE_OhRgPtrrQ?usp=sharing"><img src="./images/Colab.png" height="20"> | (59 more examples below) | **Free** Kaggle B <a href="https://www.kaggle.com/danielhanchen/unsloth-slimorca-t4-ddp"><img src="./images/Kaggle.png" height="20"> |
|
||||
|
||||
* **NEW!** [DPO](https://arxiv.org/abs/2305.18290) support. [Free DPO Colab example](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing). [More info](#DPO).
|
||||
* **NEW!** [TinyLlama](https://github.com/jzhang38/TinyLlama) on 3T tokens. [Free Colab example](https://colab.research.google.com/drive/1AZghoNBQaMDgWJpi4RbffGM1h6raLUj9?usp=sharing). We also show automatic RoPE Scaling extending TinyLlama from 2048 to 4096 tokens!
|
||||
* **NEW!** [DPO](https://arxiv.org/abs/2305.18290) support. **Free** DPO example <a href="https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing"><img src="./images/Colab.png" height="20"> [More info](#DPO) on DPO
|
||||
* **NEW!** [TinyLlama 1.1b](https://github.com/jzhang38/TinyLlama) on 3T tokens! **Free** example <a href="https://colab.research.google.com/drive/1AZghoNBQaMDgWJpi4RbffGM1h6raLUj9?usp=sharing"><img src="./images/Colab.png" height="20">
|
||||
* **NEW!** We're in 🤗 Huggingface's official docs! We're on the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and the [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!
|
||||
* Supports Llama, Yi, Mistral, CodeLlama, Qwen (llamafied), Deepseek and their derived models (Open Hermes etc).
|
||||
* All kernels written in [OpenAI's Triton](https://openai.com/research/triton) language. **Manual backprop engine**.
|
||||
* **0% loss in accuracy** - no approximation methods - all exact.
|
||||
* No change of hardware necessary. Supports NVIDIA GPUs since 2018+. Minimum CUDA Compute Capability 7.0 (V100, T4, Titan V, RTX 20, 30, 40x, A100, H100, L40 etc) [Check your GPU](https://developer.nvidia.com/cuda-gpus)
|
||||
* **NEW!** Works on **Linux** and **Windows** via WSL.
|
||||
* **NEW!** Download 4 bit models 4x faster from Huggingface! Eg: `unsloth/mistral-7b-bnb-4bit`
|
||||
* No change of hardware necessary. Supports NVIDIA GPUs since 2018+. Minimum CUDA Compute Capability 7.0 (V100, T4, Titan V, RTX 20, 30, 40x, A100, H100, L40 etc) [Check your GPU!](https://developer.nvidia.com/cuda-gpus) GTX 1070 and 1080 work, but are a bit slow!
|
||||
* Works on **Linux** and **Windows** via WSL.
|
||||
* **NEW!** Download 4 bit models 4x faster from 🤗 Huggingface! Eg: `unsloth/mistral-7b-bnb-4bit`
|
||||
* Supports 4bit and 16bit QLoRA / LoRA finetuning via [bitsandbytes](https://github.com/TimDettmers/bitsandbytes).
|
||||
* **NEW!** Want a UI for finetuning? Try [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory) and use `--use_unsloth`!
|
||||
* Open source trains 5x faster - see [Unsloth Pro](https://unsloth.ai/) for **30x faster training**!
|
||||
|
|
@ -31,8 +32,9 @@
|
|||
| Slim Orca | 1x | 1.18x | 2.22x | **14.82x** |
|
||||
|
||||
Join our [Discord](https://discord.gg/nsS4V5Z6ge)!
|
||||
If you trained a model with Unsloth, feel free to use this cool sticker we made!
|
||||
|
||||
<img src="./images/unsloth made with love.png" width="200" />
|
||||
If you trained a model with Unsloth, feel free to use this cool sticker we made!
|
||||
|
||||
# Installation Instructions - Conda
|
||||
Select either `pytorch-cuda=11.8` for CUDA 11.8 or `pytorch-cuda=12.1` for CUDA 12.1.
|
||||
|
|
@ -79,6 +81,9 @@ pip install --upgrade pip
|
|||
|
||||
# Documentation
|
||||
We support Huggingface's TRL, Trainer, Seq2SeqTrainer or even PyTorch code!
|
||||
|
||||
We're in 🤗 Huggingface's official docs! We're on the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and the [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!
|
||||
|
||||
```python
|
||||
from unsloth import FastLanguageModel
|
||||
import torch
|
||||
|
|
@ -145,6 +150,9 @@ trainer.train()
|
|||
<a name="DPO"></a>
|
||||
# DPO (Direct Preference Optimization) Support
|
||||
DPO, PPO, Reward Modelling all seem to work as per 3rd party independent testing from [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory). We have a preliminary Google Colab notebook for reproducing Zephyr on Tesla T4 here: [notebook](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing).
|
||||
|
||||
We're in 🤗 Huggingface's official docs! We're on the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and the [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!
|
||||
|
||||
```python
|
||||
from unsloth import FastLanguageModel, PatchDPOTrainer
|
||||
PatchDPOTrainer()
|
||||
|
|
@ -262,15 +270,6 @@ Two Tesla T4s on Kaggle
|
|||
# How did we make it faster?
|
||||
Manual autograd, Triton kernels etc. See our [Benchmark Breakdown](https://unsloth.ai/blog/mistral-benchmark) for more info!
|
||||
|
||||
$$
|
||||
\begin{align}
|
||||
y &= \frac{x_i}{\sqrt{\frac{1}{n}\sum{x_i^2}+\epsilon}} \cdot w \\
|
||||
r &= \frac{1}{\sqrt{\frac{1}{n}\sum{x_i^2}+\epsilon}} \\
|
||||
\frac{dC}{dX} &= \frac{1}{n} r \bigg( n (dY \cdot w) - \bigg( x_i \cdot r \cdot \sum{dY \cdot y_i } \bigg) \bigg)
|
||||
\end{align}
|
||||
$$
|
||||
|
||||
|
||||
# Troubleshooting
|
||||
1. Sometimes `bitsandbytes` or `xformers` does not link properly. Try running:
|
||||
```bash
|
||||
|
|
|
|||
BIN
images/Colab.png
Normal file
BIN
images/Colab.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 11 KiB |
BIN
images/Kaggle.png
Normal file
BIN
images/Kaggle.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 9.5 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 14 KiB |
|
|
@ -75,12 +75,12 @@ class LoRA_MLP(torch.autograd.Function):
|
|||
i = h @ W
|
||||
|
||||
### Backpropagation chain rule
|
||||
See our blog post for more details
|
||||
|
||||
df = sigmoid(e) * (1 - f) + f
|
||||
dC/dW = h.T @ dY
|
||||
dC/dU = X.T @ (D @ W.T * f)
|
||||
dC/dG = X.T @ (D @ W.T * df * g)
|
||||
dC/dX = (D @ W.T * f) @ U.T
|
||||
+ (D @ W.T * df * g) @ G.T
|
||||
|
||||
### Down projection LoRA weights
|
||||
dC/dAw = dC/dW @ B.T
|
||||
|
|
@ -95,6 +95,8 @@ class LoRA_MLP(torch.autograd.Function):
|
|||
### Gate projection LoRA weights
|
||||
dC/dAg = X.T @ (D @ W.T * df * g) @ B.T
|
||||
dC/dBg = A.T @ X.T @ (D @ W.T * df * g)
|
||||
|
||||
Don't forget to see our blog post for more details!
|
||||
"""
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_fwd
|
||||
|
|
@ -141,13 +143,7 @@ class LoRA_MLP(torch.autograd.Function):
|
|||
# DW_dfg = (D @ W.T * df * g)
|
||||
DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
|
||||
DW, e, g = swiglu_DWf_DW_dfg_kernel(DW, e, g)
|
||||
h, DW_f, DW_dfg = DW, e, g # Inplace replacements
|
||||
# se = torch.nn.functional.sigmoid(e)
|
||||
# f = e * se
|
||||
# h = f * g
|
||||
# df = se * (1 - f) + f
|
||||
# DW_f = DW * f
|
||||
# DW_dfg = DW * df * g
|
||||
h, DW_f, DW_dfg = DW, e, g
|
||||
|
||||
# Down projection LoRA weights
|
||||
d_downA = h.t() @ (dY @ downB.t())
|
||||
|
|
@ -167,8 +163,8 @@ class LoRA_MLP(torch.autograd.Function):
|
|||
d_gateA *= gateS
|
||||
d_gateB *= gateS
|
||||
|
||||
# dC/dX = (D @ W.T * f) @ (U.T + B.T @ A.T)
|
||||
# + (D @ W.T * df * g) @ (G.T + B.T @ A.T)
|
||||
# Final derivatives to backpropagate backwards.
|
||||
# See our blogpost for more details.
|
||||
# (D @ W.T * f) @ U.T
|
||||
upW = fast_dequantize(upW.t(), upW_quant)
|
||||
# (D @ W.T * f) @ (U.T + B.T @ A.T)
|
||||
|
|
@ -176,9 +172,8 @@ class LoRA_MLP(torch.autograd.Function):
|
|||
del upW
|
||||
dX += DW_f @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
|
||||
|
||||
# (D @ W.T * f) @ (U.T + B.T @ A.T) + (D @ W.T * df * g) @ G.T
|
||||
# And add the derivative for the gate projection
|
||||
gateW = fast_dequantize(gateW.t(), gateW_quant)
|
||||
# (D @ W.T * f) @ (U.T + B.T @ A.T) + (D @ W.T * df * g) @ (G.T + B.T @ A.T)
|
||||
dX += DW_dfg @ gateW.t()
|
||||
del gateW
|
||||
dX += DW_dfg @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
|
||||
|
|
@ -217,12 +212,12 @@ class LoRA_QKV(torch.autograd.Function):
|
|||
V = X @ Wv = X @ Wv + X @ Av @ Bv
|
||||
|
||||
### Backpropagation chain rule
|
||||
See our blogpost for more details.
|
||||
|
||||
dC/dWq = X.T @ D(Wq)
|
||||
dC/dWk = X.T @ D(Wk)
|
||||
dC/dWv = X.T @ D(Wv)
|
||||
dC/dX = D(Wq) @ Wq.T
|
||||
+ D(Wk) @ Wk.T
|
||||
+ D(Wv) @ Wv.T
|
||||
We then sum them all to find dC/dX
|
||||
|
||||
### Q projection LoRA weights
|
||||
dC/dAq = X.T @ D(Wq) @ B.T
|
||||
|
|
@ -275,8 +270,7 @@ class LoRA_QKV(torch.autograd.Function):
|
|||
dtype = X.dtype
|
||||
|
||||
### Weight projection LoRA weights
|
||||
# dC/dAq = X.T @ D(Wq) @ B.T
|
||||
# dC/dBq = A.T @ X.T @ D(Wq)
|
||||
# See our blogpost for more details.
|
||||
|
||||
# Q Projection
|
||||
d_QA = X.t() @ (dQ @ QB.t())
|
||||
|
|
@ -296,24 +290,21 @@ class LoRA_QKV(torch.autograd.Function):
|
|||
d_VA *= VS
|
||||
d_VB *= VS
|
||||
|
||||
# d/dX
|
||||
# dC/dX = D(Wq) @ Wq.T
|
||||
# Combine derivatives to find dX
|
||||
# dQ
|
||||
QW = fast_dequantize(QW.t(), QW_quant)
|
||||
# D(Wq) @ (Wq.T + B.T @ A.T)
|
||||
dX = torch.matmul(dQ, QW.t(), out = X)
|
||||
del QW
|
||||
dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
|
||||
|
||||
# D(Wq) @ Wq.T + D(Wk) @ Wk.T
|
||||
# dK
|
||||
KW = fast_dequantize(KW.t(), KW_quant)
|
||||
# D(Wq) @ Wq.T + D(Wk) @ (Wk.T + B.T @ A.T)
|
||||
dX += dK @ KW.t()
|
||||
del KW
|
||||
dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
|
||||
|
||||
# D(Wq) @ Wq.T + D(Wk) @ Wk.T + D(Wv) @ Wv.T
|
||||
# dV
|
||||
VW = fast_dequantize(VW.t(), VW_quant)
|
||||
# D(Wq) @ Wq.T + D(Wk) @ Wk.T + D(Wv) @ (Wv.T + B.T @ A.T)
|
||||
dX += dV @ VW.t()
|
||||
del VW
|
||||
dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
|
||||
|
|
@ -356,9 +347,6 @@ class LoRA_W(torch.autograd.Function):
|
|||
dC/dWq = X.T @ D(Wq)
|
||||
dC/dWk = X.T @ D(Wk)
|
||||
dC/dWv = X.T @ D(Wv)
|
||||
dC/dX = D(Wq) @ Wq.T
|
||||
+ D(Wk) @ Wk.T
|
||||
+ D(Wv) @ Wv.T
|
||||
|
||||
### Q projection LoRA weights
|
||||
dC/dAq = X.T @ D(Wq) @ B.T
|
||||
|
|
@ -392,21 +380,18 @@ class LoRA_W(torch.autograd.Function):
|
|||
A, B = A.t(), B.t()
|
||||
|
||||
batch, seq_len, hd = X.shape
|
||||
dY = dY.reshape(-1, dY.shape[-1]) # .view doesn't work on non contiguous
|
||||
X = X .reshape(-1, X .shape[-1]) # .view doesn't work on non contiguous
|
||||
dY = dY.reshape(-1, dY.shape[-1]) # Must be reshape
|
||||
X = X .reshape(-1, X .shape[-1]) # Must be reshape
|
||||
dtype = X.dtype
|
||||
|
||||
### Weight projection LoRA weights
|
||||
# dC/dAq = X.T @ D(Wq) @ B.T
|
||||
# dC/dBq = A.T @ X.T @ D(Wq)
|
||||
|
||||
# Weight projection
|
||||
d_A = X.t() @ (dY @ B.t())
|
||||
d_B = (A.t() @ X.t()) @ dY
|
||||
d_A *= S
|
||||
d_B *= S
|
||||
|
||||
# dC/dX = D(Wq) @ Wq.T
|
||||
# Get derivative for dX
|
||||
W = fast_dequantize(W.t(), W_quant)
|
||||
dX = dY @ W.t()
|
||||
del W
|
||||
|
|
|
|||
|
|
@ -27,6 +27,11 @@ def _rms_layernorm_forward(
|
|||
n_cols, eps,
|
||||
BLOCK_SIZE : tl.constexpr
|
||||
):
|
||||
"""
|
||||
Fast RMS Layernorm kernel
|
||||
Inspiration from a Triton tutorial:
|
||||
https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
||||
"""
|
||||
row_idx = tl.program_id(0)
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
|
|
@ -49,7 +54,6 @@ pass
|
|||
|
||||
@triton.jit
|
||||
def _rms_layernorm_backward(
|
||||
#dX, dX_row_stride,
|
||||
dY, dY_row_stride,
|
||||
X, X_row_stride,
|
||||
W, W_row_stride,
|
||||
|
|
@ -58,11 +62,15 @@ def _rms_layernorm_backward(
|
|||
n_cols, eps,
|
||||
BLOCK_SIZE : tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Fast RMS Layernorm kernel for the backward pass
|
||||
Inspiration from a Triton tutorial:
|
||||
https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
||||
"""
|
||||
row_idx = tl.program_id(0)
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
|
||||
#dX += row_idx * dX_row_stride + col_offsets
|
||||
dY += row_idx * dY_row_stride
|
||||
X += row_idx * X_row_stride
|
||||
r += row_idx * r_row_stride
|
||||
|
|
@ -71,15 +79,13 @@ def _rms_layernorm_backward(
|
|||
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
|
||||
# row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
|
||||
# inv_var = 1 / tl.sqrt(row_var + eps)
|
||||
# Get saved row variance
|
||||
inv_var = tl.load(r).to(tl.float32)
|
||||
normed = X_row * inv_var
|
||||
|
||||
dY_W = dY_row * W_row
|
||||
rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
|
||||
output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
|
||||
#tl.store(dX, output, mask = mask)
|
||||
tl.store(dY + col_offsets, output, mask = mask)
|
||||
pass
|
||||
|
||||
|
|
@ -92,9 +98,10 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
|
|||
X = X.view(-1, dim)
|
||||
n_rows, n_cols = X.shape
|
||||
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
||||
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda")
|
||||
|
||||
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda")
|
||||
r = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
|
||||
|
||||
_rms_layernorm_forward[(n_rows,)](
|
||||
Y, Y.stride(0),
|
||||
X, X.stride(0),
|
||||
|
|
@ -120,10 +127,7 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
|
|||
n_rows, n_cols = dY.shape
|
||||
dW = X
|
||||
|
||||
# dX = torch.empty_like(dY)
|
||||
# dX = dY
|
||||
_rms_layernorm_backward[(n_rows,)](
|
||||
#dX, dX.stride(0),
|
||||
dY, dY.stride(0),
|
||||
X, X .stride(0),
|
||||
W, W .stride(0),
|
||||
|
|
@ -133,9 +137,7 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
|
|||
BLOCK_SIZE = ctx.BLOCK_SIZE,
|
||||
num_warps = ctx.num_warps,
|
||||
)
|
||||
#dX = dX.view(*shape)
|
||||
dX = dY.view(*shape)
|
||||
# X, W, eps
|
||||
return dX, None, None
|
||||
pass
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -28,42 +28,35 @@ def _rope_embedding(
|
|||
BACKWARD_PASS: tl.constexpr,
|
||||
BLOCK_SIZE : tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Calculates the RoPE Embedding quickly
|
||||
RoPE is Q * cos + rotate_half(Q) * sin
|
||||
See our blog post for more info
|
||||
"""
|
||||
row_position = tl.program_id(0)
|
||||
head_position = tl.program_id(1)
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
half_head_dim = head_dim // 2
|
||||
mask = col_offsets < half_head_dim
|
||||
|
||||
# TODO: Fixup int32 locations to int64
|
||||
rot_position = row_position % seqlen
|
||||
|
||||
Q += row_position* Q_row_stride + head_position*head_dim
|
||||
cos += rot_position*cos_row_stride
|
||||
sin += rot_position*sin_row_stride
|
||||
|
||||
Q1 = tl.load(Q + half_head_dim*0 + col_offsets, mask = mask, other = 0)
|
||||
sin1 = tl.load(sin + half_head_dim*0 + col_offsets, mask = mask, other = 0)
|
||||
cos1 = tl.load(cos + half_head_dim*0 + col_offsets, mask = mask, other = 0)
|
||||
|
||||
Q2 = tl.load(Q + half_head_dim*1 + col_offsets, mask = mask, other = 0)
|
||||
# RoPE repeats sin and cos so 128 = [64, 64].
|
||||
Q1 = tl.load(Q + row_position*Q_row_stride + head_position*head_dim + \
|
||||
half_head_dim*0 + col_offsets, mask = mask, other = 0)
|
||||
Q2 = tl.load(Q + row_position*Q_row_stride + head_position*head_dim + \
|
||||
half_head_dim*1 + col_offsets, mask = mask, other = 0)
|
||||
sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
|
||||
half_head_dim*0 + col_offsets, mask = mask, other = 0)
|
||||
cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
|
||||
half_head_dim*0 + col_offsets, mask = mask, other = 0)
|
||||
|
||||
if BACKWARD_PASS:
|
||||
"""
|
||||
Q * cos + rotate_half(Q) * sin
|
||||
is equivalent to
|
||||
Q * cos + Q @ R * sin
|
||||
where R is a rotation matrix [ 0, I]
|
||||
[-I, 0]
|
||||
dC/dQ = dY * cos + dY @ R.T * sin
|
||||
where R.T is again the same [ 0, -I]
|
||||
but the minus is transposed. [ I, 0]
|
||||
"""
|
||||
# See our blog post for more info.
|
||||
sin1 = -sin1
|
||||
|
||||
# RoPE repeats sin and cos so 128 = [64, 64].
|
||||
tl.store(Q + half_head_dim*0 + col_offsets, Q1*cos1 - Q2*sin1, mask = mask)
|
||||
tl.store(Q + half_head_dim*1 + col_offsets, Q2*cos1 + Q1*sin1, mask = mask)
|
||||
pass
|
||||
|
||||
tl.store(Q + row_position*Q_row_stride + head_position*head_dim + \
|
||||
half_head_dim*0 + col_offsets, Q1*cos1 - Q2*sin1, mask = mask)
|
||||
tl.store(Q + row_position*Q_row_stride + head_position*head_dim + \
|
||||
half_head_dim*1 + col_offsets, Q2*cos1 + Q1*sin1, mask = mask)
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -90,7 +83,7 @@ class Fast_RoPE_Embedding(torch.autograd.Function):
|
|||
)
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
ctx.num_warps = num_warps
|
||||
ctx.cos = cos # Don't need save_for_backward since a view
|
||||
ctx.cos = cos
|
||||
ctx.sin = sin
|
||||
return Q.view(batch, seq_len, n_heads, head_dim)
|
||||
pass
|
||||
|
|
@ -99,8 +92,7 @@ class Fast_RoPE_Embedding(torch.autograd.Function):
|
|||
def backward(ctx, dY):
|
||||
batch, seq_len, n_heads, head_dim = dY.shape
|
||||
dY = dY.reshape(batch*seq_len, n_heads*head_dim)
|
||||
# Cannot be .view since the problem lies with dK since
|
||||
# K.T's strides are incorrect.
|
||||
# Must be reshape not view
|
||||
n_rows, n_cols = dY.shape
|
||||
|
||||
cos = ctx.cos
|
||||
|
|
@ -122,10 +114,8 @@ pass
|
|||
|
||||
|
||||
def fast_rope_embedding(Q, K, cos, sin):
|
||||
# We need (batch, [seqlen, n_heads], head_dim)
|
||||
Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)
|
||||
K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2)
|
||||
# We need (batch, [n_heads, seqlen], head_dim)
|
||||
return Q, K
|
||||
pass
|
||||
|
||||
|
|
@ -155,7 +145,6 @@ class Slow_RoPE_Embedding(torch.autograd.Function):
|
|||
cos, sin = ctx.saved_tensors
|
||||
# Q * cos + rotate_half.T(Q) * sin
|
||||
half = dY.shape[-1]//2
|
||||
# We reverse the minus sign for R.T
|
||||
RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
|
||||
dY *= cos
|
||||
RH_dY *= sin
|
||||
|
|
|
|||
|
|
@ -28,12 +28,11 @@ def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
|
|||
g_row = tl.load(g + offsets, mask = mask, other = 0).to(tl.float32)
|
||||
|
||||
# f = e * sigmoid(e)
|
||||
# https://github.com/openai/triton/issues/241 exp MUST be done in f32
|
||||
# or else Triton crashes
|
||||
f_row = e_row / (1 + tl.exp(-e_row))
|
||||
# h = f * g
|
||||
h_row = f_row * g_row
|
||||
|
||||
# Store h
|
||||
tl.store(h + offsets, h_row, mask = mask)
|
||||
pass
|
||||
|
||||
|
|
@ -59,23 +58,20 @@ def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
|
|||
g_row = tl.load(g + offsets, mask = mask, other = 0).to(tl.float32)
|
||||
|
||||
# f = e * sigmoid(e)
|
||||
# https://github.com/openai/triton/issues/241 exp MUST be done in f32
|
||||
# or else Triton crashes
|
||||
se_row = 1 / (1 + tl.exp(-e_row))
|
||||
# f = e * se
|
||||
f_row = e_row * se_row
|
||||
# h = f * g
|
||||
h_row = f_row * g_row
|
||||
# df = se * (1 - f) + f
|
||||
# DW_f = DW * f
|
||||
DWf_row = DW_row * f_row
|
||||
# DW_dfg = DW * df * g
|
||||
# DW_dfg = DW * (se * (1 - f) + f) * g
|
||||
# DW_dfg = DW * (se*(g - h) + h)
|
||||
DW_dfg_row = DW_row * (se_row*(g_row - h_row) + h_row)
|
||||
|
||||
tl.store(DW + offsets, h_row, mask = mask) # h
|
||||
tl.store(e + offsets, DWf_row, mask = mask) # DW * f
|
||||
tl.store(g + offsets, DW_dfg_row, mask = mask) # DW * df * g
|
||||
# Store derivatives in buffers
|
||||
tl.store(DW + offsets, h_row, mask = mask)
|
||||
tl.store(e + offsets, DWf_row, mask = mask)
|
||||
tl.store(g + offsets, DW_dfg_row, mask = mask)
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -84,5 +80,5 @@ def swiglu_DWf_DW_dfg_kernel(DW, e, g):
|
|||
n_elements = e.numel()
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
_DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
|
||||
return DW, e, g # h, DW * f, DW * df * g
|
||||
return DW, e, g
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -13,12 +13,11 @@
|
|||
# limitations under the License.
|
||||
|
||||
import triton
|
||||
MAX_FUSED_SIZE = 65536 # 2**16 Solves https://github.com/unslothai/unsloth/issues/7
|
||||
MAX_FUSED_SIZE = 65536
|
||||
next_power_of_2 = triton.next_power_of_2
|
||||
|
||||
def calculate_settings(n):
|
||||
BLOCK_SIZE = next_power_of_2(n)
|
||||
# CUDA only supports 65536 - 2^16 threads per block
|
||||
if BLOCK_SIZE > MAX_FUSED_SIZE:
|
||||
raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
|
||||
f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import gc
|
|||
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
|
||||
import bitsandbytes as bnb
|
||||
from transformers.models.llama.modeling_llama import logger
|
||||
from transformers import AutoTokenizer
|
||||
from platform import system as platform_system
|
||||
platform_system = platform_system()
|
||||
|
||||
|
|
@ -115,24 +116,56 @@ def patch_tokenizer(model, tokenizer):
|
|||
pass
|
||||
|
||||
|
||||
def check_tokenizer(model, tokenizer):
|
||||
def check_tokenizer(
|
||||
model,
|
||||
tokenizer,
|
||||
model_name = "unsloth/llama-2-7b-bnb-4bit",
|
||||
model_max_length = 4096,
|
||||
padding_side = "right",
|
||||
token = None,
|
||||
_reload = True,
|
||||
):
|
||||
# Checks tokenizer for out of bounds ids.
|
||||
# Mainly a fix for https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha
|
||||
# where <sep> had token id=32002.
|
||||
# See https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha/discussions/25
|
||||
special_tokens_map = tokenizer.special_tokens_map
|
||||
max_embedding_size = model.model.embed_tokens.weight.shape[0]
|
||||
# Seems like the Fast tokenizer in Rust breaks things!
|
||||
|
||||
for token_name, token_content in special_tokens_map.items():
|
||||
if type(token_content) is not str: continue
|
||||
token_ids = tokenizer([token_content], add_special_tokens = False, return_attention_mask = False)
|
||||
token_ids = token_ids.input_ids[0][0]
|
||||
if token_ids < 0 or token_ids >= max_embedding_size:
|
||||
raise RuntimeError(
|
||||
f"Unsloth: Extra special token `{token_content}` with id={token_ids} exceeds "\
|
||||
f"the maximum vocabulary size of {max_embedding_size}. You must fix the tokenizer "\
|
||||
"or else out of bounds memory accesses will occur."
|
||||
max_embedding_size = model.model.embed_tokens.weight.shape[0]
|
||||
added_tokens_fast = tokenizer.added_tokens_decoder
|
||||
added_tokens_fast = {index : str(value) for index, value in added_tokens_fast.items()}
|
||||
sorted_keys = sorted(added_tokens_fast)
|
||||
added_tokens_fast = {key : added_tokens_fast[key] for key in sorted_keys}
|
||||
|
||||
for j, index in enumerate(added_tokens_fast.keys()):
|
||||
if index >= max_embedding_size:
|
||||
bad_indices = list(added_tokens_fast.keys ())[j:]
|
||||
bad_tokens = list(added_tokens_fast.values())[j:]
|
||||
if not _reload:
|
||||
raise RuntimeError(
|
||||
f"Unsloth tried to load `{model_name}`, but cannot succeed.\n"\
|
||||
f"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\n"\
|
||||
f"Fix your tokenizer since it'll perform out of bounds memory accesses."
|
||||
)
|
||||
# Try slow tokenizer which can fix things!
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name,
|
||||
model_max_length = model_max_length,
|
||||
padding_side = padding_side,
|
||||
token = token,
|
||||
use_fast = False,
|
||||
)
|
||||
return check_tokenizer(
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
model_name = model_name,
|
||||
model_max_length = model_max_length,
|
||||
padding_side = padding_side,
|
||||
token = token,
|
||||
_reload = False,
|
||||
)
|
||||
break
|
||||
pass
|
||||
pass
|
||||
return tokenizer
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -135,8 +135,6 @@ def LlamaAttention_fast_forward_inference(
|
|||
Vn = torch.cat([V1, Vn], dim = 2)
|
||||
|
||||
# Grouped query attention
|
||||
# K = repeat_kv(K, n_groups)
|
||||
# V = repeat_kv(V, n_groups)
|
||||
if n_groups != 1:
|
||||
_, _, cached_len, _ = Kn.shape
|
||||
Knn = Kn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
|
||||
|
|
@ -210,7 +208,6 @@ def LlamaAttention_fast_forward(
|
|||
pass
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
K = torch.cat([past_key_value[0], K], dim = 2)
|
||||
V = torch.cat([past_key_value[1], V], dim = 2)
|
||||
past_key_value = (K, V) if use_cache else None
|
||||
|
|
@ -219,7 +216,6 @@ def LlamaAttention_fast_forward(
|
|||
if (not HAS_FLASH_ATTENTION):
|
||||
# Xformers memory efficient attention
|
||||
# Also has Flash Attention v2 dispatching
|
||||
# (batch_size, n_heads, seq_len, head_dim) -> (batch_size, seq_len, n_heads, head_dim)
|
||||
Q = Q.transpose(1, 2)
|
||||
K = K.transpose(1, 2)
|
||||
V = V.transpose(1, 2)
|
||||
|
|
@ -231,25 +227,18 @@ def LlamaAttention_fast_forward(
|
|||
K = K.expand(bsz, q_len, n_kv_heads, n_groups, head_dim)
|
||||
V = V.expand(bsz, q_len, n_kv_heads, n_groups, head_dim)
|
||||
if hidden_states.requires_grad:
|
||||
# Xformers does not support backward, so we have to convert
|
||||
# GQA to MQA by cloning K and V
|
||||
K = K.reshape(bsz, q_len, n_heads, head_dim) # A copy will be made
|
||||
V = V.reshape(bsz, q_len, n_heads, head_dim) # A copy will be made
|
||||
K = K.reshape(bsz, q_len, n_heads, head_dim)
|
||||
V = V.reshape(bsz, q_len, n_heads, head_dim)
|
||||
else:
|
||||
# Xformers does support the forward pass though
|
||||
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
|
||||
pass
|
||||
A = xformers_attention(Q, K, V, attn_bias = causal_mask)
|
||||
A = A.view(bsz, q_len, n_heads, head_dim)
|
||||
|
||||
elif HAS_FLASH_ATTENTION:
|
||||
# Flash Attention
|
||||
# (batch_size, n_heads, seq_len, head_dim) -> (batch_size, seq_len, n_heads, head_dim)
|
||||
Q = Q.transpose(1, 2)
|
||||
K = K.transpose(1, 2)
|
||||
V = V.transpose(1, 2)
|
||||
|
||||
# Flash Attention v2 auto supports grouped query attention
|
||||
A = flash_attn_func(Q, K, V, causal = True)
|
||||
else:
|
||||
# Grouped query attention
|
||||
|
|
@ -714,7 +703,14 @@ class FastLlamaModel:
|
|||
internal_model.max_seq_length = max_position_embeddings
|
||||
|
||||
# We check the tokenizer first for errors
|
||||
check_tokenizer(model, tokenizer)
|
||||
tokenizer = check_tokenizer(
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
model_name = model_name,
|
||||
model_max_length = max_seq_length,
|
||||
padding_side = "right",
|
||||
token = token,
|
||||
)
|
||||
return model, tokenizer
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -92,28 +92,23 @@ def MistralAttention_fast_forward(
|
|||
# Attention module
|
||||
if (not HAS_FLASH_ATTENTION):
|
||||
# Xformers memory efficient attention
|
||||
# Also has Flash Attention v2 dispatching
|
||||
# (batch_size, n_heads, seq_len, head_dim) -> (batch_size, seq_len, n_heads, head_dim)
|
||||
Q = Q.transpose(1, 2)
|
||||
K = K.transpose(1, 2)
|
||||
V = V.transpose(1, 2)
|
||||
M = bsz * q_len
|
||||
|
||||
has_sliding_window = isinstance(causal_mask, xformers.attn_bias.BlockDiagonalCausalMask)
|
||||
has_swa = isinstance(causal_mask, xformers.attn_bias.BlockDiagonalCausalMask)
|
||||
|
||||
# Group query attention
|
||||
# if n_groups != 1:
|
||||
K = K .view(bsz, q_len, n_kv_heads, 1, head_dim)
|
||||
V = V .view(bsz, q_len, n_kv_heads, 1, head_dim)
|
||||
K = K.expand(bsz, q_len, n_kv_heads, n_groups, head_dim)
|
||||
V = V.expand(bsz, q_len, n_kv_heads, n_groups, head_dim)
|
||||
if hidden_states.requires_grad:
|
||||
# Xformers does not support backward, so we have to convert
|
||||
# GQA to MQA by cloning K and V
|
||||
K = K.reshape(bsz, q_len, n_heads, head_dim) # A copy will be made
|
||||
V = V.reshape(bsz, q_len, n_heads, head_dim) # A copy will be made
|
||||
K = K.reshape(bsz, q_len, n_heads, head_dim)
|
||||
V = V.reshape(bsz, q_len, n_heads, head_dim)
|
||||
|
||||
if has_sliding_window:
|
||||
if has_swa:
|
||||
Q = Q.view(1, M, n_heads, head_dim)
|
||||
K = K.view(1, M, n_heads, head_dim)
|
||||
V = V.view(1, M, n_heads, head_dim)
|
||||
|
|
@ -122,7 +117,7 @@ def MistralAttention_fast_forward(
|
|||
# Xformers does support the forward pass though
|
||||
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
|
||||
|
||||
if has_sliding_window:
|
||||
if has_swa:
|
||||
Q = Q.view(1, M, n_kv_heads, n_groups, head_dim)
|
||||
K = K.view(1, M, n_kv_heads, n_groups, head_dim)
|
||||
V = V.view(1, M, n_kv_heads, n_groups, head_dim)
|
||||
|
|
@ -133,16 +128,12 @@ def MistralAttention_fast_forward(
|
|||
A = A.view(bsz, q_len, n_heads, head_dim)
|
||||
|
||||
elif HAS_FLASH_ATTENTION:
|
||||
# Flash Attention
|
||||
# (batch_size, n_heads, seq_len, head_dim) -> (batch_size, seq_len, n_heads, head_dim)
|
||||
Q = Q.transpose(1, 2)
|
||||
K = K.transpose(1, 2)
|
||||
V = V.transpose(1, 2)
|
||||
|
||||
# Flash Attention v2 auto supports grouped query attention
|
||||
sliding_window = getattr(self.config, "sliding_window")
|
||||
sliding_window = q_len if sliding_window is None else sliding_window
|
||||
window = (-1, -1) if (q_len <= sliding_window) else (sliding_window, sliding_window)
|
||||
sw = getattr(self.config, "sliding_window")
|
||||
sw = q_len if sw is None else sw
|
||||
window = (-1, -1) if (q_len <= sw) else (sw, sw)
|
||||
A = flash_attn_func(Q, K, V, causal = True, window_size = window)
|
||||
else:
|
||||
# Grouped query attention
|
||||
|
|
@ -317,7 +308,7 @@ class FastMistralModel(FastLlamaModel):
|
|||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name,
|
||||
model_max_length = max_seq_length,
|
||||
padding_side = "right", # MUST be right or else attention fails!
|
||||
padding_side = "right",
|
||||
token = token,
|
||||
)
|
||||
|
||||
|
|
@ -339,6 +330,16 @@ class FastMistralModel(FastLlamaModel):
|
|||
internal_model = internal_model.model
|
||||
pass
|
||||
internal_model.max_seq_length = max_position_embeddings
|
||||
|
||||
# We check the tokenizer first for errors
|
||||
tokenizer = check_tokenizer(
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
model_name = model_name,
|
||||
model_max_length = max_seq_length,
|
||||
padding_side = "right",
|
||||
token = token,
|
||||
)
|
||||
return model, tokenizer
|
||||
pass
|
||||
pass
|
||||
|
|
|
|||
Loading…
Reference in a new issue