Hotfix - fix inference (#146)

* faster saving & inference

* Update llama.py

* Update save.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update mistral.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* fast inference

* Update llama.py

* Update save.py

* Update llama.py

* Mistral correct RoPE scaling

* Max sequence lengths

* Apache 2

* fast_linear_forward

* Update utils.py

* Update utils.py

* No print

* Update utils.py

* Update utils.py

* inference

* Update llama.py

* Fast inference RoPE

* Update llama.py

* Update llama.py

* RoPE

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* LoRA

* Fast LoRA saving

* Update llama.py

* hidden_states

* q_len == 1

* q_len issue

* Update mistral.py

* Update mistral.py

* incorrect inference

* Update to transformers 4.37

* Graceful FA2 error + torch 2.1.1

* Update mapper.py

* Update pyproject.toml

* Fix saving and bnb-4bit

* Update fast_lora.py

* Update fast_lora.py

* remove patching

* Update llama.py

* Update llama.py

* Update swiglu.py

* Repatch

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update llama.py

* Update fast_lora.py

* Update llama.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update swiglu.py

* Update fast_lora.py

* Update swiglu.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update save.py

* Update fast_lora.py

* Update utils.py

* Update llama.py

* Update fast_lora.py

* Update swiglu.py

* Update save.py

* Update save.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Revert "Update llama.py"

This reverts commit a208ec46e0.

* Update llama.py

* Works?

* Update pyproject.toml

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Swiglu

* Update swiglu.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update swiglu.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* attention_mask

* Update llama.py

* Update llama.py

* labels

* Update mistral.py

* Update llama.py

* attention mask

* Update save.py

* Update save.py

* Update mistral.py

* attention mask

* Update llama.py

* Update llama.py

* Update mistral.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update dpo.py

* Patch saving

* Update save.py

* Update save.py

* patch_saving_functions

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* print

* Mistral patch

* Update mistral.py

* Update save.py

* saving

* Update llama.py

* Update llama.py

* Fast inference repatch

* Update llama.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update mistral.py

* Update __init__.py

* Fix inference

* Update mistral.py

* fast lm_head

* Remove fast path

* Update rope_embedding.py

* Update loader.py

* LlamaAttention_fast_forward_inference

* if past_key_value is not None and q_len == 1:

* revert inference

* Update loader.py

* past_key_value
This commit is contained in:
Daniel Han 2024-01-31 04:03:37 +11:00 committed by GitHub
parent a3a2ad9382
commit 2f55935f94
5 changed files with 159 additions and 50 deletions

View file

@@ -22,4 +22,4 @@ from .fast_lora import (
apply_lora_qkv,
apply_lora_o,
)
from .utils import fast_dequantize, QUANT_STATE, fast_linear_forward
from .utils import fast_dequantize, fast_gemv, QUANT_STATE, fast_linear_forward

View file

@@ -134,9 +134,9 @@ class Slow_RoPE_Embedding(torch.autograd.Function):
half = Q.shape[-1]//2
RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
Q *= cos
Q.addcmul_(RH_Q, sin)
# RH_Q *= sin
# Q += RH_Q
# Q.addcmul_(RH_Q, sin)
RH_Q *= sin
Q += RH_Q
ctx.save_for_backward(cos, sin)
return Q
pass
@@ -148,9 +148,9 @@ class Slow_RoPE_Embedding(torch.autograd.Function):
half = dY.shape[-1]//2
RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
dY *= cos
dY.addcmul_(RH_dY, sin)
# RH_dY *= sin
# dY += RH_dY
# dY.addcmul_(RH_dY, sin)
RH_dY *= sin
dY += RH_dY
return dY, None, None, None
pass
pass

View file

@@ -114,11 +114,12 @@ def fast_dequantize(W, quant_state = None, out = None):
pass
def fast_gemv(X, W, quant_state, out = None, out_W = None):
quant_state = W.quant_state
bsz = 1
q_len = 1
hd = X.shape[0]
def fast_gemv(X, W, quant_state, out = None):
if quant_state is None: return torch.matmul(X, W, out = out)
# For fast X @ W where seq_len == 1
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
bsz, q_len, hd = X.shape
assert(q_len == 1)
if type(quant_state) is not list:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
@@ -137,9 +138,14 @@ def fast_gemv(X, W, quant_state, out = None, out_W = None):
offset, state2 = compressed_stats
absmax2, code2, blocksize2, _, _, _, _ = state2
pass
assert(dtype == X.dtype)
bout = shape[0]
if out is None: out = torch.empty(bout, dtype = dtype, device = "cuda")
else: assert(out.shape[0] == bout)
if out is None:
out = torch.empty((bsz, 1, bout,), dtype = dtype, device = "cuda")
else:
assert(out.shape == (bsz, 1, bout,))
pass
n = 1
m = shape[0]
@@ -170,30 +176,46 @@ def fast_gemv(X, W, quant_state, out = None, out_W = None):
ptr_stats = get_ptr(stats)
blocksize = ctypes.c_int32(blocksize)
fx(m, n, k, get_ptr(X), ptr_W, ptr_absmax, ptr_stats, get_ptr(out),
lda, ldb, ldc, blocksize)
for row in range(bsz):
fx(m, n, k, get_ptr(X[row]), ptr_W, ptr_absmax, ptr_stats, get_ptr(out[row]),
lda, ldb, ldc, blocksize)
pass
return out
pass
def fast_linear_forward(proj, X, temp_lora = None, out = None):
W, W_quant, lora_A, lora_B, lora_S = get_lora_parameters(proj)
bsz, _, in_dim = X.shape
if W_quant is None:
out = torch.matmul(X, W.t())
else:
elif bsz <= 4:
# Only batches of 4 are faster with Gemv
out = fast_gemv(X, W, W_quant, out = out)
if lora_A is not None:
# Save LoRAs for inference to stop data movement costs
if not hasattr(lora_A, "_fast_lora"):
dtype = X.dtype
lora_A._fast_lora = lora_A.to(dtype).t()
lora_B._fast_lora = lora_B.to(dtype)
pass
temp_lora = torch.matmul(X, lora_A._fast_lora, out = temp_lora)
out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
else:
W = fast_dequantize(W.t(), W_quant)
out = torch.matmul(X, W, out = out)
pass
# Add in LoRA weights
if lora_A is not None:
out_dim = out.shape[2]
dtype = X.dtype
if bsz == 1:
out = out.view(out_dim)
temp_lora = torch.mv(lora_A.to(dtype), X.ravel(), out = temp_lora)
out.addmv_(lora_B.to(dtype), temp_lora, alpha = lora_S)
else:
out = out.view(bsz, out_dim)
temp_lora = torch.mm(X.view(bsz, in_dim), lora_A.to(dtype).t(), out = temp_lora)
out.addmm_(temp_lora, lora_B.to(dtype).t(), alpha = lora_S)
pass
out = out.view(bsz, 1, out_dim)
pass
return out
pass

View file

@@ -69,7 +69,7 @@ pass
from math import sqrt as math_sqrt
def LlamaAttention_fast_forward_inference(
def _LlamaAttention_fast_forward_inference(
self,
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
@@ -185,11 +185,89 @@ def LlamaAttention_fast_forward_inference(
pass
def LlamaAttention_fast_forward_inference(
self,
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
):
"""
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
Fast inference using KV cache.
QK^T can be computed in 4 chunks
[Q, q] @ [K, k].T where q, k are the new tokens.
[QK^T, Qk^T]
[qK^T, qk^T]
Since the attention mask wipes Qk^T, we just get
[QK^T, 0]
[qK^T, qk^T]
Since softmax is row-wise, we get
softmax([QK^T, 0])
softmax([qK^T, qk^T])
We then multiply by [V]
[v]
softmax([QK^T, 0]) [softmax(QK^T)V] *
softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]
But notice * [softmax(QK^T)V] is just the last attention.
We just need to compute the last final row.
This means we can pass in a row of Q, but we need to
remember K and V, which are called the KV cache.
"""
Xn = hidden_states
bsz, _, _ = hidden_states.size()
K1, V1 = past_key_value
n_heads = self.num_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
head_dim = self.head_dim
assert(n_kv_heads * n_groups == n_heads)
Qn = self.q_proj(Xn)
Kn = self.k_proj(Xn)
Vn = self.v_proj(Xn)
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
kv_seq_len = K1.shape[-2] + 1
cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
# New KV cache
Kn = torch.cat([K1, Kn], dim = 2)
Vn = torch.cat([V1, Vn], dim = 2)
# Grouped query attention
if n_groups != 1:
_, _, cached_len, _ = Kn.shape
Knn = Kn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Vnn = Vn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
else:
Knn, Vnn = Kn, Vn
# Attention
A = torch.matmul(Qn, Knn.transpose(2, 3))
A *= 1.0 / (self.head_dim**0.5)
A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(A.dtype)
A = torch.matmul(A, Vnn)
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, self.hidden_size)
A = original_apply_o(self, A)
return A, (Kn, Vn)
pass
torch_silu = torch.nn.functional.silu
def fast_mlp_inference(self, X):
hidden_size = self.hidden_size
X = X.view(hidden_size)
# gate = self.gate_proj(X)
# up = self.up_proj(X)
gate = fast_linear_forward(self.gate_proj, X)
@@ -198,20 +276,18 @@ def fast_mlp_inference(self, X):
gate *= up
# X = self.down_proj(gate)
down = fast_linear_forward(self.down_proj, gate, out = up[:hidden_size])
X = down.view(1, 1, hidden_size)
return X
down = fast_linear_forward(self.down_proj, gate)
return down
pass
def fast_rms_layernorm_inference(self, X):
old_dtype = X.dtype
X = X.to(torch.float32)
variance = X.square().mean(-1, keepdim = True)
XX = X.to(torch.float32)
variance = XX.square().mean(-1, keepdim = True)
variance += self.variance_epsilon
X *= variance.rsqrt_()
X = X.to(old_dtype)
XX *= variance.rsqrt_()
X = XX.to(old_dtype) # Must preserve due to residual
X *= self.weight
return X
pass
@@ -234,7 +310,7 @@ def LlamaAttention_fast_forward(
bsz, q_len, _ = hidden_states.size()
# Check for inference
if False: #past_key_value is not None and q_len == 1 and bsz == 1:
if past_key_value is not None:
A, past_key_value = LlamaAttention_fast_forward_inference(
self,
hidden_states,
@@ -350,7 +426,7 @@ def LlamaDecoderLayer_fast_forward(
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
bsz, q_len, hd = hidden_states.size()
if False: #(past_key_value is not None and q_len == 1 and bsz == 1):
if past_key_value is not None:
# Self Attention
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
@@ -488,8 +564,7 @@ def LlamaModel_fast_forward(
# Fix up attention mask by setting elements to 0
# Specifically for DPO
if self._has_no_labels and attention_mask is not None and \
attention_mask.shape[1] == seq_length:
if self._has_no_labels and (attention_mask is not None) and (past_key_values is None):
# Careful for inference the attention_mask is size (1, kv_seq_len)
# Whilst the input_embeds is size (1, 1, 4096)
inputs_requires_grad = inputs_embeds.requires_grad
@@ -501,7 +576,7 @@ def LlamaModel_fast_forward(
# Ignore attention_mask
if attention_mask is None:
padding_mask = None
elif self.training:
elif False:
attention_mask = None
padding_mask = None
else:
@@ -522,7 +597,7 @@ def LlamaModel_fast_forward(
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if past_key_values is None and self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"Unsloth: `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`"
@@ -581,7 +656,7 @@ def LlamaModel_fast_forward(
pass
bsz, q_len, hd = hidden_states.size()
if (past_key_value is not None and q_len == 1):
if past_key_values is not None:
hidden_states = fast_rms_layernorm_inference(self.norm, hidden_states)
else:
hidden_states = fast_rms_layernorm(self.norm, hidden_states)
@@ -644,7 +719,13 @@ def LlamaForCausalLM_fast_forward(
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
bsz, q_len, hd = hidden_states.shape
if bsz == 1 and q_len == 1:
logits = torch.mv(self.lm_head.weight, hidden_states.ravel())
logits = logits.unsqueeze(0).unsqueeze(0)
else:
logits = self.lm_head(hidden_states)
pass
loss = None
if labels is not None:

View file

@@ -49,7 +49,7 @@ def MistralAttention_fast_forward(
bsz, q_len, _ = hidden_states.size()
# Check for inference
if past_key_value is not None and q_len == 1 and bsz == 1:
if past_key_value is not None:
A, past_key_value = LlamaAttention_fast_forward_inference(
self,
hidden_states,
@@ -210,7 +210,13 @@ def MistralForCausalLM_fast_forward(
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
bsz, q_len, hd = hidden_states.shape
if bsz == 1 and q_len == 1:
logits = torch.mv(self.lm_head.weight, hidden_states.ravel())
logits = logits.unsqueeze(0).unsqueeze(0)
else:
logits = self.lm_head(hidden_states)
pass
loss = None
if labels is not None: