mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Fix bugs + more accurate Swiglu (#137)
* faster saving & inference
* Update llama.py
* Update save.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update mistral.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* fast inference
* Update llama.py
* Update save.py
* Update llama.py
* Mistral correct RoPE scaling
* Max sequence lengths
* Apache 2
* fast_linear_forward
* Update utils.py
* Update utils.py
* No print
* Update utils.py
* Update utils.py
* inference
* Update llama.py
* Fast inference RoPE
* Update llama.py
* Update llama.py
* RoPE
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* LoRA
* Fast LoRA saving
* Update llama.py
* hidden_states
* q_len == 1
* q_len issue
* Update mistral.py
* Update mistral.py
* incorrect inference
* Update to transformers 4.37
* Graceful FA2 error + torch 2.1.1
* Update mapper.py
* Update pyproject.toml
* Fix saving and bnb-4bit
* Update fast_lora.py
* Update fast_lora.py
* remove patching
* Update llama.py
* Update llama.py
* Update swiglu.py
* Repatch
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update llama.py
* Update fast_lora.py
* Update llama.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update swiglu.py
* Update fast_lora.py
* Update swiglu.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update save.py
* Update fast_lora.py
* Update utils.py
* Update llama.py
* Update fast_lora.py
* Update swiglu.py
* Update save.py
* Update save.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Revert "Update llama.py"
This reverts commit a208ec46e0.
* Update llama.py
* Works?
* Update pyproject.toml
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Swiglu
* Update swiglu.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update swiglu.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* attention_mask
* Update llama.py
* Update llama.py
* labels
* Update mistral.py
* Update llama.py
* attention mask
This commit is contained in:
parent
a81aff286f
commit
e2bbd3819e
4 changed files with 67 additions and 39 deletions
|
|
@ -36,9 +36,9 @@ huggingface = [
|
|||
"transformers>=4.37.0",
|
||||
"datasets",
|
||||
"sentencepiece",
|
||||
"accelerate",
|
||||
"accelerate>=0.26.1",
|
||||
"trl>=0.7.9",
|
||||
"peft",
|
||||
"peft>=0.7.1",
|
||||
"tqdm",
|
||||
"psutil",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -90,6 +90,8 @@ class LoRA_MLP(torch.autograd.Function):
|
|||
|
||||
e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
|
||||
g = matmul_lora(X, upW, upW_quant, upA, upB, upS)
|
||||
# f = torch.nn.functional.silu(e)
|
||||
# h = f * g
|
||||
h = swiglu_fg_kernel(e, g)
|
||||
i = matmul_lora(h, downW, downW_quant, downA, downB, downS)
|
||||
|
||||
|
|
@ -103,6 +105,7 @@ class LoRA_MLP(torch.autograd.Function):
|
|||
return i
|
||||
pass
|
||||
|
||||
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_bwd
|
||||
def backward(ctx, dY : torch.Tensor):
|
||||
|
|
@ -121,11 +124,16 @@ class LoRA_MLP(torch.autograd.Function):
|
|||
g = g .view(-1, g .shape[-1])
|
||||
dtype = X.dtype
|
||||
|
||||
# DW_f = (D @ W.T * f)
|
||||
# DW_dfg = (D @ W.T * df * g)
|
||||
DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
|
||||
# e = e.float()
|
||||
# se = 1.0 / (1.0 + torch.exp(-e))
|
||||
# f = (se * e).to(dtype)
|
||||
# h = f * g
|
||||
# df = DW * f
|
||||
# dg = DW * g
|
||||
# de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
|
||||
DW, e, g = swiglu_DWf_DW_dfg_kernel(DW, e, g)
|
||||
h, DW_f, DW_dfg = DW, e, g
|
||||
h, df, de = DW, e, g
|
||||
|
||||
# Down projection LoRA weights
|
||||
d_downA = h.t() @ (dY @ downB.t())
|
||||
|
|
@ -134,31 +142,29 @@ class LoRA_MLP(torch.autograd.Function):
|
|||
d_downB *= downS
|
||||
|
||||
# Up projection LoRA weights
|
||||
d_upA = X.t() @ (DW_f @ upB.t())
|
||||
d_upB = (upA.t() @ X.t()) @ DW_f
|
||||
d_upA = X.t() @ (df @ upB.t())
|
||||
d_upB = (upA.t() @ X.t()) @ df
|
||||
d_upA *= upS
|
||||
d_upB *= upS
|
||||
|
||||
# Gate projection LoRA weights
|
||||
d_gateA = X.t() @ (DW_dfg @ gateB.t())
|
||||
d_gateB = (gateA.t() @ X.t()) @ DW_dfg
|
||||
d_gateA = X.t() @ (de @ gateB.t())
|
||||
d_gateB = (gateA.t() @ X.t()) @ de
|
||||
d_gateA *= gateS
|
||||
d_gateB *= gateS
|
||||
|
||||
# Final derivatives to backpropagate backwards.
|
||||
# See our blogpost for more details.
|
||||
# (D @ W.T * f) @ U.T
|
||||
upW = fast_dequantize(upW.t(), upW_quant)
|
||||
# (D @ W.T * f) @ (U.T + B.T @ A.T)
|
||||
dX = torch.matmul(DW_f, upW.t(), out = X)
|
||||
del upW
|
||||
dX += DW_f @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
|
||||
# dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
|
||||
# dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
|
||||
|
||||
upW = fast_dequantize(upW.t(), upW_quant)
|
||||
dX = torch.matmul(df, upW.t(), out = X)
|
||||
del upW
|
||||
dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
|
||||
|
||||
# And add the derivative for the gate projection
|
||||
gateW = fast_dequantize(gateW.t(), gateW_quant)
|
||||
dX += DW_dfg @ gateW.t()
|
||||
dX += de @ gateW.t()
|
||||
del gateW
|
||||
dX += DW_dfg @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
|
||||
dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
|
||||
|
||||
# gateW, gateW_quant, gateA, gateB, gateS,
|
||||
# upW, upW_quant, upA, upB, upS,
|
||||
|
|
@ -172,6 +178,11 @@ pass
|
|||
|
||||
|
||||
def apply_lora_mlp(self, X):
|
||||
# gate = self.gate_proj(X)
|
||||
# up = self. up_proj(X)
|
||||
# h = torch.nn.functional.silu(gate) * up
|
||||
# down = self.down_proj(h)
|
||||
# return down
|
||||
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
|
||||
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
|
|||
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
|
||||
|
||||
# f = e * sigmoid(e)
|
||||
f_row = e_row / (1 + tl.exp(-e_row))
|
||||
f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
|
||||
f_row = f_row.to(g_row.dtype) # Exact copy from HF
|
||||
# h = f * g
|
||||
h_row = f_row * g_row
|
||||
|
|
@ -50,30 +50,43 @@ pass
|
|||
|
||||
@triton.jit
|
||||
def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
|
||||
"""
|
||||
e = e.float()
|
||||
se = 1.0 / (1.0 + torch.exp(-e))
|
||||
f = (se * e).to(dtype)
|
||||
h = f * g
|
||||
df = DW * f
|
||||
dg = DW * g
|
||||
de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
|
||||
"""
|
||||
block_idx = tl.program_id(0)
|
||||
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
|
||||
DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
|
||||
e_row = tl.load(e + offsets, mask = mask, other = 0)#.to(tl.float32)
|
||||
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
||||
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
|
||||
|
||||
# f = e * sigmoid(e)
|
||||
se_row = 1 / (1 + tl.exp(-e_row.to(tl.float32)))
|
||||
se_row = se_row.to(e_row.dtype) # Exact copy from HF
|
||||
# f = e * se
|
||||
f_row = e_row * se_row
|
||||
# e = e.float()
|
||||
# se = 1.0 / (1.0 + torch.exp(-e))
|
||||
se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row))
|
||||
# f = (se * e).to(dtype)
|
||||
f_row = se_row * e_row
|
||||
f_row = f_row.to(DW_row.dtype)
|
||||
# h = f * g
|
||||
h_row = f_row * g_row
|
||||
# DW_f = DW * f
|
||||
DWf_row = DW_row * f_row
|
||||
# DW_dfg = DW * (se*(g - h) + h)
|
||||
DW_dfg_row = DW_row * (se_row*(g_row - h_row) + h_row)
|
||||
h_row = f_row * g_row
|
||||
# df = DW * f
|
||||
df_row = DW_row * f_row
|
||||
# dg = DW * g
|
||||
dg_row = DW_row * g_row
|
||||
# de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
|
||||
de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))
|
||||
de_row = de_row.to(DW_row.dtype)
|
||||
|
||||
# Store derivatives in buffers
|
||||
tl.store(DW + offsets, h_row, mask = mask)
|
||||
tl.store(e + offsets, DWf_row, mask = mask)
|
||||
tl.store(g + offsets, DW_dfg_row, mask = mask)
|
||||
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
|
||||
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
|
||||
tl.store(g + offsets, de_row, mask = mask) # de
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
|
||||
from peft.tuners.lora import Linear4bit as Peft_Linear4bit
|
||||
from peft.tuners.lora import Linear as Peft_Linear
|
||||
from typing import Optional, Callable, Union, List
|
||||
import torch
|
||||
import os
|
||||
|
|
@ -72,11 +73,15 @@ pass
|
|||
|
||||
|
||||
def _merge_lora(layer, name):
|
||||
if isinstance(layer, (Bnb_Linear4bit, Peft_Linear4bit)):
|
||||
|
||||
if isinstance(layer, (Bnb_Linear4bit, Peft_Linear4bit, Peft_Linear)):
|
||||
# Is LoRA so we need to merge!
|
||||
W, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
dtype = quant_state.dtype if type(quant_state) is not list else quant_state[2]
|
||||
W = fast_dequantize(W, quant_state).to(torch.float32).t()
|
||||
if quant_state is not None:
|
||||
dtype = quant_state.dtype if type(quant_state) is not list else quant_state[2]
|
||||
W = fast_dequantize(W, quant_state)
|
||||
pass
|
||||
W = W.to(torch.float32).t()
|
||||
|
||||
if A is not None:
|
||||
sAB = (A.t().to(torch.float32) @ (s * B.t().to(torch.float32)))
|
||||
|
|
@ -84,7 +89,6 @@ def _merge_lora(layer, name):
|
|||
if not torch.isfinite(W).all():
|
||||
raise ValueError(f"Unsloth: Merge failed.\n{name} has some elements = infinity.")
|
||||
pass
|
||||
|
||||
W = W.t().to(dtype)
|
||||
else:
|
||||
W = layer.weight
|
||||
|
|
|
|||
Loading…
Reference in a new issue