mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Fix bugs (#129)
* faster saving & inference * Update llama.py * Update save.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update mistral.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * fast inference * Update llama.py * Update save.py * Update llama.py * Mistral correct RoPE scaling * Max sequence lengths * Apache 2 * fast_linear_forward * Update utils.py * Update utils.py * No print * Update utils.py * Update utils.py * inference * Update llama.py * Fast inference RoPE * Update llama.py * Update llama.py * RoPE * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * LoRA * Fast LoRA saving * Update llama.py * hidden_states * q_len == 1 * q_len issue * Update mistral.py * Update mistral.py * incorrect inference * Update to transformers 4.37 * Graceful FA2 error + torch 2.1.1 * Update mapper.py * Update pyproject.toml * Fix saving and bnb-4bit * Update fast_lora.py * Update fast_lora.py * remove patching * Update llama.py * Update llama.py * Update swiglu.py * Repatch * Update fast_lora.py
This commit is contained in:
parent
04f8771821
commit
62fae3aa74
8 changed files with 98 additions and 74 deletions
|
|
@ -32,18 +32,8 @@ include-package-data = false
|
|||
exclude = ["images*"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
huggingfacedev = [
|
||||
"transformers @ git+https://github.com/huggingface/transformers",
|
||||
"datasets",
|
||||
"sentencepiece",
|
||||
"accelerate",
|
||||
"trl>=0.7.9",
|
||||
"peft",
|
||||
"tqdm",
|
||||
"psutil",
|
||||
]
|
||||
huggingface = [
|
||||
"transformers",
|
||||
"transformers>=4.37.0",
|
||||
"datasets",
|
||||
"sentencepiece",
|
||||
"accelerate",
|
||||
|
|
@ -107,15 +97,15 @@ colab_ampere = [
|
|||
"ninja",
|
||||
"flash-attn",
|
||||
]
|
||||
colab_dev = [
|
||||
"unsloth[huggingfacedev]",
|
||||
colab_torch211 = [
|
||||
"unsloth[huggingface]",
|
||||
"bitsandbytes",
|
||||
"unsloth[cu121only]",
|
||||
"unsloth[cu121onlytorch211]",
|
||||
]
|
||||
colab_ampere_dev = [
|
||||
"unsloth[huggingfacedev]",
|
||||
colab_ampere_torch211 = [
|
||||
"unsloth[huggingface]",
|
||||
"bitsandbytes",
|
||||
"unsloth[cu121only]",
|
||||
"unsloth[cu121onlytorch211]",
|
||||
"packaging",
|
||||
"ninja",
|
||||
"flash-attn",
|
||||
|
|
|
|||
|
|
@ -141,7 +141,7 @@ class LoRA_MLP(torch.autograd.Function):
|
|||
|
||||
# Gate projection LoRA weights
|
||||
d_gateA = X.t() @ (DW_dfg @ gateB.t())
|
||||
d_gateB = (gateA.t() @ X.t() @ DW_dfg)
|
||||
d_gateB = (gateA.t() @ X.t()) @ DW_dfg
|
||||
d_gateA *= gateS
|
||||
d_gateB *= gateS
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,19 @@ major_version, minor_version = torch.cuda.get_device_capability()
|
|||
if major_version >= 8:
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
HAS_FLASH_ATTENTION = True
|
||||
# Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl"
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_cuda
|
||||
HAS_FLASH_ATTENTION = True
|
||||
except:
|
||||
logger.warning_once(
|
||||
"Unsloth: Your Flash Attention 2 installation seems to be broken?\n"\
|
||||
"A possible explanation is you have a new CUDA version which isn't\n"\
|
||||
"yet compatible with FA2? Please file a ticket to Unsloth or FA2.\n"\
|
||||
"We shall now use Xformers instead, which gets a 0.01% performance hit.\n"\
|
||||
"We found this negligible impact by benchmarking on 1x A100."
|
||||
)
|
||||
HAS_FLASH_ATTENTION = False
|
||||
except:
|
||||
HAS_FLASH_ATTENTION = False
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -234,7 +234,7 @@ def LlamaAttention_fast_forward(
|
|||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
# Check for inference
|
||||
if past_key_value is not None and q_len == 1:
|
||||
if past_key_value is not None and q_len == 1 and bsz == 1:
|
||||
A, past_key_value = LlamaAttention_fast_forward_inference(
|
||||
self,
|
||||
hidden_states,
|
||||
|
|
@ -271,6 +271,7 @@ def LlamaAttention_fast_forward(
|
|||
if past_key_value is not None:
|
||||
K = torch.cat([past_key_value[0], K], dim = 2)
|
||||
V = torch.cat([past_key_value[1], V], dim = 2)
|
||||
pass
|
||||
past_key_value = (K, V) if use_cache else None
|
||||
|
||||
# Attention module
|
||||
|
|
@ -283,13 +284,13 @@ def LlamaAttention_fast_forward(
|
|||
|
||||
# Group query attention
|
||||
if n_groups != 1:
|
||||
K = K .view(bsz, q_len, n_kv_heads, 1, head_dim)
|
||||
V = V .view(bsz, q_len, n_kv_heads, 1, head_dim)
|
||||
K = K.expand(bsz, q_len, n_kv_heads, n_groups, head_dim)
|
||||
V = V.expand(bsz, q_len, n_kv_heads, n_groups, head_dim)
|
||||
K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
|
||||
V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
|
||||
K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
|
||||
V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
|
||||
if hidden_states.requires_grad:
|
||||
K = K.reshape(bsz, q_len, n_heads, head_dim)
|
||||
V = V.reshape(bsz, q_len, n_heads, head_dim)
|
||||
K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
|
||||
V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
|
||||
else:
|
||||
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
|
||||
pass
|
||||
|
|
@ -304,10 +305,10 @@ def LlamaAttention_fast_forward(
|
|||
else:
|
||||
# Grouped query attention
|
||||
if n_groups != 1:
|
||||
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
|
||||
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
|
||||
K = K.reshape(bsz, n_heads, q_len, head_dim)
|
||||
V = V.reshape(bsz, n_heads, q_len, head_dim)
|
||||
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
|
||||
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
|
||||
K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
|
||||
V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
|
||||
pass
|
||||
# Needs (batch_size, n_heads, seq_len, head_dim)
|
||||
# is_causal and attention_mask must not both be set!
|
||||
|
|
@ -349,7 +350,7 @@ def LlamaDecoderLayer_fast_forward(
|
|||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
"""
|
||||
bsz, q_len, hd = hidden_states.size()
|
||||
if (past_key_value is not None and q_len == 1):
|
||||
if (past_key_value is not None and q_len == 1 and bsz == 1):
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
|
||||
|
|
@ -722,9 +723,9 @@ class FastLlamaModel:
|
|||
statistics = \
|
||||
f"==((====))== Unsloth: Fast Llama patching release {__version__}\n"\
|
||||
f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform = {platform_system}.\n"\
|
||||
f"O^O/ \_/ \\ Pytorch: {torch.__version__}. CUDA = {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit = {torch.version.cuda}.\n"\
|
||||
f"O^O/ \_/ \\ Pytorch: {torch.__version__}. CUDA = {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit = {torch.version.cuda}.\n"\
|
||||
f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. Xformers = {xformers_version}. FA = {HAS_FLASH_ATTENTION}.\n"\
|
||||
f' "-____-" Apache 2 free license: http://github.com/unslothai/unsloth'
|
||||
f' "-____-" Free Apache license: http://github.com/unslothai/unsloth'
|
||||
logger.warning_once(statistics)
|
||||
FastLlamaModel.pre_patch()
|
||||
|
||||
|
|
@ -813,10 +814,13 @@ class FastLlamaModel:
|
|||
patch_saving_functions(tokenizer)
|
||||
|
||||
# Fix up config for transformers uploading PEFT
|
||||
name = model.config._name_or_path
|
||||
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
|
||||
name = name[:len(name) - len("-bnb-4bit")]
|
||||
model.config.update({"_name_or_path" : name})
|
||||
# Not necessary anymore since we require transformers>=4.37!
|
||||
if False:
|
||||
name = model.config._name_or_path
|
||||
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
|
||||
name = name[:len(name) - len("-bnb-4bit")]
|
||||
model.config.update({"_name_or_path" : name})
|
||||
pass
|
||||
pass
|
||||
|
||||
# Log Unsloth version for future fastpaths for inference
|
||||
|
|
@ -1019,11 +1023,13 @@ class FastLlamaModel:
|
|||
|
||||
# Fix up config for transformers uploading PEFT
|
||||
for active_adapter in model.peft_config.keys():
|
||||
name = model.peft_config[active_adapter].base_model_name_or_path
|
||||
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
|
||||
name = name[:len(name) - len("-bnb-4bit")]
|
||||
model.peft_config[active_adapter].base_model_name_or_path = name
|
||||
pass
|
||||
# Not necessary since we require transformers >= 4.37
|
||||
if False:
|
||||
name = model.peft_config[active_adapter].base_model_name_or_path
|
||||
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
|
||||
name = name[:len(name) - len("-bnb-4bit")]
|
||||
model.peft_config[active_adapter].base_model_name_or_path = name
|
||||
pass
|
||||
# Add revision to enable future fast inference paths
|
||||
model.peft_config[active_adapter].revision = f"unsloth"
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ def _get_model_name(model_name, load_in_4bit = True):
|
|||
logger.warning_once(
|
||||
f"Unsloth: Your transformers version of {transformers_version} does not support native "\
|
||||
f"4bit loading.\nThe minimum required version is 4.37.\n"\
|
||||
f'Try `pip install "git+https://github.com/huggingface/transformers.git"`\n'\
|
||||
f'Try `pip install --upgrade "transformers>=4.37"`\n'\
|
||||
f"to obtain the latest transformers build, then restart this session.\n"\
|
||||
f"For now, we shall load `{model_name}` instead (still 4bit, just slower downloading)."
|
||||
)
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ __INT_TO_FLOAT_MAPPER = \
|
|||
),
|
||||
"unsloth/zephyr-sft-bnb-4bit" : (
|
||||
"unsloth/zephyr-sft",
|
||||
"alignment-handbook/zephyr-7b-sft-full",
|
||||
"HuggingFaceH4/mistral-7b-sft-beta",
|
||||
),
|
||||
"unsloth/tinyllama-bnb-4bit" : (
|
||||
"unsloth/tinyllama",
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ def MistralAttention_fast_forward(
|
|||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
# Check for inference
|
||||
if past_key_value is not None and q_len == 1:
|
||||
if past_key_value is not None and q_len == 1 and bsz == 1:
|
||||
A, past_key_value = LlamaAttention_fast_forward_inference(
|
||||
self,
|
||||
hidden_states,
|
||||
|
|
@ -84,9 +84,9 @@ def MistralAttention_fast_forward(
|
|||
pass
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
K = torch.cat([past_key_value[0], K], dim = 2)
|
||||
V = torch.cat([past_key_value[1], V], dim = 2)
|
||||
pass
|
||||
past_key_value = (K, V) if use_cache else None
|
||||
|
||||
# Attention module
|
||||
|
|
@ -95,32 +95,33 @@ def MistralAttention_fast_forward(
|
|||
Q = Q.transpose(1, 2)
|
||||
K = K.transpose(1, 2)
|
||||
V = V.transpose(1, 2)
|
||||
M = bsz * q_len
|
||||
K_M = V_M = bsz * kv_seq_len
|
||||
Q_M = bsz * q_len
|
||||
|
||||
has_swa = isinstance(causal_mask, xformers.attn_bias.BlockDiagonalCausalMask)
|
||||
|
||||
# Group query attention
|
||||
K = K .view(bsz, q_len, n_kv_heads, 1, head_dim)
|
||||
V = V .view(bsz, q_len, n_kv_heads, 1, head_dim)
|
||||
K = K.expand(bsz, q_len, n_kv_heads, n_groups, head_dim)
|
||||
V = V.expand(bsz, q_len, n_kv_heads, n_groups, head_dim)
|
||||
K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
|
||||
V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
|
||||
K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
|
||||
V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
|
||||
if hidden_states.requires_grad:
|
||||
K = K.reshape(bsz, q_len, n_heads, head_dim)
|
||||
V = V.reshape(bsz, q_len, n_heads, head_dim)
|
||||
K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
|
||||
V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
|
||||
|
||||
if has_swa:
|
||||
Q = Q.view(1, M, n_heads, head_dim)
|
||||
K = K.view(1, M, n_heads, head_dim)
|
||||
V = V.view(1, M, n_heads, head_dim)
|
||||
Q = Q.view(1, Q_M, n_heads, head_dim)
|
||||
K = K.view(1, K_M, n_heads, head_dim)
|
||||
V = V.view(1, V_M, n_heads, head_dim)
|
||||
pass
|
||||
else:
|
||||
# Xformers does support the forward pass though
|
||||
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
|
||||
|
||||
if has_swa:
|
||||
Q = Q.view(1, M, n_kv_heads, n_groups, head_dim)
|
||||
K = K.view(1, M, n_kv_heads, n_groups, head_dim)
|
||||
V = V.view(1, M, n_kv_heads, n_groups, head_dim)
|
||||
Q = Q.view(1, Q_M, n_kv_heads, n_groups, head_dim)
|
||||
K = K.view(1, K_M, n_kv_heads, n_groups, head_dim)
|
||||
V = V.view(1, V_M, n_kv_heads, n_groups, head_dim)
|
||||
pass
|
||||
pass
|
||||
|
||||
|
|
@ -132,16 +133,16 @@ def MistralAttention_fast_forward(
|
|||
K = K.transpose(1, 2)
|
||||
V = V.transpose(1, 2)
|
||||
sw = getattr(self.config, "sliding_window", None)
|
||||
sw = q_len if (sw is None or sw == "null") else sw
|
||||
window = (-1, -1) if (q_len <= sw) else (sw, sw)
|
||||
sw = kv_seq_len if (sw is None or sw == "null") else sw
|
||||
window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
|
||||
A = flash_attn_func(Q, K, V, causal = True, window_size = window)
|
||||
else:
|
||||
# Grouped query attention
|
||||
# if n_groups != 1:
|
||||
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
|
||||
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
|
||||
K = K.reshape(bsz, n_heads, q_len, head_dim)
|
||||
V = V.reshape(bsz, n_heads, q_len, head_dim)
|
||||
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
|
||||
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
|
||||
K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
|
||||
V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
|
||||
# pass
|
||||
# Needs (batch_size, n_heads, seq_len, head_dim)
|
||||
# is_causal and attention_mask must not both be set!
|
||||
|
|
@ -278,7 +279,7 @@ class FastMistralModel(FastLlamaModel):
|
|||
statistics = \
|
||||
f"==((====))== Unsloth: Fast Mistral patching release {__version__}\n"\
|
||||
f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform = {platform_system}.\n"\
|
||||
f"O^O/ \_/ \\ Pytorch: {torch.__version__}. CUDA = {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit = {torch.version.cuda}.\n"\
|
||||
f"O^O/ \_/ \\ Pytorch: {torch.__version__}. CUDA = {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit = {torch.version.cuda}.\n"\
|
||||
f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. Xformers = {xformers_version}. FA = {HAS_FLASH_ATTENTION}.\n"\
|
||||
f' "-____-" Apache 2 free license: http://github.com/unslothai/unsloth'
|
||||
logger.warning_once(statistics)
|
||||
|
|
@ -363,11 +364,13 @@ class FastMistralModel(FastLlamaModel):
|
|||
patch_saving_functions(tokenizer)
|
||||
|
||||
# Fix up config for transformers uploading PEFT
|
||||
name = model.config._name_or_path
|
||||
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
|
||||
name = name[:len(name) - len("-bnb-4bit")]
|
||||
model.config.update({"_name_or_path" : name})
|
||||
pass
|
||||
# Not necessary anymore since we require transformers>=4.37
|
||||
if False:
|
||||
name = model.config._name_or_path
|
||||
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
|
||||
name = name[:len(name) - len("-bnb-4bit")]
|
||||
model.config.update({"_name_or_path" : name})
|
||||
pass
|
||||
|
||||
# Log Unsloth version for future fastpaths for inference
|
||||
model.config.update({"unsloth_version" : __version__})
|
||||
|
|
|
|||
|
|
@ -135,6 +135,17 @@ def unsloth_save_model(
|
|||
temporary_location : str = "_unsloth_temporary_saved_buffers",
|
||||
maximum_memory_usage : float = 0.9,
|
||||
):
|
||||
if save_method == "merged_4bit":
|
||||
raise RuntimeError(
|
||||
"Unsloth: Merging into 4bit will cause your model to lose accuracy if you plan\n"\
|
||||
"to merge to GGUF or others later on. I suggest you to do this as a final step\n"\
|
||||
"if you're planning to do multiple saves.\n"\
|
||||
"If you are certain, change `save_method` to `merged_4bit_forced`."
|
||||
)
|
||||
elif save_method == "merged_4bit_forced":
|
||||
save_method = "merged_4bit"
|
||||
pass
|
||||
|
||||
save_pretrained_settings = dict(locals())
|
||||
for deletion in ("model", "tokenizer", "save_method", "temporary_location", "maximum_memory_usage"):
|
||||
del save_pretrained_settings[deletion]
|
||||
|
|
@ -457,6 +468,8 @@ pass
|
|||
def install_llama_cpp_make_non_blocking():
    """Start a background build of llama.cpp and return its process handle.

    A blocking `make clean` is run in the `llama.cpp` directory first, then
    `make -j <n_jobs> -C llama.cpp` is launched with `subprocess.Popen` and
    `LLAMA_CUBLAS=1` added to the inherited environment. The build's stdout
    and stderr are discarded.

    Returns:
        subprocess.Popen: handle of the running build. This function does not
        wait for it to finish.
    """
    env = { **os.environ, "LLAMA_CUBLAS": "1", }

    # psutil.cpu_count() returns None when the core count cannot be
    # determined; fall back to one core instead of raising a TypeError.
    n_cpus = psutil.cpu_count() or 1
    n_jobs = max(int(n_cpus * 1.5), 1)

    # Force make clean. An argument list avoids going through the shell, and
    # the exit status is deliberately ignored (best effort, as before).
    subprocess.run(["make", "clean", "-C", "llama.cpp"], check = False)

    full_command = ["make", "-j", str(n_jobs), "-C", "llama.cpp"]
    run_installer = subprocess.Popen(
        full_command,
        env = env,
        stdout = subprocess.DEVNULL,
        stderr = subprocess.STDOUT,
    )
    return run_installer
|
||||
|
|
@ -487,8 +500,8 @@ pass
|
|||
|
||||
|
||||
def save_to_gguf(
|
||||
model_directory : str = "unsloth_finetuned_model",
|
||||
quantization_method : str = "fast_quantized",
|
||||
model_directory : str = "unsloth_finetuned_model",
|
||||
quantization_method : str = "fast_quantized",
|
||||
_run_installer = None, # Non blocking install of llama.cpp
|
||||
):
|
||||
from transformers.models.llama.modeling_llama import logger
|
||||
|
|
@ -566,7 +579,7 @@ def unsloth_save_pretrained_merged(
|
|||
self,
|
||||
save_directory : Union[str, os.PathLike],
|
||||
tokenizer = None,
|
||||
save_method : str = "merged_16bit", # ["lora", "merged_16bit", "merged_4bit"]
|
||||
save_method : str = "merged_16bit", # ["lora", "merged_16bit", "merged_4bit"]
|
||||
push_to_hub : bool = False,
|
||||
token : Optional[Union[str, bool]] = None,
|
||||
is_main_process : bool = True,
|
||||
|
|
|
|||
Loading…
Reference in a new issue