2-4x faster native HF inference (#119)

* faster saving & inference

* Update llama.py

* Update save.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update mistral.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* fast inference

* Update llama.py

* Update save.py

* Update llama.py

* Mistral correct RoPE scaling

* Max sequence lengths

* Apache 2

* fast_linear_forward

* Update utils.py

* Update utils.py

* No print

* Update utils.py

* Update utils.py

* inference

* Update llama.py

* Fast inference RoPE

* Update llama.py

* Update llama.py

* RoPE

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* LoRA

* Fast LoRA saving
This commit is contained in:
Daniel Han 2024-01-23 03:55:24 +11:00 committed by GitHub
parent 3a9b2dee98
commit 04f8771821
8 changed files with 291 additions and 89 deletions

View file

@ -22,4 +22,4 @@ from .fast_lora import (
apply_lora_qkv,
apply_lora_o,
)
from .utils import fast_dequantize, QUANT_STATE
from .utils import fast_dequantize, QUANT_STATE, fast_linear_forward

View file

@ -13,28 +13,10 @@
# limitations under the License.
import torch
from .utils import fast_dequantize, QUANT_STATE
from .utils import fast_dequantize, QUANT_STATE, get_lora_parameters
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
def get_lora_parameters(proj):
# Extract (W, quant_state, A, B, scaling) from a peft/bitsandbytes projection.
# For DPO or disabled adapters
# peft wraps the real Linear in `base_layer`; plain layers are used directly.
base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
W = base_layer.weight
# No adapter attached, adapters disabled, or weights already merged:
# return only the base weight; the LoRA triplet comes back as None.
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
return W, QUANT_STATE(W), None, None, None
pass
# `active_adapters` (list) vs `active_adapter` (single name) — handle both peft APIs.
active_adapter = proj.active_adapters[0] if \
hasattr(proj, "active_adapters") else proj.active_adapter
A = proj.lora_A [active_adapter].weight
B = proj.lora_B [active_adapter].weight
s = proj.scaling[active_adapter]
return W, QUANT_STATE(W), A, B, s
pass
def matmul_lora(X, W, W_quant, A, B, s, out = None):
dtype = X.dtype
W = fast_dequantize(W.t(), W_quant)

View file

@ -134,8 +134,9 @@ class Slow_RoPE_Embedding(torch.autograd.Function):
half = Q.shape[-1]//2
RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
Q *= cos
RH_Q *= sin
Q += RH_Q
Q.addcmul_(RH_Q, sin)
# RH_Q *= sin
# Q += RH_Q
ctx.save_for_backward(cos, sin)
return Q
pass
@ -147,8 +148,9 @@ class Slow_RoPE_Embedding(torch.autograd.Function):
half = dY.shape[-1]//2
RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
dY *= cos
RH_dY *= sin
dY += RH_dY
dY.addcmul_(RH_dY, sin)
# RH_dY *= sin
# dY += RH_dY
return dY, None, None, None
pass
pass

View file

@ -33,14 +33,36 @@ import bitsandbytes as bnb
get_ptr = bnb.functional.get_ptr
import ctypes
import torch
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
def QUANT_STATE(W):
    """Return the bitsandbytes ``quant_state`` attached to weight *W*.

    Plain (non-quantized) parameters carry no such attribute, in which case
    ``None`` is returned.
    """
    if hasattr(W, "quant_state"):
        return W.quant_state
    return None
pass
def get_lora_parameters(proj):
    """Return ``(W, quant_state, A, B, s)`` for a (possibly LoRA-wrapped) projection.

    For plain layers, disabled adapters (e.g. DPO reference models) or
    already-merged weights the LoRA triplet is ``(None, None, None)``.
    """
    # peft wraps the real Linear in `base_layer`; plain layers are used directly.
    layer = getattr(proj, "base_layer", proj)
    W = layer.weight

    no_adapter = (
        not hasattr(proj, "disable_adapters")
        or proj.disable_adapters
        or proj.merged
    )
    if no_adapter:
        return W, QUANT_STATE(W), None, None, None

    # `active_adapters` (list) vs `active_adapter` (single name) — both peft APIs.
    if hasattr(proj, "active_adapters"):
        adapter = proj.active_adapters[0]
    else:
        adapter = proj.active_adapter
    A = proj.lora_A[adapter].weight
    B = proj.lora_B[adapter].weight
    s = proj.scaling[adapter]
    return W, QUANT_STATE(W), A, B, s
pass
def fast_dequantize(W, quant_state = None, out = None):
if quant_state is None: return W
if type(quant_state) is not list:
@ -90,3 +112,85 @@ def fast_dequantize(W, quant_state = None, out = None):
is_transposed = (True if W.shape[0] == 1 else False)
return out.t() if is_transposed else out
pass
def fast_gemv(X, W, quant_state, out = None, out_W = None):
# Single-token 4-bit matrix-vector product for inference (bsz = q_len = 1),
# dispatched straight to bitsandbytes' naive 4-bit GEMM kernels via ctypes.
# NOTE(review): the `quant_state` argument is immediately shadowed by
# W.quant_state, so that parameter (and `out_W`) is effectively unused — confirm.
quant_state = W.quant_state
bsz = 1
q_len = 1
hd = X.shape[0]
# bitsandbytes changed quant_state from a list to an object:
if type(quant_state) is not list:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
absmax = quant_state.absmax
shape = quant_state.shape
dtype = quant_state.dtype
blocksize = quant_state.blocksize
stats = quant_state.code
offset = quant_state.offset
state2 = quant_state.state2
absmax2 = state2.absmax
code2 = state2.code
blocksize2 = state2.blocksize
else:
# Legacy list layout: nested (offset, state2) holds the double-quant stats.
absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
offset, state2 = compressed_stats
absmax2, code2, blocksize2, _, _, _, _ = state2
pass
# Output vector length = the projection's out_features (shape[0]).
bout = shape[0]
if out is None: out = torch.empty(bout, dtype = dtype, device = "cuda")
else: assert(out.shape[0] == bout)
# GEMV dims: (m x k) packed weight times a length-k X; n = 1 column.
n = 1
m = shape[0]
k = shape[1]
lda = shape[0]
ldc = shape[0]
# Two 4-bit values are packed per byte, hence the (len+1)//2 leading dim.
ldb = (X.shape[-1]+1)//2
m = ctypes.c_int32(m)
n = ctypes.c_int32(n)
k = ctypes.c_int32(k)
lda = ctypes.c_int32(lda)
ldb = ctypes.c_int32(ldb)
ldc = ctypes.c_int32(ldc)
# Dequantize the double-quantized absmax back to fp32 (plus stored offset)
# so the 4-bit kernel receives real per-block scales.
df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda")
cdequantize_blockwise_fp32(
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
ctypes.c_int(blocksize2), ctypes.c_int(df.numel()),
)
df += offset
absmax = df
# Pick the fused kernel matching the weight's compute dtype (fp16 or bf16).
fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
cgemm_4bit_inference_naive_bf16
ptr_W = get_ptr(W)
ptr_absmax = get_ptr(absmax)
ptr_stats = get_ptr(stats)
blocksize = ctypes.c_int32(blocksize)
fx(m, n, k, get_ptr(X), ptr_W, ptr_absmax, ptr_stats, get_ptr(out),
lda, ldb, ldc, blocksize)
return out
pass
def fast_linear_forward(proj, X, temp_lora = None, out = None):
# Inference-only linear layer: 4-bit GEMV on the base weight, then the
# rank-r LoRA update out += s * B @ (A @ X) applied with in-place CUDA ops.
# `temp_lora` / `out` let callers reuse scratch buffers across decode steps.
W, W_quant, lora_A, lora_B, lora_S = get_lora_parameters(proj)
out = fast_gemv(X, W, W_quant, out = out)
if lora_A is not None:
# Save LoRAs for inference to stop data movement costs
if not hasattr(lora_A, "_fast_lora"):
dtype = X.dtype
# Cache cast/transposed copies once; cleared again when training resumes.
lora_A._fast_lora = lora_A.to(dtype).t()
lora_B._fast_lora = lora_B.to(dtype)
pass
# temp = A @ X (as a row-vector matmul against the cached A^T).
temp_lora = torch.matmul(X, lora_A._fast_lora, out = temp_lora)
# Fused out += lora_S * (B @ temp) without allocating a new tensor.
out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
pass
return out
pass

View file

@ -68,6 +68,7 @@ def original_apply_o(self, X):
pass
from math import sqrt as math_sqrt
def LlamaAttention_fast_forward_inference(
self,
hidden_states: torch.Tensor,
@ -102,60 +103,104 @@ def LlamaAttention_fast_forward_inference(
This means we can pass in a row of Q, but we need to
remember K and V, which are called the KV cache.
"""
Xn = hidden_states
bsz, _, _ = hidden_states.size()
K1, V1 = past_key_value
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
assert(n_kv_heads * n_groups == n_heads)
# assert(n_kv_heads * n_groups == n_heads)
Qn = self.q_proj(Xn)
Kn = self.k_proj(Xn)
Vn = self.v_proj(Xn)
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Xn = hidden_states.view(self.hidden_size)
K1, V1 = past_key_value
seq_len = K1.shape[-2]
K1 = K1.view(n_kv_heads, seq_len, head_dim)
V1 = V1.view(n_kv_heads, seq_len, head_dim)
kv_seq_len = K1.shape[-2] + 1
cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
# LoRA or general matrix multiplication
dtype = Xn.dtype
# Qn = self.q_proj(Xn)
# Kn = self.k_proj(Xn)
# Vn = self.v_proj(Xn)
Qn = fast_linear_forward(self.q_proj, Xn)
Kn = fast_linear_forward(self.k_proj, Xn)
Vn = fast_linear_forward(self.v_proj, Xn)
# Qn = Qn.view(1, 1, n_heads, head_dim).transpose(1, 2)
# Kn = Kn.view(1, 1, n_kv_heads, head_dim).transpose(1, 2)
# Vn = Vn.view(1, 1, n_kv_heads, head_dim).transpose(1, 2)
Qn = Qn.view(n_heads, 1, head_dim)
Kn = Kn.view(n_kv_heads, 1, head_dim)
Vn = Vn.view(n_kv_heads, 1, head_dim)
# kv_seq_len = K1.shape[-2] + 1
# cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
# Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
cos = self.rotary_emb.cos_cached[seq_len]
sin = self.rotary_emb.sin_cached[seq_len]
h = head_dim // 2
RH_Q = torch.empty((n_heads, 1, head_dim), dtype = dtype, device = "cuda")
RH_Q[:, :, :h] = Qn[:, :, h:]; RH_Q[:, :, h:] = Qn[:, :, :h]; torch.neg(RH_Q[:, :, :h], out = RH_Q[:, :, :h]);
Qn *= cos; Qn.addcmul_(RH_Q, sin);
RH_K = RH_Q[:n_kv_heads, :, :] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda")
RH_K[:, :, :h] = Kn[:, :, h:]; RH_K[:, :, h:] = Kn[:, :, :h]; torch.neg(RH_K[:, :, :h], out = RH_K[:, :, :h]);
Kn *= cos; Kn.addcmul_(RH_K, sin);
# New KV cache
Kn = torch.cat([K1, Kn], dim = 2)
Vn = torch.cat([V1, Vn], dim = 2)
# Kn = torch.cat([K1, Kn], dim = 2)
# Vn = torch.cat([V1, Vn], dim = 2)
Kn = torch.cat([K1, Kn], dim = 1)
Vn = torch.cat([V1, Vn], dim = 1)
# Grouped query attention
if n_groups != 1:
_, _, cached_len, _ = Kn.shape
Knn = Kn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Vnn = Vn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
# _, _, cached_len, _ = Kn.shape
# Knn = Kn[:, :, None, :, :].expand(1, n_kv_heads, n_groups, cached_len, head_dim)
# Vnn = Vn[:, :, None, :, :].expand(1, n_kv_heads, n_groups, cached_len, head_dim)
# Knn = Knn.reshape(1, n_heads, cached_len, head_dim)
# Vnn = Vnn.reshape(1, n_heads, cached_len, head_dim)
new_seq_len = seq_len + 1
Knn = Kn[:, None, :, :].expand(n_kv_heads, n_groups, new_seq_len, head_dim)
Vnn = Vn[:, None, :, :].expand(n_kv_heads, n_groups, new_seq_len, head_dim)
Knn = Knn.reshape(n_heads, new_seq_len, head_dim)
Vnn = Vnn.reshape(n_heads, new_seq_len, head_dim)
else:
Knn, Vnn = Kn, Vn
# Attention
A = torch.matmul(Qn, Knn.transpose(2, 3))
A *= 1.0 / (self.head_dim**0.5)
A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(A.dtype)
A = torch.matmul(A, Vnn)
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, self.hidden_size)
A = self.o_proj(A)
return A, (Kn, Vn)
# A = torch.matmul(Qn, Knn.transpose(2, 3))
A = torch.matmul(Qn, Knn.transpose(1, 2))
A *= 1.0 / math_sqrt(self.head_dim)
A[:] = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
A = torch.matmul(A, Vnn, out = Qn)
# A = A.transpose(1, 2)
A = A.view(self.hidden_size)
# A = self.o_proj(A)
A = fast_linear_forward(self.o_proj, A)
A = A.reshape(1, 1, self.hidden_size)
# return A, (Kn, Vn)
return A, (Kn.unsqueeze(0), Vn.unsqueeze(0))
pass
torch_silu = torch.nn.functional.silu
def fast_mlp_inference(self, X):
gate = self.gate_proj(X)
up = self.up_proj(X)
hidden_size = self.hidden_size
X = X.view(hidden_size)
# gate = self.gate_proj(X)
# up = self.up_proj(X)
gate = fast_linear_forward(self.gate_proj, X)
up = fast_linear_forward(self. up_proj, X)
gate = torch_silu(gate, inplace = True)
gate *= up
X = self.down_proj(gate)
# X = self.down_proj(gate)
down = fast_linear_forward(self.down_proj, gate, out = up[:hidden_size])
X = down.view(1, 1, hidden_size)
return X
pass
@ -676,10 +721,10 @@ class FastLlamaModel:
statistics = \
f"==((====))== Unsloth: Fast Llama patching release {__version__}\n"\
f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB\n"\
f"O^O/ \_/ \\ CUDA capability = {gpu_stats.major}.{gpu_stats.minor}. Xformers = {xformers_version}. FA = {HAS_FLASH_ATTENTION}.\n"\
f"\ / Pytorch version: {torch.__version__}. CUDA Toolkit = {torch.version.cuda}\n"\
f' "-____-" bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. Platform = {platform_system}\n'
f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform = {platform_system}.\n"\
f"O^O/ \_/ \\ Pytorch: {torch.__version__}. CUDA = {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit = {torch.version.cuda}.\n"\
f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. Xformers = {xformers_version}. FA = {HAS_FLASH_ATTENTION}.\n"\
f' "-____-" Apache 2 free license: http://github.com/unslothai/unsloth'
logger.warning_once(statistics)
FastLlamaModel.pre_patch()
@ -731,7 +776,7 @@ class FastLlamaModel:
)
tokenizer = AutoTokenizer.from_pretrained(
model_name,
model_max_length = max_seq_length,
model_max_length = max_position_embeddings,
padding_side = "right",
token = token,
)
@ -760,7 +805,7 @@ class FastLlamaModel:
model = model,
tokenizer = tokenizer,
model_name = model_name,
model_max_length = max_seq_length,
model_max_length = max_position_embeddings,
padding_side = "right",
token = token,
)
@ -1076,4 +1121,47 @@ class FastLlamaModel:
internal_model.max_seq_length = max_seq_length
return model
pass
@staticmethod
def for_inference(model):
if not hasattr(model, "_original_forward"):
model._original_forward = model.forward
pass
model.forward = torch.inference_mode(model._original_forward)
internal_model = model
internal_model.gradient_checkpointing = False
internal_model.training = False
while hasattr(internal_model, "model"):
internal_model = internal_model.model
internal_model.gradient_checkpointing = False
internal_model.training = False
pass
pass
@staticmethod
def for_training(model, use_gradient_checkpointing = True):
    """Return *model* to training mode after ``for_inference``.

    Restores the original ``forward``, drops the inference-only cached LoRA
    tensors (``param._fast_lora``), and re-enables the ``training`` flag and
    gradient checkpointing on every nested ``.model`` wrapper level.
    """
    # Undo the torch.inference_mode wrapper installed by for_inference, if any.
    if hasattr(model, "_original_forward"):
        model.forward = model._original_forward

    # Delete all fast inference loras so training sees fresh adapter weights.
    for p in model.parameters():
        if hasattr(p, "_fast_lora"):
            del p._fast_lora

    # Walk the wrapper chain and flip each level back into training mode.
    node = model
    while True:
        node.gradient_checkpointing = use_gradient_checkpointing
        node.training = True
        if not hasattr(node, "model"):
            break
        node = node.model

View file

@ -42,6 +42,12 @@ __INT_TO_FLOAT_MAPPER = \
"unsloth/tinyllama",
"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
),
"unsloth/mistral-7b-instruct-v0.1-bnb-4bit" : (
"mistralai/Mistral-7B-Instruct-v0.1",
),
"unsloth/mistral-7b-instruct-v0.2-bnb-4bit" : (
"mistralai/Mistral-7B-Instruct-v0.2",
),
}
INT_TO_FLOAT_MAPPER = {}

View file

@ -132,7 +132,7 @@ def MistralAttention_fast_forward(
K = K.transpose(1, 2)
V = V.transpose(1, 2)
sw = getattr(self.config, "sliding_window", None)
sw = q_len if sw is None else sw
sw = q_len if (sw is None or sw == "null") else sw
window = (-1, -1) if (q_len <= sw) else (sw, sw)
A = flash_attn_func(Q, K, V, causal = True, window_size = window)
else:
@ -176,7 +176,7 @@ def MistralForCausalLM_fast_forward(
if causal_mask is None:
bsz, q_len = input_ids.shape
sliding_window = getattr(self.config, "sliding_window", None)
if sliding_window is None or sliding_window <= 0:
if sliding_window is None or sliding_window == "null" or sliding_window <= 0:
causal_mask = xformers.attn_bias.LowerTriangularMask()
elif q_len <= sliding_window:
causal_mask = xformers.attn_bias.LowerTriangularMask()
@ -265,9 +265,11 @@ class FastMistralModel(FastLlamaModel):
rope_scaling = None, # Mistral does not support RoPE scaling
fix_tokenizer = True,
**kwargs,
):
):
# Mistral does NOT support RoPE Scaling!
if rope_scaling is not None:
logger.warning_once("Unsloth: Mistral models do not support RoPE scaling.")
pass
SUPPORTS_BFLOAT16 = torch.cuda.is_bf16_supported()
gpu_stats = torch.cuda.get_device_properties(0)
@ -275,10 +277,10 @@ class FastMistralModel(FastLlamaModel):
statistics = \
f"==((====))== Unsloth: Fast Mistral patching release {__version__}\n"\
f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB\n"\
f"O^O/ \_/ \\ CUDA capability = {gpu_stats.major}.{gpu_stats.minor}. Xformers = {xformers_version}. FA = {HAS_FLASH_ATTENTION}.\n"\
f"\ / Pytorch version: {torch.__version__}. CUDA Toolkit = {torch.version.cuda}\n"\
f' "-____-" bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. Platform = {platform_system}\n'
f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform = {platform_system}.\n"\
f"O^O/ \_/ \\ Pytorch: {torch.__version__}. CUDA = {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit = {torch.version.cuda}.\n"\
f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. Xformers = {xformers_version}. FA = {HAS_FLASH_ATTENTION}.\n"\
f' "-____-" Apache 2 free license: http://github.com/unslothai/unsloth'
logger.warning_once(statistics)
FastMistralModel.pre_patch()
@ -290,6 +292,18 @@ class FastMistralModel(FastLlamaModel):
assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)
# Check max sequence length
model_config = AutoConfig.from_pretrained(model_name, token = token)
model_max_seq_length = model_config.max_position_embeddings
# Mistral does NOT support RoPE Scaling sadly so we have to error out.
if max_seq_length > model_max_seq_length:
raise RuntimeError(
"Unsloth: Unfortunately Mistral type models do not support RoPE scaling!\n"\
f"The maximum sequence length supported is {model_max_seq_length}.",
)
pass
bnb_config = None
if load_in_4bit:
bnb_config = BitsAndBytesConfig(
@ -299,20 +313,21 @@ class FastMistralModel(FastLlamaModel):
bnb_4bit_compute_dtype = dtype,
)
max_position_embeddings = max(max_seq_length, model_max_seq_length)
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map = device_map,
torch_dtype = dtype,
device_map = device_map,
torch_dtype = dtype,
quantization_config = bnb_config,
token = token,
# rope_scaling = rope_scaling,
token = token,
# rope_scaling = rope_scaling,
**kwargs,
)
tokenizer = AutoTokenizer.from_pretrained(
model_name,
model_max_length = max_seq_length,
padding_side = "right",
token = token,
model_max_length = max_position_embeddings,
padding_side = "right",
token = token,
)
model, tokenizer = patch_tokenizer(model, tokenizer)
@ -337,12 +352,12 @@ class FastMistralModel(FastLlamaModel):
# We check the tokenizer first for errors
if fix_tokenizer:
tokenizer = check_tokenizer(
model = model,
tokenizer = tokenizer,
model_name = model_name,
model_max_length = max_seq_length,
padding_side = "right",
token = token,
model = model,
tokenizer = tokenizer,
model_name = model_name,
model_max_length = max_position_embeddings,
padding_side = "right",
token = token,
)
pass
patch_saving_functions(tokenizer)

View file

@ -77,10 +77,14 @@ def _merge_lora(layer, name):
W, quant_state, A, B, s = get_lora_parameters(layer)
dtype = quant_state.dtype if type(quant_state) is not list else quant_state[2]
W = fast_dequantize(W, quant_state).to(torch.float32).t()
sAB = (A.t().to(torch.float32) @ (s * B.t().to(torch.float32)))
W += sAB
if not torch.isfinite(W).all():
raise ValueError(f"Unsloth: Merge failed.\n{name} has some elements = infinity.")
if A is not None:
sAB = (A.t().to(torch.float32) @ (s * B.t().to(torch.float32)))
W += sAB
if not torch.isfinite(W).all():
raise ValueError(f"Unsloth: Merge failed.\n{name} has some elements = infinity.")
pass
W = W.t().to(dtype)
else:
W = layer.weight
@ -156,6 +160,7 @@ def unsloth_save_model(
pass
if save_method == "merged_4bit":
print("Unsloth: Merging 4bit and LoRA weights to 4bit...")
print("This might take 5 minutes...")
model = model.merge_and_unload()