Fix bugs + more accurate Swiglu (#137)

* faster saving & inference

* Update llama.py

* Update save.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update mistral.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* fast inference

* Update llama.py

* Update save.py

* Update llama.py

* Mistral correct RoPE scaling

* Max sequence lengths

* Apache 2

* fast_linear_forward

* Update utils.py

* Update utils.py

* No print

* Update utils.py

* Update utils.py

* inference

* Update llama.py

* Fast inference RoPE

* Update llama.py

* Update llama.py

* RoPE

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* LoRA

* Fast LoRA saving

* Update llama.py

* hidden_states

* q_len == 1

* q_len issue

* Update mistral.py

* Update mistral.py

* incorrect inference

* Update to transformers 4.37

* Graceful FA2 error + torch 2.1.1

* Update mapper.py

* Update pyproject.toml

* Fix saving and bnb-4bit

* Update fast_lora.py

* Update fast_lora.py

* remove patching

* Update llama.py

* Update llama.py

* Update swiglu.py

* Repatch

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update llama.py

* Update fast_lora.py

* Update llama.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update swiglu.py

* Update fast_lora.py

* Update swiglu.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update save.py

* Update fast_lora.py

* Update utils.py

* Update llama.py

* Update fast_lora.py

* Update swiglu.py

* Update save.py

* Update save.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Revert "Update llama.py"

This reverts commit a208ec46e0.

* Update llama.py

* Works?

* Update pyproject.toml

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Swiglu

* Update swiglu.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update swiglu.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* attention_mask

* Update llama.py

* Update llama.py

* labels

* Update mistral.py

* Update llama.py

* attention mask
This commit is contained in:
Daniel Han 2024-01-28 04:20:06 +11:00 committed by GitHub
parent a81aff286f
commit e2bbd3819e
4 changed files with 67 additions and 39 deletions

View file

@@ -36,9 +36,9 @@ huggingface = [
"transformers>=4.37.0",
"datasets",
"sentencepiece",
"accelerate",
"accelerate>=0.26.1",
"trl>=0.7.9",
"peft",
"peft>=0.7.1",
"tqdm",
"psutil",
]

View file

@@ -90,6 +90,8 @@ class LoRA_MLP(torch.autograd.Function):
e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
g = matmul_lora(X, upW, upW_quant, upA, upB, upS)
# f = torch.nn.functional.silu(e)
# h = f * g
h = swiglu_fg_kernel(e, g)
i = matmul_lora(h, downW, downW_quant, downA, downB, downS)
@@ -103,6 +105,7 @@ class LoRA_MLP(torch.autograd.Function):
return i
pass
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, dY : torch.Tensor):
@@ -121,11 +124,16 @@ class LoRA_MLP(torch.autograd.Function):
g = g .view(-1, g .shape[-1])
dtype = X.dtype
# DW_f = (D @ W.T * f)
# DW_dfg = (D @ W.T * df * g)
DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
# e = e.float()
# se = 1.0 / (1.0 + torch.exp(-e))
# f = (se * e).to(dtype)
# h = f * g
# df = DW * f
# dg = DW * g
# de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
DW, e, g = swiglu_DWf_DW_dfg_kernel(DW, e, g)
h, DW_f, DW_dfg = DW, e, g
h, df, de = DW, e, g
# Down projection LoRA weights
d_downA = h.t() @ (dY @ downB.t())
@@ -134,31 +142,29 @@ class LoRA_MLP(torch.autograd.Function):
d_downB *= downS
# Up projection LoRA weights
d_upA = X.t() @ (DW_f @ upB.t())
d_upB = (upA.t() @ X.t()) @ DW_f
d_upA = X.t() @ (df @ upB.t())
d_upB = (upA.t() @ X.t()) @ df
d_upA *= upS
d_upB *= upS
# Gate projection LoRA weights
d_gateA = X.t() @ (DW_dfg @ gateB.t())
d_gateB = (gateA.t() @ X.t()) @ DW_dfg
d_gateA = X.t() @ (de @ gateB.t())
d_gateB = (gateA.t() @ X.t()) @ de
d_gateA *= gateS
d_gateB *= gateS
# Final derivatives to backpropagate backwards.
# See our blogpost for more details.
# (D @ W.T * f) @ U.T
upW = fast_dequantize(upW.t(), upW_quant)
# (D @ W.T * f) @ (U.T + B.T @ A.T)
dX = torch.matmul(DW_f, upW.t(), out = X)
del upW
dX += DW_f @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
# dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
# dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
upW = fast_dequantize(upW.t(), upW_quant)
dX = torch.matmul(df, upW.t(), out = X)
del upW
dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
# And add the derivative for the gate projection
gateW = fast_dequantize(gateW.t(), gateW_quant)
dX += DW_dfg @ gateW.t()
dX += de @ gateW.t()
del gateW
dX += DW_dfg @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
# gateW, gateW_quant, gateA, gateB, gateS,
# upW, upW_quant, upA, upB, upS,
@@ -172,6 +178,11 @@ pass
def apply_lora_mlp(self, X):
# gate = self.gate_proj(X)
# up = self. up_proj(X)
# h = torch.nn.functional.silu(gate) * up
# down = self.down_proj(h)
# return down
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)

View file

@@ -28,7 +28,7 @@ def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
# f = e * sigmoid(e)
f_row = e_row / (1 + tl.exp(-e_row))
f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
f_row = f_row.to(g_row.dtype) # Exact copy from HF
# h = f * g
h_row = f_row * g_row
@@ -50,30 +50,43 @@ pass
@triton.jit
def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
"""
e = e.float()
se = 1.0 / (1.0 + torch.exp(-e))
f = (se * e).to(dtype)
h = f * g
df = DW * f
dg = DW * g
de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
"""
block_idx = tl.program_id(0)
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
e_row = tl.load(e + offsets, mask = mask, other = 0)#.to(tl.float32)
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
# f = e * sigmoid(e)
se_row = 1 / (1 + tl.exp(-e_row.to(tl.float32)))
se_row = se_row.to(e_row.dtype) # Exact copy from HF
# f = e * se
f_row = e_row * se_row
# e = e.float()
# se = 1.0 / (1.0 + torch.exp(-e))
se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row))
# f = (se * e).to(dtype)
f_row = se_row * e_row
f_row = f_row.to(DW_row.dtype)
# h = f * g
h_row = f_row * g_row
# DW_f = DW * f
DWf_row = DW_row * f_row
# DW_dfg = DW * (se*(g - h) + h)
DW_dfg_row = DW_row * (se_row*(g_row - h_row) + h_row)
h_row = f_row * g_row
# df = DW * f
df_row = DW_row * f_row
# dg = DW * g
dg_row = DW_row * g_row
# de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))
de_row = de_row.to(DW_row.dtype)
# Store derivatives in buffers
tl.store(DW + offsets, h_row, mask = mask)
tl.store(e + offsets, DWf_row, mask = mask)
tl.store(g + offsets, DW_dfg_row, mask = mask)
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
tl.store(g + offsets, de_row, mask = mask) # de
pass

View file

@@ -14,6 +14,7 @@
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
from peft.tuners.lora import Linear4bit as Peft_Linear4bit
from peft.tuners.lora import Linear as Peft_Linear
from typing import Optional, Callable, Union, List
import torch
import os
@@ -72,11 +73,15 @@ pass
def _merge_lora(layer, name):
if isinstance(layer, (Bnb_Linear4bit, Peft_Linear4bit)):
if isinstance(layer, (Bnb_Linear4bit, Peft_Linear4bit, Peft_Linear)):
# Is LoRA so we need to merge!
W, quant_state, A, B, s = get_lora_parameters(layer)
dtype = quant_state.dtype if type(quant_state) is not list else quant_state[2]
W = fast_dequantize(W, quant_state).to(torch.float32).t()
if quant_state is not None:
dtype = quant_state.dtype if type(quant_state) is not list else quant_state[2]
W = fast_dequantize(W, quant_state)
pass
W = W.to(torch.float32).t()
if A is not None:
sAB = (A.t().to(torch.float32) @ (s * B.t().to(torch.float32)))
@@ -84,7 +89,6 @@ def _merge_lora(layer, name):
if not torch.isfinite(W).all():
raise ValueError(f"Unsloth: Merge failed.\n{name} has some elements = infinity.")
pass
W = W.t().to(dtype)
else:
W = layer.weight