mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Hotfix - fix inference (#146)
* faster saving & inference
* Update llama.py
* Update save.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update mistral.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* fast inference
* Update llama.py
* Update save.py
* Update llama.py
* Mistral correct RoPE scaling
* Max sequence lengths
* Apache 2
* fast_linear_forward
* Update utils.py
* Update utils.py
* No print
* Update utils.py
* Update utils.py
* inference
* Update llama.py
* Fast inference RoPE
* Update llama.py
* Update llama.py
* RoPE
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* LoRA
* Fast LoRA saving
* Update llama.py
* hidden_states
* q_len == 1
* q_len issue
* Update mistral.py
* Update mistral.py
* incorrect inference
* Update to transformers 4.37
* Graceful FA2 error + torch 2.1.1
* Update mapper.py
* Update pyproject.toml
* Fix saving and bnb-4bit
* Update fast_lora.py
* Update fast_lora.py
* remove patching
* Update llama.py
* Update llama.py
* Update swiglu.py
* Repatch
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update llama.py
* Update fast_lora.py
* Update llama.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update swiglu.py
* Update fast_lora.py
* Update swiglu.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update save.py
* Update fast_lora.py
* Update utils.py
* Update llama.py
* Update fast_lora.py
* Update swiglu.py
* Update save.py
* Update save.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Revert "Update llama.py"
This reverts commit a208ec46e0.
* Update llama.py
* Works?
* Update pyproject.toml
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Swiglu
* Update swiglu.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update swiglu.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* attention_mask
* Update llama.py
* Update llama.py
* labels
* Update mistral.py
* Update llama.py
* attention mask
* Update save.py
* Update save.py
* Update mistral.py
* attention mask
* Update llama.py
* Update llama.py
* Update mistral.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update dpo.py
* Patch saving
* Update save.py
* Update save.py
* patch_saving_functions
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* print
* Mistral patch
* Update mistral.py
* Update save.py
* saving
* Update llama.py
* Update llama.py
* Fast inference repatch
* Update llama.py
* Update utils.py
* Update utils.py
* Update utils.py
* Update mistral.py
* Update __init__.py
* Fix inference
* Update mistral.py
* fast lm_head
* Remove fast path
* Update rope_embedding.py
* Update loader.py
* LlamaAttention_fast_forward_inference
* if past_key_value is not None and q_len == 1:
* revert inference
* Update loader.py
* past_key_value
This commit is contained in:
parent
a3a2ad9382
commit
2f55935f94
5 changed files with 159 additions and 50 deletions
|
|
@ -22,4 +22,4 @@ from .fast_lora import (
|
|||
apply_lora_qkv,
|
||||
apply_lora_o,
|
||||
)
|
||||
from .utils import fast_dequantize, QUANT_STATE, fast_linear_forward
|
||||
from .utils import fast_dequantize, fast_gemv, QUANT_STATE, fast_linear_forward
|
||||
|
|
|
|||
|
|
@ -134,9 +134,9 @@ class Slow_RoPE_Embedding(torch.autograd.Function):
|
|||
half = Q.shape[-1]//2
|
||||
RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
|
||||
Q *= cos
|
||||
Q.addcmul_(RH_Q, sin)
|
||||
# RH_Q *= sin
|
||||
# Q += RH_Q
|
||||
# Q.addcmul_(RH_Q, sin)
|
||||
RH_Q *= sin
|
||||
Q += RH_Q
|
||||
ctx.save_for_backward(cos, sin)
|
||||
return Q
|
||||
pass
|
||||
|
|
@ -148,9 +148,9 @@ class Slow_RoPE_Embedding(torch.autograd.Function):
|
|||
half = dY.shape[-1]//2
|
||||
RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
|
||||
dY *= cos
|
||||
dY.addcmul_(RH_dY, sin)
|
||||
# RH_dY *= sin
|
||||
# dY += RH_dY
|
||||
# dY.addcmul_(RH_dY, sin)
|
||||
RH_dY *= sin
|
||||
dY += RH_dY
|
||||
return dY, None, None, None
|
||||
pass
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -114,11 +114,12 @@ def fast_dequantize(W, quant_state = None, out = None):
|
|||
pass
|
||||
|
||||
|
||||
def fast_gemv(X, W, quant_state, out = None, out_W = None):
|
||||
quant_state = W.quant_state
|
||||
bsz = 1
|
||||
q_len = 1
|
||||
hd = X.shape[0]
|
||||
def fast_gemv(X, W, quant_state, out = None):
|
||||
if quant_state is None: return torch.matmul(X, W, out = out)
|
||||
# For fast X @ W where seq_len == 1
|
||||
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
|
||||
bsz, q_len, hd = X.shape
|
||||
assert(q_len == 1)
|
||||
|
||||
if type(quant_state) is not list:
|
||||
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
|
||||
|
|
@ -137,9 +138,14 @@ def fast_gemv(X, W, quant_state, out = None, out_W = None):
|
|||
offset, state2 = compressed_stats
|
||||
absmax2, code2, blocksize2, _, _, _, _ = state2
|
||||
pass
|
||||
assert(dtype == X.dtype)
|
||||
bout = shape[0]
|
||||
if out is None: out = torch.empty(bout, dtype = dtype, device = "cuda")
|
||||
else: assert(out.shape[0] == bout)
|
||||
|
||||
if out is None:
|
||||
out = torch.empty((bsz, 1, bout,), dtype = dtype, device = "cuda")
|
||||
else:
|
||||
assert(out.shape == (bsz, 1, bout,))
|
||||
pass
|
||||
|
||||
n = 1
|
||||
m = shape[0]
|
||||
|
|
@ -170,30 +176,46 @@ def fast_gemv(X, W, quant_state, out = None, out_W = None):
|
|||
ptr_stats = get_ptr(stats)
|
||||
blocksize = ctypes.c_int32(blocksize)
|
||||
|
||||
fx(m, n, k, get_ptr(X), ptr_W, ptr_absmax, ptr_stats, get_ptr(out),
|
||||
lda, ldb, ldc, blocksize)
|
||||
for row in range(bsz):
|
||||
fx(m, n, k, get_ptr(X[row]), ptr_W, ptr_absmax, ptr_stats, get_ptr(out[row]),
|
||||
lda, ldb, ldc, blocksize)
|
||||
pass
|
||||
|
||||
return out
|
||||
pass
|
||||
|
||||
|
||||
def fast_linear_forward(proj, X, temp_lora = None, out = None):
|
||||
|
||||
W, W_quant, lora_A, lora_B, lora_S = get_lora_parameters(proj)
|
||||
|
||||
bsz, _, in_dim = X.shape
|
||||
|
||||
if W_quant is None:
|
||||
out = torch.matmul(X, W.t())
|
||||
else:
|
||||
elif bsz <= 4:
|
||||
# Only batches of 4 are faster with Gemv
|
||||
out = fast_gemv(X, W, W_quant, out = out)
|
||||
if lora_A is not None:
|
||||
|
||||
# Save LoRAs for inference to stop data movement costs
|
||||
if not hasattr(lora_A, "_fast_lora"):
|
||||
dtype = X.dtype
|
||||
lora_A._fast_lora = lora_A.to(dtype).t()
|
||||
lora_B._fast_lora = lora_B.to(dtype)
|
||||
pass
|
||||
|
||||
temp_lora = torch.matmul(X, lora_A._fast_lora, out = temp_lora)
|
||||
out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
|
||||
else:
|
||||
W = fast_dequantize(W.t(), W_quant)
|
||||
out = torch.matmul(X, W, out = out)
|
||||
pass
|
||||
|
||||
# Add in LoRA weights
|
||||
if lora_A is not None:
|
||||
out_dim = out.shape[2]
|
||||
dtype = X.dtype
|
||||
if bsz == 1:
|
||||
out = out.view(out_dim)
|
||||
temp_lora = torch.mv(lora_A.to(dtype), X.ravel(), out = temp_lora)
|
||||
out.addmv_(lora_B.to(dtype), temp_lora, alpha = lora_S)
|
||||
else:
|
||||
out = out.view(bsz, out_dim)
|
||||
temp_lora = torch.mm(X.view(bsz, in_dim), lora_A.to(dtype).t(), out = temp_lora)
|
||||
out.addmm_(temp_lora, lora_B.to(dtype).t(), alpha = lora_S)
|
||||
pass
|
||||
out = out.view(bsz, 1, out_dim)
|
||||
pass
|
||||
|
||||
return out
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ pass
|
|||
|
||||
|
||||
from math import sqrt as math_sqrt
|
||||
def LlamaAttention_fast_forward_inference(
|
||||
def _LlamaAttention_fast_forward_inference(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]],
|
||||
|
|
@ -185,11 +185,89 @@ def LlamaAttention_fast_forward_inference(
|
|||
pass
|
||||
|
||||
|
||||
def LlamaAttention_fast_forward_inference(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]],
|
||||
position_ids,
|
||||
):
|
||||
"""
|
||||
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
|
||||
Fast inference using KV cache.
|
||||
QK^T can be computed in 4 chunks
|
||||
|
||||
[Q, q] @ [K, k].T where q, k are the new tokens.
|
||||
[QK^T, Qk^T]
|
||||
[qK^T, qk^T]
|
||||
|
||||
Since the attention mask wipes Qk^T, we just get
|
||||
[QK^T, 0]
|
||||
[qK^T, qk^T]
|
||||
|
||||
Since softmax is row-wise, we get
|
||||
softmax([QK^T, 0])
|
||||
softmax([qK^T, qk^T])
|
||||
|
||||
We then multiply by [V]
|
||||
[v]
|
||||
softmax([QK^T, 0]) [softmax(QK^T)V] *
|
||||
softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]
|
||||
|
||||
But notice * [softmax(QK^T)V] is just the last attention.
|
||||
We just need to compute the last final row.
|
||||
|
||||
This means we can pass in a row of Q, but we need to
|
||||
remember K and V, which are called the KV cache.
|
||||
"""
|
||||
Xn = hidden_states
|
||||
bsz, _, _ = hidden_states.size()
|
||||
K1, V1 = past_key_value
|
||||
|
||||
n_heads = self.num_heads
|
||||
n_groups = self.num_key_value_groups
|
||||
n_kv_heads = self.num_key_value_heads
|
||||
head_dim = self.head_dim
|
||||
assert(n_kv_heads * n_groups == n_heads)
|
||||
|
||||
Qn = self.q_proj(Xn)
|
||||
Kn = self.k_proj(Xn)
|
||||
Vn = self.v_proj(Xn)
|
||||
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
|
||||
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
|
||||
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = K1.shape[-2] + 1
|
||||
cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
|
||||
Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
|
||||
|
||||
# New KV cache
|
||||
Kn = torch.cat([K1, Kn], dim = 2)
|
||||
Vn = torch.cat([V1, Vn], dim = 2)
|
||||
|
||||
# Grouped query attention
|
||||
if n_groups != 1:
|
||||
_, _, cached_len, _ = Kn.shape
|
||||
Knn = Kn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
|
||||
Vnn = Vn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
|
||||
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
|
||||
Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
|
||||
else:
|
||||
Knn, Vnn = Kn, Vn
|
||||
|
||||
# Attention
|
||||
A = torch.matmul(Qn, Knn.transpose(2, 3))
|
||||
A *= 1.0 / (self.head_dim**0.5)
|
||||
A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(A.dtype)
|
||||
A = torch.matmul(A, Vnn)
|
||||
A = A.transpose(1, 2)
|
||||
A = A.reshape(bsz, 1, self.hidden_size)
|
||||
A = original_apply_o(self, A)
|
||||
return A, (Kn, Vn)
|
||||
pass
|
||||
|
||||
|
||||
torch_silu = torch.nn.functional.silu
|
||||
def fast_mlp_inference(self, X):
|
||||
hidden_size = self.hidden_size
|
||||
X = X.view(hidden_size)
|
||||
|
||||
# gate = self.gate_proj(X)
|
||||
# up = self.up_proj(X)
|
||||
gate = fast_linear_forward(self.gate_proj, X)
|
||||
|
|
@ -198,20 +276,18 @@ def fast_mlp_inference(self, X):
|
|||
gate *= up
|
||||
|
||||
# X = self.down_proj(gate)
|
||||
down = fast_linear_forward(self.down_proj, gate, out = up[:hidden_size])
|
||||
X = down.view(1, 1, hidden_size)
|
||||
|
||||
return X
|
||||
down = fast_linear_forward(self.down_proj, gate)
|
||||
return down
|
||||
pass
|
||||
|
||||
|
||||
def fast_rms_layernorm_inference(self, X):
|
||||
old_dtype = X.dtype
|
||||
X = X.to(torch.float32)
|
||||
variance = X.square().mean(-1, keepdim = True)
|
||||
XX = X.to(torch.float32)
|
||||
variance = XX.square().mean(-1, keepdim = True)
|
||||
variance += self.variance_epsilon
|
||||
X *= variance.rsqrt_()
|
||||
X = X.to(old_dtype)
|
||||
XX *= variance.rsqrt_()
|
||||
X = XX.to(old_dtype) # Must preserve due to residual
|
||||
X *= self.weight
|
||||
return X
|
||||
pass
|
||||
|
|
@ -234,7 +310,7 @@ def LlamaAttention_fast_forward(
|
|||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
# Check for inference
|
||||
if False: #past_key_value is not None and q_len == 1 and bsz == 1:
|
||||
if past_key_value is not None:
|
||||
A, past_key_value = LlamaAttention_fast_forward_inference(
|
||||
self,
|
||||
hidden_states,
|
||||
|
|
@ -350,7 +426,7 @@ def LlamaDecoderLayer_fast_forward(
|
|||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
"""
|
||||
bsz, q_len, hd = hidden_states.size()
|
||||
if False: #(past_key_value is not None and q_len == 1 and bsz == 1):
|
||||
if past_key_value is not None:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
|
||||
|
|
@ -488,8 +564,7 @@ def LlamaModel_fast_forward(
|
|||
|
||||
# Fix up attention mask by setting elements to 0
|
||||
# Specifically for DPO
|
||||
if self._has_no_labels and attention_mask is not None and \
|
||||
attention_mask.shape[1] == seq_length:
|
||||
if self._has_no_labels and (attention_mask is not None) and (past_key_values is None):
|
||||
# Careful for inference the attention_mask is size (1, kv_seq_len)
|
||||
# Whilst the input_embeds is size (1, 1, 4096)
|
||||
inputs_requires_grad = inputs_embeds.requires_grad
|
||||
|
|
@ -501,7 +576,7 @@ def LlamaModel_fast_forward(
|
|||
# Ignore attention_mask
|
||||
if attention_mask is None:
|
||||
padding_mask = None
|
||||
elif self.training:
|
||||
elif False:
|
||||
attention_mask = None
|
||||
padding_mask = None
|
||||
else:
|
||||
|
|
@ -522,7 +597,7 @@ def LlamaModel_fast_forward(
|
|||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if past_key_values is None and self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"Unsloth: `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`"
|
||||
|
|
@ -581,7 +656,7 @@ def LlamaModel_fast_forward(
|
|||
pass
|
||||
|
||||
bsz, q_len, hd = hidden_states.size()
|
||||
if (past_key_value is not None and q_len == 1):
|
||||
if past_key_values is not None:
|
||||
hidden_states = fast_rms_layernorm_inference(self.norm, hidden_states)
|
||||
else:
|
||||
hidden_states = fast_rms_layernorm(self.norm, hidden_states)
|
||||
|
|
@ -644,7 +719,13 @@ def LlamaForCausalLM_fast_forward(
|
|||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
bsz, q_len, hd = hidden_states.shape
|
||||
if bsz == 1 and q_len == 1:
|
||||
logits = torch.mv(self.lm_head.weight, hidden_states.ravel())
|
||||
logits = logits.unsqueeze(0).unsqueeze(0)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states)
|
||||
pass
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ def MistralAttention_fast_forward(
|
|||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
# Check for inference
|
||||
if past_key_value is not None and q_len == 1 and bsz == 1:
|
||||
if past_key_value is not None:
|
||||
A, past_key_value = LlamaAttention_fast_forward_inference(
|
||||
self,
|
||||
hidden_states,
|
||||
|
|
@ -210,7 +210,13 @@ def MistralForCausalLM_fast_forward(
|
|||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
bsz, q_len, hd = hidden_states.shape
|
||||
if bsz == 1 and q_len == 1:
|
||||
logits = torch.mv(self.lm_head.weight, hidden_states.ravel())
|
||||
logits = logits.unsqueeze(0).unsqueeze(0)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states)
|
||||
pass
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue