First upload of Unsloth code

This commit is contained in:
Daniel Han-Chen 2023-11-30 03:51:54 +11:00
parent 1e2ba1b1d2
commit c3d6def64a
13 changed files with 2053 additions and 2 deletions

View file

@ -1,2 +1,34 @@
# unsloth
2x faster 50% less memory LLM finetuning on a single GPU
# Unsloth
2x faster, 50% less memory LLM finetuning on a single GPU.
`!pip install "unsloth[cu118] @ git+https://github.com/unslothai/unsloth.git"`
`!pip install "unsloth[cu121] @ git+https://github.com/unslothai/unsloth.git"`
### Google Colab examples
1. [Unsloth fast finetuning example](https://colab.research.google.com/drive/1oW55fBmwzCOrBVX66RcpptL3a99qWBxb?usp=sharing)
2. [Original slow finetuning example](https://colab.research.google.com/drive/1c7zxdLHaLJ9R9YTZ74y4tUERvS-kySyA?usp=sharing)
### Installation instructions
In Google Colab:
```
!ldconfig /usr/lib64-nvidia
!pip install xformers --index-url https://download.pytorch.org/whl/cu118
!pip install git+https://github.com/danielhanchen/unsloth.git
```
`!ldconfig /usr/lib64-nvidia` is necessary (for now) to link CUDA with Python. Possibly a Google Colab linking bug.
For general installations:
1. Install Xformers *OR* Flash Attention — choose one of the two. Older GPUs should use Xformers; newer GPUs can use Flash Attention.
2. For Xformers, find your Pytorch CUDA version via `torch.version.cuda` or `nvidia-smi`.
* If you have Conda, `conda install xformers -c xformers`
* If you have CUDA 11.8, `pip install xformers --index-url https://download.pytorch.org/whl/cu118`
* If you have CUDA 12.1, `pip install xformers --index-url https://download.pytorch.org/whl/cu121`
* Go to https://github.com/facebookresearch/xformers for other issues.
* You must have Pytorch 2.1 installed for Xformers. If not, try Flash Attention.
* Xformers supports all GPUs (Tesla T4 etc).
3. For Flash Attention, you must have an Ampere, Ada or Hopper GPU (A100, RTX 3090, RTX 4090, H100).
* Install Flash Attention via `pip uninstall -y ninja && pip install ninja` then `pip install flash-attn --no-build-isolation`.
* Xformers has native support for Flash Attention, so technically installing Xformers is enough.
4. Then install Unsloth:
`pip install git+https://github.com/danielhanchen/unsloth.git`

51
pyproject.toml Normal file
View file

@ -0,0 +1,51 @@
[build-system]
requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"
[project]
name = "unsloth"
version = "2023.11"
description = "2X faster LLM finetuning"
readme = "README.md"
requires-python = ">=3.9"
license = {file = "LICENSE"}
keywords = ["ai", "llm",]
authors = [
{email = "info@unsloth.ai"},
{name = "Unsloth AI team"},
]
maintainers = [
{name = "Daniel Han", email = "danielhanchen@gmail.com"},
{name = "Michael Han", email = "info@unsloth.ai"},
]
classifiers = [
"Programming Language :: Python",
]
dependencies = [
"transformers",
"bitsandbytes",
"datasets",
"sentencepiece",
"accelerate",
"trl",
"peft",
"torch>=2.1.0",
]
[project.optional-dependencies]
cu118 = [
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system=='Linux'",
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system=='Linux'",
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system=='Linux'",
]
cu121 = [
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system=='Linux'",
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system=='Linux'",
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system=='Linux'",
]
[project.urls]
homepage = "http://www.unsloth.ai"
documentation = "https://github.com/unslothai/unsloth"
repository = "https://github.com/unslothai/unsloth"

16
unsloth/__init__.py Normal file
View file

@ -0,0 +1,16 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "2023.11"
from .models import *

View file

@ -0,0 +1,24 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .cross_entropy_loss import fast_cross_entropy_loss
from .rms_layernorm import fast_rms_layernorm
from .rope_embedding import fast_rope_embedding, inplace_rope_embedding
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
from .fast_lora import (
apply_lora_mlp,
apply_lora_qkv,
apply_lora_o,
)
from .utils import fast_dequantize, QUANT_STATE

View file

@ -0,0 +1,167 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
import triton.language as tl
import torch
from .utils import calculate_settings
@triton.jit
def _cross_entropy_forward(logits_ptr, logits_row_stride,
                           loss_ptr,
                           lse_ptr,
                           labels_ptr,
                           n_cols,
                           BLOCK_SIZE: tl.constexpr,):
    """
    Per-row cross entropy forward pass: one Triton program handles one row
    of logits, writing that row's loss to loss_ptr and its logsumexp to
    lse_ptr (the logsumexp is reused by the backward kernel).

    Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
    Pi = exp(xi) / sum(exp(xi))
    CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
         = -y [ x - log[sum(exp(x))] ]
         = y * (log[sum(exp(x))] - x)
    If y == 0: CE_i = 0
    If y == 1: CE_i = logsumexp - x

    Rows whose label is -100 (the Hugging Face ignore index) get loss 0.
    """
    row_idx = tl.program_id(0)
    # Advance every pointer to this program's row.
    logits_ptr += row_idx * logits_row_stride
    loss_ptr += row_idx
    lse_ptr += row_idx
    labels_ptr += row_idx

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    # TODO: Fixup int32 locations to int64
    # https://github.com/Dao-AILab/flash-attention/commit/c79de85ffa0d19b80fa468f90c5086e837499d72
    label_idx = tl.load(labels_ptr).to(tl.int32)
    # Out-of-range lanes load -inf so they contribute exp(-inf) = 0 below.
    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
    max_logits = tl.max(logits, 0)
    # Maximum stops overflow: numerically stable logsumexp.
    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
    tl.store(lse_ptr, lse)

    if label_idx != -100:
        logits_label = tl.load(logits_ptr + label_idx).to(tl.float32)
        loss = lse - logits_label
    else:
        loss = 0.0
    tl.store(loss_ptr, loss)
pass
@triton.jit
def _cross_entropy_backward(logits_ptr, logits_row_stride,
                            dloss_ptr, dloss_row_stride,
                            lse_ptr,
                            labels_ptr,
                            n_cols,
                            BLOCK_SIZE: tl.constexpr,):
    """
    Cross entropy backward pass, one program per row.  Overwrites the row
    of logits IN PLACE with dloss * (softmax(logits) - onehot(label)),
    using the logsumexp cached by the forward kernel.

    CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
    dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
    From https://en.wikipedia.org/wiki/LogSumExp
    d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
    dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
    dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
    dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
    If y == 0: dC/dx = 0
    If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
    If y == 1 and x != label: dC/dx = exp[x - logsumexp]
    """
    row_idx = tl.program_id(0)
    logits_ptr += row_idx * logits_row_stride
    dloss_ptr += row_idx * dloss_row_stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    # TODO: Fixup int32 locations to int64
    # https://github.com/Dao-AILab/flash-attention/commit/c79de85ffa0d19b80fa468f90c5086e837499d72
    label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)

    # Ignored rows (label == -100) get zero gradient.
    if label_idx != -100:
        dloss = tl.load(dloss_ptr)
    else:
        dloss = 0.0

    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = 0).to(tl.float32)
    lse = tl.load(lse_ptr + row_idx)
    probs = tl.exp(logits - lse)                                   # softmax row
    probs = tl.where(col_offsets == label_idx, probs - 1.0, probs) # subtract one-hot
    tl.store(logits_ptr + col_offsets, dloss * probs, mask = mask)
pass
class Fast_CrossEntropyLoss(torch.autograd.Function):
    """Autograd wrapper around the Triton cross entropy kernels.

    forward returns per-row float32 losses; backward lets the kernel
    overwrite the saved logits buffer in place with dC/dlogits and returns
    that buffer directly (no extra allocation).
    """
    @staticmethod
    def forward(ctx, logits, labels):
        # logits: (n_rows, vocab), labels: (n_rows,) with -100 = ignore.
        n_rows, n_cols = logits.shape
        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        # Allocate on the logits' own device.  The previous hard-coded
        # device = "cuda" broke when inputs live on a non-default GPU
        # (e.g. cuda:1) in multi-GPU setups.
        device = logits.device
        losses    = torch.empty(n_rows, dtype = torch.float32, device = device)
        logsumexp = torch.empty(n_rows, dtype = torch.float32, device = device)

        _cross_entropy_forward[(n_rows,)](
            logits, logits.stride(0),
            losses,
            logsumexp,
            labels,
            n_cols,
            BLOCK_SIZE = BLOCK_SIZE,
            num_warps  = num_warps,
        )
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps  = num_warps
        ctx.save_for_backward(logits, logsumexp, labels)
        return losses
    pass

    @staticmethod
    def backward(ctx, dlosses):
        logits, logsumexp, labels = ctx.saved_tensors
        n_rows, n_cols = logits.shape
        # The kernel overwrites `logits` in place with dloss * (softmax - onehot),
        # so the saved logits buffer itself becomes the gradient.
        _cross_entropy_backward[(n_rows,)](
            logits, logits.stride(0),
            dlosses, dlosses.stride(0),
            logsumexp,
            labels,
            n_cols,
            BLOCK_SIZE = ctx.BLOCK_SIZE,
            num_warps  = ctx.num_warps,
        )
        return logits, None, None,
    pass
pass
def fast_cross_entropy_loss(logits, labels):
    """
    Mean cross entropy over the non-ignored (label != -100) positions.

    Arguments:
        logits: (batch, seq_len, vocab_size)
        labels: (batch, seq_len,)
    Returns:
        losses: float scalar tensor
    """
    batch, seq_len, vocab_size = logits.shape
    assert(labels.shape == (batch, seq_len))

    flat_logits = logits.view(batch * seq_len, vocab_size)
    flat_labels = labels.view(-1)
    losses = Fast_CrossEntropyLoss.apply(flat_logits, flat_labels)

    # Average only over positions whose label is not the ignore index.
    n_items = torch.count_nonzero(labels != -100)
    return losses.sum() / n_items
pass

View file

@ -0,0 +1,414 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from .utils import fast_dequantize, QUANT_STATE
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
def get_lora_parameters(proj):
    """Return (W, quant_state, lora_A, lora_B, scaling) for a peft linear layer."""
    # Newer peft versions expose a list of active adapters; older versions
    # have a single active_adapter attribute.
    if hasattr(proj, "active_adapters"):
        adapter = proj.active_adapters[0]
    else:
        adapter = proj.active_adapter
    base_layer = getattr(proj, "base_layer", proj)
    W = base_layer.weight
    return (
        W,
        QUANT_STATE(W),
        proj.lora_A[adapter].weight,
        proj.lora_B[adapter].weight,
        proj.scaling[adapter],
    )
pass
def matmul_lora(X, W, W_quant, A, B, s, out = None):
    """Compute X @ dequant(W) + s * (X @ A.T) @ B.T, optionally into `out`.

    Accepts 2D or 3D X; a 3D input is flattened for the matmuls and the
    result is reshaped back to (batch, seq_len, -1) at the end.
    """
    W = fast_dequantize(W.t(), W_quant)
    A, B = A.t(), B.t()

    is_3d = X.dim() == 3
    if is_3d:
        batch, seq_len, _ = X.shape
        X = X.view(-1, X.shape[-1])

    out = torch.matmul(X, W, out = out)
    # The dequantized copy is only a temporary when W was quantized;
    # free it as soon as the matmul is done.
    if W_quant is not None: del W

    out += (X @ A) @ (s * B)
    return out.view(batch, seq_len, -1) if is_3d else out
pass
class LoRA_MLP(torch.autograd.Function):
    """
    Fused SwiGLU MLP with LoRA adapters on the gate, up and down projections.

    ### LoRA weights
    G = G + Ag @ Bg
    U = U + Au @ Bu
    W = W + Aw @ Bw

    ### SwiGLU(X)
    e = X @ G
    f = e * sigmoid(e)
    g = X @ U
    h = f * g
    i = h @ W

    ### Backpropagation chain rule
    df = sigmoid(e) * (1 - f) + f
    dC/dW = h.T @ dY
    dC/dU = X.T @ (D @ W.T * f)
    dC/dG = X.T @ (D @ W.T * df * g)
    dC/dX = (D @ W.T * f) @ U.T
          + (D @ W.T * df * g) @ G.T

    ### Down projection LoRA weights
    dC/dAw = dC/dW @ B.T
    dC/dBw = A.T @ dC/dW
    dC/dAw = h.T @ dY @ B.T
    dC/dBw = A.T @ h.T @ dY

    ### Up projection LoRA weights
    dC/dAu = X.T @ (D @ W.T * f) @ B.T
    dC/dBu = A.T @ X.T @ (D @ W.T * f)

    ### Gate projection LoRA weights
    dC/dAg = X.T @ (D @ W.T * df * g) @ B.T
    dC/dBg = A.T @ X.T @ (D @ W.T * df * g)
    """
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, X : torch.Tensor,
                gateW, gateW_quant, gateA, gateB, gateS,
                upW, upW_quant, upA, upB, upS,
                downW, downW_quant, downA, downB, downS):
        # X: (batch, seq_len, hd).  Each projection is W + s * A @ B.
        dtype = X.dtype
        e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
        g = matmul_lora(X, upW, upW_quant, upA, upB, upS)
        h = swiglu_fg_kernel(e, g)  # h = (e * sigmoid(e)) * g
        i = matmul_lora(h, downW, downW_quant, downA, downB, downS)

        # Base weights / quant states / scalings need no grad tracking,
        # so they bypass save_for_backward.
        ctx.custom_saved_tensors = (
            gateW, gateW_quant, gateS,
            upW, upW_quant, upS,
            downW, downW_quant, downS,
        )
        ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB,
                              X, e, g)
        return i
    pass

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dY : torch.Tensor):
        gateW, gateW_quant, gateS, upW, upW_quant, upS, downW, downW_quant, downS, = \
            ctx.custom_saved_tensors
        gateA, gateB, upA, upB, downA, downB, \
            X, e, g = ctx.saved_tensors

        gateA, gateB, upA, upB, downA, downB = \
            gateA.t(), gateB.t(), upA.t(), upB.t(), downA.t(), downB.t()

        batch, seq_len, hd = X.shape
        # Flatten (batch, seq_len, *) to 2D for the matmuls below.
        dY = dY.view(-1, dY.shape[-1])
        X  = X .view(-1, X .shape[-1])
        e  = e .view(-1, e .shape[-1])
        g  = g .view(-1, g .shape[-1])
        dtype = X.dtype

        # DW_f   = (D @ W.T * f)
        # DW_dfg = (D @ W.T * df * g)
        DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
        # The kernel overwrites its three buffers: DW -> h, e -> DW*f, g -> DW*df*g.
        DW, e, g = swiglu_DWf_DW_dfg_kernel(DW, e, g)
        h, DW_f, DW_dfg = DW, e, g  # Inplace replacements
        # se = torch.nn.functional.sigmoid(e)
        # f = e * se
        # h = f * g
        # df = se * (1 - f) + f
        # DW_f = DW * f
        # DW_dfg = DW * df * g

        # Down projection LoRA weights
        d_downA = h.t() @ (dY @ downB.t())
        d_downB = (downA.t() @ h.t()) @ dY
        d_downA *= downS
        d_downB *= downS

        # Up projection LoRA weights
        d_upA = X.t() @ (DW_f @ upB.t())
        d_upB = (upA.t() @ X.t()) @ DW_f
        d_upA *= upS
        d_upB *= upS

        # Gate projection LoRA weights
        d_gateA = X.t() @ (DW_dfg @ gateB.t())
        d_gateB = (gateA.t() @ X.t() @ DW_dfg)
        d_gateA *= gateS
        d_gateB *= gateS

        # dC/dX = (D @ W.T * f) @ (U.T + B.T @ A.T)
        #       + (D @ W.T * df * g) @ (G.T + B.T @ A.T)
        # (D @ W.T * f) @ U.T
        upW = fast_dequantize(upW.t(), upW_quant)
        # (D @ W.T * f) @ (U.T + B.T @ A.T)
        # NOTE(review): writes into X's buffer to avoid an allocation; all
        # uses of X (d_upA, d_gateA, ...) happen above this line.
        dX = torch.matmul(DW_f, upW.t(), out = X)
        del upW
        dX += DW_f @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
        # (D @ W.T * f) @ (U.T + B.T @ A.T) + (D @ W.T * df * g) @ G.T
        gateW = fast_dequantize(gateW.t(), gateW_quant)
        # (D @ W.T * f) @ (U.T + B.T @ A.T) + (D @ W.T * df * g) @ (G.T + B.T @ A.T)
        dX += DW_dfg @ gateW.t()
        del gateW
        dX += DW_dfg @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())

        # One gradient slot per forward argument:
        # gateW, gateW_quant, gateA, gateB, gateS,
        # upW,   upW_quant,   upA,   upB,   upS,
        # downW, downW_quant, downA, downB, downS,
        return dX.view(batch, seq_len, hd), \
            None, None, d_gateA.t(), d_gateB.t(), None, \
            None, None, d_upA.t(), d_upB.t(), None, \
            None, None, d_downA.t(), d_downB.t(), None,
    pass
pass
def apply_lora_mlp(self, X):
    """Run the fused LoRA SwiGLU MLP (gate/up/down projections) on X."""
    gate = get_lora_parameters(self.gate_proj)
    up   = get_lora_parameters(self.up_proj)
    down = get_lora_parameters(self.down_proj)
    return LoRA_MLP.apply(X, *gate, *up, *down)
pass
class LoRA_QKV(torch.autograd.Function):
    """
    Fused Q/K/V projections, each with its own LoRA adapter.

    ### LoRA weights
    Wq = Wq + Aq @ Bq
    Wk = Wk + Ak @ Bk
    Wv = Wv + Av @ Bv
    Q = X @ Wq = X @ Wq + X @ Aq @ Bq
    K = X @ Wk = X @ Wk + X @ Ak @ Bk
    V = X @ Wv = X @ Wv + X @ Av @ Bv

    ### Backpropagation chain rule
    dC/dWq = X.T @ D(Wq)
    dC/dWk = X.T @ D(Wk)
    dC/dWv = X.T @ D(Wv)
    dC/dX = D(Wq) @ Wq.T
          + D(Wk) @ Wk.T
          + D(Wv) @ Wv.T

    ### Q projection LoRA weights
    dC/dAq = X.T @ D(Wq) @ B.T
    dC/dBq = A.T @ X.T @ D(Wq)
    ### K projection LoRA weights
    dC/dAk = X.T @ D(Wk) @ B.T
    dC/dBk = A.T @ X.T @ D(Wk)
    ### V projection LoRA weights
    dC/dAv = X.T @ D(Wv) @ B.T
    dC/dBv = A.T @ X.T @ D(Wv)
    """
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, X : torch.Tensor,
                QW, QW_quant, QA, QB, QS,
                KW, KW_quant, KA, KB, KS,
                VW, VW_quant, VA, VB, VS,):
        # X: (batch, seq_len, hd).  Each projection is W + s * A @ B.
        dtype = X.dtype
        Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)
        K = matmul_lora(X, KW, KW_quant, KA, KB, KS)
        V = matmul_lora(X, VW, VW_quant, VA, VB, VS)
        # Base weights / quant states / scalings bypass save_for_backward.
        ctx.custom_saved_tensors = (
            QW, QW_quant, QS,
            KW, KW_quant, KS,
            VW, VW_quant, VS,
        )
        ctx.save_for_backward(X, QA, QB, KA, KB, VA, VB,)
        return Q, K, V
    pass

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dQ, dK, dV):
        QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = \
            ctx.custom_saved_tensors
        X, QA, QB, KA, KB, VA, VB, = ctx.saved_tensors

        QA, QB, KA, KB, VA, VB = \
            QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()

        batch, seq_len, hd = X.shape
        dQ = dQ.view(-1, dQ.shape[-1])
        dK = dK.reshape(-1, dK.shape[-1]) # view doesn't work on K.T
        dV = dV.view(-1, dV.shape[-1])
        X  = X .view(-1, X .shape[-1])
        dtype = X.dtype

        ### Weight projection LoRA weights
        # dC/dAq = X.T @ D(Wq) @ B.T
        # dC/dBq = A.T @ X.T @ D(Wq)
        # Q Projection
        d_QA = X.t() @ (dQ @ QB.t())
        d_QB = (QA.t() @ X.t()) @ dQ
        d_QA *= QS
        d_QB *= QS

        # K Projection
        d_KA = X.t() @ (dK @ KB.t())
        d_KB = (KA.t() @ X.t()) @ dK
        d_KA *= KS
        d_KB *= KS

        # V Projection
        d_VA = X.t() @ (dV @ VB.t())
        d_VB = (VA.t() @ X.t()) @ dV
        d_VA *= VS
        d_VB *= VS

        # d/dX
        # dC/dX = D(Wq) @ Wq.T
        QW = fast_dequantize(QW.t(), QW_quant)
        # D(Wq) @ (Wq.T + B.T @ A.T)
        # NOTE(review): writes into X's buffer to avoid an allocation; all
        # uses of X (d_QA, d_KA, d_VA) happen above this line.
        dX = torch.matmul(dQ, QW.t(), out = X)
        del QW
        dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
        # D(Wq) @ Wq.T + D(Wk) @ Wk.T
        KW = fast_dequantize(KW.t(), KW_quant)
        # D(Wq) @ Wq.T + D(Wk) @ (Wk.T + B.T @ A.T)
        dX += dK @ KW.t()
        del KW
        dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
        # D(Wq) @ Wq.T + D(Wk) @ Wk.T + D(Wv) @ Wv.T
        VW = fast_dequantize(VW.t(), VW_quant)
        # D(Wq) @ Wq.T + D(Wk) @ Wk.T + D(Wv) @ (Wv.T + B.T @ A.T)
        dX += dV @ VW.t()
        del VW
        dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())

        # One gradient slot per forward argument:
        # QW, QW_quant, QA, QB, QS,
        # KW, KW_quant, KA, KB, KS,
        # VW, VW_quant, VA, VB, VS,
        return dX.view(batch, seq_len, hd), \
            None, None, d_QA.t(), d_QB.t(), None, \
            None, None, d_KA.t(), d_KB.t(), None, \
            None, None, d_VA.t(), d_VB.t(), None
    pass
pass
def apply_lora_qkv(self, X):
    """Run the fused LoRA Q/K/V projections on X, returning (Q, K, V)."""
    q = get_lora_parameters(self.q_proj)
    k = get_lora_parameters(self.k_proj)
    v = get_lora_parameters(self.v_proj)
    return LoRA_QKV.apply(X, *q, *k, *v)
pass
class LoRA_W(torch.autograd.Function):
    """
    A single LoRA linear projection (used for o_proj via apply_lora_o).
    (The previous docstring here was copy-pasted from LoRA_QKV and described
    three projections; this class handles exactly one.)

    ### LoRA weights
    W = W + A @ B
    Y = X @ W = X @ W + X @ A @ B

    ### Backpropagation chain rule
    dC/dX = dY @ W.T + dY @ B.T @ A.T
    ### LoRA weight gradients
    dC/dA = X.T @ dY @ B.T
    dC/dB = A.T @ X.T @ dY
    """
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, X : torch.Tensor,
                W, W_quant, A, B, S):
        dtype = X.dtype
        XW = matmul_lora(X, W, W_quant, A, B, S)
        # Base weight / quant state / scaling bypass save_for_backward.
        ctx.custom_saved_tensors = (W, W_quant, S,)
        ctx.save_for_backward(A, B, X)
        return XW
    pass

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dY : torch.Tensor):
        W, W_quant, S = ctx.custom_saved_tensors
        A, B, X = ctx.saved_tensors
        A, B = A.t(), B.t()

        batch, seq_len, hd = X.shape
        dY = dY.reshape(-1, dY.shape[-1]) # .view doesn't work on non contiguous
        X  = X .reshape(-1, X .shape[-1]) # .view doesn't work on non contiguous
        dtype = X.dtype

        ### Weight projection LoRA weights
        # dC/dA = X.T @ dY @ B.T
        # dC/dB = A.T @ X.T @ dY
        d_A = X.t() @ (dY @ B.t())
        d_B = (A.t() @ X.t()) @ dY
        d_A *= S
        d_B *= S

        # dC/dX = dY @ W.T + dY @ B.T @ (S * A.T)
        W = fast_dequantize(W.t(), W_quant)
        dX = dY @ W.t()
        del W
        dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())

        # One gradient slot per forward argument: X, W, W_quant, A, B, S
        return dX.view(batch, seq_len, hd), \
            None, None, d_A.t(), d_B.t(), None
    pass
pass
def apply_lora_o(self, X):
    """Run the LoRA output projection (o_proj) on X."""
    params = get_lora_parameters(self.o_proj)
    return LoRA_W.apply(X, *params)
pass

View file

@ -0,0 +1,149 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
import triton.language as tl
import torch
from .utils import calculate_settings
@triton.jit
def _rms_layernorm_forward(
    Y, Y_row_stride,
    X, X_row_stride,
    W, W_row_stride,
    r, r_row_stride,
    n_cols, eps,
    BLOCK_SIZE : tl.constexpr
):
    """
    RMS layernorm forward, one program per row:
        Y = (X / sqrt(mean(X^2) + eps)) * W
    Also stores the per-row inverse RMS into `r` so the backward kernel
    does not have to recompute it.
    """
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    Y += row_idx * Y_row_stride
    X += row_idx * X_row_stride
    r += row_idx * r_row_stride

    # Compute in float32 for numerical stability regardless of storage dtype.
    X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
    W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)

    row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
    inv_var = 1 / tl.sqrt(row_var + eps)
    tl.store(r, inv_var)  # cached for the backward pass
    normed = X_row * inv_var
    output = normed * W_row
    tl.store(Y + col_offsets, output, mask = mask)
pass
@triton.jit
def _rms_layernorm_backward(
    #dX, dX_row_stride,
    dY, dY_row_stride,
    X, X_row_stride,
    W, W_row_stride,
    r, r_row_stride,
    dW, dW_row_stride,
    n_cols, eps,
    BLOCK_SIZE : tl.constexpr,
):
    """
    RMS layernorm backward, one program per row.  Overwrites dY IN PLACE
    with dX.  `r` holds the per-row 1/sqrt(mean(X^2)+eps) cached by the
    forward kernel.  NOTE: the dW / dW_row_stride arguments are accepted
    but never used in this kernel — no weight gradient is produced.
    """
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    #dX += row_idx * dX_row_stride + col_offsets
    dY += row_idx * dY_row_stride
    X  += row_idx * X_row_stride
    r  += row_idx * r_row_stride

    dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
    X_row  = tl.load(X  + col_offsets, mask = mask, other = 0).to(tl.float32)
    W_row  = tl.load(W  + col_offsets, mask = mask, other = 0).to(tl.float32)

    # Reuse the cached inverse RMS instead of recomputing:
    # row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
    # inv_var = 1 / tl.sqrt(row_var + eps)
    inv_var = tl.load(r).to(tl.float32)
    normed = X_row * inv_var
    dY_W = dY_row * W_row
    rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
    output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
    #tl.store(dX, output, mask = mask)
    tl.store(dY + col_offsets, output, mask = mask)  # dY is overwritten with dX
pass
class Fast_RMS_Layernorm(torch.autograd.Function):
    """Autograd wrapper for the Triton RMS layernorm kernels.

    backward overwrites the incoming gradient dY in place with dX.
    No gradient is computed for the weight W (its slot returns None).
    """
    @staticmethod
    def forward(ctx, X, W, eps):
        shape = X.shape
        dim = shape[-1]
        X = X.view(-1, dim)
        n_rows, n_cols = X.shape
        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        # Allocate on X's own device.  The previous hard-coded
        # device = "cuda" broke when the input lives on a non-default GPU
        # (e.g. cuda:1) in multi-GPU setups.
        Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = X.device)
        r = torch.empty(n_rows, dtype = torch.float32, device = X.device)

        _rms_layernorm_forward[(n_rows,)](
            Y, Y.stride(0),
            X, X.stride(0),
            W, W.stride(0),
            r, r.stride(0),
            n_cols, eps,
            BLOCK_SIZE = BLOCK_SIZE,
            num_warps  = num_warps,
        )
        ctx.eps = eps
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps  = num_warps
        ctx.save_for_backward(X, W, r)
        return Y.view(*shape)
    pass

    @staticmethod
    def backward(ctx, dY):
        shape = dY.shape
        dim = shape[-1]
        dY = dY.view(-1, dim)
        X, W, r = ctx.saved_tensors
        n_rows, n_cols = dY.shape
        # The kernel takes dW pointers but never writes a weight gradient;
        # X is passed as a placeholder.
        dW = X
        _rms_layernorm_backward[(n_rows,)](
            dY, dY.stride(0),
            X,  X .stride(0),
            W,  W .stride(0),
            r,  r .stride(0),
            dW, dW.stride(0),
            n_cols, ctx.eps,
            BLOCK_SIZE = ctx.BLOCK_SIZE,
            num_warps  = ctx.num_warps,
        )
        # dY was overwritten in place with dX by the kernel.
        dX = dY.view(*shape)
        return dX, None, None
    pass
pass
def fast_rms_layernorm(layernorm, X):
    """Apply the fused Triton RMS layernorm using `layernorm`'s weight and eps."""
    return Fast_RMS_Layernorm.apply(
        X, layernorm.weight, layernorm.variance_epsilon,
    )
pass

View file

@ -0,0 +1,178 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
import triton.language as tl
import torch
from .utils import calculate_settings
@triton.heuristics({"BACKWARD_PASS": lambda args: args["BACKWARD_PASS"],})
@triton.jit
def _rope_embedding(
    Q, Q_row_stride,
    cos, cos_row_stride,
    sin, sin_row_stride,
    seqlen, head_dim,
    BACKWARD_PASS: tl.constexpr,
    BLOCK_SIZE : tl.constexpr,
):
    """
    In-place rotary position embedding on Q; one program per (row, head).
    Q is laid out as (batch*seq_len, n_heads*head_dim).  Forward computes
    Q * cos + rotate_half(Q) * sin; with BACKWARD_PASS the sign of sin is
    flipped, which applies the transposed rotation for the gradient.
    """
    row_position  = tl.program_id(0)
    head_position = tl.program_id(1)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    half_head_dim = head_dim // 2
    mask = col_offsets < half_head_dim

    # TODO: Fixup int32 locations to int64
    # https://github.com/Dao-AILab/flash-attention/commit/c79de85ffa0d19b80fa468f90c5086e837499d72
    # Rows wrap modulo seqlen so every batch shares the same cos/sin rows.
    rot_position = row_position % seqlen
    Q   += row_position* Q_row_stride + head_position*head_dim
    cos += rot_position*cos_row_stride
    sin += rot_position*sin_row_stride

    Q1   = tl.load(Q + half_head_dim*0 + col_offsets, mask = mask, other = 0)
    sin1 = tl.load(sin + half_head_dim*0 + col_offsets, mask = mask, other = 0)
    cos1 = tl.load(cos + half_head_dim*0 + col_offsets, mask = mask, other = 0)
    Q2   = tl.load(Q + half_head_dim*1 + col_offsets, mask = mask, other = 0)
    # RoPE repeats sin and cos so 128 = [64, 64]: the second half equals
    # the first, hence sin2/cos2 loads are unnecessary.
    # sin2 = tl.load(sin + half_head_dim*1, mask = mask, other = 0)
    # cos2 = tl.load(cos + half_head_dim*1, mask = mask, other = 0)

    if BACKWARD_PASS:
        """
        Q * cos + rotate_half(Q) * sin
        is equivalent to
        Q * cos + Q @ R * sin
        where R is a rotation matrix [ 0, I]
                                     [-I, 0]
        dC/dY = dY * cos + dY @ R.T * sin
        where R.T is again the same [ 0, -I]
        but the minus is transposed [ I,  0]
        — implemented simply by negating sin.
        """
        # sin1, sin2 = -sin1, -sin2
        sin1 = -sin1

    # tl.store(Q + half_head_dim*0, Q1*cos1 - Q2*sin1, mask = mask)
    # tl.store(Q + half_head_dim*1, Q2*cos2 + Q1*sin2, mask = mask)
    # RoPE repeats sin and cos so 128 = [64, 64].
    tl.store(Q + half_head_dim*0 + col_offsets, Q1*cos1 - Q2*sin1, mask = mask)
    tl.store(Q + half_head_dim*1 + col_offsets, Q2*cos1 + Q1*sin1, mask = mask)
pass
class Fast_RoPE_Embedding(torch.autograd.Function):
    """Rotary position embedding via the Triton kernel, applied in place on Q.

    Expects Q shaped (batch, seq_len, n_heads, head_dim).
    """
    @staticmethod
    def forward(ctx, Q, cos, sin):
        # squeeze() drops all size-1 dims; presumably this leaves cos/sin
        # as (seq_len, head_dim) — TODO confirm against the caller.
        cos, sin = cos.squeeze(), sin.squeeze()
        batch, seq_len, n_heads, head_dim = Q.shape
        Q = Q.view(batch*seq_len, n_heads*head_dim)
        n_rows, n_cols = Q.shape
        assert(seq_len <= cos.shape[0])

        # [TODO] Changing blocksize to head_dim//2 seems to have
        # some concurrency / un-deterministic issues.
        BLOCK_SIZE, num_warps = calculate_settings(head_dim) # (head_dim//2)
        # One program per (row, head); the kernel rotates Q in place.
        _rope_embedding[(n_rows, n_heads,)](
            Q,   Q.stride(0),
            cos, cos.stride(0),
            sin, sin.stride(0),
            seq_len, head_dim,
            BACKWARD_PASS = False,
            BLOCK_SIZE = BLOCK_SIZE,
            num_warps = num_warps,
        )
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.cos = cos # Don't need save_for_backward since a view
        ctx.sin = sin
        return Q.view(batch, seq_len, n_heads, head_dim)
    pass

    @staticmethod
    def backward(ctx, dY):
        batch, seq_len, n_heads, head_dim = dY.shape
        dY = dY.reshape(batch*seq_len, n_heads*head_dim)
        # Cannot be .view since the problem lies with dK since
        # K.T's strides are incorrect.
        n_rows, n_cols = dY.shape
        cos = ctx.cos
        sin = ctx.sin

        # Same kernel with BACKWARD_PASS = True negates sin, applying the
        # transposed rotation in place on dY.
        _rope_embedding[(n_rows, n_heads,)](
            dY,  dY .stride(0),
            cos, cos.stride(0),
            sin, sin.stride(0),
            seq_len, head_dim,
            BACKWARD_PASS = True,
            BLOCK_SIZE = ctx.BLOCK_SIZE,
            num_warps = ctx.num_warps,
        )
        dY = dY.view(batch, seq_len, n_heads, head_dim)
        return dY, None, None,
    pass
pass
def fast_rope_embedding(Q, K, cos, sin):
    """Apply RoPE to Q and K (both (batch, n_heads, seq_len, head_dim)).

    The Triton kernel expects (batch, seq_len, n_heads, head_dim), hence
    the transpose before and after each call.
    """
    Qt = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin)
    Kt = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin)
    return Qt.transpose(1, 2), Kt.transpose(1, 2)
pass
class Slow_RoPE_Embedding(torch.autograd.Function):
    """Pure-PyTorch RoPE fallback; mutates Q in place and returns it."""
    @staticmethod
    def forward(ctx, Q, cos, sin, position_ids):
        if position_ids is not None:
            # The first two dims of cos/sin are size 1: drop them, gather the
            # rows for each position, then re-insert the head dimension.
            cos = cos.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
            sin = sin.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]

        # Q <- Q * cos + rotate_half(Q) * sin, computed in place on Q.
        half = Q.shape[-1]//2
        rotated = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
        Q *= cos
        rotated *= sin
        Q += rotated

        ctx.save_for_backward(cos, sin)
        return Q
    pass

    @staticmethod
    def backward(ctx, dY):
        cos, sin = ctx.saved_tensors
        # Gradient uses the transposed rotation: signs flipped vs forward.
        half = dY.shape[-1]//2
        rotated = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
        dY *= cos
        rotated *= sin
        dY += rotated
        return dY, None, None, None
    pass
pass
def inplace_rope_embedding(Q, K, cos, sin, position_ids):
    """Slow (pure PyTorch) in-place RoPE for Q and K with explicit position ids."""
    return (
        Slow_RoPE_Embedding.apply(Q, cos, sin, position_ids),
        Slow_RoPE_Embedding.apply(K, cos, sin, position_ids),
    )
pass

88
unsloth/kernels/swiglu.py Normal file
View file

@ -0,0 +1,88 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
import triton.language as tl
import torch
from .utils import calculate_settings
@triton.jit
def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
    """Elementwise SwiGLU forward: h = (e * sigmoid(e)) * g, one block per program."""
    block_idx = tl.program_id(0)
    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
    g_row = tl.load(g + offsets, mask = mask, other = 0).to(tl.float32)

    # f = e * sigmoid(e), written as e / (1 + exp(-e))
    # https://github.com/openai/triton/issues/241 exp MUST be done in f32
    # or else Triton crashes
    f_row = e_row / (1 + tl.exp(-e_row))
    # h = f * g
    h_row = f_row * g_row

    tl.store(h + offsets, h_row, mask = mask)
pass
def swiglu_fg_kernel(e, g):
    """Compute h = swish(e) * g = (e * sigmoid(e)) * g elementwise via Triton.

    e and g must have the same shape; returns a new tensor h of that shape.
    """
    n_elements = e.numel()
    # empty_like inherits e's shape, dtype AND device.  The previous
    # torch.empty((batch, seq_len, hd), device = "cuda") both broke on
    # non-default GPUs (e.g. cuda:1) and needlessly required 3D inputs.
    h = torch.empty_like(e)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
    return h
pass
@triton.jit
def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
    """
    Fused SwiGLU backward helper.  Overwrites its three buffers IN PLACE:
        DW <- h       = f * g
        e  <- DW * f
        g  <- DW * df * g
    where f = e * sigmoid(e) and df = sigmoid(e) * (1 - f) + f.
    """
    block_idx = tl.program_id(0)
    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    DW_row = tl.load(DW + offsets, mask = mask, other = 0).to(tl.float32)
    e_row  = tl.load(e  + offsets, mask = mask, other = 0).to(tl.float32)
    g_row  = tl.load(g  + offsets, mask = mask, other = 0).to(tl.float32)

    # f = e * sigmoid(e)
    # https://github.com/openai/triton/issues/241 exp MUST be done in f32
    # or else Triton crashes
    se_row = 1 / (1 + tl.exp(-e_row))
    f_row = e_row * se_row
    # h = f * g
    h_row = f_row * g_row
    # df = se * (1 - f) + f
    # DW_f = DW * f
    DWf_row = DW_row * f_row
    # DW_dfg = DW * df * g
    # DW_dfg = DW * (se * (1 - f) + f) * g
    # DW_dfg = DW * (se*(g - h) + h)   <- same expression, fewer multiplies
    DW_dfg_row = DW_row * (se_row*(g_row - h_row) + h_row)

    tl.store(DW + offsets, h_row, mask = mask)      # h
    tl.store(e + offsets, DWf_row, mask = mask)     # DW * f
    tl.store(g + offsets, DW_dfg_row, mask = mask)  # DW * df * g
pass
def swiglu_DWf_DW_dfg_kernel(DW, e, g):
    """SwiGLU backward driver — mutates DW, e and g in place via Triton.

    After the call: DW holds h = f*g, e holds DW*f, g holds DW*df*g
    (f = silu of the original e). Returns the same three (mutated) tensors.
    """
    batch_seq_len, hd = e.shape   # e must be 2D: (rows, hidden)
    total = e.numel()

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    _DWf_DW_dfg_kernel[grid](DW, e, g, total, BLOCK_SIZE = 1024,)
    return DW, e, g  # h, DW * f, DW * df * g
pass

93
unsloth/kernels/utils.py Normal file
View file

@ -0,0 +1,93 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import triton
# Largest row length we will fuse into a single Triton program:
# 65535 = 2**16 - 1 (see the launch-limit check in calculate_settings).
MAX_FUSED_SIZE = 65535 # 2**16 - 1
# Module-level alias to Triton's helper.
next_power_of_2 = triton.next_power_of_2
def calculate_settings(n):
    """Pick (BLOCK_SIZE, num_warps) for a fused Triton kernel over rows of length n.

    BLOCK_SIZE is n rounded up to a power of two; num_warps scales up with the
    block size. Raises RuntimeError when the row is too wide to fuse.
    """
    BLOCK_SIZE = next_power_of_2(n)
    # CUDA caps us at 65535 = 2**16 - 1 for a single fused block.
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
                           f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")

    # More warps for bigger blocks; 4 is the floor.
    num_warps = 4
    for threshold, warps in ((32768, 32), (8192, 16), (2048, 8)):
        if BLOCK_SIZE >= threshold:
            num_warps = warps
            break
    return BLOCK_SIZE, num_warps
pass
import bitsandbytes as bnb
get_ptr = bnb.functional.get_ptr
import ctypes
import torch
# Raw C entry points from bitsandbytes' shared library, called via ctypes in
# fast_dequantize below — presumably to bypass the Python-side wrapper overhead.
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
def QUANT_STATE(W):
    """Return W's bitsandbytes quant_state, or None if W is not quantized."""
    try:
        return W.quant_state
    except AttributeError:
        return None
pass
def fast_dequantize(W, quant_state = None, out = None):
    """Dequantize a bitsandbytes NF4 4-bit weight W to a dense tensor.

    Args:
        W: packed 4-bit weight storage (bnb Params4bit data).
        quant_state: bnb quantization state — either the new class form or
            the legacy list-of-lists form. None means W is unquantized and is
            returned as-is.
        out: optional preallocated output of the target shape/dtype.

    Returns:
        The dequantized weight; transposed when W.shape[0] == 1 (presumably
        matching how bnb stores such weights — confirm against bnb internals).
    """
    if quant_state is None: return W
    if type(quant_state) is not list:
        # New quant_state as a class
        # https://github.com/TimDettmers/bitsandbytes/pull/763/files
        absmax = quant_state.absmax
        shape = quant_state.shape
        dtype = quant_state.dtype
        blocksize = quant_state.blocksize
        offset = quant_state.offset       # double-quantization offset
        state2 = quant_state.state2       # nested state for the quantized absmax
        absmax2 = state2.absmax
        code2 = state2.code
        blocksize2 = state2.blocksize
    else:
        # Old quant_state as a list of lists
        absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
        offset, state2 = compressed_stats
        absmax2, code2, blocksize2, _, _, _, _ = state2
    pass

    # Create weight matrix
    # NOTE(review): device is hard-coded to the default CUDA device — relies on
    # the package-level single-GPU / CUDA_VISIBLE_DEVICES=0 assumption; confirm.
    if out is None:
        out = torch.empty(shape, dtype = dtype, device = "cuda")
    else:
        assert(out.shape == shape)
        assert(out.dtype == dtype)

    # NF4 dequantization of statistics: absmax is itself blockwise-quantized
    # (double quantization), so first recover the fp32 absmax.
    n_elements_absmax = absmax.numel()
    out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda")

    # Do dequantization
    ptr_out_absmax = get_ptr(out_absmax)
    cdequantize_blockwise_fp32(
        get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
        ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax)
    )
    out_absmax += offset  # undo the offset stored by double quantization

    # Dequantize the packed 4-bit weights with the recovered absmax.
    # NOTE(review): only fp16/bf16 kernels exist here — a float32 dtype would
    # silently fall through to the bf16 kernel; confirm dtype is always 16-bit.
    fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
        cdequantize_blockwise_bf16_nf4
    fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
       ctypes.c_int(blocksize), ctypes.c_int(out.numel()))

    # Careful returning transposed data
    is_transposed = (True if W.shape[0] == 1 else False)
    return out.t() if is_transposed else out
pass

View file

@ -0,0 +1,44 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import os
# Currently only supports 1 GPU, or else seg faults will occur.
# Force CUDA_VISIBLE_DEVICES=0 when more than one GPU is visible.
reload_package = False
n_gpus = torch.cuda.device_count()
if n_gpus == 0:
    raise RuntimeError("Unsloth: Requires at least 1 GPU. Found 0.")
elif n_gpus > 1:
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        device = os.environ["CUDA_VISIBLE_DEVICES"]
        # Anything other than a single device index (e.g. "0,1") is rejected.
        if not device.isdigit():
            print(f"Unsloth: 'CUDA_VISIBLE_DEVICES' is currently {device} "\
                  "but we require 'CUDA_VISIBLE_DEVICES=0'\n"\
                  "We shall set it ourselves.")
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
            reload_package = True
    else:
        print("Unsloth: 'CUDA_VISIBLE_DEVICES' is not set. We shall set it ourselves.")
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        reload_package = True
pass

# Reload Pytorch with CUDA_VISIBLE_DEVICES
# NOTE(review): torch.cuda.device_count() above may already have initialized
# the CUDA driver state, and importlib.reload(torch) does not re-run CUDA
# initialization — confirm this actually restricts torch to device 0.
if reload_package:
    import importlib
    importlib.reload(torch)
pass
from .llama import FastLlamaModel

61
unsloth/models/_utils.py Normal file
View file

@ -0,0 +1,61 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Union, Optional, List, Any, Callable
import numpy as np
import warnings
import gc
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
import bitsandbytes as bnb
def prepare_model_for_kbit_training(
    model                      : Any,
    use_gradient_checkpointing : bool = True,
    use_reentrant              : Optional[bool] = True,
) -> Any:
    """Freeze all model weights and optionally enable gradient checkpointing.

    Args:
        model: Any LlamaModel with layers.
        use_gradient_checkpointing (`bool`, *optional*):
            Default enabled. Saves memory by recomputing activations in the
            backward pass instead of storing them all.
        use_reentrant (`bool`, *optional*):
            https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L354
            When True (the Pytorch default), the checkpointed inputs must
            require grad, so we force that on the input embeddings.

    Returns:
        The same model object, mutated in place.
    """
    # Freeze everything; LoRA adapters added later own the trainable params.
    for weight in model.parameters():
        weight.requires_grad_(False)

    if use_gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # Reentrant checkpointing needs grad-requiring inputs, otherwise no
    # gradient flows back through the checkpointed segments at all.
    if use_reentrant:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            # Fallback: hook the embedding output directly.
            def _require_grads_hook(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(_require_grads_hook)

    return model
pass

734
unsloth/models/llama.py Normal file
View file

@ -0,0 +1,734 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Optional, Tuple, List, Union
from torch.nn.functional import scaled_dot_product_attention
from transformers.models.llama.modeling_llama import (
# apply_rotary_pos_emb,
# repeat_kv,
# _prepare_4d_causal_attention_mask,
logger,
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from ..kernels import *
# Get Flash Attention v2 if Ampere (RTX 30xx, A100): compute capability >= 8.
major_version, minor_version = torch.cuda.get_device_capability()
if major_version >= 8:
    try:
        from flash_attn import flash_attn_func
        HAS_FLASH_ATTENTION = True
    except Exception:
        # Narrowed from a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit. `Exception` still covers ImportError
        # as well as broken installs whose CUDA extension raises
        # OSError/RuntimeError at import time — the intent is best-effort
        # fallback to xformers.
        HAS_FLASH_ATTENTION = False
else:
    # Tri Dao's benchmark shows xformers is faster for now.
    HAS_FLASH_ATTENTION = False
pass

# xformers is always imported; used whenever flash-attn is unavailable.
import xformers.ops.fmha as xformers
xformers_attention = xformers.memory_efficient_attention
# Final patching code
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaModel,
LlamaForCausalLM,
)
from peft import PeftModelForCausalLM
import gc
import peft
import bitsandbytes as bnb
import numpy as np
import types
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import set_seed as transformers_set_seed
from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model
from ._utils import (
prepare_model_for_kbit_training,
)
def original_apply_qkv(self, X):
    """Unfused default QKV path: run the three projections separately.

    Installed as `layer.self_attn.apply_qkv`; replaced by the fused LoRA
    version when adapters are attached.
    """
    return (
        self.q_proj(X),
        self.k_proj(X),
        self.v_proj(X),
    )
pass
def original_apply_o(self, X):
    """Unfused default output path: plain o_proj forward.

    Installed as `layer.self_attn.apply_o`; replaced by the fused LoRA
    version when adapters are attached.
    """
    return self.o_proj(X)
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L320
def LlamaAttention_fast_forward(
    self,
    hidden_states: torch.Tensor,
    causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.LongTensor] = None,
    *args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Fast replacement for LlamaAttention.forward.

    Uses fused Triton RoPE kernels and dispatches attention to flash-attn v2
    when available, otherwise xformers memory-efficient attention. Always
    returns attn_weights as None.
    """
    bsz, q_len, _ = hidden_states.size()

    # QKV through the pluggable apply_qkv hook so fused LoRA projections can
    # be swapped in per layer (see FastLlamaModel.get_peft_model).
    # Q = self.q_proj(hidden_states)
    # K = self.k_proj(hidden_states)
    # V = self.v_proj(hidden_states)
    Q, K, V = self.apply_qkv(self, hidden_states)

    n_heads = self.num_heads
    n_groups = self.num_key_value_groups
    n_kv_heads = self.num_key_value_heads
    head_dim = self.head_dim
    assert(n_kv_heads * n_groups == n_heads)

    # (bsz, q_len, heads*head_dim) -> (bsz, heads, q_len, head_dim)
    Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
    K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
    V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)

    kv_seq_len = K.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    # RoPE: fast Triton path reads the cached cos/sin tables directly; the
    # slow path handles explicit position_ids (KV-cache decoding).
    # cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
    # Q, K = apply_rotary_pos_emb(Q, K, cos, sin, position_ids)
    if position_ids is None:
        cos = self.rotary_emb.cos_cached
        sin = self.rotary_emb.sin_cached
        Q, K = fast_rope_embedding(Q, K, cos, sin)
    else:
        cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
        Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
    pass

    if past_key_value is not None:
        # reuse k, v, self_attention
        K = torch.cat([past_key_value[0], K], dim = 2)
        V = torch.cat([past_key_value[1], V], dim = 2)
    past_key_value = (K, V) if use_cache else None

    # Attention module
    # no_attention_mask = attention_mask is None
    # Ignore attention_mask
    if (not HAS_FLASH_ATTENTION): #and no_attention_mask:
        # Xformers memory efficient attention
        # Also has Flash Attention v2 dispatching
        # (batch_size, n_heads, seq_len, head_dim) -> (batch_size, seq_len, n_heads, head_dim)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Grouped query attention: broadcast the KV heads to the 5D layout
        # xformers expects, without copying (expand).
        # NOTE(review): these reshapes use q_len, but K/V here already include
        # any cached tokens (kv_seq_len >= q_len) — presumably this path
        # assumes training with no KV cache; confirm. Also the
        # (n_groups, n_kv_heads) axis ordering only reshapes cleanly when
        # n_kv_heads == n_groups — verify against a real GQA model.
        if n_groups != 1:
            Q = Q.reshape(bsz, q_len, n_groups, n_kv_heads, head_dim)
            K = K.reshape(bsz, q_len, n_groups, 1, head_dim)
            V = V.reshape(bsz, q_len, n_groups, 1, head_dim)
            K = K.expand(bsz, q_len, n_groups, n_kv_heads, head_dim)
            V = V.expand(bsz, q_len, n_groups, n_kv_heads, head_dim)
        pass

        A = xformers_attention(Q, K, V, attn_bias = causal_mask)
        A = A.view(bsz, q_len, n_heads, head_dim)

    elif HAS_FLASH_ATTENTION: # and no_attention_mask:
        # Flash Attention
        # (batch_size, n_heads, seq_len, head_dim) -> (batch_size, seq_len, n_heads, head_dim)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        # Flash Attention v2 auto supports grouped query attention
        A = flash_attn_func(Q, K, V, causal = True)
    else:
        # Uses Pytorch's scaled dot product attention.
        # NOTE(review): unreachable as written — the first two branches are
        # exhaustive. Kept from the original design where the fast paths also
        # required no_attention_mask (see the commented-out conditions above).
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
        pass

        # Grouped query attention
        # K = repeat_kv(K, n_groups)
        # V = repeat_kv(V, n_groups)
        # NOTE(review): expands use q_len on the KV sequence axis — only valid
        # when kv_seq_len == q_len (no KV cache); confirm.
        if n_groups != 1:
            K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
            V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
            K = K.reshape(bsz, n_heads, q_len, head_dim)
            V = V.reshape(bsz, n_heads, q_len, head_dim)
        pass

        # Needs (batch_size, n_heads, seq_len, head_dim)
        # is_casual and attention_mask must not be both set!
        A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = attention_mask is None)
        # Go back to (batch_size, seq_len, n_heads, head_dim)
        A = A.transpose(1, 2)
    pass

    attn_output = A.reshape(bsz, q_len, self.hidden_size)
    # Output projection through the pluggable apply_o hook.
    # attn_output = self.o_proj(attn_output)
    attn_output = self.apply_o(self, attn_output)
    attn_weights = None
    return attn_output, attn_weights, past_key_value
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def LlamaDecoderLayer_fast_forward(
    self,
    hidden_states: torch.Tensor,
    causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    padding_mask: Optional[torch.LongTensor] = None,
    *args, **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """Fast replacement for LlamaDecoderLayer.forward.

    Same pre-norm residual structure as upstream, but both RMS layernorms go
    through the fused fast_rms_layernorm kernel and the extra xformers
    `causal_mask` is threaded into self-attention.

    Args:
        hidden_states (`torch.FloatTensor`): input of shape `(batch, seq_len, embed_dim)`
        attention_mask (`torch.FloatTensor`, *optional*): mask of size
            `(batch, 1, tgt_len, src_len)` with large negative values at padding.
        output_attentions (`bool`, *optional*): whether to also return attention weights.
        use_cache (`bool`, *optional*): whether to return the updated KV cache entry.
        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached K/V states.
    """
    # --- Self-attention sub-block (pre-norm + residual) ---
    skip = hidden_states
    normed = fast_rms_layernorm(self.input_layernorm, hidden_states)
    attn_out, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=normed,
        causal_mask=causal_mask,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        padding_mask=padding_mask,
    )
    hidden_states = skip + attn_out

    # --- MLP sub-block (pre-norm + residual) ---
    skip = hidden_states
    normed = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
    hidden_states = skip + self.mlp(normed)

    result = [hidden_states]
    if output_attentions:
        result.append(self_attn_weights)
    if use_cache:
        result.append(present_key_value)
    return tuple(result)
pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
def LlamaModel_fast_forward(
    self,
    input_ids: torch.LongTensor,
    causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    *args, **kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Fast drop-in replacement for LlamaModel.forward.

    Differences vs upstream: accepts an extra xformers `causal_mask` that is
    forwarded to every decoder layer, never builds the dense 4D attention
    mask (that branch is disabled below), builds position_ids as int32, and
    runs the final norm through the fused fast_rms_layernorm kernel.
    output_attentions is not supported (asserted False).
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    assert(output_attentions is False)  # fast attention never materializes attn probs
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds — exactly one must be provided
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

    seq_length_with_past = seq_length
    past_key_values_length = 0

    if past_key_values is not None:
        past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length_with_past + past_key_values_length

    # We already handle KV cache position_ids ourselves.
    if (past_key_values_length != 0):
        # Decoding with a cache: positions continue after the cached tokens.
        # int32 instead of long — presumably what the Triton RoPE kernel
        # expects; TODO confirm.
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length,
            dtype = torch.int32,#dtype=torch.long,
            device = "cuda",
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    elif position_ids is not None:
        position_ids = position_ids.view(-1, seq_length).to(torch.int32)#.long()
    else:
        position_ids = None

    if position_ids is not None:
        # Broadcast a single row of positions across the batch.
        if position_ids.shape[0] != batch_size:
            position_ids = position_ids.repeat((batch_size, 1))

    # embed positions
    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    # Ignore attention_mask: the dense-mask branch below is dead code kept
    # for reference. NOTE(review): `_prepare_4d_causal_attention_mask` is also
    # commented out of the imports, so the else-branch would NameError if ever
    # re-enabled.
    if True:
        # if attention_mask is None:
        #     attention_mask = torch.ones(
        #         (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
        #     )
        padding_mask = None
    else:
        if 0 in attention_mask:
            padding_mask = attention_mask
        else:
            padding_mask = None

        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )
    pass

    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False
    pass

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)

                return custom_forward

            # preserve_rng_state=False skips RNG bookkeeping in checkpointing —
            # presumably safe because no dropout is active here; confirm.
            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(decoder_layer),
                hidden_states,
                causal_mask,
                attention_mask,
                position_ids,
                use_reentrant=True,
                preserve_rng_state=False,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                causal_mask=causal_mask,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                padding_mask=padding_mask,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)
    pass

    # Final norm via the fused Triton RMS layernorm.
    # hidden_states = self.norm(hidden_states)
    hidden_states = fast_rms_layernorm(self.norm, hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
pass
def LlamaForCausalLM_fast_forward(
    self,
    input_ids: torch.LongTensor = None,
    causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    *args, **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Fast drop-in replacement for LlamaForCausalLM.forward.

    Builds a default xformers lower-triangular causal mask when none is
    supplied, and computes the LM loss with fast_cross_entropy_loss instead
    of slicing/shifting logits (see the label-shift trick below).
    """
    if causal_mask is None:
        causal_mask = xformers.attn_bias.LowerTriangularMask()

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        causal_mask=causal_mask,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)

    loss = None
    if labels is not None:
        # Label-shift trick: instead of materializing shifted copies of both
        # logits and labels (the commented-out upstream code), keep logits
        # as-is and shift only the labels, appending a column of -100
        # (ignore_index) sliced from the preallocated self.extra_ignored_labels
        # buffer (created in FastLlamaModel.get_peft_model).
        # NOTE(review): that buffer has max_seq_length rows and is sliced by
        # batch size here — assumes batch_size <= max_seq_length; confirm.
        # logits = logits.float()
        # shift_logits = logits[..., :-1, :].contiguous()
        # shift_labels = labels[..., 1:].contiguous()
        # shift_labels = shift_labels.view(-1)
        # shift_logits = shift_logits.view(-1, self.config.vocab_size)
        shift_logits = logits
        shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
        # loss_fct = torch.nn.CrossEntropyLoss(
        #     ignore_index = self.ignore_index,
        #     label_smoothing = self.label_smoothing,
        # )
        # loss = loss_fct(shift_logits, shift_labels)
        loss = fast_cross_entropy_loss(
            logits = shift_logits,
            labels = shift_labels,
        )
    pass

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
pass
def PeftModelForCausalLM_fast_forward(
    self,
    input_ids=None,
    causal_mask=None,
    attention_mask=None,
    inputs_embeds=None,
    labels=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    task_ids=None,
    **kwargs,
):
    """Forward straight to the wrapped base model, preserving `causal_mask`.

    Replaces PeftModelForCausalLM.forward so the extra `causal_mask` argument
    reaches the patched fast Llama forward. `task_ids` is accepted for PEFT
    API compatibility but not forwarded.
    """
    return self.base_model(
        input_ids=input_ids,
        causal_mask=causal_mask,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        labels=labels,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        **kwargs,
    )
pass
class FastLlamaModel:
    """Static helpers that load a HF Llama model and patch it for fast training.

    Typical flow: `from_pretrained` (applies the global forward patches and
    loads model + tokenizer) then `get_peft_model` (adds LoRA adapters and
    installs the fused LoRA kernels).
    """

    @staticmethod
    def pre_patch():
        """Globally monkey-patch the HF Llama and PEFT forward methods."""
        LlamaAttention      .forward = LlamaAttention_fast_forward
        LlamaDecoderLayer   .forward = LlamaDecoderLayer_fast_forward
        LlamaModel          .forward = LlamaModel_fast_forward
        LlamaForCausalLM    .forward = LlamaForCausalLM_fast_forward
        PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
        return
    pass

    @staticmethod
    def from_pretrained(
        model_name     = "meta-llama/Llama-2-7b-hf",
        max_seq_length = 4096,
        dtype          = None,
        load_in_4bit   = True,
        token          = None,
        device_map     = "sequential",
    ):
        """Load a patched Llama model and tokenizer.

        Args:
            model_name: HF hub id or local path.
            max_seq_length: tokenizer model_max_length; must be <= 4096
                (RoPE scaling not implemented yet, see TODO below).
            dtype: torch.float16 / torch.bfloat16 / torch.float32; None picks
                bfloat16 when supported, else float16.
            load_in_4bit: quantize with bnb NF4 + double quantization.
            token: HF auth token.
            device_map: forwarded to AutoModelForCausalLM.from_pretrained.

        Returns:
            (model, tokenizer) tuple.
        """
        gpu_stats = torch.cuda.get_device_properties(0)
        max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
        SUPPORTS_BFLOAT16 = torch.cuda.is_bf16_supported()

        # Startup banner with environment info.
        statistics = \
            "==((====))== Unsloth: Fast Llama patching release 23.11\n"\
            f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB\n"\
            f"O^O/ \_/ \\ CUDA compute capability = {gpu_stats.major}.{gpu_stats.minor}\n"\
            f"\ / Pytorch version: {torch.__version__}. CUDA Toolkit = {torch.version.cuda}\n"\
            f' "-____-" bfloat16 support = {str(SUPPORTS_BFLOAT16).upper()}\n'
        print(statistics)
        FastLlamaModel.pre_patch()

        # Resolve / validate the compute dtype.
        if dtype is None:
            dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16
        elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
            logger.warning_once("Device does not support bfloat16. Will change to float16.")
            dtype = torch.float16
        assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)

        # [TODO]: Determine RoPE scaling
        # https://github.com/huggingface/transformers/pull/24653
        assert(max_seq_length <= 4096)

        bnb_config = None
        if load_in_4bit:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit              = True,
                bnb_4bit_use_double_quant = True,
                bnb_4bit_quant_type       = "nf4",
                bnb_4bit_compute_dtype    = dtype,
            )
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map          = device_map,
            torch_dtype         = dtype,
            quantization_config = bnb_config,
            token               = token,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            model_max_length = max_seq_length,
            padding_side     = "right",
            token            = token,
        )
        # Llama ships without a pad token; reuse UNK for padding.
        tokenizer.add_special_tokens({"pad_token" : tokenizer.unk_token});
        tokenizer.pad_token = tokenizer.unk_token
        # NOTE(review): HF's config.update(...) returns None, so `config` here
        # is None — looks like only the side effect on model.config is wanted.
        config = model.config.update({"pad_token_id" : tokenizer.unk_token_id});

        model = FastLlamaModel.post_patch(model)

        # Patch up QKV / O and MLP: install the unfused default projections;
        # get_peft_model later swaps in the fused LoRA versions per layer.
        for idx, layer in enumerate(model.model.layers):
            layer.self_attn.apply_qkv = original_apply_qkv
            layer.self_attn.apply_o   = original_apply_o
        pass
        return model, tokenizer
    pass

    @staticmethod
    def post_patch(model):
        """Post-load fixups: rebuild embeddings/lm_head and correct bnb dtypes."""
        # Patch model
        layers = model.model.layers

        # Torch.compile fails on embedding matrix??
        # Workaround randomnly fixes it for torch versions < 2.2
        model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight)

        # We also do this for the lm_head: rebuild a plain Linear that shares
        # the existing weight tensor (no copy).
        lm_head = torch.nn.Linear(1, 1, bias = None)
        del lm_head.weight
        lm_head.weight = model.lm_head.weight
        lm_head.in_features  = lm_head.weight.shape[1]
        lm_head.out_features = lm_head.weight.shape[0]
        model.lm_head = lm_head

        # Also patch all dtypes - BnB seems to not allocate the correct type?
        # BnB default dtype seems to be float16!
        correct_dtype = lm_head.weight.dtype

        for name, module in model.named_modules():
            if isinstance(module, (bnb.nn.Linear4bit, peft.tuners.lora.Linear4bit)):
                weight = module.weight
                quant_state = weight.quant_state

                if type(quant_state) is list:
                    # BnB seems to have float16 as default!
                    # Legacy list form stores dtype at index 2.
                    module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
                else:
                    # https://github.com/TimDettmers/bitsandbytes/pull/763/files
                    quant_state.dtype = correct_dtype
                pass
            pass
        pass

        # Clear deleted GPU items
        gc.collect()
        torch.cuda.empty_cache()
        return model
    pass

    @staticmethod
    def get_peft_model(
        model,
        r = 16,
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                          "gate_proj", "up_proj", "down_proj"],
        lora_alpha = 16,
        lora_dropout = 0,
        bias = "none",
        layers_to_transform = None,
        use_gradient_checkpointing = True,
        random_state = 3407,
        max_seq_length = 2048,
    ):
        """Attach LoRA adapters and install the fused LoRA forward paths.

        Only lora_dropout == 0 and bias == 'none' are supported by the fused
        kernels; anything else raises TypeError. target_modules must be a
        subset of the seven Llama projection names.
        """
        if lora_dropout != 0:
            raise TypeError("Unsloth: Fast Llama patching only works with dropout = 0.")
        if bias != "none":
            raise TypeError("Unsloth: Fast Llama patching only works with bias = 'none'.")
        transformers_set_seed(random_state)

        accepted_modules = frozenset(("q_proj", "k_proj", "v_proj", "o_proj",
                                      "gate_proj", "up_proj", "down_proj",),)
        for module in target_modules:
            assert(module in accepted_modules)
        pass

        # Get LoRA
        lora_config = LoraConfig(
            r                   = r,
            lora_alpha          = lora_alpha,
            target_modules      = target_modules,
            lora_dropout        = 0,
            bias                = "none",
            task_type           = TaskType.CAUSAL_LM,
            layers_to_transform = layers_to_transform,
        )
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing = use_gradient_checkpointing,
            use_reentrant = True,
        )
        model = _get_peft_model(model, lora_config)

        # Do patching: install the fused forward paths only where every
        # projection in the group actually received a LoRA adapter.
        # NOTE(review): apply_lora_mlp / apply_lora_qkv / apply_lora_o come in
        # via `from ..kernels import *` — confirm they exist there.
        for idx, layer in enumerate(model.model.model.layers):

            # MLP patching
            if hasattr(layer.mlp.gate_proj, "lora_A") and \
               hasattr(layer.mlp.up_proj, "lora_A") and \
               hasattr(layer.mlp.down_proj, "lora_A"):
                # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module
                layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
            pass

            # QKV attention patching
            if hasattr(layer.self_attn.q_proj, "lora_A") and \
               hasattr(layer.self_attn.k_proj, "lora_A") and \
               hasattr(layer.self_attn.v_proj, "lora_A"):
                layer.self_attn.apply_qkv = apply_lora_qkv
            pass

            # O attention patching
            if hasattr(layer.self_attn.o_proj, "lora_A"):
                layer.self_attn.apply_o = apply_lora_o
            pass
        pass

        # Patch cross entropy loss labels: preallocate a column of -100
        # (ignore_index) appended to shifted labels so the sequence length is
        # preserved (consumed by LlamaForCausalLM_fast_forward).
        model.model.extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda")
        return model
    pass
pass