mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
2024 Release (#96)
* Fix tokenizer, dropout, bias for LoRA * Update loader.py * Fix LoRA downcasting * Update _utils.py * Saving to GGUF * fix * colab_quantize_to_gguf * move save modules * save module * Update __init__.py * Update save.py * Temp downgrade due to TRL issue * Fix up bugs * Faster saving + other changes * Update llama.py * Saving modules * spelling * Update llama.py * Update save.py * Update save.py * Update loader.py * Update llama.py * patch saving * Update save.py * Update save.py * Update save.py * patch saving * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * original_model * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * Update save.py * saving to RAM leakage? * Update save.py * new_save_directory * Update save.py * Update save.py * Update save.py * Update save.py * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml
This commit is contained in:
parent
9e2dec16fb
commit
d691516ab9
9 changed files with 1094 additions and 188 deletions
|
|
@ -65,8 +65,7 @@ try:
|
|||
libcuda_dirs()
|
||||
except:
|
||||
warnings.warn(
|
||||
"CUDA is not linked properly.\n"\
|
||||
"We shall run `ldconfig /usr/lib64-nvidia` to try to fix it."
|
||||
"Running `ldconfig /usr/lib64-nvidia` to link CUDA."\
|
||||
)
|
||||
os.system("ldconfig /usr/lib64-nvidia")
|
||||
importlib.reload(bnb)
|
||||
|
|
|
|||
|
|
@ -41,12 +41,13 @@ def _rms_layernorm_forward(
|
|||
r += row_idx * r_row_stride
|
||||
|
||||
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)
|
||||
|
||||
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
|
||||
inv_var = 1 / tl.sqrt(row_var + eps)
|
||||
inv_var = 1.0 / tl.sqrt(row_var + eps)
|
||||
tl.store(r, inv_var)
|
||||
normed = X_row * inv_var
|
||||
normed = normed.to(W_row.dtype) # Exact copy from HF
|
||||
output = normed * W_row
|
||||
tl.store(Y + col_offsets, output, mask = mask)
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -25,10 +25,11 @@ def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
|
|||
mask = offsets < n_elements
|
||||
|
||||
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
||||
g_row = tl.load(g + offsets, mask = mask, other = 0).to(tl.float32)
|
||||
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
|
||||
|
||||
# f = e * sigmoid(e)
|
||||
f_row = e_row / (1 + tl.exp(-e_row))
|
||||
f_row = f_row.to(g_row.dtype) # Exact copy from HF
|
||||
# h = f * g
|
||||
h_row = f_row * g_row
|
||||
|
||||
|
|
@ -53,12 +54,13 @@ def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
|
|||
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
|
||||
DW_row = tl.load(DW + offsets, mask = mask, other = 0).to(tl.float32)
|
||||
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
||||
g_row = tl.load(g + offsets, mask = mask, other = 0).to(tl.float32)
|
||||
DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
|
||||
e_row = tl.load(e + offsets, mask = mask, other = 0)#.to(tl.float32)
|
||||
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
|
||||
|
||||
# f = e * sigmoid(e)
|
||||
se_row = 1 / (1 + tl.exp(-e_row))
|
||||
se_row = 1 / (1 + tl.exp(-e_row.to(tl.float32)))
|
||||
se_row = se_row.to(e_row.dtype) # Exact copy from HF
|
||||
# f = e * se
|
||||
f_row = e_row * se_row
|
||||
# h = f * g
|
||||
|
|
|
|||
|
|
@ -14,9 +14,7 @@
|
|||
|
||||
import torch
|
||||
from typing import Union, Optional, List, Any, Callable
|
||||
import numpy as np
|
||||
import warnings
|
||||
import gc
|
||||
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
|
||||
import bitsandbytes as bnb
|
||||
from transformers.models.llama.modeling_llama import logger
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@
|
|||
import torch
|
||||
from typing import Optional, Tuple, List, Union
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
logger,
|
||||
BaseModelOutputWithPast,
|
||||
|
|
@ -46,16 +45,13 @@ except:
|
|||
LlamaFlashAttention2 = LlamaAttention
|
||||
pass
|
||||
|
||||
from peft import PeftModelForCausalLM
|
||||
import gc
|
||||
import peft
|
||||
import bitsandbytes as bnb
|
||||
import numpy as np
|
||||
import types
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
|
||||
from transformers import set_seed as transformers_set_seed
|
||||
from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model
|
||||
from peft import PeftModelForCausalLM
|
||||
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
|
||||
from peft.tuners.lora import Linear4bit as Peft_Linear4bit
|
||||
from ..save import patch_saving_functions
|
||||
|
||||
|
||||
def original_apply_qkv(self, X):
|
||||
|
|
@ -110,18 +106,15 @@ def LlamaAttention_fast_forward_inference(
|
|||
bsz, _, _ = hidden_states.size()
|
||||
K1, V1 = past_key_value
|
||||
|
||||
Wq = self.q_proj.weight
|
||||
Wk = self.k_proj.weight
|
||||
Wv = self.v_proj.weight
|
||||
Wo = self.o_proj.weight
|
||||
|
||||
n_heads = self.num_heads
|
||||
n_groups = self.num_key_value_groups
|
||||
n_kv_heads = self.num_key_value_heads
|
||||
head_dim = self.head_dim
|
||||
assert(n_kv_heads * n_groups == n_heads)
|
||||
|
||||
Qn, Kn, Vn = original_apply_qkv(self, Xn)
|
||||
Qn = self.q_proj(Xn)
|
||||
Kn = self.k_proj(Xn)
|
||||
Vn = self.v_proj(Xn)
|
||||
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
|
||||
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
|
||||
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
|
||||
|
|
@ -156,6 +149,28 @@ def LlamaAttention_fast_forward_inference(
|
|||
pass
|
||||
|
||||
|
||||
torch_silu = torch.nn.functional.silu
|
||||
def fast_mlp_inference(self, X):
|
||||
gate = self.gate_proj(X)
|
||||
up = self.up_proj(X)
|
||||
gate = torch_silu(gate, inplace = True)
|
||||
gate *= up
|
||||
X = self.down_proj(gate)
|
||||
return X
|
||||
pass
|
||||
|
||||
|
||||
def fast_rms_layernorm_inference(self, X):
|
||||
X = X.to(torch.float32)
|
||||
variance = X.square().mean(-1, keepdim = True)
|
||||
variance += self.variance_epsilon
|
||||
X *= variance.rsqrt_()
|
||||
X = X.to(residual.dtype)
|
||||
X *= self.weight
|
||||
return X
|
||||
pass
|
||||
|
||||
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L320
|
||||
def LlamaAttention_fast_forward(
|
||||
self,
|
||||
|
|
@ -287,28 +302,51 @@ def LlamaDecoderLayer_fast_forward(
|
|||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
"""
|
||||
residual = hidden_states
|
||||
bsz, q_len, hd = hidden_states.size()
|
||||
|
||||
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
|
||||
if (self.training):
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
causal_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
causal_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
else:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
causal_mask=causal_mask,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
)
|
||||
hidden_states += residual
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states)
|
||||
hidden_states = fast_mlp_inference(self.mlp, hidden_states)
|
||||
hidden_states += residual
|
||||
pass
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
|
|
@ -378,6 +416,7 @@ def LlamaModel_fast_forward(
|
|||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
pass
|
||||
|
||||
# We already handle KV cache position_ids ourselves.
|
||||
if (past_key_values_length != 0):
|
||||
|
|
@ -391,10 +430,12 @@ def LlamaModel_fast_forward(
|
|||
position_ids = position_ids.view(-1, seq_length).to(torch.int32)#.long()
|
||||
else:
|
||||
position_ids = None
|
||||
pass
|
||||
|
||||
if position_ids is not None:
|
||||
if position_ids.shape[0] != batch_size:
|
||||
position_ids = position_ids.repeat((batch_size, 1))
|
||||
pass
|
||||
|
||||
# embed positions
|
||||
if inputs_embeds is None:
|
||||
|
|
@ -403,19 +444,22 @@ def LlamaModel_fast_forward(
|
|||
# Ignore attention_mask
|
||||
if attention_mask is None:
|
||||
padding_mask = None
|
||||
elif self.training:
|
||||
attention_mask = None
|
||||
padding_mask = None
|
||||
else:
|
||||
if 0 in attention_mask:
|
||||
padding_mask = attention_mask
|
||||
else:
|
||||
padding_mask = None
|
||||
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
sliding_window = None if not hasattr(self.config, "sliding_window") else \
|
||||
self.config.sliding_window,
|
||||
sliding_window = getattr(self.config, "sliding_window"),
|
||||
)
|
||||
pass
|
||||
|
||||
|
|
@ -479,7 +523,11 @@ def LlamaModel_fast_forward(
|
|||
all_self_attns += (layer_outputs[1],)
|
||||
pass
|
||||
|
||||
hidden_states = fast_rms_layernorm(self.norm, hidden_states)
|
||||
if (self.training):
|
||||
hidden_states = fast_rms_layernorm(self.norm, hidden_states)
|
||||
else:
|
||||
hidden_states = fast_rms_layernorm_inference(self.norm, hidden_states)
|
||||
pass
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
|
|
@ -665,6 +713,7 @@ class FastLlamaModel:
|
|||
bnb_4bit_quant_type = "nf4",
|
||||
bnb_4bit_compute_dtype = dtype,
|
||||
)
|
||||
pass
|
||||
|
||||
# https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/12
|
||||
# RoPE Scaling's max_position_embeddings must be updated
|
||||
|
|
@ -714,6 +763,7 @@ class FastLlamaModel:
|
|||
token = token,
|
||||
)
|
||||
pass
|
||||
patch_saving_functions(tokenizer)
|
||||
|
||||
# Fix up config for transformers uploading PEFT
|
||||
name = model.config._name_or_path
|
||||
|
|
@ -721,6 +771,7 @@ class FastLlamaModel:
|
|||
name = name[:len(name) - len("-bnb-4bit")]
|
||||
model.config.update({"_name_or_path" : name})
|
||||
pass
|
||||
|
||||
# Log Unsloth version for future fastpaths for inference
|
||||
model.config.update({"unsloth_version" : __version__})
|
||||
|
||||
|
|
@ -751,7 +802,7 @@ class FastLlamaModel:
|
|||
correct_dtype = lm_head.weight.dtype
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, (bnb.nn.Linear4bit, peft.tuners.lora.Linear4bit)):
|
||||
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
|
||||
weight = module.weight
|
||||
quant_state = weight.quant_state
|
||||
|
||||
|
|
@ -766,8 +817,10 @@ class FastLlamaModel:
|
|||
pass
|
||||
|
||||
# Clear deleted GPU items
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
import gc
|
||||
for _ in range(3):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return model
|
||||
pass
|
||||
|
||||
|
|
@ -782,11 +835,26 @@ class FastLlamaModel:
|
|||
lora_dropout = 0,
|
||||
bias = "none",
|
||||
layers_to_transform = None,
|
||||
layers_pattern = None,
|
||||
use_gradient_checkpointing = True,
|
||||
random_state = 3407,
|
||||
max_seq_length = 2048, # not used anymore
|
||||
use_rslora = False,
|
||||
init_lora_weights = True,
|
||||
loftq_config = None,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(model, PeftModelForCausalLM):
|
||||
raise TypeError(
|
||||
"Unsloth: Your model already has LoRA adapters. No need to run this again!"
|
||||
)
|
||||
pass
|
||||
|
||||
import inspect
|
||||
signature = str(inspect.signature(LoraConfig))
|
||||
SUPPORTS_LOFTQ = "loftq_config" in signature
|
||||
SUPPORTS_RSLORA = "use_rslora" in signature
|
||||
|
||||
assert(max_seq_length <= model.max_seq_length)
|
||||
|
||||
if lora_dropout != 0:
|
||||
|
|
@ -794,11 +862,61 @@ class FastLlamaModel:
|
|||
f"Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = {lora_dropout}.\n"\
|
||||
f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
|
||||
)
|
||||
pass
|
||||
|
||||
if bias != "none":
|
||||
logger.warning_once(
|
||||
f"Unsloth: bias = `none` is supported for fast patching. You are using bias = {bias}.\n"\
|
||||
f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
|
||||
)
|
||||
pass
|
||||
|
||||
if not (type(init_lora_weights) is bool or \
|
||||
init_lora_weights == "gaussian" or init_lora_weights == "loftq"):
|
||||
raise ValueError(
|
||||
'Unsloth: `init_lora_weights` must be either [True, False, "gaussian", "loftq"].'
|
||||
)
|
||||
pass
|
||||
|
||||
if init_lora_weights == "loftq":
|
||||
|
||||
if not SUPPORTS_LOFTQ:
|
||||
import peft
|
||||
raise RuntimeError(
|
||||
f"Unsloth: Your PEFT version of {peft.__version__} does not support LoftQ init.\n"\
|
||||
"Please install PEFT 0.7.2 or higher.\n"\
|
||||
"You can also install from source: `pip install git+https://github.com/huggingface/peft.git"
|
||||
)
|
||||
pass
|
||||
|
||||
if loftq_config is None:
|
||||
from peft import LoftQConfig
|
||||
logger.warning_once(
|
||||
f"Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"\
|
||||
f"We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
|
||||
)
|
||||
loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)
|
||||
pass
|
||||
|
||||
if hasattr(model.config, "quantization_config"):
|
||||
raise ValueError(
|
||||
"Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\n"\
|
||||
"Reload your model without any quantization by setting `load_in_4bit = False`."
|
||||
)
|
||||
pass
|
||||
pass
|
||||
|
||||
assert(type(use_rslora) is bool)
|
||||
if use_rslora:
|
||||
if not SUPPORTS_RSLORA:
|
||||
import peft
|
||||
raise RuntimeError(
|
||||
f"Unsloth: Your PEFT version of {peft.__version__} does not support use_rslora.\n"\
|
||||
"Please install PEFT 0.7.2 or higher.\n"\
|
||||
"You can also install from source: `pip install git+https://github.com/huggingface/peft.git"
|
||||
)
|
||||
pass
|
||||
pass
|
||||
|
||||
transformers_set_seed(random_state)
|
||||
|
||||
|
|
@ -810,16 +928,23 @@ class FastLlamaModel:
|
|||
pass
|
||||
|
||||
# Get LoRA
|
||||
lora_config = LoraConfig(
|
||||
r = r,
|
||||
lora_alpha = lora_alpha,
|
||||
target_modules = target_modules,
|
||||
lora_dropout = lora_dropout,
|
||||
bias = bias,
|
||||
task_type = TaskType.CAUSAL_LM,
|
||||
arguments = dict(
|
||||
r = r,
|
||||
lora_alpha = lora_alpha,
|
||||
target_modules = target_modules,
|
||||
lora_dropout = lora_dropout,
|
||||
bias = bias,
|
||||
task_type = TaskType.CAUSAL_LM,
|
||||
layers_to_transform = layers_to_transform,
|
||||
init_lora_weights = init_lora_weights,
|
||||
loftq_config = loftq_config,
|
||||
use_rslora = use_rslora,
|
||||
**kwargs,
|
||||
)
|
||||
if not SUPPORTS_LOFTQ: del arguments["loftq_config"]
|
||||
if not SUPPORTS_RSLORA: del arguments["use_rslora"]
|
||||
|
||||
lora_config = LoraConfig(**arguments)
|
||||
|
||||
model = prepare_model_for_kbit_training(
|
||||
model,
|
||||
|
|
@ -828,10 +953,21 @@ class FastLlamaModel:
|
|||
)
|
||||
model = _get_peft_model(model, lora_config)
|
||||
|
||||
# Fix up config for transformers uploading PEFT
|
||||
name = model.peft_config["default"].base_model_name_or_path
|
||||
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
|
||||
name = name[:len(name) - len("-bnb-4bit")]
|
||||
model.peft_config["default"].base_model_name_or_path = name
|
||||
pass
|
||||
# Add revision to enable future fast inference paths
|
||||
model.peft_config["default"].revision = f"unsloth"
|
||||
|
||||
# Do patching
|
||||
n_mlp = 0
|
||||
n_qkv = 0
|
||||
n_o = 0
|
||||
import types
|
||||
|
||||
if lora_dropout == 0 and bias == "none":
|
||||
for idx, layer in enumerate(model.model.model.layers):
|
||||
|
||||
|
|
@ -897,6 +1033,7 @@ class FastLlamaModel:
|
|||
f"Unsloth {__version__} patched {len(model.model.model.layers)} layers with "\
|
||||
f"{n_qkv} QKV layers, {n_o} O layers and {n_mlp} MLP layers.",
|
||||
)
|
||||
patch_saving_functions(model)
|
||||
|
||||
# Patch cross entropy loss labels
|
||||
# Fixes https://github.com/unslothai/unsloth/issues/10
|
||||
|
|
|
|||
|
|
@ -16,16 +16,9 @@ from .llama import FastLlamaModel, logger
|
|||
from .mistral import FastMistralModel
|
||||
from transformers import AutoConfig
|
||||
from transformers import __version__ as transformers_version
|
||||
from peft import PeftConfig, PeftModel
|
||||
from .mapper import INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER
|
||||
|
||||
FOURBIT_MAPPER = \
|
||||
{
|
||||
"unsloth/mistral-7b-bnb-4bit" : "unsloth/mistral-7b",
|
||||
"unsloth/llama-2-7b-bnb-4bit" : "unsloth/llama-2-7b",
|
||||
"unsloth/llama-2-13b-bnb-4bit" : "unsloth/llama-13-7b",
|
||||
"unsloth/codellama-34b-bnb-4bit" : "codellama/CodeLlama-34b-hf",
|
||||
"unsloth/zephyr-sft-bnb-4bit" : "unsloth/zephyr-sft",
|
||||
"unsloth/tinyllama-bnb-4bit" : "unsloth/tinyllama",
|
||||
}
|
||||
|
||||
# https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
|
||||
major, minor = transformers_version.split(".")[:2]
|
||||
|
|
@ -34,6 +27,39 @@ SUPPORTS_FOURBIT = (major > 4) or (major == 4 and minor >= 37)
|
|||
del major, minor
|
||||
|
||||
|
||||
def _get_model_name(model_name, load_in_4bit = True):
|
||||
|
||||
if not SUPPORTS_FOURBIT and model_name in INT_TO_FLOAT_MAPPER:
|
||||
model_name = INT_TO_FLOAT_MAPPER[model_name]
|
||||
logger.warning_once(
|
||||
f"Unsloth: Your transformers version of {transformers_version} does not support native "\
|
||||
f"4bit loading.\nThe minimum required version is 4.37.\n"\
|
||||
f'Try `pip install "git+https://github.com/huggingface/transformers.git"`\n'\
|
||||
f"to obtain the latest transformers build, then restart this session.\n"\
|
||||
f"For now, we shall load `{model_name}` instead (still 4bit, just slower downloading)."
|
||||
)
|
||||
|
||||
elif not load_in_4bit and model_name in INT_TO_FLOAT_MAPPER:
|
||||
new_model_name = INT_TO_FLOAT_MAPPER[model_name]
|
||||
logger.warning_once(
|
||||
f"Unsloth: You passed in `{model_name}` which is a 4bit model, yet you set\n"\
|
||||
f"`load_in_4bit = False`. We shall load `{new_model_name}` instead."
|
||||
)
|
||||
model_name = new_model_name
|
||||
|
||||
elif load_in_4bit and SUPPORTS_FOURBIT and model_name in FLOAT_TO_INT_MAPPER:
|
||||
new_model_name = FLOAT_TO_INT_MAPPER[model_name]
|
||||
logger.warning_once(
|
||||
f"Unsloth: You passed in `{model_name}` and `load_in_4bit = True`.\n"\
|
||||
f"We shall load `{new_model_name}` for 4x faster loading."
|
||||
)
|
||||
model_name = new_model_name
|
||||
pass
|
||||
|
||||
return model_name
|
||||
pass
|
||||
|
||||
|
||||
class FastLanguageModel(FastLlamaModel):
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
|
|
@ -47,25 +73,27 @@ class FastLanguageModel(FastLlamaModel):
|
|||
fix_tokenizer = True,
|
||||
*args, **kwargs,
|
||||
):
|
||||
if not SUPPORTS_FOURBIT and model_name in FOURBIT_MAPPER:
|
||||
model_name = FOURBIT_MAPPER[model_name]
|
||||
logger.warning_once(
|
||||
f"Unsloth: Your transformers version of {transformers_version} does not support native "\
|
||||
f"4bit loading.\nThe minimum required version is 4.37.\n"\
|
||||
f'Try `pip install "git+https://github.com/huggingface/transformers.git"`\n'\
|
||||
f"to obtain the latest transformers build, then restart this session.\n"\
|
||||
f"For now, we shall load `{model_name}` instead (still 4bit, just slower downloading)."
|
||||
)
|
||||
elif not load_in_4bit and model_name in FOURBIT_MAPPER:
|
||||
new_model_name = FOURBIT_MAPPER[model_name]
|
||||
logger.warning_once(
|
||||
f"Unsloth: You passed in `{model_name}` which is a 4bit model, yet you set\n"\
|
||||
f"`load_in_4bit = False`. We shall load `{new_model_name}` instead."
|
||||
)
|
||||
model_name = new_model_name
|
||||
old_model_name = model_name
|
||||
model_name = _get_model_name(model_name, load_in_4bit)
|
||||
|
||||
# First check if it's a normal model via AutoConfig
|
||||
is_peft = False
|
||||
try:
|
||||
model_config = AutoConfig.from_pretrained(model_name, token = token)
|
||||
is_peft = False
|
||||
except:
|
||||
try:
|
||||
# Most likely a PEFT model
|
||||
peft_config = PeftConfig.from_pretrained(model_name, token = token)
|
||||
except:
|
||||
raise RuntimeError(f"Unsloth: `{model_name}` is not a full model or a PEFT model.")
|
||||
|
||||
# Check base model again for PEFT
|
||||
model_name = _get_model_name(peft_config.base_model_name_or_path, load_in_4bit)
|
||||
model_config = AutoConfig.from_pretrained(model_name, token = token)
|
||||
is_peft = True
|
||||
pass
|
||||
|
||||
model_config = AutoConfig.from_pretrained(model_name)
|
||||
model_type = model_config.model_type
|
||||
|
||||
if model_type == "llama": dispatch_model = FastLlamaModel
|
||||
|
|
@ -75,8 +103,9 @@ class FastLanguageModel(FastLlamaModel):
|
|||
f"Unsloth: {model_name} not supported yet!\n"\
|
||||
"Make an issue to https://github.com/unslothai/unsloth!",
|
||||
)
|
||||
pass
|
||||
|
||||
return dispatch_model.from_pretrained(
|
||||
model, tokenizer = dispatch_model.from_pretrained(
|
||||
model_name = model_name,
|
||||
max_seq_length = max_seq_length,
|
||||
dtype = dtype,
|
||||
|
|
@ -87,5 +116,30 @@ class FastLanguageModel(FastLlamaModel):
|
|||
fix_tokenizer = fix_tokenizer,
|
||||
*args, **kwargs,
|
||||
)
|
||||
|
||||
if load_in_4bit:
|
||||
# Fix up bitsandbytes config
|
||||
quantization_config = \
|
||||
{
|
||||
# Sometimes torch_dtype is not a string!!
|
||||
"bnb_4bit_compute_dtype" : model.config.to_dict()["torch_dtype"],
|
||||
"bnb_4bit_quant_type" : "nf4",
|
||||
"bnb_4bit_use_double_quant" : True,
|
||||
"llm_int8_enable_fp32_cpu_offload" : False,
|
||||
"llm_int8_has_fp16_weight" : False,
|
||||
"llm_int8_skip_modules" : "null",
|
||||
"llm_int8_threshold" : 6.0,
|
||||
"load_in_4bit" : True,
|
||||
"load_in_8bit" : False,
|
||||
"quant_method" : "bitsandbytes",
|
||||
}
|
||||
model.config.update({"quantization_config" : quantization_config})
|
||||
pass
|
||||
|
||||
if is_peft:
|
||||
# Now add PEFT adapters
|
||||
model = PeftModel.from_pretrained(model, old_model_name)
|
||||
pass
|
||||
return model, tokenizer
|
||||
pass
|
||||
pass
|
||||
|
|
|
|||
56
unsloth/models/mapper.py
Normal file
56
unsloth/models/mapper.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__all__ = [
|
||||
"INT_TO_FLOAT_MAPPER",
|
||||
"FLOAT_TO_INT_MAPPER",
|
||||
]
|
||||
|
||||
__INT_TO_FLOAT_MAPPER = \
|
||||
{
|
||||
"unsloth/mistral-7b-bnb-4bit" : (
|
||||
"unsloth/mistral-7b",
|
||||
"mistralai/Mistral-7B-v0.1",
|
||||
),
|
||||
"unsloth/llama-2-7b-bnb-4bit" : (
|
||||
"unsloth/llama-2-7b",
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
),
|
||||
"unsloth/llama-2-13b-bnb-4bit" : (
|
||||
"unsloth/llama-13-7b",
|
||||
"meta-llama/Llama-2-13b-hf",
|
||||
),
|
||||
"unsloth/codellama-34b-bnb-4bit" : (
|
||||
"codellama/CodeLlama-34b-hf",
|
||||
),
|
||||
"unsloth/zephyr-sft-bnb-4bit" : (
|
||||
"unsloth/zephyr-sft",
|
||||
"alignment-handbook/zephyr-7b-sft-full",
|
||||
),
|
||||
"unsloth/tinyllama-bnb-4bit" : (
|
||||
"unsloth/tinyllama",
|
||||
"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
|
||||
),
|
||||
}
|
||||
|
||||
INT_TO_FLOAT_MAPPER = {}
|
||||
FLOAT_TO_INT_MAPPER = {}
|
||||
|
||||
for key, values in __INT_TO_FLOAT_MAPPER.items():
|
||||
INT_TO_FLOAT_MAPPER[key] = values[0]
|
||||
|
||||
for value in values:
|
||||
FLOAT_TO_INT_MAPPER[value] = key
|
||||
pass
|
||||
pass
|
||||
|
|
@ -343,6 +343,7 @@ class FastMistralModel(FastLlamaModel):
|
|||
token = token,
|
||||
)
|
||||
pass
|
||||
patch_saving_functions(tokenizer)
|
||||
|
||||
# Fix up config for transformers uploading PEFT
|
||||
name = model.config._name_or_path
|
||||
|
|
@ -350,6 +351,7 @@ class FastMistralModel(FastLlamaModel):
|
|||
name = name[:len(name) - len("-bnb-4bit")]
|
||||
model.config.update({"_name_or_path" : name})
|
||||
pass
|
||||
|
||||
# Log Unsloth version for future fastpaths for inference
|
||||
model.config.update({"unsloth_version" : __version__})
|
||||
|
||||
|
|
|
|||
865
unsloth/save.py
865
unsloth/save.py
|
|
@ -12,24 +12,26 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from peft import PeftModelForCausalLM
|
||||
from collections import OrderedDict
|
||||
import bitsandbytes as bnb
|
||||
import peft
|
||||
import gc
|
||||
import os
|
||||
from tqdm import tqdm as ProgressBar
|
||||
import shutil
|
||||
from typing import Optional, Callable, Union
|
||||
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
|
||||
from peft.tuners.lora import Linear4bit as Peft_Linear4bit
|
||||
from typing import Optional, Callable, Union, List
|
||||
import torch
|
||||
import os
|
||||
import pickle
|
||||
import gc
|
||||
from transformers.models.llama.modeling_llama import logger
|
||||
from .kernels import fast_dequantize, QUANT_STATE, get_lora_parameters
|
||||
import subprocess
|
||||
import psutil
|
||||
|
||||
__all__ = [
|
||||
"print_quantization_methods",
|
||||
"unsloth_save_model",
|
||||
#"colab_quantize_to_gguf",
|
||||
"save_to_gguf",
|
||||
"patch_saving_functions",
|
||||
]
|
||||
|
||||
|
||||
LLAMA_WEIGHTS = (
|
||||
"self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "self_attn.o_proj",
|
||||
"mlp.gate_proj", "mlp.up_proj", "mlp.down_proj",
|
||||
|
|
@ -41,25 +43,36 @@ LLAMA_LAYERNORMS = (
|
|||
# From https://mlabonne.github.io/blog/posts/Quantize_Llama_2_models_using_ggml.html
|
||||
ALLOWED_QUANTS = \
|
||||
{
|
||||
"q2_k" : "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
|
||||
"q3_k_l" : "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
|
||||
"q3_k_m" : "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
|
||||
"q3_k_s" : "Uses Q3_K for all tensors",
|
||||
"q4_0" : "Original quant method, 4-bit.",
|
||||
"q4_1" : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
|
||||
"q4_k_m" : "Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
|
||||
"q4_k_s" : "Uses Q4_K for all tensors",
|
||||
"q5_0" : "Higher accuracy, higher resource usage and slower inference.",
|
||||
"q5_1" : "Even higher accuracy, resource usage and slower inference.",
|
||||
"q5_k_m" : "Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
|
||||
"q5_k_s" : "Uses Q5_K for all tensors",
|
||||
"q6_k" : "Uses Q8_K for all tensors",
|
||||
"q8_0" : "Almost indistinguishable from float16. High resource use and slow. Not recommended for most users.",
|
||||
"not_quantized" : "Recommended. Fast conversion. Slow inference, big files.",
|
||||
"fast_quantized" : "Recommended. Fast conversion. OK inference, OK file size.",
|
||||
"quantized" : "Recommended. Slow conversion. Fast inference, small files.",
|
||||
"f32" : "Not recommended. Retains 100% accuracy, but super slow and memory hungry.",
|
||||
"f16" : "Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
|
||||
"q8_0" : "Fast conversion. High resource use, but generally acceptable.",
|
||||
"q4_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
|
||||
"q5_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
|
||||
"q2_k" : "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
|
||||
"q3_k_l" : "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
|
||||
"q3_k_m" : "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
|
||||
"q3_k_s" : "Uses Q3_K for all tensors",
|
||||
"q4_0" : "Original quant method, 4-bit.",
|
||||
"q4_1" : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
|
||||
"q4_k_s" : "Uses Q4_K for all tensors",
|
||||
"q5_0" : "Higher accuracy, higher resource usage and slower inference.",
|
||||
"q5_1" : "Even higher accuracy, resource usage and slower inference.",
|
||||
"q5_k_s" : "Uses Q5_K for all tensors",
|
||||
"q6_k" : "Uses Q8_K for all tensors",
|
||||
}
|
||||
|
||||
def print_quantization_methods():
|
||||
for key, value in ALLOWED_QUANTS.items():
|
||||
print(f'"{key}" ==> {value}')
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
def _merge_lora(layer, name):
|
||||
if isinstance(layer, (bnb.nn.Linear4bit, peft.tuners.lora.Linear4bit)):
|
||||
if isinstance(layer, (Bnb_Linear4bit, Peft_Linear4bit)):
|
||||
# Is LoRA so we need to merge!
|
||||
W, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
dtype = quant_state.dtype if type(quant_state) is not list else quant_state[2]
|
||||
|
|
@ -75,100 +88,362 @@ def _merge_lora(layer, name):
|
|||
pass
|
||||
|
||||
|
||||
def fast_save_pickle(shard, name):
|
||||
# Use this if # CPUs is <= 2
|
||||
print(f"Unsloth: Saving {name}...")
|
||||
torch.save(
|
||||
shard,
|
||||
name,
|
||||
pickle_module = pickle,
|
||||
pickle_protocol = pickle.HIGHEST_PROTOCOL,
|
||||
)
|
||||
return
|
||||
pass
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def unsloth_save_model(
|
||||
model,
|
||||
tokenizer,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
is_main_process: bool = True,
|
||||
state_dict: Optional[dict] = None,
|
||||
save_function: Callable = torch.save,
|
||||
push_to_hub: bool = False,
|
||||
max_shard_size: Union[int, str] = "7GB",
|
||||
safe_serialization: bool = True,
|
||||
variant: Optional[str] = None,
|
||||
token: Optional[Union[str, bool]] = None,
|
||||
save_peft_format: bool = True,
|
||||
temporary_location = "_unsloth_temporary_saved_buffers",
|
||||
**kwargs,
|
||||
):
|
||||
logger.warning_once(
|
||||
"Unsloth: `unsloth_save_model` is still in development mode.\n"\
|
||||
"If anything errors or breaks, please file a ticket on Github.\n"\
|
||||
"Also, if you used this successfully, please tell us on Discord!"
|
||||
)
|
||||
save_directory : Union[str, os.PathLike],
|
||||
save_method : str = "lora", # ["lora", "merged_16bit", "merged_4bit"]
|
||||
push_to_hub : bool = False,
|
||||
token : Optional[Union[str, bool]] = None,
|
||||
is_main_process : bool = True,
|
||||
state_dict : Optional[dict] = None,
|
||||
save_function : Callable = torch.save,
|
||||
max_shard_size : Union[int, str] = "5GB",
|
||||
safe_serialization : bool = True,
|
||||
variant : Optional[str] = None,
|
||||
save_peft_format : bool = True,
|
||||
|
||||
# Push to hub
|
||||
use_temp_dir : Optional[bool] = None,
|
||||
commit_message : Optional[str] = None,
|
||||
private : Optional[bool] = None,
|
||||
create_pr : bool = False,
|
||||
revision : str = None,
|
||||
commit_description : str = None,
|
||||
tags : List[str] = None,
|
||||
|
||||
# Our functions
|
||||
temporary_location : str = "_unsloth_temporary_saved_buffers",
|
||||
maximum_memory_usage : float = 0.9,
|
||||
):
|
||||
save_pretrained_settings = dict(locals())
|
||||
for deletion in ("model", "tokenizer", "save_method", "temporary_location", "maximum_memory_usage"):
|
||||
del save_pretrained_settings[deletion]
|
||||
pass
|
||||
import re
|
||||
|
||||
assert(maximum_memory_usage > 0 and maximum_memory_usage <= 0.95)
|
||||
|
||||
# Clean memory up first
|
||||
for _ in range(3):
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
pass
|
||||
|
||||
save_method = save_method.lower().replace(" ", "_")
|
||||
if save_method != "lora" and save_method != "merged_16bit" and save_method != "merged_4bit":
|
||||
raise RuntimeError(
|
||||
"Unsloth: You must select one of 3 options when saving models:\n"\
|
||||
'"lora" ==> This is the fastest and easiet. Just saves LoRA modules.\n'\
|
||||
'"merged_16bit" ==> This merges LoRA weights and saves to float16. Needed for llama.cpp / GGUF.\n'\
|
||||
'"merged_4bit" ==> This merges LoRA weights and saves to 4bit. Useful for DPO / inference.'
|
||||
)
|
||||
pass
|
||||
|
||||
if save_method == "merged_4bit":
|
||||
print("Unsloth: Merging 4bit and LoRA weights to 4bit...")
|
||||
print("This might take 5 minutes...")
|
||||
model = model.merge_and_unload()
|
||||
print("Done.")
|
||||
pass
|
||||
|
||||
if tags is not None:
|
||||
assert(isinstance(tags, (list, tuple)))
|
||||
tags = list(tags) + ["unsloth",]
|
||||
else:
|
||||
tags = ["unsloth",]
|
||||
pass
|
||||
save_pretrained_settings["tags"] = tags
|
||||
|
||||
if (save_method == "lora") and push_to_hub:
|
||||
if token is None:
|
||||
raise RuntimeError(
|
||||
"Unsloth: Pushing to HF requires a token. Pass `token = 'hf_....'`\n"\
|
||||
"Go to https://huggingface.co/settings/tokens."
|
||||
)
|
||||
pass
|
||||
|
||||
model.push_to_hub(
|
||||
repo_id = save_directory,
|
||||
use_temp_dir = use_temp_dir,
|
||||
commit_message = commit_message,
|
||||
private = private,
|
||||
token = token,
|
||||
max_shard_size = max_shard_size,
|
||||
create_pr = create_pr,
|
||||
safe_serialization = safe_serialization,
|
||||
revision = revision,
|
||||
commit_description = commit_description,
|
||||
tags = tags,
|
||||
)
|
||||
if tokenizer is not None:
|
||||
tokenizer.push_to_hub(
|
||||
repo_id = save_directory,
|
||||
use_temp_dir = use_temp_dir,
|
||||
commit_message = commit_message,
|
||||
private = private,
|
||||
token = token,
|
||||
max_shard_size = max_shard_size,
|
||||
create_pr = create_pr,
|
||||
safe_serialization = safe_serialization,
|
||||
revision = revision,
|
||||
commit_description = commit_description,
|
||||
tags = tags,
|
||||
)
|
||||
pass
|
||||
return save_directory
|
||||
pass
|
||||
|
||||
# If push_to_hub, we must remove the .../ part of a repo
|
||||
if push_to_hub and "/" in save_directory:
|
||||
|
||||
new_save_directory = save_directory[save_directory.find("/"):]
|
||||
|
||||
logger.warning_once(
|
||||
f"Unsloth: You are pushing to hub, but you passed your HF username.\n"\
|
||||
f"We shall truncate {save_directory} to {new_save_directory}"
|
||||
)
|
||||
|
||||
save_pretrained_settings["save_directory"] = new_save_directory
|
||||
save_directory = new_save_directory
|
||||
pass
|
||||
|
||||
if (save_method == "merged_4bit") or (save_method == "lora") or (
|
||||
not hasattr(model, "model") or \
|
||||
not hasattr(model.model, "model") or \
|
||||
not hasattr(model.model.model, "layers")
|
||||
):
|
||||
# Do general saving
|
||||
|
||||
# Edit save_pretrained_settings
|
||||
# [TODO] _create_repo has errors due to **kwargs getting accepted
|
||||
for deletion in \
|
||||
("use_temp_dir", "commit_message", "create_pr", "revision", "commit_description", "tags",):
|
||||
del save_pretrained_settings[deletion]
|
||||
pass
|
||||
if hasattr(model, "add_model_tags"):
|
||||
model.add_model_tags(["unsloth",])
|
||||
|
||||
if tokenizer is not None:
|
||||
print("Unsloth: Saving tokenizer...", end = "")
|
||||
tokenizer.save_pretrained(**save_pretrained_settings)
|
||||
print(" Done.")
|
||||
else:
|
||||
print()
|
||||
|
||||
print("Unsloth: Saving model...", end = "")
|
||||
if save_method != "lora": print(" This might take 10 minutes for Llama-7b...", end = "")
|
||||
|
||||
model.save_pretrained(**save_pretrained_settings)
|
||||
print(" Done.")
|
||||
return save_directory
|
||||
pass
|
||||
|
||||
print("Unsloth: Merging 4bit and LoRA weights to 16bit...")
|
||||
|
||||
# Determine max RAM usage minus sharding
|
||||
max_ram = psutil.virtual_memory().available
|
||||
sharded_ram_usage = 5 * 1024 * 1024 * 1024
|
||||
if type(max_shard_size) is str:
|
||||
gb_found = re.match("([0-9]{1,})[\s]{0,}GB", max_shard_size, flags = re.IGNORECASE)
|
||||
mb_found = re.match("([0-9]{1,})[\s]{0,}MB", max_shard_size, flags = re.IGNORECASE)
|
||||
if gb_found: sharded_ram_usage = int(gb_found.group(1)) * 1024 * 1024 * 1024
|
||||
elif mb_found: sharded_ram_usage = int(mb_found.group(1)) * 1024 * 1024
|
||||
elif type(max_shard_size) is int:
|
||||
sharded_ram_usage = sharded_ram_usage
|
||||
pass
|
||||
|
||||
# Switch to our fast saving modules if it's a slow PC!
|
||||
n_cpus = psutil.cpu_count(logical = False)
|
||||
|
||||
if safe_serialization is None:
|
||||
safe_serialization = True
|
||||
save_pretrained_settings["safe_serialization"] = safe_serialization
|
||||
|
||||
elif safe_serialization and (n_cpus <= 2):
|
||||
logger.warning_once(
|
||||
f"Unsloth: You have {n_cpus} CPUs. Using `safe_serialization` is 10x slower.\n"\
|
||||
f"We shall switch to Pytorch saving, which will take 3 minutes and not 30 minutes.\n"\
|
||||
f"To force `safe_serialization`, set it to None instead.",
|
||||
)
|
||||
safe_serialization = False
|
||||
save_function = fast_save_pickle
|
||||
save_pretrained_settings["safe_serialization"] = safe_serialization
|
||||
save_pretrained_settings["save_function"] = save_function
|
||||
pass
|
||||
|
||||
# Only safe_serialization uses more RAM
|
||||
if safe_serialization:
|
||||
max_ram -= sharded_ram_usage
|
||||
else:
|
||||
max_ram -= sharded_ram_usage*0.25 # Uses much less
|
||||
pass
|
||||
|
||||
max_ram = int(max(0, max_ram) * maximum_memory_usage)
|
||||
print(f"Unsloth: Will use up to "\
|
||||
f"{round(max_ram/1024/1024/1024, 2)} out of "\
|
||||
f"{round(psutil.virtual_memory().total/1024/1024/1024, 2)} RAM for saving.")
|
||||
|
||||
# Max directory for disk saving
|
||||
if not os.path.exists(temporary_location):
|
||||
os.makedirs(temporary_location)
|
||||
pass
|
||||
|
||||
assert(hasattr(model, "model"))
|
||||
assert(hasattr(model.model, "model"))
|
||||
assert(hasattr(model.model.model, "layers"))
|
||||
|
||||
# HF also uses a OrderedDict
|
||||
from collections import OrderedDict
|
||||
state_dict = OrderedDict()
|
||||
state_dict["model.embed_tokens.weight"] = model.model.model.embed_tokens.weight
|
||||
state_dict["model.embed_tokens.weight"] = model.model.model.embed_tokens.weight.data
|
||||
|
||||
print("Unsloth: Merging 4bit and LoRA weights to 16bit...")
|
||||
max_vram = int(torch.cuda.get_device_properties(0).total_memory * maximum_memory_usage)
|
||||
|
||||
from tqdm import tqdm as ProgressBar
|
||||
for j, layer in enumerate(ProgressBar(model.model.model.layers)):
|
||||
for item in LLAMA_WEIGHTS:
|
||||
proj = eval(f"layer.{item}")
|
||||
name = f"model.layers.{j}.{item}.weight"
|
||||
W = _merge_lora(proj, name)
|
||||
filename = os.path.join(temporary_location, f"{name}.pt")
|
||||
torch.save(W, filename)
|
||||
state_dict[name] = torch.load(filename, map_location = "cpu", mmap = True)
|
||||
|
||||
if (torch.cuda.memory_allocated() + W.nbytes) < max_vram:
|
||||
# Save to GPU memory
|
||||
state_dict[name] = W
|
||||
# elif (max_ram - W.nbytes) > 0:
|
||||
# # Save to CPU memory
|
||||
# logger.warning_once(f"We will save to RAM and not VRAM now.")
|
||||
# state_dict[name] = W.to("cpu", non_blocking = True)
|
||||
# max_ram = max(max_ram - W.nbytes, 0)
|
||||
else:
|
||||
# Save to Disk
|
||||
logger.warning_once(f"We will save to Disk and not RAM now.")
|
||||
filename = os.path.join(temporary_location, f"{name}.pt")
|
||||
torch.save(W, filename, pickle_module = pickle, pickle_protocol = pickle.HIGHEST_PROTOCOL,)
|
||||
state_dict[name] = torch.load(filename, map_location = "cpu", mmap = True)
|
||||
pass
|
||||
for item in LLAMA_LAYERNORMS:
|
||||
state_dict[f"model.layers.{j}.{item}.weight"] = eval(f"layer.{item}.weight")
|
||||
state_dict[f"model.layers.{j}.{item}.weight"] = eval(f"layer.{item}.weight.data")
|
||||
pass
|
||||
pass
|
||||
|
||||
state_dict["model.norm.weight"] = model.model.model.norm.weight
|
||||
state_dict["lm_head.weight"] = model.model.lm_head.weight
|
||||
state_dict["model.norm.weight"] = model.model.model.norm.weight.data
|
||||
state_dict["lm_head.weight"] = model.model.lm_head.weight.data
|
||||
|
||||
print("Unsloth: Saving tokenizer...")
|
||||
tokenizer.save_pretrained(
|
||||
save_directory = save_directory,
|
||||
is_main_process = is_main_process,
|
||||
state_dict = state_dict,
|
||||
save_function = save_function,
|
||||
push_to_hub = push_to_hub,
|
||||
max_shard_size = max_shard_size,
|
||||
safe_serialization = safe_serialization,
|
||||
variant = variant,
|
||||
token = token,
|
||||
save_peft_format = save_peft_format,
|
||||
)
|
||||
# All tensors MUST be type torch.Tensor and not torch.nn.parameter.Parameter
|
||||
for key, value in state_dict.items():
|
||||
if hasattr(value, "data"): state_dict[key] = value = value.data
|
||||
if type(value) is not torch.Tensor:
|
||||
logger.warning_once(f"Unsloth: {key} is not a Tensor but a {type(value)}.")
|
||||
pass
|
||||
pass
|
||||
|
||||
print("Unsloth: Saving model. This will take 5 minutes for Llama-7b...")
|
||||
model.model.save_pretrained(
|
||||
save_directory = save_directory,
|
||||
is_main_process = is_main_process,
|
||||
state_dict = state_dict,
|
||||
save_function = save_function,
|
||||
push_to_hub = push_to_hub,
|
||||
max_shard_size = max_shard_size,
|
||||
safe_serialization = safe_serialization,
|
||||
variant = variant,
|
||||
token = token,
|
||||
save_peft_format = save_peft_format,
|
||||
)
|
||||
# Edit save_pretrained_settings
|
||||
# [TODO] _create_repo has errors due to **kwargs getting accepted
|
||||
save_pretrained_settings["state_dict"] = state_dict
|
||||
for deletion in \
|
||||
("use_temp_dir", "commit_message", "create_pr", "revision", "commit_description", "tags",):
|
||||
del save_pretrained_settings[deletion]
|
||||
pass
|
||||
if hasattr(model, "add_model_tags"):
|
||||
model.add_model_tags(["unsloth",])
|
||||
|
||||
if tokenizer is not None:
|
||||
print("Unsloth: Saving tokenizer...", end = "")
|
||||
tokenizer.save_pretrained(**save_pretrained_settings)
|
||||
print(" Done.")
|
||||
else:
|
||||
print()
|
||||
|
||||
print("Unsloth: Saving model... This might take 5 minutes for Llama-7b...")
|
||||
model.model.save_pretrained(**save_pretrained_settings)
|
||||
print("Done.")
|
||||
|
||||
save_pretrained_settings["state_dict"] = None
|
||||
|
||||
# for j, (key, value) in enumerate(state_dict.items()):
|
||||
# state_dict[key] = None
|
||||
# if j % 10 == 0:
|
||||
# torch.cuda.empty_cache()
|
||||
# gc.collect()
|
||||
# pass
|
||||
# pass
|
||||
# state_dict = None
|
||||
# del state_dict
|
||||
# torch.cuda.empty_cache()
|
||||
# gc.collect()
|
||||
|
||||
# Remove temporary location
|
||||
import shutil
|
||||
shutil.rmtree(temporary_location)
|
||||
|
||||
# for _ in range(3):
|
||||
# torch.cuda.empty_cache()
|
||||
# gc.collect()
|
||||
return save_directory
|
||||
pass
|
||||
|
||||
|
||||
"""
|
||||
def _colab_quantize_to_gguf(save_directory, quantization_method = "q4_k_m"):
|
||||
def install_llama_cpp_clone_non_blocking():
|
||||
full_command = ["git", "clone", "https://github.com/ggerganov/llama.cpp"]
|
||||
run_installer = subprocess.Popen(full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT)
|
||||
return run_installer
|
||||
pass
|
||||
|
||||
logger.warning_once(
|
||||
"Unsloth: `colab_quantize_to_gguf` is still in development mode.\n"\
|
||||
"If anything errors or breaks, please file a ticket on Github.\n"\
|
||||
"Also, if you used this successfully, please tell us on Discord!"
|
||||
)
|
||||
|
||||
def install_llama_cpp_make_non_blocking():
|
||||
env = { **os.environ, "LLAMA_CUBLAS": "1", }
|
||||
n_jobs = max(int(psutil.cpu_count()*1.5), 1)
|
||||
full_command = ["make", "-j", str(n_jobs), "-C", "llama.cpp"]
|
||||
run_installer = subprocess.Popen(full_command, env = env, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT)
|
||||
return run_installer
|
||||
pass
|
||||
|
||||
|
||||
def install_python_non_blocking(packages = []):
|
||||
full_command = ["pip", "install"] + packages
|
||||
run_installer = subprocess.Popen(full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT)
|
||||
return run_installer
|
||||
pass
|
||||
|
||||
|
||||
def install_llama_cpp_blocking():
|
||||
commands = [
|
||||
"git clone https://github.com/ggerganov/llama.cpp",
|
||||
f"cd llama.cpp && make clean && LLAMA_CUBLAS=1 make -j {psutil.cpu_count()*2}",
|
||||
"pip install gguf protobuf",
|
||||
]
|
||||
if os.path.exists("llama.cpp"): return
|
||||
for command in commands:
|
||||
with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, bufsize = 1) as sp:
|
||||
for line in sp.stdout:
|
||||
print(line.decode("utf-8"), flush = True, end = "")
|
||||
pass
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
def save_to_gguf(
|
||||
model_directory : str = "unsloth_finetuned_model",
|
||||
quantization_method : str = "fast_quantized",
|
||||
_run_installer = None, # Non blocking install of llama.cpp
|
||||
):
|
||||
from transformers.models.llama.modeling_llama import logger
|
||||
|
||||
if quantization_method == "not_quantized": quantization_method = "f16"
|
||||
elif quantization_method == "fast_quantized": quantization_method = "q8_0"
|
||||
elif quantization_method == "quantized": quantization_method = "q4_k_m"
|
||||
elif quantization_method is None: quantization_method = "q8_0"
|
||||
|
||||
if quantization_method not in ALLOWED_QUANTS.keys():
|
||||
error = f"Unsloth: Quant method = [{quantization_method}] not supported. Choose from below:\n"
|
||||
|
|
@ -181,27 +456,409 @@ def _colab_quantize_to_gguf(save_directory, quantization_method = "q4_k_m"):
|
|||
f"==((====))== Unsloth: Conversion from QLoRA to GGUF information\n"\
|
||||
f" \\\ /| [0] Installing llama.cpp will take 3 minutes.\n"\
|
||||
f"O^O/ \_/ \\ [1] Converting HF to GUUF 16bits will take 3 minutes.\n"\
|
||||
f"\ / [2] Converting GGUF 16bits to q4_k_m will take 20 minutes.\n"\
|
||||
f"\ / [2] Converting GGUF 16bits to {quantization_method} will take 20 minutes.\n"\
|
||||
f' "-____-" In total, you will have to wait around 26 minutes.\n'
|
||||
print(print_info)
|
||||
|
||||
if not os.path.exists("llama.cpp"):
|
||||
print("Unsloth: [0] Installing llama.cpp. This will take 3 minutes...")
|
||||
!git clone https://github.com/ggerganov/llama.cpp
|
||||
!cd llama.cpp && make clean && LLAMA_CUBLAS=1 make -j
|
||||
!pip install gguf protobuf
|
||||
print("Unsloth: [0] Installing llama.cpp. This will take 3 minutes...")
|
||||
if _run_installer is not None:
|
||||
_run_installer.wait()
|
||||
else:
|
||||
install_llama_cpp_blocking()
|
||||
pass
|
||||
|
||||
print("Unsloth: [1] Converting HF into GGUF 16bit. This will take 3 minutes...")
|
||||
!python llama.cpp/convert.py {save_directory} \
|
||||
--outfile {save_directory}-unsloth.gguf \
|
||||
--outtype f16
|
||||
print("Unsloth: [1] Converting HF into GGUF format. This will take 3 minutes...")
|
||||
first_conversion = "f16"
|
||||
if quantization_method == "f32": first_conversion = "f32"
|
||||
elif quantization_method == "f16": first_conversion = "f16"
|
||||
elif quantization_method == "q8_0": first_conversion = "q8_0"
|
||||
|
||||
print("Unsloth: [2] Converting GGUF 16bit into q4_k_m. This will take 20 minutes...")
|
||||
final_location = f"./{save_directory}-{quantization_method}-unsloth.gguf"
|
||||
!./llama.cpp/quantize ./{save_directory}-unsloth.gguf \
|
||||
{final_location} {quantization_method}
|
||||
n_cpus = psutil.cpu_count()*2
|
||||
# Concurrency from https://rentry.org/llama-cpp-conversions#merging-loras-into-a-model
|
||||
|
||||
final_location = f"./{model_directory}-unsloth.{first_conversion.upper()}.gguf"
|
||||
|
||||
print(f"Unsloth: Output location: {final_location}")
|
||||
command = f"python llama.cpp/convert.py {model_directory} "\
|
||||
f"--outfile {final_location} "\
|
||||
f"--outtype {first_conversion} --concurrency {n_cpus}"
|
||||
|
||||
with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, bufsize = 1) as sp:
|
||||
for line in sp.stdout:
|
||||
print(line.decode("utf-8"), flush = True, end = "")
|
||||
pass
|
||||
|
||||
print(f"Unsloth: Conversion completed! Output location: {final_location}")
|
||||
|
||||
if quantization_method != first_conversion:
|
||||
old_location = final_location
|
||||
print(f"Unsloth: [2] Converting GGUF 16bit into {quantization_method}. This will take 20 minutes...")
|
||||
final_location = f"./{model_directory}-unsloth.{quantization_method.upper()}.gguf"
|
||||
|
||||
command = f"./llama.cpp/quantize {old_location} "\
|
||||
f"{final_location} {quantization_method} {n_cpus}"
|
||||
|
||||
with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, bufsize = 1) as sp:
|
||||
for line in sp.stdout:
|
||||
print(line.decode("utf-8"), flush = True, end = "")
|
||||
pass
|
||||
print(f"Unsloth: Conversion completed! Output location: {final_location}")
|
||||
pass
|
||||
|
||||
return final_location
|
||||
pass
|
||||
|
||||
|
||||
def unsloth_save_pretrained_merged(
|
||||
self,
|
||||
save_directory : Union[str, os.PathLike],
|
||||
save_method : str = "merged_16bit", # ["lora", "merged_16bit", "merged_4bit"]
|
||||
push_to_hub : bool = False,
|
||||
token : Optional[Union[str, bool]] = None,
|
||||
is_main_process : bool = True,
|
||||
state_dict : Optional[dict] = None,
|
||||
save_function : Callable = torch.save,
|
||||
max_shard_size : Union[int, str] = "5GB",
|
||||
safe_serialization : bool = True,
|
||||
variant : Optional[str] = None,
|
||||
save_peft_format : bool = True,
|
||||
tags : List[str] = None,
|
||||
temporary_location : str = "_unsloth_temporary_saved_buffers",
|
||||
maximum_memory_usage : float = 0.85,
|
||||
):
|
||||
"""
|
||||
Same as .save_pretrained(...) except 4bit weights are auto
|
||||
converted to float16 with as few overhead as possible.
|
||||
|
||||
Choose for `save_method` to be either:
|
||||
1. `merged_16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
|
||||
2. `merged_4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
|
||||
3. `lora`: Save LoRA adapters with no merging. Useful for HF inference.
|
||||
"""
|
||||
arguments = dict(locals())
|
||||
arguments["model"] = self
|
||||
arguments["tokenizer"] = None
|
||||
del arguments["self"]
|
||||
unsloth_save_model(**arguments)
|
||||
for _ in range(3):
|
||||
gc.collect()
|
||||
pass
|
||||
|
||||
|
||||
def unsloth_push_to_hub_merged(
|
||||
self,
|
||||
repo_id : str,
|
||||
save_method : str = "merged_16bit", # ["lora", "merged_16bit", "merged_4bit"]
|
||||
use_temp_dir : Optional[bool] = None,
|
||||
commit_message : Optional[str] = None,
|
||||
private : Optional[bool] = None,
|
||||
token : Union[bool, str, None] = None,
|
||||
max_shard_size : Union[int, str, None] = "5GB",
|
||||
create_pr : bool = False,
|
||||
safe_serialization : bool = True,
|
||||
revision : str = None,
|
||||
commit_description : str = None,
|
||||
tags : Optional[List[str]] = None,
|
||||
temporary_location : str = "_unsloth_temporary_saved_buffers",
|
||||
maximum_memory_usage : float = 0.85,
|
||||
):
|
||||
"""
|
||||
Same as .push_to_hub(...) except 4bit weights are auto
|
||||
converted to float16 with as few overhead as possible.
|
||||
|
||||
Choose for `save_method` to be either:
|
||||
1. `merged_16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
|
||||
2. `merged_4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
|
||||
3. `lora`: Save LoRA adapters with no merging. Useful for HF inference.
|
||||
"""
|
||||
arguments = dict(locals())
|
||||
arguments["model"] = self
|
||||
arguments["tokenizer"] = None
|
||||
arguments["save_directory"] = repo_id
|
||||
arguments["push_to_hub"] = True
|
||||
del arguments["self"]
|
||||
del arguments["repo_id"]
|
||||
unsloth_save_model(**arguments)
|
||||
for _ in range(3):
|
||||
gc.collect()
|
||||
pass
|
||||
|
||||
|
||||
def unsloth_save_pretrained_gguf(
|
||||
self,
|
||||
save_directory : Union[str, os.PathLike],
|
||||
tokenizer = None,
|
||||
quantization_method : str = "fast_quantized",
|
||||
push_to_hub : bool = False,
|
||||
token : Optional[Union[str, bool]] = None,
|
||||
is_main_process : bool = True,
|
||||
state_dict : Optional[dict] = None,
|
||||
save_function : Callable = torch.save,
|
||||
max_shard_size : Union[int, str] = "5GB",
|
||||
safe_serialization : bool = True,
|
||||
variant : Optional[str] = None,
|
||||
save_peft_format : bool = True,
|
||||
tags : List[str] = None,
|
||||
temporary_location : str = "_unsloth_temporary_saved_buffers",
|
||||
maximum_memory_usage : float = 0.85,
|
||||
):
|
||||
"""
|
||||
Same as .save_pretrained(...) except 4bit weights are auto
|
||||
converted to float16 then converted to GGUF / llama.cpp format.
|
||||
|
||||
Choose for `quantization_method` to be:
|
||||
"not_quantized" : "Recommended. Fast conversion. Slow inference, big files.",
|
||||
"fast_quantized" : "Recommended. Fast conversion. OK inference, OK file size.",
|
||||
"quantized" : "Recommended. Slow conversion. Fast inference, small files.",
|
||||
"f32" : "Not recommended. Retains 100% accuracy, but super slow and memory hungry.",
|
||||
"f16" : "Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
|
||||
"q8_0" : "Fast conversion. High resource use, but generally acceptable.",
|
||||
"q4_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
|
||||
"q5_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
|
||||
"q2_k" : "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
|
||||
"q3_k_l" : "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
|
||||
"q3_k_m" : "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
|
||||
"q3_k_s" : "Uses Q3_K for all tensors",
|
||||
"q4_0" : "Original quant method, 4-bit.",
|
||||
"q4_1" : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
|
||||
"q4_k_s" : "Uses Q4_K for all tensors",
|
||||
"q5_0" : "Higher accuracy, higher resource usage and slower inference.",
|
||||
"q5_1" : "Even higher accuracy, resource usage and slower inference.",
|
||||
"q5_k_s" : "Uses Q5_K for all tensors",
|
||||
"q6_k" : "Uses Q8_K for all tensors",
|
||||
"""
|
||||
if tokenizer is None:
|
||||
raise ValueError("Unsloth: Saving to GGUF must have a tokenizer.")
|
||||
|
||||
arguments = dict(locals())
|
||||
arguments["model"] = self
|
||||
arguments["tokenizer"] = tokenizer
|
||||
arguments["push_to_hub"] = False # We save ourselves
|
||||
arguments["save_method"] = "merged_16bit" # Must be 16bit
|
||||
del arguments["self"]
|
||||
del arguments["quantization_method"]
|
||||
|
||||
# Non blocking install GGUF first
|
||||
git_clone = install_llama_cpp_clone_non_blocking()
|
||||
python_install = install_python_non_blocking(["gguf", "protobuf"])
|
||||
git_clone.wait()
|
||||
makefile = install_llama_cpp_make_non_blocking()
|
||||
new_save_directory = unsloth_save_model(**arguments)
|
||||
python_install.wait()
|
||||
|
||||
for _ in range(3):
|
||||
gc.collect()
|
||||
|
||||
file_location = save_to_gguf(new_save_directory, quantization_method, makefile)
|
||||
|
||||
# And save to HF
|
||||
if push_to_hub:
|
||||
print("Unsloth: Uploading GGUF to Huggingface Hub...")
|
||||
|
||||
from huggingface_hub import create_repo
|
||||
create_repo(
|
||||
repo_id = save_directory,
|
||||
token = token,
|
||||
repo_type = "model",
|
||||
exist_ok = True,
|
||||
)
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
hf_api = HfApi(token = token)
|
||||
|
||||
if "/" in file_location:
|
||||
uploaded_location = file_location[file_location.rfind("/")+1:]
|
||||
else:
|
||||
uploaded_location = file_location
|
||||
pass
|
||||
|
||||
hf_api.upload_file(
|
||||
path_or_fileobj = file_location,
|
||||
path_in_repo = uploaded_location,
|
||||
repo_id = save_directory,
|
||||
repo_type = "model",
|
||||
)
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
def unsloth_push_to_hub_gguf(
|
||||
self,
|
||||
repo_id : str,
|
||||
tokenizer = None,
|
||||
quantization_method : str = "fast_quantized",
|
||||
use_temp_dir : Optional[bool] = None,
|
||||
commit_message : Optional[str] = None,
|
||||
private : Optional[bool] = None,
|
||||
token : Union[bool, str, None] = None,
|
||||
max_shard_size : Union[int, str, None] = "5GB",
|
||||
create_pr : bool = False,
|
||||
safe_serialization : bool = True,
|
||||
revision : str = None,
|
||||
commit_description : str = None,
|
||||
tags : Optional[List[str]] = None,
|
||||
temporary_location : str = "_unsloth_temporary_saved_buffers",
|
||||
maximum_memory_usage : float = 0.85,
|
||||
):
|
||||
"""
|
||||
Same as .push_to_hub(...) except 4bit weights are auto
|
||||
converted to float16 then converted to GGUF / llama.cpp format.
|
||||
|
||||
Choose for `quantization_method` to be:
|
||||
"not_quantized" : "Recommended. Fast conversion. Slow inference, big files.",
|
||||
"fast_quantized" : "Recommended. Fast conversion. OK inference, OK file size.",
|
||||
"quantized" : "Recommended. Slow conversion. Fast inference, small files.",
|
||||
"f32" : "Not recommended. Retains 100% accuracy, but super slow and memory hungry.",
|
||||
"f16" : "Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
|
||||
"q8_0" : "Fast conversion. High resource use, but generally acceptable.",
|
||||
"q4_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
|
||||
"q5_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
|
||||
"q2_k" : "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
|
||||
"q3_k_l" : "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
|
||||
"q3_k_m" : "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
|
||||
"q3_k_s" : "Uses Q3_K for all tensors",
|
||||
"q4_0" : "Original quant method, 4-bit.",
|
||||
"q4_1" : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
|
||||
"q4_k_s" : "Uses Q4_K for all tensors",
|
||||
"q5_0" : "Higher accuracy, higher resource usage and slower inference.",
|
||||
"q5_1" : "Even higher accuracy, resource usage and slower inference.",
|
||||
"q5_k_s" : "Uses Q5_K for all tensors",
|
||||
"q6_k" : "Uses Q8_K for all tensors",
|
||||
"""
|
||||
if tokenizer is None:
|
||||
raise ValueError("Unsloth: Saving to GGUF must have a tokenizer.")
|
||||
|
||||
arguments = dict(locals())
|
||||
arguments["model"] = self
|
||||
arguments["tokenizer"] = tokenizer
|
||||
arguments["save_directory"] = repo_id
|
||||
arguments["push_to_hub"] = False # We save ourselves
|
||||
arguments["save_method"] = "merged_16bit" # Must be 16bit
|
||||
del arguments["self"]
|
||||
del arguments["repo_id"]
|
||||
del arguments["quantization_method"]
|
||||
|
||||
# Non blocking install GGUF first
|
||||
git_clone = install_llama_cpp_clone_non_blocking()
|
||||
python_install = install_python_non_blocking(["gguf", "protobuf"])
|
||||
git_clone.wait()
|
||||
makefile = install_llama_cpp_make_non_blocking()
|
||||
new_save_directory = unsloth_save_model(**arguments)
|
||||
|
||||
for _ in range(3):
|
||||
gc.collect()
|
||||
|
||||
python_install.wait()
|
||||
file_location = save_to_gguf(new_save_directory, quantization_method, makefile)
|
||||
|
||||
# Save to hub
|
||||
print("Unsloth: Uploading GGUF to Huggingface Hub...")
|
||||
|
||||
from huggingface_hub import create_repo
|
||||
create_repo(
|
||||
repo_id = save_directory,
|
||||
private = private,
|
||||
token = token,
|
||||
repo_type = "model",
|
||||
exist_ok = True,
|
||||
)
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
hf_api = HfApi(token = token)
|
||||
|
||||
if "/" in file_location:
|
||||
uploaded_location = file_location[file_location.rfind("/")+1:]
|
||||
else:
|
||||
uploaded_location = file_location
|
||||
pass
|
||||
|
||||
hf_api.upload_file(
|
||||
path_or_fileobj = file_location,
|
||||
path_in_repo = uploaded_location,
|
||||
repo_id = save_directory,
|
||||
repo_type = "model",
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
def patch_saving_functions(model):
    """Monkey-patch Unsloth saving helpers onto `model`.

    Replaces `model.push_to_hub` with a generated wrapper that appends the
    "unsloth" tag before delegating to the original method, and attaches
    `push_to_hub_merged` / `save_pretrained_merged` / `push_to_hub_gguf` /
    `save_pretrained_gguf` convenience methods.  The same patching is then
    applied to every nested inner model reachable via `.model` (e.g. the
    PEFT-wrapped base model), so saving works from any level of wrapping.
    Idempotent: returns immediately if the model was already patched.
    """
    import inspect
    import re
    import types
    from typing import Callable, Optional, Union, List

    # Already patched — patching twice would wrap the wrapper and recurse.
    if hasattr(model, "_original_push_to_hub"): return

    original_push_to_hub = model.push_to_hub
    # Recover the original method's signature as source text so the generated
    # replacement accepts exactly the same parameters and defaults.
    # "NoneType" appears in reprs of Optional defaults but is not a valid
    # source-level name, so rewrite it to "None".
    signature = str(inspect.signature(original_push_to_hub)).replace("NoneType", "None")
    # Drop the leading "(" — "(self, " is written explicitly in the template
    # below; the trailing ")" from the repr is kept and closes the def.
    signature = signature[1:]
    # The repr of a function-valued default (save_function=<function save at
    # 0x...>) is not valid source; substitute the importable name instead.
    signature = re.sub("<function save at .+?>", "torch.save", signature)
    # NOTE(review): assumes __doc__ is not None — would raise AttributeError
    # otherwise.  The encode/decode round-trip presumably normalizes the text
    # to valid UTF-8 before embedding it into generated source.
    docs = original_push_to_hub.__doc__.encode("utf-8").decode("utf-8")
    # Keep a handle to the original so the generated wrapper can delegate,
    # and so the hasattr guard above detects re-entry.
    model._original_push_to_hub = original_push_to_hub

    # Source text for the replacement method.  It copies its own arguments,
    # injects the "unsloth" tag (or falls back to add_model_tags), and
    # delegates to the saved original; if the original rejects `tags`
    # (older transformers), it retries without them.
    push_to_hub_text = f'''def unsloth_push_to_hub(self, {signature}:
    """
    {docs}
    """
    arguments = dict(locals())
    del arguments["self"]
    if "tags" in arguments and arguments["tags"] is not None:
        assert(isinstance(arguments["tags"], (list, tuple)))
        arguments["tags"] = list(arguments["tags"]) + ["unsloth",]
    elif "tags" in arguments:
        arguments["tags"] = ["unsloth",]
    elif hasattr(self, "add_model_tags"):
        self.add_model_tags(["unsloth",])
    try:
        return self._original_push_to_hub(**arguments)
    except:
        del arguments["tags"]
        return self._original_push_to_hub(**arguments)
    pass
'''
    # Compile the generated wrapper into module globals so `unsloth_push_to_hub`
    # is resolvable below (and for the inner models later).
    exec(push_to_hub_text, globals())
    model.push_to_hub = types.MethodType(unsloth_push_to_hub, model)

    if hasattr(model, "add_model_tags"):
        model.add_model_tags(["unsloth",])

    if hasattr(model, "config"):
        # Counteract tokenizers
        # A real model (has a config): bind the Unsloth merged/GGUF savers.
        model.push_to_hub_merged = types.MethodType(unsloth_push_to_hub_merged, model)
        model.save_pretrained_merged = types.MethodType(unsloth_save_pretrained_merged, model)
        model.push_to_hub_gguf = types.MethodType(unsloth_push_to_hub_gguf, model)
        model.save_pretrained_gguf = types.MethodType(unsloth_save_pretrained_gguf, model)
    else:
        # Tokenizer-like object: alias the merged/GGUF names to the plain
        # save methods so callers can use one code path for both objects.
        model.push_to_hub_merged = model.push_to_hub
        model.save_pretrained_merged = model.save_pretrained
        model.push_to_hub_gguf = model.push_to_hub
        model.save_pretrained_gguf = model.save_pretrained
    pass

    # Walk the `.model` chain (PEFT / wrapper layers) and apply the same
    # patching to each inner model, skipping any already-patched level.
    original_model = model
    while hasattr(original_model, "model"):
        original_model = original_model.model
        if hasattr(original_model, "_original_push_to_hub"): continue

        original_model._original_push_to_hub = original_model.push_to_hub
        original_model.push_to_hub = types.MethodType(unsloth_push_to_hub, original_model)

        if hasattr(original_model, "add_model_tags"):
            original_model.add_model_tags(["unsloth",])

        if hasattr(original_model, "config"):
            # Counteract tokenizers
            original_model.push_to_hub_merged = \
                types.MethodType(unsloth_push_to_hub_merged, original_model)

            original_model.save_pretrained_merged = \
                types.MethodType(unsloth_save_pretrained_merged, original_model)

            original_model.push_to_hub_gguf = \
                types.MethodType(unsloth_push_to_hub_gguf, original_model)

            original_model.save_pretrained_gguf = \
                types.MethodType(unsloth_save_pretrained_gguf, original_model)
        pass
    pass
    return
pass
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in a new issue