Torch 2.2 (#157)

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update save.py

* Update fast_lora.py

* Update utils.py

* Update llama.py

* Update fast_lora.py

* Update swiglu.py

* Update save.py

* Update save.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Revert "Update llama.py"

This reverts commit a208ec46e0.

* Update llama.py

* Works?

* Update pyproject.toml

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Swiglu

* Update swiglu.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update swiglu.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* attention_mask

* Update llama.py

* Update llama.py

* labels

* Update mistral.py

* Update llama.py

* attention mask

* Update save.py

* Update save.py

* Update mistral.py

* attention mask

* Update llama.py

* Update llama.py

* Update mistral.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update dpo.py

* Patch saving

* Update save.py

* Update save.py

* patch_saving_functions

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* print

* Mistral patch

* Update mistral.py

* Update save.py

* saving

* Update llama.py

* Update llama.py

* Fast inference repatch

* Update llama.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update mistral.py

* Update __init__.py

* Fix inference

* Update mistral.py

* fast lm_head

* Remove fast path

* Update rope_embedding.py

* Update loader.py

* LlamaAttention_fast_forward_inference

* if past_key_value is not None and q_len == 1:

* revert inference

* Update loader.py

* past_key_value

* Update llama.py

* Update llama.py

* Fix SDPA

* Update llama.py

* padding

* Inference

* Update llama.py

* Revert

* Update mistral.py

* faster inference

* inference

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* inference

* Update llama.py

* Update utils.py

* faster inference

* Update llama.py

* revert

* lm_head

* Update llama.py

* inference

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update mistral.py

* Update llama.py

* faster inference

* Update llama.py

* fast inference

* Update llama.py

* Update llama.py

* Update mistral.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* torch compile

* past_key_values

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update llama.py

* fast inference + saving config.json

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update mistral.py

* fast inference again

* more temp matrices

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* fast inference

* Update mistral.py

* Update llama.py

* SDPA

* attention_mask

* New version

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update utils.py

* Update utils.py

* Update save.py

* Update save.py

* Torch 2.2.0

* Update save.py

* mistral swa

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Fix SWA inference

* Fix llm_int8_skip_modules

* SWA inference

* Update save.py

* Update save.py

* Update pyproject.toml

* __version__

* __version__

* Update save.py

* Update save.py

* Update mistral.py
Commit 25cfc7f590 (parent bb66faaa33)
Author: Daniel Han
Date: 2024-02-07 04:40:50 +11:00 (committed via GitHub)
5 changed files with 84 additions and 15 deletions

pyproject.toml

@@ -23,7 +23,7 @@ classifiers = [
 ]
 
 [tool.setuptools.dynamic]
-version = {attr = "unsloth.__version__"}
+version = {attr = "unsloth.models._utils.__version__"}
 
 [tool.setuptools]
 include-package-data = false
@@ -62,6 +62,16 @@ cu121onlytorch211 = [
     "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
     "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
 ]
+cu118onlytorch220 = [
+    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
+    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
+    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
+]
+cu121onlytorch220 = [
+    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
+    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
+    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
+]
 cu118 = [
     "unsloth[huggingface]",
     "bitsandbytes",
@@ -82,6 +92,16 @@ cu121_torch211 = [
     "bitsandbytes",
     "unsloth[cu121onlytorch211]",
 ]
+cu118_torch220 = [
+    "unsloth[huggingface]",
+    "bitsandbytes",
+    "unsloth[cu118onlytorch220]",
+]
+cu121_torch220 = [
+    "unsloth[huggingface]",
+    "bitsandbytes",
+    "unsloth[cu121onlytorch220]",
+]
 kaggle = [
     "unsloth[huggingface]",
 ]
@@ -110,6 +130,19 @@ colab_ampere_torch211 = [
     "ninja",
     "flash-attn",
 ]
+colab_torch220 = [
+    "unsloth[huggingface]",
+    "bitsandbytes",
+    "unsloth[cu121onlytorch220]",
+]
+colab_ampere_torch220 = [
+    "unsloth[huggingface]",
+    "bitsandbytes",
+    "unsloth[cu121onlytorch220]",
+    "packaging",
+    "ninja",
+    "flash-attn",
+]
 cu118_ampere = [
     "unsloth[huggingface]",
     "bitsandbytes",
@@ -142,6 +175,22 @@ cu121_ampere_torch211 = [
     "ninja",
     "flash-attn",
 ]
+cu118_ampere_torch220 = [
+    "unsloth[huggingface]",
+    "bitsandbytes",
+    "unsloth[cu118onlytorch220]",
+    "packaging",
+    "ninja",
+    "flash-attn",
+]
+cu121_ampere_torch220 = [
+    "unsloth[huggingface]",
+    "bitsandbytes",
+    "unsloth[cu121onlytorch220]",
+    "packaging",
+    "ninja",
+    "flash-attn",
+]
 
 [project.urls]
 homepage = "http://www.unsloth.ai"
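The new optional-dependency groups mirror the existing Torch 2.1.1 ones, pinning xformers 0.0.24 wheels built for Torch 2.2.0 on CUDA 11.8 and 12.1. A minimal selection sketch (illustrative only, not part of this diff; the install-from-git pattern and the helper logic are assumptions):

    # Pick the matching Torch 2.2.0 extra for the current environment (hypothetical helper).
    import torch

    assert torch.__version__.startswith("2.2."), "these extras target Torch 2.2.0"
    extra = "cu118_torch220" if torch.version.cuda == "11.8" else "cu121_torch220"
    print(f'pip install "unsloth[{extra}] @ git+https://github.com/unslothai/unsloth.git"')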


@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-__version__ = "2024.1"
 import os
 import warnings
 import importlib
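Together with the pyproject.toml change above, the hard-coded __version__ string is dropped here and the package version is read dynamically from unsloth.models._utils. A one-line check (an illustrative sketch, assuming the package is importable) of where the attribute now resolves:

    # Illustrative: the attribute that setuptools' dynamic version now points at.
    import importlib
    print(importlib.import_module("unsloth.models._utils").__version__)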

unsloth/models/llama.py

@@ -171,15 +171,28 @@ def LlamaAttention_fast_forward_inference(
     Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
     Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)
 
-    # Grouped query attention
-    if n_groups != 1:
-        _, _, cached_len, _ = Kn.shape
-        Knn = Kn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
-        Vnn = Vn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
-        Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
-        Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
+    # Handle sliding windows
+    sliding_window = getattr(self.config, "sliding_window", None)
+    if sliding_window is not None and kv_seq_len > sliding_window:
+        # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
+        slicing_tokens = 1 - sliding_window
+        Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
+        Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
     else:
         Knn, Vnn = Kn, Vn
     pass
+
+    # Grouped query attention
+    if n_groups != 1:
+        _, _, cached_len, _ = Knn.shape
+        Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
+        Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
+        Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
+        Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
+    pass
+    # else:
+    #     Knn, Vnn = Knn, Vnn
+    # pass
+
     # Attention
     A = torch.matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:kv_seq_len])
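The fix slices the cached keys/values down to the sliding window before the grouped-query expansion, so Mistral-style SWA inference never attends over more than sliding_window positions; slicing_tokens = 1 - sliding_window keeps the most recent sliding_window - 1 cached entries plus the incoming token. A standalone sketch of the same cache handling on dummy tensors (assumed shapes, not Unsloth's actual buffers):

    import torch

    bsz, n_heads, n_kv_heads, head_dim = 1, 32, 8, 128
    n_groups = n_heads // n_kv_heads            # 4 query heads share each KV head
    kv_seq_len, sliding_window = 5000, 4096

    Kn = torch.randn(bsz, n_kv_heads, kv_seq_len, head_dim)

    # Keep only the most recent (sliding_window - 1) cached positions, as in the diff.
    if kv_seq_len > sliding_window:
        Knn = Kn[:, :, 1 - sliding_window:, :]
    else:
        Knn = Kn

    # Grouped query attention: broadcast each KV head across its group of query heads.
    _, _, cached_len, _ = Knn.shape
    Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
    Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
    print(Knn.shape)   # torch.Size([1, 32, 4095, 128])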

unsloth/models/loader.py

@@ -128,7 +128,7 @@ class FastLanguageModel(FastLlamaModel):
             "bnb_4bit_use_double_quant" : True,
             "llm_int8_enable_fp32_cpu_offload" : False,
             "llm_int8_has_fp16_weight" : False,
-            "llm_int8_skip_modules" : "null",
+            "llm_int8_skip_modules" : None,
             "llm_int8_threshold" : 6.0,
             "load_in_4bit" : True,
             "load_in_8bit" : False,

unsloth/save.py

@@ -72,6 +72,7 @@ def print_quantization_methods():
 pass
 
 def _merge_lora(layer, name):
     if isinstance(layer, (Bnb_Linear4bit, Peft_Linear4bit, Peft_Linear)):
@@ -85,9 +86,12 @@ def _merge_lora(layer, name):
         W = W.to(torch.float32).t()
 
         if A is not None:
-            sAB = (A.t().to(torch.float32) @ (s * B.t().to(torch.float32)))
-            W += sAB
-            if not torch.isfinite(W).all():
+            # sAB = (A.t().to(torch.float32) @ (s * B.t().to(torch.float32)))
+            # W += sAB
+            W.addmm_(A.t().to(torch.float32), B.t().to(torch.float32), alpha = s)
+            # if not torch.isfinite(W).all():
+            maximum_element = torch.max(W.min().abs(), W.max())
+            if not torch.isfinite(maximum_element).item():
                 raise ValueError(f"Unsloth: Merge failed.\n{name} has some elements = infinity.")
         pass
         W = W.t().to(dtype)
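The merge now folds the LoRA update into W in place with addmm_ instead of materialising the intermediate sAB matrix, and the finiteness check inspects only the largest-magnitude element. A toy sketch (assumed shapes and names) showing the two formulations agree:

    import torch

    out_dim, in_dim, r, s = 64, 32, 8, 2.0
    W = torch.randn(in_dim, out_dim)      # already transposed, as in the merge code
    A = torch.randn(r, in_dim)            # LoRA A weight
    B = torch.randn(out_dim, r)           # LoRA B weight

    reference = W + s * (A.t() @ B.t())   # the old explicit form
    W.addmm_(A.t(), B.t(), alpha = s)     # the new in-place form
    assert torch.allclose(W, reference, atol = 1e-5)

    maximum_element = torch.max(W.min().abs(), W.max())
    assert torch.isfinite(maximum_element).item()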
@@ -373,7 +377,7 @@ def unsloth_save_model(
             # elif (max_ram - W.nbytes) > 0:
             #     # Save to CPU memory
             #     logger.warning_once(f"We will save to RAM and not VRAM now.")
-            #     state_dict[name] = W.to("cpu", non_blocking = True)
+            #     state_dict[name] = W.to("cpu", non_blocking = True, copy = True)
             #     max_ram = max(max_ram - W.nbytes, 0)
             else:
                 # Save to Disk
@@ -579,9 +583,11 @@ def save_to_gguf(
         f"--outfile {final_location} "\
         f"--outtype {first_conversion} --concurrency {n_cpus}"
 
-    with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, bufsize = 1) as sp:
+    with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.PIPE, bufsize = 1) as sp:
         for line in sp.stdout:
             print(line.decode("utf-8"), flush = True, end = "")
+        if sp.returncode is not None and sp.returncode != 0:
+            raise subprocess.CalledProcessError(sp.returncode, sp.args)
     pass
 
     # Check if quantization succeeded!
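The GGUF conversion step now captures stderr as well and raises a CalledProcessError when a non-zero exit code is detected, instead of continuing silently. A generic sketch of the pattern (the command is a stand-in, not the real llama.cpp invocation):

    import subprocess

    command = "echo converting && exit 1"   # hypothetical stand-in for the convert/quantize call
    with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE,
                          stderr = subprocess.STDOUT) as sp:
        for line in sp.stdout:
            print(line.decode("utf-8"), flush = True, end = "")
        returncode = sp.wait()
    if returncode != 0:
        # Mirrors the new error path: surface the failure instead of continuing.
        raise subprocess.CalledProcessError(returncode, sp.args)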
@@ -609,6 +615,8 @@ def save_to_gguf(
     with subprocess.Popen(command, shell = True, stderr = subprocess.PIPE, bufsize = 1) as sp:
         for line in sp.stderr:
             print(line.decode("utf-8"), flush = True, end = "")
+        if sp.returncode is not None and sp.returncode != 0:
+            raise subprocess.CalledProcessError(sp.returncode, sp.args)
     pass
 
     # Check if quantization succeeded!