mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
2-4x faster native HF inference (#119)
* faster saving & inference * Update llama.py * Update save.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update mistral.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * fast inference * Update llama.py * Update save.py * Update llama.py * Mistral correct RoPE scaling * Max sequence lengths * Apache 2 * fast_linear_forward * Update utils.py * Update utils.py * No print * Update utils.py * Update utils.py * inference * Update llama.py * Fast inference RoPE * Update llama.py * Update llama.py * RoPE * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * LoRA * Fast LoRA saving
This commit is contained in:
parent
3a9b2dee98
commit
04f8771821
8 changed files with 291 additions and 89 deletions
|
|
@ -22,4 +22,4 @@ from .fast_lora import (
|
|||
apply_lora_qkv,
|
||||
apply_lora_o,
|
||||
)
|
||||
from .utils import fast_dequantize, QUANT_STATE
|
||||
from .utils import fast_dequantize, QUANT_STATE, fast_linear_forward
|
||||
|
|
|
|||
|
|
@ -13,28 +13,10 @@
|
|||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from .utils import fast_dequantize, QUANT_STATE
|
||||
from .utils import fast_dequantize, QUANT_STATE, get_lora_parameters
|
||||
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
|
||||
|
||||
|
||||
def get_lora_parameters(proj):
    """
    Collect the tensors needed to apply a projection with an optional LoRA.

    Returns a 5-tuple ``(W, quant_state, A, B, s)``:
      * ``W``            - ``weight`` of ``proj.base_layer`` when that attribute
                           exists, otherwise ``proj.weight``.
      * ``quant_state``  - whatever ``QUANT_STATE(W)`` reports for that weight.
      * ``A``, ``B``     - ``.weight`` of ``proj.lora_A[...]`` / ``proj.lora_B[...]``
                           for the first active adapter.
      * ``s``            - ``proj.scaling[...]`` for that adapter.

    ``A``, ``B`` and ``s`` are ``None`` when ``proj`` has no
    ``disable_adapters`` attribute, or when the adapters are disabled or merged
    (the original comment notes this covers DPO / disabled adapters).
    """
    base_layer = getattr(proj, "base_layer", proj)
    W = base_layer.weight

    # Same short-circuit order as the classic
    # `not hasattr(...) or disable_adapters or merged` test, just inverted.
    lora_is_live = (
        hasattr(proj, "disable_adapters")
        and not proj.disable_adapters
        and not proj.merged
    )
    if not lora_is_live:
        return W, QUANT_STATE(W), None, None, None

    # `active_adapters` is a sequence when present; otherwise fall back to the
    # single `active_adapter` attribute.
    if hasattr(proj, "active_adapters"):
        adapter_name = proj.active_adapters[0]
    else:
        adapter_name = proj.active_adapter

    A = proj.lora_A [adapter_name].weight
    B = proj.lora_B [adapter_name].weight
    s = proj.scaling[adapter_name]
    return W, QUANT_STATE(W), A, B, s
pass
|
||||
|
||||
|
||||
def matmul_lora(X, W, W_quant, A, B, s, out = None):
|
||||
dtype = X.dtype
|
||||
W = fast_dequantize(W.t(), W_quant)
|
||||
|
|
|
|||
|
|
@ -134,8 +134,9 @@ class Slow_RoPE_Embedding(torch.autograd.Function):
|
|||
half = Q.shape[-1]//2
|
||||
RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
|
||||
Q *= cos
|
||||
RH_Q *= sin
|
||||
Q += RH_Q
|
||||
Q.addcmul_(RH_Q, sin)
|
||||
# RH_Q *= sin
|
||||
# Q += RH_Q
|
||||
ctx.save_for_backward(cos, sin)
|
||||
return Q
|
||||
pass
|
||||
|
|
@ -147,8 +148,9 @@ class Slow_RoPE_Embedding(torch.autograd.Function):
|
|||
half = dY.shape[-1]//2
|
||||
RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
|
||||
dY *= cos
|
||||
RH_dY *= sin
|
||||
dY += RH_dY
|
||||
dY.addcmul_(RH_dY, sin)
|
||||
# RH_dY *= sin
|
||||
# dY += RH_dY
|
||||
return dY, None, None, None
|
||||
pass
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -33,14 +33,36 @@ import bitsandbytes as bnb
|
|||
get_ptr = bnb.functional.get_ptr
|
||||
import ctypes
|
||||
import torch
|
||||
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
|
||||
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
|
||||
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
|
||||
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
|
||||
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
|
||||
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
|
||||
cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
|
||||
cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
|
||||
|
||||
|
||||
def QUANT_STATE(W):
    """Return ``W.quant_state`` when the attribute exists, otherwise ``None``."""
    try:
        return W.quant_state
    except AttributeError:
        # No quantisation metadata is attached to this object.
        return None
pass
|
||||
|
||||
|
||||
def get_lora_parameters(proj):
    """
    Collect the tensors needed to apply a projection with an optional LoRA.

    Returns a 5-tuple ``(W, quant_state, A, B, s)``:
      * ``W``            - ``weight`` of ``proj.base_layer`` when that attribute
                           exists, otherwise ``proj.weight``.
      * ``quant_state``  - whatever ``QUANT_STATE(W)`` reports for that weight.
      * ``A``, ``B``     - ``.weight`` of ``proj.lora_A[...]`` / ``proj.lora_B[...]``
                           for the first active adapter.
      * ``s``            - ``proj.scaling[...]`` for that adapter.

    ``A``, ``B`` and ``s`` are ``None`` when ``proj`` has no
    ``disable_adapters`` attribute, or when the adapters are disabled or merged
    (the original comment notes this covers DPO / disabled adapters).
    """
    base_layer = getattr(proj, "base_layer", proj)
    W = base_layer.weight

    # Same short-circuit order as the classic
    # `not hasattr(...) or disable_adapters or merged` test, just inverted.
    lora_is_live = (
        hasattr(proj, "disable_adapters")
        and not proj.disable_adapters
        and not proj.merged
    )
    if not lora_is_live:
        return W, QUANT_STATE(W), None, None, None

    # `active_adapters` is a sequence when present; otherwise fall back to the
    # single `active_adapter` attribute.
    if hasattr(proj, "active_adapters"):
        adapter_name = proj.active_adapters[0]
    else:
        adapter_name = proj.active_adapter

    A = proj.lora_A [adapter_name].weight
    B = proj.lora_B [adapter_name].weight
    s = proj.scaling[adapter_name]
    return W, QUANT_STATE(W), A, B, s
pass
|
||||
|
||||
|
||||
def fast_dequantize(W, quant_state = None, out = None):
|
||||
if quant_state is None: return W
|
||||
if type(quant_state) is not list:
|
||||
|
|
@ -90,3 +112,85 @@ def fast_dequantize(W, quant_state = None, out = None):
|
|||
is_transposed = (True if W.shape[0] == 1 else False)
|
||||
return out.t() if is_transposed else out
|
||||
pass
|
||||
|
||||
|
||||
def fast_gemv(X, W, quant_state, out = None, out_W = None):
    """
    Matrix-vector product of a weight with a single input vector.

    When the weight carries bitsandbytes 4-bit quantisation metadata, the
    product is computed by the ``cgemm_4bit_inference_naive_{fp16,bf16}``
    kernels. When no quantisation metadata is available, a plain dense matmul
    is used instead, so un-quantized (e.g. 16-bit) weights no longer raise
    ``AttributeError``.

    Args:
        X: input vector. Its last dimension is used to derive ``ldb``.
           Assumes a 1-D tensor of length ``in_features`` - TODO confirm
           against callers.
        W: weight tensor. On the quantized path only its data pointer is used.
           On the dense path it is assumed to be laid out
           ``(out_features, in_features)`` like ``torch.nn.Linear.weight`` -
           TODO confirm.
        quant_state: quantisation state to use when ``W`` itself has no
           ``quant_state`` attribute. ``W.quant_state`` takes precedence when
           present, which is what this function has always used.
        out: optional preallocated output buffer. On the quantized path its
           first dimension must equal ``shape[0]`` of the quant state.
        out_W: unused; kept so existing call sites remain valid.

    Returns:
        The output tensor (``out`` itself when one was supplied).
    """
    # Prefer the state attached to the weight; fall back to the argument.
    quant_state = getattr(W, "quant_state", quant_state)
    if quant_state is None:
        # Not quantized: ordinary dense GEMV.
        return torch.matmul(X, W.t(), out = out)
    pass

    if not isinstance(quant_state, list):
        # Object-style state, see
        # https://github.com/TimDettmers/bitsandbytes/pull/763/files
        absmax     = quant_state.absmax
        shape      = quant_state.shape
        dtype      = quant_state.dtype
        blocksize  = quant_state.blocksize
        stats      = quant_state.code
        offset     = quant_state.offset
        state2     = quant_state.state2
        absmax2    = state2.absmax
        code2      = state2.code
        blocksize2 = state2.blocksize
    else:
        # Older list-style state.
        absmax, shape, dtype, blocksize, compressed_stats, _, stats = quant_state
        offset, state2 = compressed_stats
        absmax2, code2, blocksize2, _, _, _, _ = state2
    pass

    out_features = shape[0]
    in_features  = shape[1]
    if out is None:
        # Allocate next to the input rather than on whatever "cuda" resolves to.
        out = torch.empty(out_features, dtype = dtype, device = X.device)
    else:
        assert out.shape[0] == out_features, \
            "Unsloth: `out` does not match the weight's output dimension."
    pass

    # The C kernels take their dimensions as 32-bit ints.
    m   = ctypes.c_int32(out_features)
    n   = ctypes.c_int32(1)
    k   = ctypes.c_int32(in_features)
    lda = ctypes.c_int32(out_features)
    ldb = ctypes.c_int32((X.shape[-1] + 1) // 2)
    ldc = ctypes.c_int32(out_features)

    # `absmax` is itself stored quantized (via `state2`); expand it to float32
    # and add back `offset` before handing it to the GEMV kernel.
    df = torch.empty(absmax.shape, dtype = torch.float32, device = absmax.device)
    cdequantize_blockwise_fp32(
        get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
        ctypes.c_int(blocksize2), ctypes.c_int(df.numel()),
    )
    df += offset
    absmax = df

    # Anything that is not float16 goes through the bfloat16 kernel.
    fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
        cgemm_4bit_inference_naive_bf16

    fx(
        m, n, k,
        get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
        lda, ldb, ldc, ctypes.c_int32(blocksize),
    )

    return out
pass
|
||||
|
||||
|
||||
def fast_linear_forward(proj, X, temp_lora = None, out = None):
    """
    Apply projection ``proj`` to ``X`` on the fast inference path.

    The base product comes from ``fast_gemv``. When ``get_lora_parameters``
    reports an active adapter, its contribution is accumulated into the same
    output in place via ``addmv_`` with ``alpha`` set to the adapter scaling.

    Args:
        proj: projection layer, passed to ``get_lora_parameters``.
        X: input tensor; its dtype decides the dtype of the cached LoRA copies.
        temp_lora: optional buffer for the intermediate ``X @ A^T`` product.
        out: optional buffer for the result, forwarded to ``fast_gemv``.

    Returns:
        The output tensor produced by ``fast_gemv`` (updated in place when a
        LoRA adapter is active).
    """
    weight, weight_quant, A, B, scale = get_lora_parameters(proj)
    out = fast_gemv(X, weight, weight_quant, out = out)

    # No active adapter: the base projection is the answer.
    if A is None:
        return out

    # Cast (and, for A, transpose) the adapter weights once and keep the
    # result on the parameter so later calls skip the conversion.
    if not hasattr(A, "_fast_lora"):
        compute_dtype = X.dtype
        A._fast_lora = A.to(compute_dtype).t()
        B._fast_lora = B.to(compute_dtype)

    temp_lora = torch.matmul(X, A._fast_lora, out = temp_lora)
    out.addmv_(B._fast_lora, temp_lora, alpha = scale)
    return out
pass
|
||||
|
|
|
|||
|
|
@ -68,6 +68,7 @@ def original_apply_o(self, X):
|
|||
pass
|
||||
|
||||
|
||||
from math import sqrt as math_sqrt
|
||||
def LlamaAttention_fast_forward_inference(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
|
@ -102,60 +103,104 @@ def LlamaAttention_fast_forward_inference(
|
|||
This means we can pass in a row of Q, but we need to
|
||||
remember K and V, which are called the KV cache.
|
||||
"""
|
||||
Xn = hidden_states
|
||||
bsz, _, _ = hidden_states.size()
|
||||
K1, V1 = past_key_value
|
||||
|
||||
n_heads = self.num_heads
|
||||
n_groups = self.num_key_value_groups
|
||||
n_kv_heads = self.num_key_value_heads
|
||||
head_dim = self.head_dim
|
||||
assert(n_kv_heads * n_groups == n_heads)
|
||||
# assert(n_kv_heads * n_groups == n_heads)
|
||||
|
||||
Qn = self.q_proj(Xn)
|
||||
Kn = self.k_proj(Xn)
|
||||
Vn = self.v_proj(Xn)
|
||||
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
|
||||
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
|
||||
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
|
||||
Xn = hidden_states.view(self.hidden_size)
|
||||
K1, V1 = past_key_value
|
||||
seq_len = K1.shape[-2]
|
||||
K1 = K1.view(n_kv_heads, seq_len, head_dim)
|
||||
V1 = V1.view(n_kv_heads, seq_len, head_dim)
|
||||
|
||||
kv_seq_len = K1.shape[-2] + 1
|
||||
cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
|
||||
Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
|
||||
# LoRA or general matrix multiplication
|
||||
dtype = Xn.dtype
|
||||
# Qn = self.q_proj(Xn)
|
||||
# Kn = self.k_proj(Xn)
|
||||
# Vn = self.v_proj(Xn)
|
||||
Qn = fast_linear_forward(self.q_proj, Xn)
|
||||
Kn = fast_linear_forward(self.k_proj, Xn)
|
||||
Vn = fast_linear_forward(self.v_proj, Xn)
|
||||
|
||||
# Qn = Qn.view(1, 1, n_heads, head_dim).transpose(1, 2)
|
||||
# Kn = Kn.view(1, 1, n_kv_heads, head_dim).transpose(1, 2)
|
||||
# Vn = Vn.view(1, 1, n_kv_heads, head_dim).transpose(1, 2)
|
||||
Qn = Qn.view(n_heads, 1, head_dim)
|
||||
Kn = Kn.view(n_kv_heads, 1, head_dim)
|
||||
Vn = Vn.view(n_kv_heads, 1, head_dim)
|
||||
|
||||
# kv_seq_len = K1.shape[-2] + 1
|
||||
# cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
|
||||
# Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
|
||||
cos = self.rotary_emb.cos_cached[seq_len]
|
||||
sin = self.rotary_emb.sin_cached[seq_len]
|
||||
h = head_dim // 2
|
||||
|
||||
RH_Q = torch.empty((n_heads, 1, head_dim), dtype = dtype, device = "cuda")
|
||||
RH_Q[:, :, :h] = Qn[:, :, h:]; RH_Q[:, :, h:] = Qn[:, :, :h]; torch.neg(RH_Q[:, :, :h], out = RH_Q[:, :, :h]);
|
||||
Qn *= cos; Qn.addcmul_(RH_Q, sin);
|
||||
|
||||
RH_K = RH_Q[:n_kv_heads, :, :] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda")
|
||||
RH_K[:, :, :h] = Kn[:, :, h:]; RH_K[:, :, h:] = Kn[:, :, :h]; torch.neg(RH_K[:, :, :h], out = RH_K[:, :, :h]);
|
||||
Kn *= cos; Kn.addcmul_(RH_K, sin);
|
||||
|
||||
# New KV cache
|
||||
Kn = torch.cat([K1, Kn], dim = 2)
|
||||
Vn = torch.cat([V1, Vn], dim = 2)
|
||||
# Kn = torch.cat([K1, Kn], dim = 2)
|
||||
# Vn = torch.cat([V1, Vn], dim = 2)
|
||||
Kn = torch.cat([K1, Kn], dim = 1)
|
||||
Vn = torch.cat([V1, Vn], dim = 1)
|
||||
|
||||
# Grouped query attention
|
||||
if n_groups != 1:
|
||||
_, _, cached_len, _ = Kn.shape
|
||||
Knn = Kn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
|
||||
Vnn = Vn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
|
||||
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
|
||||
Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
|
||||
# _, _, cached_len, _ = Kn.shape
|
||||
# Knn = Kn[:, :, None, :, :].expand(1, n_kv_heads, n_groups, cached_len, head_dim)
|
||||
# Vnn = Vn[:, :, None, :, :].expand(1, n_kv_heads, n_groups, cached_len, head_dim)
|
||||
# Knn = Knn.reshape(1, n_heads, cached_len, head_dim)
|
||||
# Vnn = Vnn.reshape(1, n_heads, cached_len, head_dim)
|
||||
new_seq_len = seq_len + 1
|
||||
Knn = Kn[:, None, :, :].expand(n_kv_heads, n_groups, new_seq_len, head_dim)
|
||||
Vnn = Vn[:, None, :, :].expand(n_kv_heads, n_groups, new_seq_len, head_dim)
|
||||
Knn = Knn.reshape(n_heads, new_seq_len, head_dim)
|
||||
Vnn = Vnn.reshape(n_heads, new_seq_len, head_dim)
|
||||
else:
|
||||
Knn, Vnn = Kn, Vn
|
||||
|
||||
# Attention
|
||||
A = torch.matmul(Qn, Knn.transpose(2, 3))
|
||||
A *= 1.0 / (self.head_dim**0.5)
|
||||
A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(A.dtype)
|
||||
A = torch.matmul(A, Vnn)
|
||||
A = A.transpose(1, 2)
|
||||
A = A.reshape(bsz, 1, self.hidden_size)
|
||||
A = self.o_proj(A)
|
||||
return A, (Kn, Vn)
|
||||
# A = torch.matmul(Qn, Knn.transpose(2, 3))
|
||||
A = torch.matmul(Qn, Knn.transpose(1, 2))
|
||||
A *= 1.0 / math_sqrt(self.head_dim)
|
||||
A[:] = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
|
||||
A = torch.matmul(A, Vnn, out = Qn)
|
||||
# A = A.transpose(1, 2)
|
||||
A = A.view(self.hidden_size)
|
||||
|
||||
# A = self.o_proj(A)
|
||||
A = fast_linear_forward(self.o_proj, A)
|
||||
A = A.reshape(1, 1, self.hidden_size)
|
||||
|
||||
# return A, (Kn, Vn)
|
||||
return A, (Kn.unsqueeze(0), Vn.unsqueeze(0))
|
||||
pass
|
||||
|
||||
|
||||
torch_silu = torch.nn.functional.silu
|
||||
def fast_mlp_inference(self, X):
|
||||
gate = self.gate_proj(X)
|
||||
up = self.up_proj(X)
|
||||
hidden_size = self.hidden_size
|
||||
X = X.view(hidden_size)
|
||||
|
||||
# gate = self.gate_proj(X)
|
||||
# up = self.up_proj(X)
|
||||
gate = fast_linear_forward(self.gate_proj, X)
|
||||
up = fast_linear_forward(self. up_proj, X)
|
||||
gate = torch_silu(gate, inplace = True)
|
||||
gate *= up
|
||||
X = self.down_proj(gate)
|
||||
|
||||
# X = self.down_proj(gate)
|
||||
down = fast_linear_forward(self.down_proj, gate, out = up[:hidden_size])
|
||||
X = down.view(1, 1, hidden_size)
|
||||
|
||||
return X
|
||||
pass
|
||||
|
||||
|
|
@ -676,10 +721,10 @@ class FastLlamaModel:
|
|||
|
||||
statistics = \
|
||||
f"==((====))== Unsloth: Fast Llama patching release {__version__}\n"\
|
||||
f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB\n"\
|
||||
f"O^O/ \_/ \\ CUDA capability = {gpu_stats.major}.{gpu_stats.minor}. Xformers = {xformers_version}. FA = {HAS_FLASH_ATTENTION}.\n"\
|
||||
f"\ / Pytorch version: {torch.__version__}. CUDA Toolkit = {torch.version.cuda}\n"\
|
||||
f' "-____-" bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. Platform = {platform_system}\n'
|
||||
f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform = {platform_system}.\n"\
|
||||
f"O^O/ \_/ \\ Pytorch: {torch.__version__}. CUDA = {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit = {torch.version.cuda}.\n"\
|
||||
f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. Xformers = {xformers_version}. FA = {HAS_FLASH_ATTENTION}.\n"\
|
||||
f' "-____-" Apache 2 free license: http://github.com/unslothai/unsloth'
|
||||
logger.warning_once(statistics)
|
||||
FastLlamaModel.pre_patch()
|
||||
|
||||
|
|
@ -731,7 +776,7 @@ class FastLlamaModel:
|
|||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name,
|
||||
model_max_length = max_seq_length,
|
||||
model_max_length = max_position_embeddings,
|
||||
padding_side = "right",
|
||||
token = token,
|
||||
)
|
||||
|
|
@ -760,7 +805,7 @@ class FastLlamaModel:
|
|||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
model_name = model_name,
|
||||
model_max_length = max_seq_length,
|
||||
model_max_length = max_position_embeddings,
|
||||
padding_side = "right",
|
||||
token = token,
|
||||
)
|
||||
|
|
@ -1076,4 +1121,47 @@ class FastLlamaModel:
|
|||
internal_model.max_seq_length = max_seq_length
|
||||
return model
|
||||
pass
|
||||
|
||||
|
||||
@staticmethod
|
||||
def for_inference(model):
|
||||
if not hasattr(model, "_original_forward"):
|
||||
model._original_forward = model.forward
|
||||
pass
|
||||
model.forward = torch.inference_mode(model._original_forward)
|
||||
|
||||
internal_model = model
|
||||
internal_model.gradient_checkpointing = False
|
||||
internal_model.training = False
|
||||
|
||||
while hasattr(internal_model, "model"):
|
||||
internal_model = internal_model.model
|
||||
internal_model.gradient_checkpointing = False
|
||||
internal_model.training = False
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
@staticmethod
|
||||
def for_training(model, use_gradient_checkpointing = True):
|
||||
if hasattr(model, "_original_forward"):
|
||||
model.forward = model._original_forward
|
||||
pass
|
||||
|
||||
internal_model = model
|
||||
internal_model.gradient_checkpointing = use_gradient_checkpointing
|
||||
internal_model.training = True
|
||||
|
||||
# Delete all fast inference loras
|
||||
for param in model.parameters():
|
||||
if hasattr(param, "_fast_lora"):
|
||||
del param._fast_lora
|
||||
pass
|
||||
|
||||
while hasattr(internal_model, "model"):
|
||||
internal_model = internal_model.model
|
||||
internal_model.gradient_checkpointing = use_gradient_checkpointing
|
||||
internal_model.training = True
|
||||
pass
|
||||
pass
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -42,6 +42,12 @@ __INT_TO_FLOAT_MAPPER = \
|
|||
"unsloth/tinyllama",
|
||||
"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
|
||||
),
|
||||
"unsloth/mistral-7b-instruct-v0.1-bnb-4bit" : (
|
||||
"mistralai/Mistral-7B-Instruct-v0.1",
|
||||
),
|
||||
"unsloth/mistral-7b-instruct-v0.2-bnb-4bit" : (
|
||||
"mistralai/Mistral-7B-Instruct-v0.2",
|
||||
),
|
||||
}
|
||||
|
||||
INT_TO_FLOAT_MAPPER = {}
|
||||
|
|
|
|||
|
|
@ -132,7 +132,7 @@ def MistralAttention_fast_forward(
|
|||
K = K.transpose(1, 2)
|
||||
V = V.transpose(1, 2)
|
||||
sw = getattr(self.config, "sliding_window", None)
|
||||
sw = q_len if sw is None else sw
|
||||
sw = q_len if (sw is None or sw == "null") else sw
|
||||
window = (-1, -1) if (q_len <= sw) else (sw, sw)
|
||||
A = flash_attn_func(Q, K, V, causal = True, window_size = window)
|
||||
else:
|
||||
|
|
@ -176,7 +176,7 @@ def MistralForCausalLM_fast_forward(
|
|||
if causal_mask is None:
|
||||
bsz, q_len = input_ids.shape
|
||||
sliding_window = getattr(self.config, "sliding_window", None)
|
||||
if sliding_window is None or sliding_window <= 0:
|
||||
if sliding_window is None or sliding_window == "null" or sliding_window <= 0:
|
||||
causal_mask = xformers.attn_bias.LowerTriangularMask()
|
||||
elif q_len <= sliding_window:
|
||||
causal_mask = xformers.attn_bias.LowerTriangularMask()
|
||||
|
|
@ -265,9 +265,11 @@ class FastMistralModel(FastLlamaModel):
|
|||
rope_scaling = None, # Mistral does not support RoPE scaling
|
||||
fix_tokenizer = True,
|
||||
**kwargs,
|
||||
):
|
||||
):
|
||||
# Mistral does NOT support RoPE Scaling!
|
||||
if rope_scaling is not None:
|
||||
logger.warning_once("Unsloth: Mistral models do not support RoPE scaling.")
|
||||
pass
|
||||
|
||||
SUPPORTS_BFLOAT16 = torch.cuda.is_bf16_supported()
|
||||
gpu_stats = torch.cuda.get_device_properties(0)
|
||||
|
|
@ -275,10 +277,10 @@ class FastMistralModel(FastLlamaModel):
|
|||
|
||||
statistics = \
|
||||
f"==((====))== Unsloth: Fast Mistral patching release {__version__}\n"\
|
||||
f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB\n"\
|
||||
f"O^O/ \_/ \\ CUDA capability = {gpu_stats.major}.{gpu_stats.minor}. Xformers = {xformers_version}. FA = {HAS_FLASH_ATTENTION}.\n"\
|
||||
f"\ / Pytorch version: {torch.__version__}. CUDA Toolkit = {torch.version.cuda}\n"\
|
||||
f' "-____-" bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. Platform = {platform_system}\n'
|
||||
f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform = {platform_system}.\n"\
|
||||
f"O^O/ \_/ \\ Pytorch: {torch.__version__}. CUDA = {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit = {torch.version.cuda}.\n"\
|
||||
f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. Xformers = {xformers_version}. FA = {HAS_FLASH_ATTENTION}.\n"\
|
||||
f' "-____-" Apache 2 free license: http://github.com/unslothai/unsloth'
|
||||
logger.warning_once(statistics)
|
||||
FastMistralModel.pre_patch()
|
||||
|
||||
|
|
@ -290,6 +292,18 @@ class FastMistralModel(FastLlamaModel):
|
|||
|
||||
assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)
|
||||
|
||||
# Check max sequence length
|
||||
model_config = AutoConfig.from_pretrained(model_name, token = token)
|
||||
model_max_seq_length = model_config.max_position_embeddings
|
||||
|
||||
# Mistral does NOT support RoPE Scaling sadly so we have to error out.
|
||||
if max_seq_length > model_max_seq_length:
|
||||
raise RuntimeError(
|
||||
"Unsloth: Unfortunately Mistral type models do not support RoPE scaling!\n"\
|
||||
f"The maximum sequence length supported is {model_max_seq_length}.",
|
||||
)
|
||||
pass
|
||||
|
||||
bnb_config = None
|
||||
if load_in_4bit:
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
|
|
@ -299,20 +313,21 @@ class FastMistralModel(FastLlamaModel):
|
|||
bnb_4bit_compute_dtype = dtype,
|
||||
)
|
||||
|
||||
max_position_embeddings = max(max_seq_length, model_max_seq_length)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
device_map = device_map,
|
||||
torch_dtype = dtype,
|
||||
device_map = device_map,
|
||||
torch_dtype = dtype,
|
||||
quantization_config = bnb_config,
|
||||
token = token,
|
||||
# rope_scaling = rope_scaling,
|
||||
token = token,
|
||||
# rope_scaling = rope_scaling,
|
||||
**kwargs,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name,
|
||||
model_max_length = max_seq_length,
|
||||
padding_side = "right",
|
||||
token = token,
|
||||
model_max_length = max_position_embeddings,
|
||||
padding_side = "right",
|
||||
token = token,
|
||||
)
|
||||
|
||||
model, tokenizer = patch_tokenizer(model, tokenizer)
|
||||
|
|
@ -337,12 +352,12 @@ class FastMistralModel(FastLlamaModel):
|
|||
# We check the tokenizer first for errors
|
||||
if fix_tokenizer:
|
||||
tokenizer = check_tokenizer(
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
model_name = model_name,
|
||||
model_max_length = max_seq_length,
|
||||
padding_side = "right",
|
||||
token = token,
|
||||
model = model,
|
||||
tokenizer = tokenizer,
|
||||
model_name = model_name,
|
||||
model_max_length = max_position_embeddings,
|
||||
padding_side = "right",
|
||||
token = token,
|
||||
)
|
||||
pass
|
||||
patch_saving_functions(tokenizer)
|
||||
|
|
|
|||
|
|
@ -77,10 +77,14 @@ def _merge_lora(layer, name):
|
|||
W, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
dtype = quant_state.dtype if type(quant_state) is not list else quant_state[2]
|
||||
W = fast_dequantize(W, quant_state).to(torch.float32).t()
|
||||
sAB = (A.t().to(torch.float32) @ (s * B.t().to(torch.float32)))
|
||||
W += sAB
|
||||
if not torch.isfinite(W).all():
|
||||
raise ValueError(f"Unsloth: Merge failed.\n{name} has some elements = infinity.")
|
||||
|
||||
if A is not None:
|
||||
sAB = (A.t().to(torch.float32) @ (s * B.t().to(torch.float32)))
|
||||
W += sAB
|
||||
if not torch.isfinite(W).all():
|
||||
raise ValueError(f"Unsloth: Merge failed.\n{name} has some elements = infinity.")
|
||||
pass
|
||||
|
||||
W = W.t().to(dtype)
|
||||
else:
|
||||
W = layer.weight
|
||||
|
|
@ -156,6 +160,7 @@ def unsloth_save_model(
|
|||
pass
|
||||
|
||||
if save_method == "merged_4bit":
|
||||
|
||||
print("Unsloth: Merging 4bit and LoRA weights to 4bit...")
|
||||
print("This might take 5 minutes...")
|
||||
model = model.merge_and_unload()
|
||||
|
|
|
|||
Loading…
Reference in a new issue