mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
First upload of Unsloth code
This commit is contained in:
parent
1e2ba1b1d2
commit
c3d6def64a
13 changed files with 2053 additions and 2 deletions
36
README.md
36
README.md
|
|
@ -1,2 +1,34 @@
|
|||
# unsloth
|
||||
2x faster 50% less memory LLM finetuning on a single GPU
|
||||
# Unsloth
|
||||
2x faster 50% less memory LLM finetuning on a single GPU.
|
||||
|
||||
`!pip install "unsloth[cu118] @ git+https://github.com/unslothai/unsloth.git"`
|
||||
`!pip install "unsloth[cu121] @ git+https://github.com/unslothai/unsloth.git"`
|
||||
|
||||
|
||||
### Google Colab examples
|
||||
1. [Unsloth fast finetuning example](https://colab.research.google.com/drive/1oW55fBmwzCOrBVX66RcpptL3a99qWBxb?usp=sharing)
|
||||
2. [Original slow finetuning example](https://colab.research.google.com/drive/1c7zxdLHaLJ9R9YTZ74y4tUERvS-kySyA?usp=sharing)
|
||||
|
||||
### Installation instructions
|
||||
In Google Colab:
|
||||
```
|
||||
!ldconfig /usr/lib64-nvidia
|
||||
!pip install xformers --index-url https://download.pytorch.org/whl/cu118
|
||||
!pip install git+https://github.com/danielhanchen/unsloth.git
|
||||
```
|
||||
`!ldconfig /usr/lib64-nvidia` is necessary (for now) to link CUDA with Python. Possibly a Google Colab linking bug.
|
||||
|
||||
For general installations:
|
||||
1. Install Xformers *OR* Flash Attention. Choose one. Older GPUs should use Xformers; newer GPUs should use Flash Attention.
|
||||
2. For Xformers, find your Pytorch CUDA version via `torch.version.cuda` or `nvidia-smi`.
|
||||
* If you have Conda, `conda install xformers -c xformers`
|
||||
* If you have CUDA 11.8, `pip install xformers --index-url https://download.pytorch.org/whl/cu118`
|
||||
* If you have CUDA 12.1, `pip install xformers --index-url https://download.pytorch.org/whl/cu121`
|
||||
* Go to https://github.com/facebookresearch/xformers for other issues.
|
||||
* You must have Pytorch 2.1 installed for Xformers. If not, try Flash Attention.
|
||||
* Xformers supports all GPUs (Tesla T4 etc).
|
||||
3. For Flash Attention, you must have an Ampere, Ada or Hopper GPU (A100, RTX 3090, RTX 4090, H100).
|
||||
* Install Flash Attention via `pip uninstall -y ninja && pip install ninja` then `pip install flash-attn --no-build-isolation`.
|
||||
* Xformers has native support for Flash Attention, so technically installing Xformers is enough.
|
||||
4. Then install Unsloth:
|
||||
`pip install git+https://github.com/danielhanchen/unsloth.git`
|
||||
|
|
|
|||
51
pyproject.toml
Normal file
51
pyproject.toml
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
[build-system]
|
||||
requires = ["setuptools", "setuptools-scm"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "unsloth"
|
||||
version = "2023.11"
|
||||
description = "2X faster LLM finetuning"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
license = {file = "LICENSE"}
|
||||
keywords = ["ai", "llm",]
|
||||
authors = [
|
||||
{email = "info@unsloth.ai"},
|
||||
{name = "Unsloth AI team"},
|
||||
]
|
||||
maintainers = [
|
||||
{name = "Daniel Han", email = "danielhanchen@gmail.com"},
|
||||
{name = "Michael Han", email = "info@unsloth.ai"},
|
||||
]
|
||||
classifiers = [
|
||||
"Programming Language :: Python",
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
"transformers",
|
||||
"bitsandbytes",
|
||||
"datasets",
|
||||
"sentencepiece",
|
||||
"accelerate",
|
||||
"trl",
|
||||
"peft",
|
||||
"torch>=2.1.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
cu118 = [
|
||||
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system=='Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system=='Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system=='Linux'",
|
||||
]
|
||||
cu121 = [
|
||||
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system=='Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system=='Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system=='Linux'",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
homepage = "http://www.unsloth.ai"
|
||||
documentation = "https://github.com/unslothai/unsloth"
|
||||
repository = "https://github.com/unslothai/unsloth"
|
||||
16
unsloth/__init__.py
Normal file
16
unsloth/__init__.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
__version__ = "2023.11"
|
||||
|
||||
from .models import *
|
||||
24
unsloth/kernels/__init__.py
Normal file
24
unsloth/kernels/__init__.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .cross_entropy_loss import fast_cross_entropy_loss
|
||||
from .rms_layernorm import fast_rms_layernorm
|
||||
from .rope_embedding import fast_rope_embedding, inplace_rope_embedding
|
||||
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
|
||||
from .fast_lora import (
|
||||
apply_lora_mlp,
|
||||
apply_lora_qkv,
|
||||
apply_lora_o,
|
||||
)
|
||||
from .utils import fast_dequantize, QUANT_STATE
|
||||
167
unsloth/kernels/cross_entropy_loss.py
Normal file
167
unsloth/kernels/cross_entropy_loss.py
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
from .utils import calculate_settings
|
||||
|
||||
@triton.jit
|
||||
def _cross_entropy_forward(logits_ptr, logits_row_stride,
|
||||
loss_ptr,
|
||||
lse_ptr,
|
||||
labels_ptr,
|
||||
n_cols,
|
||||
BLOCK_SIZE: tl.constexpr,):
|
||||
"""
|
||||
Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
|
||||
Pi = exp(xi) / sum(exp(xi))
|
||||
CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
|
||||
= -y [ x - log[sum(exp(x))] ]
|
||||
= y * (log[sum(exp(x))] - x)
|
||||
If y == 0: CE_i = 0
|
||||
If y == 1: CE_i = logsumexp - x
|
||||
"""
|
||||
row_idx = tl.program_id(0)
|
||||
logits_ptr += row_idx * logits_row_stride
|
||||
loss_ptr += row_idx
|
||||
lse_ptr += row_idx
|
||||
labels_ptr += row_idx
|
||||
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
|
||||
# TODO: Fixup int32 locations to int64
|
||||
# https://github.com/Dao-AILab/flash-attention/commit/c79de85ffa0d19b80fa468f90c5086e837499d72
|
||||
label_idx = tl.load(labels_ptr).to(tl.int32)
|
||||
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
|
||||
max_logits = tl.max(logits, 0)
|
||||
# Maximum stops overflow
|
||||
lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
|
||||
tl.store(lse_ptr, lse)
|
||||
|
||||
if label_idx != -100:
|
||||
logits_label = tl.load(logits_ptr + label_idx).to(tl.float32)
|
||||
loss = lse - logits_label
|
||||
else:
|
||||
loss = 0.0
|
||||
tl.store(loss_ptr, loss)
|
||||
pass
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _cross_entropy_backward(logits_ptr, logits_row_stride,
|
||||
dloss_ptr, dloss_row_stride,
|
||||
lse_ptr,
|
||||
labels_ptr,
|
||||
n_cols,
|
||||
BLOCK_SIZE: tl.constexpr,):
|
||||
"""
|
||||
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
|
||||
dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
|
||||
|
||||
From https://en.wikipedia.org/wiki/LogSumExp
|
||||
d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
|
||||
|
||||
dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
|
||||
dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
|
||||
dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
|
||||
|
||||
If y == 0: dC/dx = 0
|
||||
If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
|
||||
If y == 1 and x != label: dC/dx = exp[x - logsumexp]
|
||||
"""
|
||||
row_idx = tl.program_id(0)
|
||||
logits_ptr += row_idx * logits_row_stride
|
||||
dloss_ptr += row_idx * dloss_row_stride
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
# TODO: Fixup int32 locations to int64
|
||||
# https://github.com/Dao-AILab/flash-attention/commit/c79de85ffa0d19b80fa468f90c5086e837499d72
|
||||
label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
|
||||
|
||||
if label_idx != -100:
|
||||
dloss = tl.load(dloss_ptr)
|
||||
else:
|
||||
dloss = 0.0
|
||||
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
lse = tl.load(lse_ptr + row_idx)
|
||||
probs = tl.exp(logits - lse)
|
||||
|
||||
probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
|
||||
tl.store(logits_ptr + col_offsets, dloss * probs, mask = mask)
|
||||
pass
|
||||
|
||||
|
||||
class Fast_CrossEntropyLoss(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, logits, labels):
|
||||
n_rows, n_cols = logits.shape
|
||||
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
||||
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
|
||||
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
|
||||
|
||||
_cross_entropy_forward[(n_rows,)](
|
||||
logits, logits.stride(0),
|
||||
losses,
|
||||
logsumexp,
|
||||
labels,
|
||||
n_cols,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
num_warps = num_warps,
|
||||
)
|
||||
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
ctx.num_warps = num_warps
|
||||
ctx.save_for_backward(logits, logsumexp, labels)
|
||||
return losses
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dlosses):
|
||||
logits, logsumexp, labels = ctx.saved_tensors
|
||||
n_rows, n_cols = logits.shape
|
||||
|
||||
_cross_entropy_backward[(n_rows,)](
|
||||
logits, logits.stride(0),
|
||||
dlosses, dlosses.stride(0),
|
||||
logsumexp,
|
||||
labels,
|
||||
n_cols,
|
||||
BLOCK_SIZE = ctx.BLOCK_SIZE,
|
||||
num_warps = ctx.num_warps,
|
||||
)
|
||||
return logits, None, None,
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
def fast_cross_entropy_loss(logits, labels):
|
||||
"""
|
||||
Arguments:
|
||||
logits: (batch, seq_len, vocab_size)
|
||||
labels: (batch, seq_len,)
|
||||
Returns:
|
||||
losses: float
|
||||
"""
|
||||
batch, seq_len, d = logits.shape
|
||||
assert(labels.shape == (batch, seq_len))
|
||||
|
||||
loss = Fast_CrossEntropyLoss.apply(
|
||||
logits.view(batch*seq_len, d),
|
||||
labels.view(-1),
|
||||
)
|
||||
n_items = torch.count_nonzero(labels != -100)
|
||||
return loss.sum() / n_items
|
||||
pass
|
||||
414
unsloth/kernels/fast_lora.py
Normal file
414
unsloth/kernels/fast_lora.py
Normal file
|
|
@ -0,0 +1,414 @@
|
|||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from .utils import fast_dequantize, QUANT_STATE
|
||||
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
|
||||
|
||||
def get_lora_parameters(proj):
|
||||
active_adapter = proj.active_adapters[0] if \
|
||||
hasattr(proj, "active_adapters") else proj.active_adapter
|
||||
base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
|
||||
W = base_layer.weight
|
||||
A = proj.lora_A [active_adapter].weight
|
||||
B = proj.lora_B [active_adapter].weight
|
||||
s = proj.scaling[active_adapter]
|
||||
return W, QUANT_STATE(W), A, B, s
|
||||
pass
|
||||
|
||||
|
||||
def matmul_lora(X, W, W_quant, A, B, s, out = None):
|
||||
dtype = X.dtype
|
||||
W = fast_dequantize(W.t(), W_quant)
|
||||
A, B = A.t(), B.t()
|
||||
|
||||
if X.dim() == 3:
|
||||
batch, seq_len, d = X.shape
|
||||
X = X.view(-1, X.shape[-1])
|
||||
reshape = True
|
||||
else:
|
||||
reshape = False
|
||||
pass
|
||||
|
||||
out = torch.matmul(X, W, out = out)
|
||||
if W_quant is not None: del W
|
||||
out += (X @ A) @ (s * B)
|
||||
return out.view(batch, seq_len, -1) if reshape else out
|
||||
pass
|
||||
|
||||
|
||||
class LoRA_MLP(torch.autograd.Function):
|
||||
"""
|
||||
### LoRA weights
|
||||
G = G + Ag @ Bg
|
||||
U = U + Au @ Bu
|
||||
W = W + Aw @ Bw
|
||||
|
||||
### SwiGLU(X)
|
||||
e = X @ G
|
||||
f = e * sigmoid(e)
|
||||
g = X @ U
|
||||
h = f * g
|
||||
i = h @ W
|
||||
|
||||
### Backpropagation chain rule
|
||||
df = sigmoid(e) * (1 - f) + f
|
||||
dC/dW = h.T @ dY
|
||||
dC/dU = X.T @ (D @ W.T * f)
|
||||
dC/dG = X.T @ (D @ W.T * df * g)
|
||||
dC/dX = (D @ W.T * f) @ U.T
|
||||
+ (D @ W.T * df * g) @ G.T
|
||||
|
||||
### Down projection LoRA weights
|
||||
dC/dAw = dC/dW @ B.T
|
||||
dC/dBw = A.T @ dC/dW
|
||||
dC/dAw = h.T @ dY @ B.T
|
||||
dC/dBw = A.T @ h.T @ dY
|
||||
|
||||
### Up projection LoRA weights
|
||||
dC/dAu = X.T @ (D @ W.T * f) @ B.T
|
||||
dC/dBu = A.T @ X.T @ (D @ W.T * f)
|
||||
|
||||
### Gate projection LoRA weights
|
||||
dC/dAg = X.T @ (D @ W.T * df * g) @ B.T
|
||||
dC/dBg = A.T @ X.T @ (D @ W.T * df * g)
|
||||
"""
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_fwd
|
||||
def forward(ctx, X : torch.Tensor,
|
||||
gateW, gateW_quant, gateA, gateB, gateS,
|
||||
upW, upW_quant, upA, upB, upS,
|
||||
downW, downW_quant, downA, downB, downS):
|
||||
dtype = X.dtype
|
||||
|
||||
e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
|
||||
g = matmul_lora(X, upW, upW_quant, upA, upB, upS)
|
||||
h = swiglu_fg_kernel(e, g)
|
||||
i = matmul_lora(h, downW, downW_quant, downA, downB, downS)
|
||||
|
||||
ctx.custom_saved_tensors = (
|
||||
gateW, gateW_quant, gateS,
|
||||
upW, upW_quant, upS,
|
||||
downW, downW_quant, downS,
|
||||
)
|
||||
ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB,
|
||||
X, e, g)
|
||||
return i
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_bwd
|
||||
def backward(ctx, dY : torch.Tensor):
|
||||
gateW, gateW_quant, gateS, upW, upW_quant, upS, downW, downW_quant, downS, = \
|
||||
ctx.custom_saved_tensors
|
||||
gateA, gateB, upA,upB, downA, downB, \
|
||||
X, e, g = ctx.saved_tensors
|
||||
|
||||
gateA, gateB, upA,upB, downA, downB = \
|
||||
gateA.t(), gateB.t(), upA.t(), upB.t(), downA.t(), downB.t()
|
||||
|
||||
batch, seq_len, hd = X.shape
|
||||
dY = dY.view(-1, dY.shape[-1])
|
||||
X = X .view(-1, X .shape[-1])
|
||||
e = e .view(-1, e .shape[-1])
|
||||
g = g .view(-1, g .shape[-1])
|
||||
dtype = X.dtype
|
||||
|
||||
# DW_f = (D @ W.T * f)
|
||||
# DW_dfg = (D @ W.T * df * g)
|
||||
DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
|
||||
DW, e, g = swiglu_DWf_DW_dfg_kernel(DW, e, g)
|
||||
h, DW_f, DW_dfg = DW, e, g # Inplace replacements
|
||||
# se = torch.nn.functional.sigmoid(e)
|
||||
# f = e * se
|
||||
# h = f * g
|
||||
# df = se * (1 - f) + f
|
||||
# DW_f = DW * f
|
||||
# DW_dfg = DW * df * g
|
||||
|
||||
# Down projection LoRA weights
|
||||
d_downA = h.t() @ (dY @ downB.t())
|
||||
d_downB = (downA.t() @ h.t()) @ dY
|
||||
d_downA *= downS
|
||||
d_downB *= downS
|
||||
|
||||
# Up projection LoRA weights
|
||||
d_upA = X.t() @ (DW_f @ upB.t())
|
||||
d_upB = (upA.t() @ X.t()) @ DW_f
|
||||
d_upA *= upS
|
||||
d_upB *= upS
|
||||
|
||||
# Gate projection LoRA weights
|
||||
d_gateA = X.t() @ (DW_dfg @ gateB.t())
|
||||
d_gateB = (gateA.t() @ X.t() @ DW_dfg)
|
||||
d_gateA *= gateS
|
||||
d_gateB *= gateS
|
||||
|
||||
# dC/dX = (D @ W.T * f) @ (U.T + B.T @ A.T)
|
||||
# + (D @ W.T * df * g) @ (G.T + B.T @ A.T)
|
||||
# (D @ W.T * f) @ U.T
|
||||
upW = fast_dequantize(upW.t(), upW_quant)
|
||||
# (D @ W.T * f) @ (U.T + B.T @ A.T)
|
||||
dX = torch.matmul(DW_f, upW.t(), out = X)
|
||||
del upW
|
||||
dX += DW_f @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
|
||||
|
||||
# (D @ W.T * f) @ (U.T + B.T @ A.T) + (D @ W.T * df * g) @ G.T
|
||||
gateW = fast_dequantize(gateW.t(), gateW_quant)
|
||||
# (D @ W.T * f) @ (U.T + B.T @ A.T) + (D @ W.T * df * g) @ (G.T + B.T @ A.T)
|
||||
dX += DW_dfg @ gateW.t()
|
||||
del gateW
|
||||
dX += DW_dfg @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
|
||||
|
||||
# gateW, gateW_quant, gateA, gateB, gateS,
|
||||
# upW, upW_quant, upA, upB, upS,
|
||||
# downW, downW_quant, downA, downB, downS,
|
||||
return dX.view(batch, seq_len, hd), \
|
||||
None, None, d_gateA.t(), d_gateB.t(), None, \
|
||||
None, None, d_upA.t(), d_upB.t(), None, \
|
||||
None, None, d_downA.t(), d_downB.t(), None,
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
def apply_lora_mlp(self, X):
|
||||
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
|
||||
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
||||
out = LoRA_MLP.apply(X,
|
||||
gateW, gateW_quant, gateA, gateB, gateS,
|
||||
upW, upW_quant, upA, upB, upS,
|
||||
downW, downW_quant, downA, downB, downS)
|
||||
return out
|
||||
pass
|
||||
|
||||
|
||||
class LoRA_QKV(torch.autograd.Function):
|
||||
"""
|
||||
### LoRA weights
|
||||
Wq = Wq + Aq @ Bq
|
||||
Wk = Wk + Ak @ Bk
|
||||
Wv = Wv + Av @ Bv
|
||||
Q = X @ Wq = X @ Wq + X @ Aq @ Bq
|
||||
K = X @ Wk = X @ Wk + X @ Ak @ Bk
|
||||
V = X @ Wv = X @ Wv + X @ Av @ Bv
|
||||
|
||||
### Backpropagation chain rule
|
||||
dC/dWq = X.T @ D(Wq)
|
||||
dC/dWk = X.T @ D(Wk)
|
||||
dC/dWv = X.T @ D(Wv)
|
||||
dC/dX = D(Wq) @ Wq.T
|
||||
+ D(Wk) @ Wk.T
|
||||
+ D(Wv) @ Wv.T
|
||||
|
||||
### Q projection LoRA weights
|
||||
dC/dAq = X.T @ D(Wq) @ B.T
|
||||
dC/dBq = A.T @ X.T @ D(Wq)
|
||||
|
||||
### K projection LoRA weights
|
||||
dC/dAk = X.T @ D(Wk) @ B.T
|
||||
dC/dBk = A.T @ X.T @ D(Wk)
|
||||
|
||||
### V projection LoRA weights
|
||||
dC/dAv = X.T @ D(Wv) @ B.T
|
||||
dC/dBv = A.T @ X.T @ D(Wv)
|
||||
"""
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_fwd
|
||||
def forward(ctx, X : torch.Tensor,
|
||||
QW, QW_quant, QA, QB, QS,
|
||||
KW, KW_quant, KA, KB, KS,
|
||||
VW, VW_quant, VA, VB, VS,):
|
||||
dtype = X.dtype
|
||||
|
||||
Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)
|
||||
K = matmul_lora(X, KW, KW_quant, KA, KB, KS)
|
||||
V = matmul_lora(X, VW, VW_quant, VA, VB, VS)
|
||||
|
||||
ctx.custom_saved_tensors = (
|
||||
QW, QW_quant, QS,
|
||||
KW, KW_quant, KS,
|
||||
VW, VW_quant, VS,
|
||||
)
|
||||
ctx.save_for_backward(X, QA, QB, KA, KB, VA, VB,)
|
||||
return Q, K, V
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_bwd
|
||||
def backward(ctx, dQ, dK, dV):
|
||||
QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = \
|
||||
ctx.custom_saved_tensors
|
||||
X, QA, QB, KA, KB, VA, VB, = ctx.saved_tensors
|
||||
|
||||
QA, QB, KA, KB, VA, VB = \
|
||||
QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()
|
||||
|
||||
batch, seq_len, hd = X.shape
|
||||
dQ = dQ.view(-1, dQ.shape[-1])
|
||||
dK = dK.reshape(-1, dK.shape[-1]) # view doesn't work on K.T
|
||||
dV = dV.view(-1, dV.shape[-1])
|
||||
X = X .view(-1, X .shape[-1])
|
||||
dtype = X.dtype
|
||||
|
||||
### Weight projection LoRA weights
|
||||
# dC/dAq = X.T @ D(Wq) @ B.T
|
||||
# dC/dBq = A.T @ X.T @ D(Wq)
|
||||
|
||||
# Q Projection
|
||||
d_QA = X.t() @ (dQ @ QB.t())
|
||||
d_QB = (QA.t() @ X.t()) @ dQ
|
||||
d_QA *= QS
|
||||
d_QB *= QS
|
||||
|
||||
# K Projection
|
||||
d_KA = X.t() @ (dK @ KB.t())
|
||||
d_KB = (KA.t() @ X.t()) @ dK
|
||||
d_KA *= KS
|
||||
d_KB *= KS
|
||||
|
||||
# V Projection
|
||||
d_VA = X.t() @ (dV @ VB.t())
|
||||
d_VB = (VA.t() @ X.t()) @ dV
|
||||
d_VA *= VS
|
||||
d_VB *= VS
|
||||
|
||||
# d/dX
|
||||
# dC/dX = D(Wq) @ Wq.T
|
||||
QW = fast_dequantize(QW.t(), QW_quant)
|
||||
# D(Wq) @ (Wq.T + B.T @ A.T)
|
||||
dX = torch.matmul(dQ, QW.t(), out = X)
|
||||
del QW
|
||||
dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
|
||||
|
||||
# D(Wq) @ Wq.T + D(Wk) @ Wk.T
|
||||
KW = fast_dequantize(KW.t(), KW_quant)
|
||||
# D(Wq) @ Wq.T + D(Wk) @ (Wk.T + B.T @ A.T)
|
||||
dX += dK @ KW.t()
|
||||
del KW
|
||||
dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
|
||||
|
||||
# D(Wq) @ Wq.T + D(Wk) @ Wk.T + D(Wv) @ Wv.T
|
||||
VW = fast_dequantize(VW.t(), VW_quant)
|
||||
# D(Wq) @ Wq.T + D(Wk) @ Wk.T + D(Wv) @ (Wv.T + B.T @ A.T)
|
||||
dX += dV @ VW.t()
|
||||
del VW
|
||||
dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
|
||||
|
||||
# QW, QW_quant, QA, QB, QS,
|
||||
# KW, KW_quant, KA, KB, KS,
|
||||
# VW, VW_quant, VA, VB, VS,
|
||||
return dX.view(batch, seq_len, hd), \
|
||||
None, None, d_QA.t(), d_QB.t(), None, \
|
||||
None, None, d_KA.t(), d_KB.t(), None, \
|
||||
None, None, d_VA.t(), d_VB.t(), None
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
def apply_lora_qkv(self, X):
|
||||
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
|
||||
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
|
||||
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
|
||||
Q, K, V = LoRA_QKV.apply(X,
|
||||
QW, QW_quant, QA, QB, QS,
|
||||
KW, KW_quant, KA, KB, KS,
|
||||
VW, VW_quant, VA, VB, VS,
|
||||
)
|
||||
return Q, K, V
|
||||
pass
|
||||
|
||||
|
||||
class LoRA_W(torch.autograd.Function):
|
||||
"""
|
||||
### LoRA weights
|
||||
Wq = Wq + Aq @ Bq
|
||||
Wk = Wk + Ak @ Bk
|
||||
Wv = Wv + Av @ Bv
|
||||
Q = X @ Wq = X @ Wq + X @ Aq @ Bq
|
||||
K = X @ Wk = X @ Wk + X @ Ak @ Bk
|
||||
V = X @ Wv = X @ Wv + X @ Av @ Bv
|
||||
|
||||
### Backpropagation chain rule
|
||||
dC/dWq = X.T @ D(Wq)
|
||||
dC/dWk = X.T @ D(Wk)
|
||||
dC/dWv = X.T @ D(Wv)
|
||||
dC/dX = D(Wq) @ Wq.T
|
||||
+ D(Wk) @ Wk.T
|
||||
+ D(Wv) @ Wv.T
|
||||
|
||||
### Q projection LoRA weights
|
||||
dC/dAq = X.T @ D(Wq) @ B.T
|
||||
dC/dBq = A.T @ X.T @ D(Wq)
|
||||
|
||||
### K projection LoRA weights
|
||||
dC/dAk = X.T @ D(Wk) @ B.T
|
||||
dC/dBk = A.T @ X.T @ D(Wk)
|
||||
|
||||
### V projection LoRA weights
|
||||
dC/dAv = X.T @ D(Wv) @ B.T
|
||||
dC/dBv = A.T @ X.T @ D(Wv)
|
||||
"""
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_fwd
|
||||
def forward(ctx, X : torch.Tensor,
|
||||
W, W_quant, A, B, S):
|
||||
dtype = X.dtype
|
||||
XW = matmul_lora(X, W, W_quant, A, B, S)
|
||||
ctx.custom_saved_tensors = (W, W_quant, S,)
|
||||
ctx.save_for_backward(A, B, X)
|
||||
return XW
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_bwd
|
||||
def backward(ctx, dY : torch.Tensor):
|
||||
W, W_quant, S = ctx.custom_saved_tensors
|
||||
A, B, X = ctx.saved_tensors
|
||||
|
||||
A, B = A.t(), B.t()
|
||||
|
||||
batch, seq_len, hd = X.shape
|
||||
dY = dY.reshape(-1, dY.shape[-1]) # .view doesn't work on non contiguous
|
||||
X = X .reshape(-1, X .shape[-1]) # .view doesn't work on non contiguous
|
||||
dtype = X.dtype
|
||||
|
||||
### Weight projection LoRA weights
|
||||
# dC/dAq = X.T @ D(Wq) @ B.T
|
||||
# dC/dBq = A.T @ X.T @ D(Wq)
|
||||
|
||||
# Weight projection
|
||||
d_A = X.t() @ (dY @ B.t())
|
||||
d_B = (A.t() @ X.t()) @ dY
|
||||
d_A *= S
|
||||
d_B *= S
|
||||
|
||||
# dC/dX = D(Wq) @ Wq.T
|
||||
W = fast_dequantize(W.t(), W_quant)
|
||||
dX = dY @ W.t()
|
||||
del W
|
||||
dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
|
||||
|
||||
# W, W_quant, A, B, S
|
||||
return dX.view(batch, seq_len, hd), \
|
||||
None, None, d_A.t(), d_B.t(), None
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
def apply_lora_o(self, X):
|
||||
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
|
||||
O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
|
||||
return O
|
||||
pass
|
||||
149
unsloth/kernels/rms_layernorm.py
Normal file
149
unsloth/kernels/rms_layernorm.py
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
from .utils import calculate_settings
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rms_layernorm_forward(
|
||||
Y, Y_row_stride,
|
||||
X, X_row_stride,
|
||||
W, W_row_stride,
|
||||
r, r_row_stride,
|
||||
n_cols, eps,
|
||||
BLOCK_SIZE : tl.constexpr
|
||||
):
|
||||
row_idx = tl.program_id(0)
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
|
||||
Y += row_idx * Y_row_stride
|
||||
X += row_idx * X_row_stride
|
||||
r += row_idx * r_row_stride
|
||||
|
||||
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
|
||||
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
|
||||
inv_var = 1 / tl.sqrt(row_var + eps)
|
||||
tl.store(r, inv_var)
|
||||
normed = X_row * inv_var
|
||||
output = normed * W_row
|
||||
tl.store(Y + col_offsets, output, mask = mask)
|
||||
pass
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rms_layernorm_backward(
|
||||
#dX, dX_row_stride,
|
||||
dY, dY_row_stride,
|
||||
X, X_row_stride,
|
||||
W, W_row_stride,
|
||||
r, r_row_stride,
|
||||
dW, dW_row_stride,
|
||||
n_cols, eps,
|
||||
BLOCK_SIZE : tl.constexpr,
|
||||
):
|
||||
row_idx = tl.program_id(0)
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
|
||||
#dX += row_idx * dX_row_stride + col_offsets
|
||||
dY += row_idx * dY_row_stride
|
||||
X += row_idx * X_row_stride
|
||||
r += row_idx * r_row_stride
|
||||
|
||||
dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
|
||||
# row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
|
||||
# inv_var = 1 / tl.sqrt(row_var + eps)
|
||||
inv_var = tl.load(r).to(tl.float32)
|
||||
normed = X_row * inv_var
|
||||
|
||||
dY_W = dY_row * W_row
|
||||
rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
|
||||
output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
|
||||
#tl.store(dX, output, mask = mask)
|
||||
tl.store(dY + col_offsets, output, mask = mask)
|
||||
pass
|
||||
|
||||
|
||||
class Fast_RMS_Layernorm(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, X, W, eps):
|
||||
shape = X.shape
|
||||
dim = shape[-1]
|
||||
X = X.view(-1, dim)
|
||||
n_rows, n_cols = X.shape
|
||||
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
||||
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda")
|
||||
|
||||
r = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
|
||||
_rms_layernorm_forward[(n_rows,)](
|
||||
Y, Y.stride(0),
|
||||
X, X.stride(0),
|
||||
W, W.stride(0),
|
||||
r, r.stride(0),
|
||||
n_cols, eps,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
num_warps = num_warps,
|
||||
)
|
||||
ctx.eps = eps
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
ctx.num_warps = num_warps
|
||||
ctx.save_for_backward(X, W, r)
|
||||
return Y.view(*shape)
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dY):
|
||||
shape = dY.shape
|
||||
dim = shape[-1]
|
||||
dY = dY.view(-1, dim)
|
||||
X, W, r = ctx.saved_tensors
|
||||
n_rows, n_cols = dY.shape
|
||||
dW = X
|
||||
|
||||
# dX = torch.empty_like(dY)
|
||||
# dX = dY
|
||||
_rms_layernorm_backward[(n_rows,)](
|
||||
#dX, dX.stride(0),
|
||||
dY, dY.stride(0),
|
||||
X, X .stride(0),
|
||||
W, W .stride(0),
|
||||
r, r .stride(0),
|
||||
dW, dW.stride(0),
|
||||
n_cols, ctx.eps,
|
||||
BLOCK_SIZE = ctx.BLOCK_SIZE,
|
||||
num_warps = ctx.num_warps,
|
||||
)
|
||||
#dX = dX.view(*shape)
|
||||
dX = dY.view(*shape)
|
||||
# X, W, eps
|
||||
return dX, None, None
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
def fast_rms_layernorm(layernorm, X):
|
||||
W = layernorm.weight
|
||||
eps = layernorm.variance_epsilon
|
||||
out = Fast_RMS_Layernorm.apply(X, W, eps)
|
||||
return out
|
||||
pass
|
||||
178
unsloth/kernels/rope_embedding.py
Normal file
178
unsloth/kernels/rope_embedding.py
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
from .utils import calculate_settings
|
||||
|
||||
|
||||
# In-place rotary position embedding (RoPE) kernel.
# Grid: axis 0 = one program per (batch*seq_len) row, axis 1 = one per head.
# BACKWARD_PASS is a compile-time flag (via the heuristic below): the backward
# rotation is the forward one with sin negated (multiplication by R.T).
@triton.heuristics({"BACKWARD_PASS": lambda args: args["BACKWARD_PASS"],})
@triton.jit
def _rope_embedding(
    Q, Q_row_stride,        # flattened (batch*seq_len, n_heads*head_dim) buffer, mutated in-place
    cos, cos_row_stride,    # cached cos table, one row per sequence position
    sin, sin_row_stride,    # cached sin table, one row per sequence position
    seqlen, head_dim,
    BACKWARD_PASS: tl.constexpr,
    BLOCK_SIZE : tl.constexpr,
):
    row_position = tl.program_id(0)
    head_position = tl.program_id(1)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    half_head_dim = head_dim // 2
    # RoPE pairs element i with element i + half_head_dim, so each program
    # only loads half the head dimension per load.
    mask = col_offsets < half_head_dim

    # TODO: Fixup int32 locations to int64
    # https://github.com/Dao-AILab/flash-attention/commit/c79de85ffa0d19b80fa468f90c5086e837499d72
    # Rows are laid out batch-major, so position within the sequence is row % seqlen.
    rot_position = row_position % seqlen

    Q += row_position* Q_row_stride + head_position*head_dim
    cos += rot_position*cos_row_stride
    sin += rot_position*sin_row_stride

    Q1 = tl.load(Q + half_head_dim*0 + col_offsets, mask = mask, other = 0)
    sin1 = tl.load(sin + half_head_dim*0 + col_offsets, mask = mask, other = 0)
    cos1 = tl.load(cos + half_head_dim*0 + col_offsets, mask = mask, other = 0)

    Q2 = tl.load(Q + half_head_dim*1 + col_offsets, mask = mask, other = 0)
    # RoPE repeats sin and cos so 128 = [64, 64].
    # sin2 = tl.load(sin + half_head_dim*1, mask = mask, other = 0)
    # cos2 = tl.load(cos + half_head_dim*1, mask = mask, other = 0)

    if BACKWARD_PASS:
        """
        Q * cos + rotate_half(Q) * sin
        is equivalent to
        Q * cos + Q @ R * sin
        where R is a rotation matrix [ 0, I]
                                     [-I, 0]
        dC/dY = dY * cos + dY @ R.T * sin
        where R.T is again the same [ 0, -I]
        but the minus is transposed. [ I, 0]
        """
        # sin1, sin2 = -sin1, -sin2
        sin1 = -sin1

    # tl.store(Q + half_head_dim*0, Q1*cos1 - Q2*sin1, mask = mask)
    # tl.store(Q + half_head_dim*1, Q2*cos2 + Q1*sin2, mask = mask)
    # RoPE repeats sin and cos so 128 = [64, 64].
    tl.store(Q + half_head_dim*0 + col_offsets, Q1*cos1 - Q2*sin1, mask = mask)
    tl.store(Q + half_head_dim*1 + col_offsets, Q2*cos1 + Q1*sin1, mask = mask)
pass
|
||||
|
||||
|
||||
class Fast_RoPE_Embedding(torch.autograd.Function):
    """Fused Triton RoPE for tensors shaped (batch, seq_len, n_heads, head_dim).

    Rotates Q in-place. Backward reuses the same kernel with the sin sign
    flipped (rotation by R.T), so only cos/sin references are kept on ctx.
    """
    @staticmethod
    def forward(ctx, Q, cos, sin):
        cos, sin = cos.squeeze(), sin.squeeze()
        batch, seq_len, n_heads, head_dim = Q.shape
        # Collapse to 2D: one Triton program per (batch, position) row per head.
        Q = Q.view(batch*seq_len, n_heads*head_dim)
        n_rows, n_cols = Q.shape
        # The cached cos/sin tables must cover every position we rotate.
        assert(seq_len <= cos.shape[0])

        # [TODO] Changing blocksize to head_dim//2 seems to have
        # some concurrency / un-deterministic issues.
        BLOCK_SIZE, num_warps = calculate_settings(head_dim) # (head_dim//2)
        _rope_embedding[(n_rows, n_heads,)](
            Q, Q.stride(0),
            cos, cos.stride(0),
            sin, sin.stride(0),
            seq_len, head_dim,
            BACKWARD_PASS = False,
            BLOCK_SIZE = BLOCK_SIZE,
            num_warps = num_warps,
        )
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.cos = cos # Don't need save_for_backward since a view
        ctx.sin = sin
        return Q.view(batch, seq_len, n_heads, head_dim)
    pass

    @staticmethod
    def backward(ctx, dY):
        batch, seq_len, n_heads, head_dim = dY.shape
        dY = dY.reshape(batch*seq_len, n_heads*head_dim)
        # Cannot be .view since the problem lies with dK since
        # K.T's strides are incorrect.
        n_rows, n_cols = dY.shape

        cos = ctx.cos
        sin = ctx.sin

        # Same kernel as forward; BACKWARD_PASS=True negates sin (rotate by R.T).
        _rope_embedding[(n_rows, n_heads,)](
            dY,  dY .stride(0),
            cos, cos.stride(0),
            sin, sin.stride(0),
            seq_len, head_dim,
            BACKWARD_PASS = True,
            BLOCK_SIZE = ctx.BLOCK_SIZE,
            num_warps = ctx.num_warps,
        )
        dY = dY.view(batch, seq_len, n_heads, head_dim)
        # Gradients for (Q, cos, sin); only dQ is produced.
        return dY, None, None,
    pass
pass
|
||||
|
||||
|
||||
def fast_rope_embedding(Q, K, cos, sin):
    """Rotate Q and K with the fused Triton RoPE kernel.

    Callers pass (batch, n_heads, seq_len, head_dim); the kernel expects
    (batch, seq_len, n_heads, head_dim), hence the transposes around apply.
    """
    rotate = Fast_RoPE_Embedding.apply
    Q = rotate(Q.transpose(1, 2), cos, sin).transpose(1, 2)
    K = rotate(K.transpose(1, 2), cos, sin).transpose(1, 2)
    return Q, K
pass
|
||||
|
||||
|
||||
class Slow_RoPE_Embedding(torch.autograd.Function):
    """Reference (non-fused) RoPE with optional position_ids gathering.

    Operates in-place on the input tensor; backward applies the transposed
    rotation by swapping the halves with the opposite sign.
    """

    @staticmethod
    def forward(ctx, Q, cos, sin, position_ids):
        if position_ids is not None:
            # cos/sin arrive as [1, 1, seq_len, dim]; drop the unit axes,
            # gather the requested positions, then re-add a heads axis.
            cos = cos.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
            sin = sin.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]

        # Q <- Q * cos + rotate_half(Q) * sin, accumulated in-place on Q.
        half = Q.shape[-1] // 2
        rotated = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
        rotated *= sin
        Q *= cos
        Q += rotated
        ctx.save_for_backward(cos, sin)
        return Q
    pass

    @staticmethod
    def backward(ctx, dY):
        cos, sin = ctx.saved_tensors
        # dY <- dY * cos + rotate_half.T(dY) * sin — the rotation matrix is
        # orthogonal, so its transpose just flips which half is negated.
        half = dY.shape[-1] // 2
        rotated = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
        rotated *= sin
        dY *= cos
        dY += rotated
        return dY, None, None, None
    pass
pass
|
||||
|
||||
|
||||
def inplace_rope_embedding(Q, K, cos, sin, position_ids):
    """Fallback RoPE path used when position_ids is supplied; mutates Q and K."""
    apply_rope = Slow_RoPE_Embedding.apply
    Q = apply_rope(Q, cos, sin, position_ids)
    K = apply_rope(K, cos, sin, position_ids)
    return Q, K
pass
|
||||
88
unsloth/kernels/swiglu.py
Normal file
88
unsloth/kernels/swiglu.py
Normal file
|
|
@@ -0,0 +1,88 @@
|
|||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
from .utils import calculate_settings
|
||||
|
||||
|
||||
@triton.jit
def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
    # Elementwise SwiGLU forward: h = silu(e) * g = (e * sigmoid(e)) * g.
    # One program handles BLOCK_SIZE contiguous elements of the flat buffers.
    block_idx = tl.program_id(0)
    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
    g_row = tl.load(g + offsets, mask = mask, other = 0).to(tl.float32)

    # f = e * sigmoid(e)
    # https://github.com/openai/triton/issues/241 exp MUST be done in f32
    # or else Triton crashes
    f_row = e_row / (1 + tl.exp(-e_row))
    # h = f * g
    h_row = f_row * g_row

    # Stored in fp32 math; the store narrows back to h's dtype.
    tl.store(h + offsets, h_row, mask = mask)
pass
|
||||
|
||||
|
||||
def swiglu_fg_kernel(e, g):
    """Compute the SwiGLU forward h = silu(e) * g with the fused Triton kernel.

    Args:
        e: gate activations, shape (batch, seq_len, hd).
        g: up-projection activations, same shape/dtype as e.
    Returns:
        h: new tensor of shape (batch, seq_len, hd) holding silu(e) * g.
    """
    batch, seq_len, hd = e.shape
    n_elements = e.numel()
    # Fix: allocate on e's own device, not a hard-coded "cuda". With more than
    # one GPU, "cuda" means cuda:0, which may not be where e/g live — the
    # kernel would then read and write across devices. Identical behavior on
    # the default device, so callers are unaffected.
    h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = e.device)
    # 1D launch over the flattened buffers, one program per 1024 elements.
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
    return h
pass
|
||||
|
||||
|
||||
@triton.jit
def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
    # Fused SwiGLU backward. Recomputes the forward in fp32 and overwrites
    # all three buffers in-place:
    #   DW <- h = f * g          (recomputed forward output)
    #   e  <- DW * f             (where f = silu(e))
    #   g  <- DW * df * g        (grad through the silu derivative; see the
    #                             algebra in the inline comments below)
    block_idx = tl.program_id(0)
    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    DW_row = tl.load(DW + offsets, mask = mask, other = 0).to(tl.float32)
    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
    g_row = tl.load(g + offsets, mask = mask, other = 0).to(tl.float32)

    # f = e * sigmoid(e)
    # https://github.com/openai/triton/issues/241 exp MUST be done in f32
    # or else Triton crashes
    se_row = 1 / (1 + tl.exp(-e_row))
    f_row = e_row * se_row
    # h = f * g
    h_row = f_row * g_row
    # df = se * (1 - f) + f
    # DW_f = DW * f
    DWf_row = DW_row * f_row
    # DW_dfg = DW * df * g
    # DW_dfg = DW * (se * (1 - f) + f) * g
    # DW_dfg = DW * (se*(g - h) + h)
    DW_dfg_row = DW_row * (se_row*(g_row - h_row) + h_row)

    tl.store(DW + offsets, h_row, mask = mask) # h
    tl.store(e + offsets, DWf_row, mask = mask) # DW * f
    tl.store(g + offsets, DW_dfg_row, mask = mask) # DW * df * g
pass
|
||||
|
||||
|
||||
def swiglu_DWf_DW_dfg_kernel(DW, e, g):
    """Launch the fused in-place SwiGLU backward kernel.

    All three buffers are overwritten by the kernel and handed back as
    (h, DW * f, DW * df * g).
    """
    # Unpacking doubles as a cheap 2D-shape check on e.
    batch_seq_len, hd = e.shape
    total = e.numel()

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    _DWf_DW_dfg_kernel[grid](DW, e, g, total, BLOCK_SIZE = 1024,)
    return DW, e, g # h, DW * f, DW * df * g
pass
|
||||
93
unsloth/kernels/utils.py
Normal file
93
unsloth/kernels/utils.py
Normal file
|
|
@@ -0,0 +1,93 @@
|
|||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import triton
|
||||
MAX_FUSED_SIZE = 65535 # 2**16 - 1
|
||||
next_power_of_2 = triton.next_power_of_2
|
||||
|
||||
def calculate_settings(n):
    """Pick a (BLOCK_SIZE, num_warps) pair for a Triton launch over rows of width n."""
    BLOCK_SIZE = next_power_of_2(n)
    # CUDA caps the threads per block, so refuse rows needing a bigger tile.
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
                           f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")

    # Scale warps with the tile size; 4 is the floor for small rows.
    if BLOCK_SIZE >= 32768:
        num_warps = 32
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8
    else:
        num_warps = 4
    return BLOCK_SIZE, num_warps
pass
|
||||
|
||||
|
||||
import bitsandbytes as bnb
|
||||
get_ptr = bnb.functional.get_ptr
|
||||
import ctypes
|
||||
import torch
|
||||
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
|
||||
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
|
||||
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
|
||||
|
||||
def QUANT_STATE(W):
    """Return W's bitsandbytes quant_state if it has one, else None."""
    state = getattr(W, "quant_state", None)
    return state
pass
|
||||
|
||||
def fast_dequantize(W, quant_state = None, out = None):
    """Dequantize a bitsandbytes NF4 weight W back to fp16/bf16.

    Calls bitsandbytes' C kernels directly via ctypes: first dequantizes the
    double-quantized absmax statistics to fp32, then dequantizes W itself
    into `out`.

    Args:
        W: packed 4-bit weight tensor in bnb layout.
        quant_state: bnb QuantState object (new API) or list-of-lists (old
            API); None means W is not quantized and is returned unchanged.
        out: optional pre-allocated destination; must match shape and dtype.
    """
    if quant_state is None: return W
    if type(quant_state) is not list:
        # New quant_state as a class
        # https://github.com/TimDettmers/bitsandbytes/pull/763/files
        absmax = quant_state.absmax
        shape = quant_state.shape
        dtype = quant_state.dtype
        blocksize = quant_state.blocksize
        offset = quant_state.offset
        state2 = quant_state.state2
        absmax2 = state2.absmax
        code2 = state2.code
        blocksize2 = state2.blocksize
    else:
        # Old quant_state as a list of lists
        absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
        offset, state2 = compressed_stats
        absmax2, code2, blocksize2, _, _, _, _ = state2
    pass

    # Create weight matrix
    # NOTE(review): allocations use the hard-coded default "cuda" device —
    # presumably fine while only one GPU is supported; confirm for multi-GPU.
    if out is None:
        out = torch.empty(shape, dtype = dtype, device = "cuda")
    else:
        assert(out.shape == shape)
        assert(out.dtype == dtype)

    # NF4 dequantization of statistics
    n_elements_absmax = absmax.numel()
    out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda")

    # Do dequantization
    # Step 1: recover the per-block absmax scales (they are themselves
    # blockwise-quantized, shifted by `offset`).
    ptr_out_absmax = get_ptr(out_absmax)
    cdequantize_blockwise_fp32(
        get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
        ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax)
    )
    out_absmax += offset

    # Step 2: dequantize the NF4 payload into `out`, picking the C kernel
    # that matches the target dtype (fp16 vs bf16).
    fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
        cdequantize_blockwise_bf16_nf4
    fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
       ctypes.c_int(blocksize), ctypes.c_int(out.numel()))

    # Careful returning transposed data
    is_transposed = (True if W.shape[0] == 1 else False)
    return out.t() if is_transposed else out
pass
|
||||
44
unsloth/models/__init__.py
Normal file
44
unsloth/models/__init__.py
Normal file
|
|
@@ -0,0 +1,44 @@
|
|||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
import os

# Currently only supports 1 GPU, or else seg faults will occur.
# With multiple visible GPUs, force CUDA_VISIBLE_DEVICES down to one device.
reload_package = False
n_gpus = torch.cuda.device_count()
if n_gpus == 0:
    raise RuntimeError("Unsloth: Requires at least 1 GPU. Found 0.")
elif n_gpus > 1:
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        device = os.environ["CUDA_VISIBLE_DEVICES"]
        # A single digit (e.g. "0" or "2") already selects exactly one GPU;
        # anything else (lists like "0,1") is overridden with "0".
        if not device.isdigit():
            print(f"Unsloth: 'CUDA_VISIBLE_DEVICES' is currently {device} "\
                  "but we require 'CUDA_VISIBLE_DEVICES=0'\n"\
                  "We shall set it ourselves.")
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
            reload_package = True
    else:
        print("Unsloth: 'CUDA_VISIBLE_DEVICES' is not set. We shall set it ourselves.")
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        reload_package = True
pass

# Reload Pytorch with CUDA_VISIBLE_DEVICES
# NOTE(review): torch.cuda.device_count() above may already have initialized
# CUDA, in which case importlib.reload(torch) does not re-read
# CUDA_VISIBLE_DEVICES — confirm this actually restricts the device set.
if reload_package:
    import importlib
    importlib.reload(torch)
pass

from .llama import FastLlamaModel
|
||||
61
unsloth/models/_utils.py
Normal file
61
unsloth/models/_utils.py
Normal file
|
|
@@ -0,0 +1,61 @@
|
|||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from typing import Union, Optional, List, Any, Callable
|
||||
import numpy as np
|
||||
import warnings
|
||||
import gc
|
||||
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
|
||||
import bitsandbytes as bnb
|
||||
|
||||
|
||||
def prepare_model_for_kbit_training(
    model                      : Any,
    use_gradient_checkpointing : bool = True,
    use_reentrant              : Optional[bool] = True,
) -> Any:
    """
    Calculates where to place the gradient checkpoints given n_layers.
    We also freeze all other layers's gradients

    Args:
        model: Any LlamaModel with layers.
        use_gradient_checkpointing (`bool`, *optional*):
            Default enabled. Provides memory savings by not saving all activations,
            but only some.
        use_reentrant (`bool`, *optional*):
            https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L354
            Optimal gradient checkpointing algorithm which will be the default in
            future Pytorch versions.
    """

    # Nothing in the base model trains; only adapters added later will.
    for weight in model.parameters():
        weight.requires_grad_(False)

    if use_gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # Reentrant checkpointing (the Pytorch default) needs inputs that require
    # grad, otherwise nothing flows back through the checkpointed segments.
    if use_reentrant:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    return model
pass
|
||||
734
unsloth/models/llama.py
Normal file
734
unsloth/models/llama.py
Normal file
|
|
@@ -0,0 +1,734 @@
|
|||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from typing import Optional, Tuple, List, Union
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
# apply_rotary_pos_emb,
|
||||
# repeat_kv,
|
||||
# _prepare_4d_causal_attention_mask,
|
||||
logger,
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
)
|
||||
from ..kernels import *
|
||||
|
||||
# Get Flash Attention v2 if Ampere (RTX 30xx, A100)
|
||||
major_version, minor_version = torch.cuda.get_device_capability()
|
||||
if major_version >= 8:
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
HAS_FLASH_ATTENTION = True
|
||||
except:
|
||||
HAS_FLASH_ATTENTION = False
|
||||
else:
|
||||
# Tri Dao's benchmark shows xformers is faster for now.
|
||||
HAS_FLASH_ATTENTION = False
|
||||
pass
|
||||
|
||||
import xformers.ops.fmha as xformers
|
||||
xformers_attention = xformers.memory_efficient_attention
|
||||
|
||||
# Final patching code
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaDecoderLayer,
|
||||
LlamaModel,
|
||||
LlamaForCausalLM,
|
||||
)
|
||||
from peft import PeftModelForCausalLM
|
||||
import gc
|
||||
import peft
|
||||
import bitsandbytes as bnb
|
||||
import numpy as np
|
||||
import types
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||
from transformers import set_seed as transformers_set_seed
|
||||
from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model
|
||||
from ._utils import (
|
||||
prepare_model_for_kbit_training,
|
||||
)
|
||||
|
||||
|
||||
def original_apply_qkv(self, X):
    """Unfused QKV projection: run the three linear layers separately."""
    return (
        self.q_proj(X),
        self.k_proj(X),
        self.v_proj(X),
    )
pass
|
||||
|
||||
|
||||
def original_apply_o(self, X):
    """Unfused attention output projection."""
    return self.o_proj(X)
pass
|
||||
|
||||
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L320
|
||||
def LlamaAttention_fast_forward(
    self,
    hidden_states: torch.Tensor,
    causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.LongTensor] = None,
    *args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Patched LlamaAttention.forward: fused QKV/RoPE + xformers/FlashAttn.

    Returns (attn_output, attn_weights, past_key_value); attn_weights is
    always None here.
    """
    bsz, q_len, _ = hidden_states.size()

    # Q = self.q_proj(hidden_states)
    # K = self.k_proj(hidden_states)
    # V = self.v_proj(hidden_states)
    # apply_qkv is patched per-module (fused LoRA path or the original three projections).
    Q, K, V = self.apply_qkv(self, hidden_states)

    n_heads = self.num_heads
    n_groups = self.num_key_value_groups
    n_kv_heads = self.num_key_value_heads
    head_dim = self.head_dim
    assert(n_kv_heads * n_groups == n_heads)

    # To (bsz, heads, q_len, head_dim) layout for RoPE / attention.
    Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
    K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
    V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)

    kv_seq_len = K.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    # cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
    # Q, K = apply_rotary_pos_emb(Q, K, cos, sin, position_ids)
    # Fast Triton RoPE when positions are implicit; slow gather path otherwise.
    if position_ids is None:
        cos = self.rotary_emb.cos_cached
        sin = self.rotary_emb.sin_cached
        Q, K = fast_rope_embedding(Q, K, cos, sin)
    else:
        cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
        Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
    pass

    if past_key_value is not None:
        # reuse k, v, self_attention
        K = torch.cat([past_key_value[0], K], dim = 2)
        V = torch.cat([past_key_value[1], V], dim = 2)
    past_key_value = (K, V) if use_cache else None

    # Attention module
    # no_attention_mask = attention_mask is None
    # Ignore attention_mask

    if (not HAS_FLASH_ATTENTION): #and no_attention_mask:
        # Xformers memory efficient attention
        # Also has Flash Attention v2 dispatching
        # (batch_size, n_heads, seq_len, head_dim) -> (batch_size, seq_len, n_heads, head_dim)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Grouped query attention
        if n_groups != 1:
            # NOTE(review): Q is reshaped (…, n_groups, n_kv_heads, head_dim)
            # while K/V are expanded along the last-but-one axis — this pairs
            # heads group-major; confirm the ordering matches how q/k heads
            # are laid out after the projections.
            Q = Q.reshape(bsz, q_len, n_groups, n_kv_heads, head_dim)

            K = K.reshape(bsz, q_len, n_groups, 1, head_dim)
            V = V.reshape(bsz, q_len, n_groups, 1, head_dim)
            K = K .expand(bsz, q_len, n_groups, n_kv_heads, head_dim)
            V = V .expand(bsz, q_len, n_groups, n_kv_heads, head_dim)
        pass

        A = xformers_attention(Q, K, V, attn_bias = causal_mask)
        A = A.view(bsz, q_len, n_heads, head_dim)

    elif HAS_FLASH_ATTENTION:# and no_attention_mask:
        # Flash Attention
        # (batch_size, n_heads, seq_len, head_dim) -> (batch_size, seq_len, n_heads, head_dim)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Flash Attention v2 auto supports grouped query attention
        A = flash_attn_func(Q, K, V, causal = True)

    else:
        # NOTE(review): this branch is unreachable — `not HAS_FLASH_ATTENTION`
        # and `HAS_FLASH_ATTENTION` cover all cases. Presumably the
        # commented-out `no_attention_mask` conditions once made it live.
        # Uses Pytorch's scaled dot product attention
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
        pass

        # Grouped query attention
        # K = repeat_kv(K, n_groups)
        # V = repeat_kv(V, n_groups)
        if n_groups != 1:
            # NOTE(review): the expands use q_len, but with a KV cache K/V
            # have kv_seq_len rows — looks cache-incompatible; confirm.
            K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
            V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
            K = K.reshape(bsz, n_heads, q_len, head_dim)
            V = V.reshape(bsz, n_heads, q_len, head_dim)
        pass

        # Needs (batch_size, n_heads, seq_len, head_dim)
        # is_casual and attention_mask must not be both set!
        A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = attention_mask is None)
        # Go back to (batch_size, seq_len, n_heads, head_dim)
        A = A.transpose(1, 2)
    pass
    attn_output = A.reshape(bsz, q_len, self.hidden_size)

    # attn_output = self.o_proj(attn_output)
    attn_output = self.apply_o(self, attn_output)

    attn_weights = None
    return attn_output, attn_weights, past_key_value
pass
|
||||
|
||||
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
|
||||
def LlamaDecoderLayer_fast_forward(
    self,
    hidden_states: torch.Tensor,
    causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    padding_mask: Optional[torch.LongTensor] = None,
    *args, **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """
    Patched LlamaDecoderLayer.forward: the two layernorms are replaced by the
    fused Triton RMS layernorm; otherwise the standard pre-norm residual
    structure (attention block, then MLP block).

    Args:
        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
            `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
            (see `past_key_values`).
        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
    """
    residual = hidden_states

    # hidden_states = self.input_layernorm(hidden_states)
    hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)

    # Self Attention
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=hidden_states,
        causal_mask=causal_mask,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        padding_mask=padding_mask,
    )
    hidden_states = residual + hidden_states

    # Fully Connected
    residual = hidden_states
    # hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
    hidden_states = self.mlp(hidden_states)
    hidden_states = residual + hidden_states

    # Mirror the HF convention: optional attn weights and KV cache appended.
    outputs = (hidden_states,)

    if output_attentions:
        outputs += (self_attn_weights,)

    if use_cache:
        outputs += (present_key_value,)

    return outputs
pass
|
||||
|
||||
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
|
||||
def LlamaModel_fast_forward(
    self,
    input_ids: torch.LongTensor,
    causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    *args, **kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Patched LlamaModel.forward: embeds, runs the decoder stack (with
    optional reentrant gradient checkpointing), then the fused final RMS
    layernorm. attention_mask handling is deliberately bypassed (causal_mask
    carries the causal structure instead); output_attentions is unsupported.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    # This patched path never materializes attention probabilities.
    assert(output_attentions is False)
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

    seq_length_with_past = seq_length
    past_key_values_length = 0

    if past_key_values is not None:
        past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length_with_past + past_key_values_length

    # We already handle KV cache position_ids ourselves.
    if (past_key_values_length != 0):
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length,
            dtype = torch.int32,#dtype=torch.long,
            device = "cuda",
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    elif position_ids is not None:
        position_ids = position_ids.view(-1, seq_length).to(torch.int32)#.long()
    else:
        position_ids = None

    if position_ids is not None:
        # Broadcast a single row of positions across the batch if needed.
        if position_ids.shape[0] != batch_size:
            position_ids = position_ids.repeat((batch_size, 1))

    # embed positions
    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    # Ignore attention_mask
    # NOTE(review): the else branch below is dead (`if True:`) and references
    # _prepare_4d_causal_attention_mask, whose import is commented out at the
    # top of this file — it would NameError if ever re-enabled.
    if True:
        # if attention_mask is None:
        #     attention_mask = torch.ones(
        #         (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
        #     )
        padding_mask = None
    else:
        if 0 in attention_mask:
            padding_mask = attention_mask
        else:
            padding_mask = None

        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )
    pass

    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False
    pass

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:

            # Closure re-binds the non-tensor args so checkpoint only sees tensors.
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(decoder_layer),
                hidden_states,
                causal_mask,
                attention_mask,
                position_ids,
                use_reentrant=True,
                preserve_rng_state=False,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                causal_mask=causal_mask,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                padding_mask=padding_mask,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)
    pass

    # hidden_states = self.norm(hidden_states)
    hidden_states = fast_rms_layernorm(self.norm, hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
pass
|
||||
|
||||
|
||||
def LlamaForCausalLM_fast_forward(
    self,
    input_ids: torch.LongTensor = None,
    causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    *args, **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Fast forward pass for `LlamaForCausalLM`.

    Runs the (patched) base model, projects hidden states through `lm_head`,
    and — when `labels` are given — computes the shifted next-token loss via
    `fast_cross_entropy_loss`. Flags left as None fall back to `self.config`.
    """
    # Default to a lower-triangular (causal) attention bias when the caller
    # did not supply one.
    if causal_mask is None:
        causal_mask = xformers.attn_bias.LowerTriangularMask()

    # Resolve unspecified output flags from the model config.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    # Decoder outputs: (dec_features, layer_state, dec_hidden, dec_attn).
    outputs = self.model(
        input_ids=input_ids,
        causal_mask=causal_mask,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    final_hidden = outputs[0]
    logits = self.lm_head(final_hidden)

    loss = None
    if labels is not None:
        # Rather than slicing both logits[:-1] and labels[1:] into contiguous
        # copies, keep logits untouched and pad the shifted labels with a
        # column of ignored (-100) entries so both stay length-aligned.
        shifted_labels = torch.hstack(
            (labels[..., 1:], self.extra_ignored_labels[: labels.shape[0]])
        )
        loss = fast_cross_entropy_loss(
            logits = logits,
            labels = shifted_labels,
        )

    if not return_dict:
        tuple_out = (logits,) + outputs[1:]
        if loss is None:
            return tuple_out
        return (loss,) + tuple_out

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
pass
|
||||
|
||||
|
||||
def PeftModelForCausalLM_fast_forward(
    self,
    input_ids=None,
    causal_mask=None,
    attention_mask=None,
    inputs_embeds=None,
    labels=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    task_ids=None,
    **kwargs,
):
    """Thin PEFT wrapper: delegate straight to the (patched) base model.

    Note: `task_ids` is accepted for PEFT API compatibility but is NOT
    forwarded to the base model.
    """
    forwarded = dict(
        input_ids=input_ids,
        causal_mask=causal_mask,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        labels=labels,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    return self.base_model(**forwarded, **kwargs)
pass
|
||||
|
||||
|
||||
class FastLlamaModel:
    """Static helpers that patch HuggingFace Llama classes with Unsloth's fast
    forward implementations and load/prepare models for 4-bit QLoRA finetuning.

    Typical flow: `from_pretrained` (which calls `pre_patch` and `post_patch`)
    followed by `get_peft_model`.
    """

    @staticmethod
    def pre_patch():
        """Monkey-patch the fast forward passes onto the HF / PEFT classes.

        Class-level assignment means every existing and future instance of
        these modules uses the fast implementations.
        """
        LlamaAttention      .forward = LlamaAttention_fast_forward
        LlamaDecoderLayer   .forward = LlamaDecoderLayer_fast_forward
        LlamaModel          .forward = LlamaModel_fast_forward
        LlamaForCausalLM    .forward = LlamaForCausalLM_fast_forward
        PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
        return
    pass


    @staticmethod
    def from_pretrained(
        model_name = "meta-llama/Llama-2-7b-hf",
        max_seq_length = 4096,
        dtype = None,         # None -> bfloat16 if supported, else float16
        load_in_4bit = True,  # 4-bit NF4 double quantization via bitsandbytes
        token = None,         # HF Hub auth token
        device_map = "sequential",
    ):
        """Load a Llama model + tokenizer with Unsloth's fast patches applied.

        Returns `(model, tokenizer)`. Requires a CUDA device (queries GPU 0).
        """
        gpu_stats = torch.cuda.get_device_properties(0)
        max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
        SUPPORTS_BFLOAT16 = torch.cuda.is_bf16_supported()

        # Startup banner with GPU / CUDA / dtype capability info.
        statistics = \
        "==((====))== Unsloth: Fast Llama patching release 23.11\n"\
        f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB\n"\
        f"O^O/ \_/ \\ CUDA compute capability = {gpu_stats.major}.{gpu_stats.minor}\n"\
        f"\ / Pytorch version: {torch.__version__}. CUDA Toolkit = {torch.version.cuda}\n"\
        f' "-____-" bfloat16 support = {str(SUPPORTS_BFLOAT16).upper()}\n'
        print(statistics)

        # Patch the HF classes before instantiating the model.
        FastLlamaModel.pre_patch()

        if dtype is None:
            dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16
        elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
            logger.warning_once("Device does not support bfloat16. Will change to float16.")
            dtype = torch.float16

        assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)

        # [TODO]: Determine RoPE scaling
        # https://github.com/huggingface/transformers/pull/24653
        assert(max_seq_length <= 4096)

        bnb_config = None
        if load_in_4bit:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit = True,
                bnb_4bit_use_double_quant = True,
                bnb_4bit_quant_type = "nf4",
                bnb_4bit_compute_dtype = dtype,
            )

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map = device_map,
            torch_dtype = dtype,
            quantization_config = bnb_config,
            token = token,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            model_max_length = max_seq_length,
            padding_side = "right",
            token = token,
        )
        # Use the UNK token as the padding token (Llama ships without one).
        tokenizer.add_special_tokens({"pad_token" : tokenizer.unk_token});
        tokenizer.pad_token = tokenizer.unk_token
        # NOTE(review): `PretrainedConfig.update` mutates in place and returns
        # None, so `config` is None here and is never used — confirm intended.
        config = model.config.update({"pad_token_id" : tokenizer.unk_token_id});

        model = FastLlamaModel.post_patch(model)

        # Patch up QKV / O and MLP: start from the unfused (non-LoRA) paths;
        # get_peft_model swaps in the LoRA-fused versions later.
        for idx, layer in enumerate(model.model.layers):
            layer.self_attn.apply_qkv = original_apply_qkv
            layer.self_attn.apply_o   = original_apply_o
        pass
        return model, tokenizer
    pass


    @staticmethod
    def post_patch(model):
        """Fix up module internals after loading: embedding/lm_head rewrap and
        bitsandbytes quant-state dtype correction. Returns the same model."""
        # Patch model
        # NOTE(review): `layers` is assigned but never used in this method.
        layers = model.model.layers

        # Torch.compile fails on embedding matrix??
        # Workaround randomnly fixes it for torch versions < 2.2
        model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight)

        # We also do this for the lm_head: build a shell Linear and graft the
        # pretrained weight onto it, syncing in/out feature counts.
        lm_head = torch.nn.Linear(1, 1, bias = None)
        del lm_head.weight
        lm_head.weight = model.lm_head.weight
        lm_head.in_features  = lm_head.weight.shape[1]
        lm_head.out_features = lm_head.weight.shape[0]
        model.lm_head = lm_head

        # Also patch all dtypes - BnB seems to not allocate the correct type?
        # BnB default dtype seems to be float16!
        correct_dtype = lm_head.weight.dtype

        for name, module in model.named_modules():
            if isinstance(module, (bnb.nn.Linear4bit, peft.tuners.lora.Linear4bit)):
                weight = module.weight
                quant_state = weight.quant_state

                if type(quant_state) is list:
                    # Old BnB layout: quant_state is a list; dtype at index 2.
                    # BnB seems to have float16 as default!
                    module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
                else:
                    # New BnB layout: QuantState object with a .dtype attribute.
                    # https://github.com/TimDettmers/bitsandbytes/pull/763/files
                    quant_state.dtype = correct_dtype
                pass
            pass
        pass

        # Clear deleted GPU items
        gc.collect()
        torch.cuda.empty_cache()
        return model
    pass


    @staticmethod
    def get_peft_model(
        model,
        r = 16,
        # NOTE(review): mutable (list) default argument — safe only as long as
        # callers never mutate it; consider a tuple.
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                          "gate_proj", "up_proj", "down_proj"],
        lora_alpha = 16,
        lora_dropout = 0,      # must be 0: fused fast paths assume no dropout
        bias = "none",         # must be "none": fused fast paths assume no bias
        layers_to_transform = None,
        use_gradient_checkpointing = True,
        random_state = 3407,
        max_seq_length = 2048,
    ):
        """Wrap `model` in LoRA adapters and patch fused LoRA fast paths.

        Raises TypeError when `lora_dropout != 0` or `bias != "none"`, since
        the fused kernels only support that configuration.
        """
        if lora_dropout != 0:
            raise TypeError("Unsloth: Fast Llama patching only works with dropout = 0.")
        if bias != "none":
            raise TypeError("Unsloth: Fast Llama patching only works with bias = 'none'.")

        transformers_set_seed(random_state)

        # Only these Llama projection modules have fused LoRA fast paths.
        accepted_modules = frozenset(("q_proj", "k_proj", "v_proj", "o_proj",
                                      "gate_proj", "up_proj", "down_proj",),)
        for module in target_modules:
            assert(module in accepted_modules)
        pass

        # Get LoRA
        lora_config = LoraConfig(
            r = r,
            lora_alpha = lora_alpha,
            target_modules = target_modules,
            lora_dropout = 0,
            bias = "none",
            task_type = TaskType.CAUSAL_LM,
            layers_to_transform = layers_to_transform,
        )

        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing = use_gradient_checkpointing,
            use_reentrant = True,
        )
        model = _get_peft_model(model, lora_config)

        # Do patching: swap in fused LoRA paths only where every projection in
        # the group actually received a LoRA adapter.
        for idx, layer in enumerate(model.model.model.layers):

            # MLP patching (gate/up/down must all have LoRA)
            if  hasattr(layer.mlp.gate_proj, "lora_A") and \
                hasattr(layer.mlp. up_proj, "lora_A") and \
                hasattr(layer.mlp.down_proj, "lora_A"):

                # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module
                layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
            pass

            # QKV attention patching (q/k/v must all have LoRA)
            if  hasattr(layer.self_attn.q_proj, "lora_A") and \
                hasattr(layer.self_attn.k_proj, "lora_A") and \
                hasattr(layer.self_attn.v_proj, "lora_A"):

                layer.self_attn.apply_qkv = apply_lora_qkv
            pass

            # O attention patching
            if hasattr(layer.self_attn.o_proj, "lora_A"):

                layer.self_attn.apply_o = apply_lora_o
            pass
        pass

        # Patch cross entropy loss labels: a preallocated column of -100
        # (ignore_index) used to pad shifted labels in the fast forward.
        model.model.extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda")

        return model
    pass
pass
|
||||
Loading…
Reference in a new issue