mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Torch 2.2 (#157)
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update save.py
* Update fast_lora.py
* Update utils.py
* Update llama.py
* Update fast_lora.py
* Update swiglu.py
* Update save.py
* Update save.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Revert "Update llama.py"
This reverts commit a208ec46e0.
* Update llama.py
* Works?
* Update pyproject.toml
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Swiglu
* Update swiglu.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update swiglu.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* Update fast_lora.py
* attention_mask
* Update llama.py
* Update llama.py
* labels
* Update mistral.py
* Update llama.py
* attention mask
* Update save.py
* Update save.py
* Update mistral.py
* attention mask
* Update llama.py
* Update llama.py
* Update mistral.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update dpo.py
* Patch saving
* Update save.py
* Update save.py
* patch_saving_functions
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* print
* Mistral patch
* Update mistral.py
* Update save.py
* saving
* Update llama.py
* Update llama.py
* Fast inference repatch
* Update llama.py
* Update utils.py
* Update utils.py
* Update utils.py
* Update mistral.py
* Update __init__.py
* Fix inference
* Update mistral.py
* fast lm_head
* Remove fast path
* Update rope_embedding.py
* Update loader.py
* LlamaAttention_fast_forward_inference
* if past_key_value is not None and q_len == 1:
* revert inference
* Update loader.py
* past_key_value
* Update llama.py
* Update llama.py
* Fix SDPA
* Update llama.py
* padding
* Inference
* Update llama.py
* Revert
* Update mistral.py
* faster inference
* inference
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* inference
* Update llama.py
* Update utils.py
* faster inference
* Update llama.py
* revert
* lm_head
* Update llama.py
* inference
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update mistral.py
* Update llama.py
* faster inference
* Update llama.py
* fast inference
* Update llama.py
* Update llama.py
* Update mistral.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* torch compile
* past_key_values
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update utils.py
* Update utils.py
* Update utils.py
* Update utils.py
* Update llama.py
* fast inference + saving config.json
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update mistral.py
* fast inference again
* more temp matrices
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* fast inference
* Update mistral.py
* Update llama.py
* SDPA
* attention_mask
* New version
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update utils.py
* Update utils.py
* Update save.py
* Update save.py
* Torch 2.2.0
* Update save.py
* mistral swa
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* Update save.py
* Fix SWA inference
* Fix llm_int8_skip_modules
* SWA inference
* Update save.py
* Update save.py
* Update pyproject.toml
* __version__
* __version__
* Update save.py
* Update save.py
* Update mistral.py
parent bb66faaa33
commit 25cfc7f590
5 changed files with 84 additions and 15 deletions
pyproject.toml

@@ -23,7 +23,7 @@ classifiers = [
 ]
 
 [tool.setuptools.dynamic]
-version = {attr = "unsloth.__version__"}
+version = {attr = "unsloth.models._utils.__version__"}
 
 [tool.setuptools]
 include-package-data = false
@@ -62,6 +62,16 @@ cu121onlytorch211 = [
     "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
     "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
 ]
+cu118onlytorch220 = [
+    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
+    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
+    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
+]
+cu121onlytorch220 = [
+    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
+    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
+    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
+]
 cu118 = [
     "unsloth[huggingface]",
     "bitsandbytes",
@@ -82,6 +92,16 @@ cu121_torch211 = [
     "bitsandbytes",
     "unsloth[cu121onlytorch211]",
 ]
+cu118_torch220 = [
+    "unsloth[huggingface]",
+    "bitsandbytes",
+    "unsloth[cu118onlytorch220]",
+]
+cu121_torch220 = [
+    "unsloth[huggingface]",
+    "bitsandbytes",
+    "unsloth[cu121onlytorch220]",
+]
 kaggle = [
     "unsloth[huggingface]",
 ]
@@ -110,6 +130,19 @@ colab_ampere_torch211 = [
     "ninja",
     "flash-attn",
 ]
+colab_torch220 = [
+    "unsloth[huggingface]",
+    "bitsandbytes",
+    "unsloth[cu121onlytorch220]",
+]
+colab_ampere_torch220 = [
+    "unsloth[huggingface]",
+    "bitsandbytes",
+    "unsloth[cu121onlytorch220]",
+    "packaging",
+    "ninja",
+    "flash-attn",
+]
 cu118_ampere = [
     "unsloth[huggingface]",
     "bitsandbytes",
@@ -142,6 +175,22 @@ cu121_ampere_torch211 = [
     "ninja",
     "flash-attn",
 ]
+cu118_ampere_torch220 = [
+    "unsloth[huggingface]",
+    "bitsandbytes",
+    "unsloth[cu118onlytorch220]",
+    "packaging",
+    "ninja",
+    "flash-attn",
+]
+cu121_ampere_torch220 = [
+    "unsloth[huggingface]",
+    "bitsandbytes",
+    "unsloth[cu121onlytorch220]",
+    "packaging",
+    "ninja",
+    "flash-attn",
+]
 
 [project.urls]
 homepage = "http://www.unsloth.ai"
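The new `cu118onlytorch220` / `cu121onlytorch220` wheel pins and the `*_torch220` extras extend the install matrix: each extra pairs a CUDA build tag with a torch minor version. A minimal sketch of how a user might derive the matching extra name from their environment — the helper `pick_unsloth_extra` is hypothetical, not part of the package:

```python
# Hypothetical helper: maps the installed torch build to an extras name from
# the matrix above (e.g. "cu121_torch220"). Assumes a CUDA build of torch.
import torch

def pick_unsloth_extra(ampere: bool = False) -> str:
    cuda_tag = "cu" + torch.version.cuda.replace(".", "")   # e.g. "cu121"; None on CPU builds
    major, minor = torch.__version__.split(".")[:2]
    torch_tag = f"torch{major}{minor}0"                     # e.g. "torch220"
    prefix = f"{cuda_tag}_ampere" if ampere else cuda_tag
    return f"{prefix}_{torch_tag}"

print(pick_unsloth_extra())       # e.g. "cu121_torch220"
print(pick_unsloth_extra(True))   # e.g. "cu121_ampere_torch220"
```

The chosen name then goes into an install command along the lines of `pip install "unsloth[cu121_torch220] @ git+https://github.com/unslothai/unsloth.git"`.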
unsloth/__init__.py

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-__version__ = "2024.1"
 import os
 import warnings
 import importlib
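For context on the `version = {attr = ...}` switch in pyproject.toml above: setuptools resolves the dotted path at build time, roughly equivalent to importing the module and reading the attribute. A toy sketch (setuptools itself also tries static AST parsing before importing):

```python
# Toy equivalent of setuptools' dynamic version lookup, not setuptools code.
import importlib

def resolve_dynamic_version(attr_path: str) -> str:
    module_name, _, attr_name = attr_path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr_name)

# With unsloth installed, this mirrors the new pyproject line:
# resolve_dynamic_version("unsloth.models._utils.__version__")  # e.g. "2024.1"
```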
unsloth/models/llama.py

@@ -171,15 +171,28 @@ def LlamaAttention_fast_forward_inference(
     Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
     Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)
 
-    # Grouped query attention
-    if n_groups != 1:
-        _, _, cached_len, _ = Kn.shape
-        Knn = Kn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
-        Vnn = Vn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
-        Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
-        Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
+    # Handle sliding windows
+    sliding_window = getattr(self.config, "sliding_window", None)
+    if sliding_window is not None and kv_seq_len > sliding_window:
+        # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
+        slicing_tokens = 1 - sliding_window
+        Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
+        Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
+    else:
+        Knn, Vnn = Kn, Vn
+    pass
+
+    # Grouped query attention
+    if n_groups != 1:
+        _, _, cached_len, _ = Knn.shape
+        Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
+        Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
+        Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
+        Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
+    pass
+    # else:
+    #     Knn, Vnn = Knn, Vnn
+    # pass
 
     # Attention
     A = torch.matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:kv_seq_len])
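This hunk reorders the inference KV handling: the sliding-window clamp now runs before grouped-query expansion, and the expansion reads from the clamped `Knn`/`Vnn` rather than the full cache. A standalone sketch of those two steps on dummy tensors — all shapes below are hypothetical, not the model's:

```python
# Sketch of the sliding-window slice followed by grouped-query expansion.
import torch

bsz, n_kv_heads, n_groups, head_dim = 1, 8, 4, 64
n_heads = n_kv_heads * n_groups
kv_seq_len, sliding_window = 5000, 4096

Kn = torch.randn(bsz, n_kv_heads, kv_seq_len, head_dim)  # dummy KV cache
Vn = torch.randn(bsz, n_kv_heads, kv_seq_len, head_dim)

# Sliding window: keep only the most recent tokens, matching the
# `slicing_tokens = 1 - sliding_window` negative slice in the diff.
if kv_seq_len > sliding_window:
    Knn = Kn[:, :, 1 - sliding_window :, :]
    Vnn = Vn[:, :, 1 - sliding_window :, :]
else:
    Knn, Vnn = Kn, Vn

# Grouped query attention: expand adds a group axis as a view (no copy),
# then reshape flattens (kv_head, group) into n_heads to match the queries.
_, _, cached_len, _ = Knn.shape
Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
print(Knn.shape)  # torch.Size([1, 32, 4095, 64])
```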
unsloth/models/loader.py

@@ -128,7 +128,7 @@ class FastLanguageModel(FastLlamaModel):
     "bnb_4bit_use_double_quant" : True,
     "llm_int8_enable_fp32_cpu_offload" : False,
     "llm_int8_has_fp16_weight" : False,
-    "llm_int8_skip_modules" : "null",
+    "llm_int8_skip_modules" : None,
     "llm_int8_threshold" : 6.0,
     "load_in_4bit" : True,
     "load_in_8bit" : False,
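Context for this one-line fix (the "Fix llm_int8_skip_modules" commit in the list above): `"null"` was a string where Python's `None` was intended, so any comparison against a value deserialized from a model's `config.json` (where JSON `null` loads as `None`) could never match. A minimal illustration:

```python
# Why "null" != None: JSON null deserializes to Python None, so the old
# string default never compares equal to a loaded quantization config value.
import json

defaults_old = {"llm_int8_skip_modules": "null"}
defaults_new = {"llm_int8_skip_modules": None}

loaded = json.loads('{"llm_int8_skip_modules": null}')

print(loaded["llm_int8_skip_modules"] == defaults_old["llm_int8_skip_modules"])  # False
print(loaded["llm_int8_skip_modules"] == defaults_new["llm_int8_skip_modules"])  # True
```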
unsloth/save.py

@@ -72,6 +72,7 @@ def print_quantization_methods():
 pass
 
 
+
 def _merge_lora(layer, name):
 
     if isinstance(layer, (Bnb_Linear4bit, Peft_Linear4bit, Peft_Linear)):
@@ -85,9 +86,12 @@ def _merge_lora(layer, name):
         W = W.to(torch.float32).t()
 
         if A is not None:
-            sAB = (A.t().to(torch.float32) @ (s * B.t().to(torch.float32)))
-            W += sAB
-            if not torch.isfinite(W).all():
+            # sAB = (A.t().to(torch.float32) @ (s * B.t().to(torch.float32)))
+            # W += sAB
+            W.addmm_(A.t().to(torch.float32), B.t().to(torch.float32), alpha = s)
+            # if not torch.isfinite(W).all():
+            maximum_element = torch.max(W.min().abs(), W.max())
+            if not torch.isfinite(maximum_element).item():
                 raise ValueError(f"Unsloth: Merge failed.\n{name} has some elements = infinity.")
         pass
         W = W.t().to(dtype)
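The merge path swaps the explicit `sAB` temporary for an in-place `addmm_`, which applies `W += s * (A.t() @ B.t())` without materializing the intermediate matrix, and the overflow test now reduces `W` to a single scalar instead of scanning every element. A small sketch with random data (shapes hypothetical) showing the two paths agree:

```python
# Check that the fused in-place update matches the explicit temporary.
import torch

in_dim, out_dim, rank = 256, 128, 16
W = torch.randn(in_dim, out_dim)   # already transposed, as in the diff
A = torch.randn(rank, in_dim)      # LoRA A
B = torch.randn(out_dim, rank)     # LoRA B
s = 0.5                            # scaling, e.g. lora_alpha / r

W_ref = W + (A.t() @ (s * B.t()))  # old path: materializes sAB
W.addmm_(A.t(), B.t(), alpha = s)  # new path: fused, in place

print(torch.allclose(W, W_ref, atol = 1e-5))  # True

# The cheaper finiteness check: one scalar carries any inf/NaN out of W.
maximum_element = torch.max(W.min().abs(), W.max())
assert torch.isfinite(maximum_element).item()
```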
@@ -373,7 +377,7 @@ def unsloth_save_model(
             # elif (max_ram - W.nbytes) > 0:
             #     # Save to CPU memory
             #     logger.warning_once(f"We will save to RAM and not VRAM now.")
-            #     state_dict[name] = W.to("cpu", non_blocking = True)
+            #     state_dict[name] = W.to("cpu", non_blocking = True, copy = True)
             #     max_ram = max(max_ram - W.nbytes, 0)
             else:
                 # Save to Disk
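Although this block stays commented out, the `copy = True` edit records a real pitfall: `Tensor.to` returns `self` unchanged when the tensor already matches the target device and dtype, so the "saved" entry would alias the live weight rather than snapshot it. A hedged illustration:

```python
# Tensor.to is a no-op without copy=True when device/dtype already match.
import torch

W = torch.randn(4, 4)                      # pretend weight, already on CPU
alias = W.to("cpu", non_blocking = True)   # same storage as W
snapshot = W.to("cpu", non_blocking = True, copy = True)

print(alias.data_ptr() == W.data_ptr())     # True  (aliases W)
print(snapshot.data_ptr() == W.data_ptr())  # False (independent copy)
```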
@@ -579,9 +583,11 @@ def save_to_gguf(
         f"--outfile {final_location} "\
         f"--outtype {first_conversion} --concurrency {n_cpus}"
 
-    with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, bufsize = 1) as sp:
+    with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.PIPE, bufsize = 1) as sp:
         for line in sp.stdout:
             print(line.decode("utf-8"), flush = True, end = "")
+        if sp.returncode is not None and sp.returncode != 0:
+            raise subprocess.CalledProcessError(sp.returncode, sp.args)
     pass
 
     # Check if quantization succeeded!
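Both `save_to_gguf` hunks add the same guard: stream the converter's output line by line, then surface a non-zero exit as `CalledProcessError`. One subtlety the diff's `is not None` test hints at is that `Popen.returncode` stays `None` until the child has been reaped; the general sketch below (hypothetical command, not the commit's code) makes that explicit with `wait()`:

```python
# General stream-then-check pattern; wait() reaps the child so returncode
# is guaranteed to be populated before the exit-code test.
import subprocess

def run_streaming(command: str) -> None:
    with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE,
                          stderr = subprocess.STDOUT, text = True, bufsize = 1) as sp:
        for line in sp.stdout:          # stream output as it arrives
            print(line, flush = True, end = "")
        returncode = sp.wait()          # reap the child, get its exit code
        if returncode != 0:
            raise subprocess.CalledProcessError(returncode, sp.args)

run_streaming("echo hello")  # hypothetical stand-in for the llama.cpp converter
```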
@@ -609,6 +615,8 @@ def save_to_gguf(
     with subprocess.Popen(command, shell = True, stderr = subprocess.PIPE, bufsize = 1) as sp:
         for line in sp.stderr:
             print(line.decode("utf-8"), flush = True, end = "")
+        if sp.returncode is not None and sp.returncode != 0:
+            raise subprocess.CalledProcessError(sp.returncode, sp.args)
     pass
 
     # Check if quantization succeeded!