mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Bug fixes (#1516)
* use exact model name
* Update save.py
* Update _utils.py
* Update _utils.py
* Update _utils.py
* Update _utils.py
* print
* Update _utils.py
* Update _utils.py
* Update llama.py
* Update _utils.py
* Update vision.py
* Update _utils.py
* Update _utils.py
* Update _utils.py
* Update _utils.py
* Update _utils.py
* Update _utils.py
* Update _utils.py
* Update _utils.py
* Update loader.py
* accurate_accumulation
* Update loader.py
* Update loader.py
* Update _utils.py
* Update loader.py
* Update loader.py
* Update loader.py
* Update loader.py
* Update pyproject.toml
* Update __init__.py
* Update pyproject.toml
* Update __init__.py
* Update __init__.py
* Fix Triton heuristics
https://github.com/triton-lang/triton/issues/5224
* Update __init__.py
* Update __init__.py
* Update __init__.py
* Update __init__.py
* Xformers
* Update loader.py
* Update loader.py
* Rewind
* Update _utils.py
* Update _utils.py
* requires grad
* Update loader.py
* Update _utils.py
* Update loader.py
* changing model to base_model if peft model is already used
* Improve debugging experience (#1512)
* Create CONTRIBUTING.md (#1472)
Creating contributing guidelines
* Update CONTRIBUTING.md
improved sentence
* Improve logging control in `unsloth_compile_transformers` by conditionally redirecting stdout based on UNSLOTH_DISABLE_LOGGER environment variable
---------
Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com>
Co-authored-by: Nino Risteski <95188570+NinoRisteski@users.noreply.github.com>
* Update loader.py
* Update llama.py
* Update llama.py
* Revert "Update llama.py"
This reverts commit b7ddf962d2.
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Update llama.py
* Auto change is_bfloat16_supported
* Update llama.py
* Force data-type
* Update llama.py
* All attention refactor fix (#1491)
* change initialization of n_heads, n_kv_heads, hidden_size in llama.py
* do the same for cohere, mistral, gemma2, granite
* do the same for flexattention,cohere, mistral, granite
* Update llama.py
* Update llama.py
* Update granite to work with latest post_patch methods (#1502)
* Update granite to work with latest post_patch methods
* Pass position_embeddings for granite even if transformers<4.47
* Update llama.py
---------
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
* Minor fixes for granite models (#1503)
* Update granite.py
Grab residual multiplier directly from layer
* Update llama.py
Version should read >= 4.47.1 as that is the version requiring the changes
* Update granite.py
* Update llama.py
---------
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
* support modelscope models and datasets (#1481)
* support modelscope
* change modelscope args
* remove useless import
* remove useless import
* fix
* wip
* fix
* remove useless code
* add readme
* add some comments
* change print to raise error
* update comment
* Update loader.py
---------
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
---------
Co-authored-by: Itsuro Tajima <tajima@georepublic.de>
Co-authored-by: Muhammad Osama <muhammadosama1994@gmail.com>
Co-authored-by: Edd <68678137+Erland366@users.noreply.github.com>
Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com>
Co-authored-by: Nino Risteski <95188570+NinoRisteski@users.noreply.github.com>
Co-authored-by: Kareem <81531392+KareemMusleh@users.noreply.github.com>
Co-authored-by: Datta Nimmaturi <datta.nimmaturi@nutanix.com>
Co-authored-by: Z <coffeevampirebusiness@gmail.com>
Co-authored-by: tastelikefeet <58414341+tastelikefeet@users.noreply.github.com>
This commit is contained in:
parent
48627f876c
commit
63782ea3af
16 changed files with 307 additions and 160 deletions
|
|
@ -212,6 +212,9 @@ For **advanced installation instructions** or if you see weird errors during ins
|
|||
- Go to our official [Documentation](https://docs.unsloth.ai) for saving to GGUF, checkpointing, evaluation and more!
|
||||
- We support Huggingface's TRL, Trainer, Seq2SeqTrainer or even Pytorch code!
|
||||
- We're in 🤗Hugging Face's official docs! Check out the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!
|
||||
- To download models from the ModelScope community, set the environment variable `UNSLOTH_USE_MODELSCOPE=1` and install the modelscope library with `pip install modelscope -U`.
|
||||
|
||||
> unsloth_cli.py also supports `UNSLOTH_USE_MODELSCOPE=1` for downloading models and datasets. Please remember to use the model and dataset IDs from the ModelScope community.
|
||||
|
||||
```python
|
||||
from unsloth import FastLanguageModel
|
||||
|
|
|
|||
|
|
@ -148,20 +148,20 @@ cu124onlytorch250 = [
|
|||
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
|
||||
]
|
||||
cu121onlytorch251 = [
|
||||
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
|
||||
]
|
||||
cu124onlytorch251 = [
|
||||
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
|
||||
]
|
||||
cu118 = [
|
||||
"unsloth[huggingface]",
|
||||
|
|
|
|||
|
|
@ -30,11 +30,14 @@ Happy fine-tuning!
|
|||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
||||
def run(args):
|
||||
import torch
|
||||
from unsloth import FastLanguageModel
|
||||
from datasets import load_dataset
|
||||
from transformers.utils import strtobool
|
||||
from trl import SFTTrainer
|
||||
from transformers import TrainingArguments
|
||||
from unsloth import is_bfloat16_supported
|
||||
|
|
@ -86,8 +89,13 @@ def run(args):
|
|||
texts.append(text)
|
||||
return {"text": texts}
|
||||
|
||||
# Load and format dataset
|
||||
dataset = load_dataset(args.dataset, split="train")
|
||||
use_modelscope = strtobool(os.environ.get('UNSLOTH_USE_MODELSCOPE', 'False'))
|
||||
if use_modelscope:
|
||||
from modelscope import MsDataset
|
||||
dataset = MsDataset.load(args.dataset, split="train")
|
||||
else:
|
||||
# Load and format dataset
|
||||
dataset = load_dataset(args.dataset, split="train")
|
||||
dataset = dataset.map(formatting_prompts_func, batched=True)
|
||||
print("Data is formatted and ready!")
|
||||
|
||||
|
|
|
|||
|
|
@ -17,16 +17,6 @@ from packaging.version import Version
|
|||
import os, re, subprocess, inspect
|
||||
import numpy as np
|
||||
|
||||
# # Define a list of modules to check
|
||||
# MODULES_TO_CHECK = ["bitsandbytes"]
|
||||
|
||||
# # Check if any of the modules in the list have been imported
|
||||
# for module in MODULES_TO_CHECK:
|
||||
# if module in sys.modules:
|
||||
# raise ImportError(f"Unsloth: Please import Unsloth before {module}.")
|
||||
# pass
|
||||
# pass
|
||||
|
||||
# Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so
|
||||
# enabling it will require much more work, so we have to prioritize. Please understand!
|
||||
# We do have a beta version, which you can contact us about!
|
||||
|
|
@ -55,7 +45,12 @@ else:
|
|||
pass
|
||||
|
||||
# Reduce VRAM usage by reducing fragmentation
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,roundup_power2_divisions:[64:128,256:64,>:32]"
|
||||
# And optimize pinning of memory
|
||||
# Allocator settings are passed to PyTorch as one comma-separated string;
# the adjacent literals below are concatenated into that single value.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
    "expandable_segments:True,"
    "roundup_power2_divisions:[32:256,64:128,256:64,>:32],"
    "pinned_use_cuda_host_register:True,"
    "pinned_num_register_threads:8"
)
|
||||
|
||||
# Hugging Face Hub faster downloads
|
||||
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
|
||||
|
|
@ -89,6 +84,36 @@ elif (major_torch == 2) and (minor_torch < 2):
|
|||
del os.environ["PYTORCH_CUDA_ALLOC_CONF"]
|
||||
pass
|
||||
|
||||
# Fix Xformers performance issues since 0.0.25
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
from importlib.metadata import version as importlib_version
|
||||
from packaging.version import Version
|
||||
# Best-effort, in-place patch of the installed xformers package: for versions
# older than 0.0.29, rewrite `num_splits_key=-1,` to `num_splits_key=None,`
# inside xformers/ops/fmha/cutlass.py.
try:
    xformers_version = importlib_version("xformers")
    if Version(xformers_version) < Version("0.0.29"):
        # Locate the package directory from the module spec of `xformers`.
        xformers_location = importlib.util.find_spec("xformers").origin
        xformers_location = os.path.split(xformers_location)[0]
        cutlass = Path(xformers_location) / "ops" / "fmha" / "cutlass.py"

        if cutlass.exists():
            with open(cutlass, "r+") as f:
                text = f.read()
                # See https://github.com/facebookresearch/xformers/issues/1176#issuecomment-2545829591
                if "num_splits_key=-1," in text:
                    text = text.replace("num_splits_key=-1,", "num_splits_key=None,")
                    # Rewrite the file from the start and drop any leftover tail.
                    f.seek(0)
                    f.write(text)
                    f.truncate()
                    print("Unsloth: Patching Xformers to fix some performance issues.")
                pass
            pass
        pass
    pass
except Exception:
    # Any failure here (lookup, parsing, file access) is deliberately ignored
    # so that importing continues; only `Exception` is caught so that
    # KeyboardInterrupt / SystemExit still propagate.
    pass
pass
|
||||
|
||||
# Torch 2.4 has including_emulation
|
||||
major_version, minor_version = torch.cuda.get_device_capability()
|
||||
SUPPORTS_BFLOAT16 = (major_version >= 8)
|
||||
|
|
@ -166,9 +191,18 @@ pass
|
|||
|
||||
# Check for unsloth_zoo
|
||||
# Step 1: unsloth_zoo must be installed; a failed version lookup becomes an
# install hint.
try:
    unsloth_zoo_version = importlib_version("unsloth_zoo")
    unsloth_zoo_outdated = Version(unsloth_zoo_version) < Version("2025.1.1")
except Exception as error:
    raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth_zoo`") from error
pass

# Step 2: upgrade an outdated install. os.system() signals failure through its
# return value and does not raise, so the exit status is checked explicitly:
# retry with --user when the plain upgrade fails, and ask the user to upgrade
# manually when both attempts fail.
if unsloth_zoo_outdated:
    if os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") != 0:
        if os.system("pip install --upgrade --no-cache-dir --no-deps --user unsloth_zoo") != 0:
            raise ImportError("Unsloth: Please update unsloth_zoo via `pip install --upgrade --no-cache-dir --no-deps unsloth_zoo`")
        pass
    pass
pass

# Step 3: the package must actually import.
try:
    import unsloth_zoo
except Exception as error:
    raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth_zoo`") from error
pass
|
||||
|
||||
from .models import *
|
||||
|
|
|
|||
|
|
@ -25,11 +25,6 @@ from unsloth_zoo.loss_utils import (
|
|||
)
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
"DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
|
||||
"DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
|
||||
})
|
||||
@triton.jit
|
||||
def _cross_entropy_forward(
|
||||
logits_ptr ,
|
||||
logits_row_stride ,
|
||||
|
|
@ -95,13 +90,15 @@ def _cross_entropy_forward(
|
|||
tl.store(logsumexp_ptr, logsumexp)
|
||||
tl.store(loss_ptr, loss)
|
||||
pass
|
||||
# JIT-compile the kernel, then wrap the result with heuristics that coerce
# both flags to bool. The wrapping happens after the function definition
# rather than through decorators.
_cross_entropy_forward = triton.heuristics({
    "DO_SOFTCAPPING":   lambda args: bool(args["DO_SOFTCAPPING"]),
    "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
})(triton.jit(_cross_entropy_forward))
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
"DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
|
||||
"DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
|
||||
})
|
||||
@triton.jit
|
||||
def _chunked_cross_entropy_forward(
|
||||
logits_ptr ,
|
||||
logits_row_stride ,
|
||||
|
|
@ -177,13 +174,15 @@ def _chunked_cross_entropy_forward(
|
|||
pass
|
||||
tl.store(logsumexp_ptr, logsumexp)
|
||||
pass
|
||||
# JIT-compile the kernel, then wrap the result with heuristics that coerce
# both flags to bool. The wrapping happens after the function definition
# rather than through decorators.
_chunked_cross_entropy_forward = triton.heuristics({
    "DO_SOFTCAPPING":   lambda args: bool(args["DO_SOFTCAPPING"]),
    "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
})(triton.jit(_chunked_cross_entropy_forward))
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
"DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
|
||||
"DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
|
||||
})
|
||||
@triton.jit
|
||||
def _cross_entropy_backward(
|
||||
logits_ptr ,
|
||||
logits_row_stride ,
|
||||
|
|
@ -264,10 +263,16 @@ def _cross_entropy_backward(
|
|||
# If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
|
||||
tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
|
||||
pass
|
||||
# JIT-compile the kernel, then wrap the result with heuristics that coerce
# both flags to bool. The wrapping happens after the function definition
# rather than through decorators.
_cross_entropy_backward = triton.heuristics({
    "DO_SOFTCAPPING":   lambda args: bool(args["DO_SOFTCAPPING"]),
    "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
})(triton.jit(_cross_entropy_backward))
|
||||
|
||||
|
||||
MAX_FUSED_SIZE = 65536 # 2**16
|
||||
|
||||
class Fast_CrossEntropyLoss(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : float = 0):
|
||||
|
|
|
|||
|
|
@ -43,9 +43,9 @@ if not HAS_FLEX_ATTENTION:
|
|||
# Logit softcapping
|
||||
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
|
||||
def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
|
||||
n_heads = self.num_heads
|
||||
n_heads = self.config.num_attention_heads
|
||||
head_dim = self.head_dim
|
||||
n_kv_heads = self.num_key_value_heads
|
||||
n_kv_heads = self.config.num_key_value_heads
|
||||
n_groups = self.num_key_value_groups
|
||||
|
||||
# Grouped query attention
|
||||
|
|
@ -130,7 +130,7 @@ else:
|
|||
pass
|
||||
|
||||
def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
|
||||
n_heads = self.num_heads
|
||||
n_heads = self.config.num_attention_heads
|
||||
head_dim = self.head_dim
|
||||
s = self.config.query_pre_attn_scalar
|
||||
t = self.config.attn_logit_softcapping
|
||||
|
|
@ -147,9 +147,9 @@ torch_matmul = torch.matmul
|
|||
torch_tanh = torch.tanh
|
||||
torch_nn_functional_softmax = torch.nn.functional.softmax
|
||||
def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
|
||||
n_heads = self.num_heads
|
||||
n_heads = self.config.num_attention_heads
|
||||
head_dim = self.head_dim
|
||||
n_kv_heads = self.num_key_value_heads
|
||||
n_kv_heads = self.config.num_key_value_heads
|
||||
n_groups = self.num_key_value_groups
|
||||
|
||||
# Grouped query attention
|
||||
|
|
|
|||
|
|
@ -53,8 +53,6 @@ def _rms_layernorm_forward(
|
|||
pass
|
||||
|
||||
|
||||
@triton.heuristics({"GEMMA": lambda args: bool(args["GEMMA"]),})
|
||||
@triton.jit
|
||||
def _rms_layernorm_backward(
|
||||
dY, dY_row_stride,
|
||||
dX, dX_row_stride,
|
||||
|
|
@ -97,6 +95,12 @@ def _rms_layernorm_backward(
|
|||
output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
|
||||
tl.store(dX + col_offsets, output, mask = mask)
|
||||
pass
|
||||
# JIT-compile the kernel, then wrap the result with a heuristic that coerces
# the GEMMA flag to bool. The wrapping happens after the function definition
# rather than through decorators.
_rms_layernorm_backward = triton.heuristics({
    "GEMMA": lambda args: bool(args["GEMMA"]),
})(triton.jit(_rms_layernorm_backward))
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
|
|
|||
|
|
@ -18,8 +18,6 @@ import torch
|
|||
from .utils import calculate_settings
|
||||
ROPE_GROUP_SIZE : int = 4
|
||||
|
||||
@triton.heuristics({"BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]),})
|
||||
@triton.jit
|
||||
def _rope_embedding(
|
||||
Q, Q_row_stride,
|
||||
cos, cos_row_stride,
|
||||
|
|
@ -69,6 +67,12 @@ def _rope_embedding(
|
|||
tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)
|
||||
pass
|
||||
pass
|
||||
# JIT-compile the kernel, then wrap the result with a heuristic that coerces
# the BACKWARD_PASS flag to bool. The wrapping happens after the function
# definition rather than through decorators.
_rope_embedding = triton.heuristics({
    "BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]),
})(triton.jit(_rope_embedding))
|
||||
|
||||
|
||||
class Fast_RoPE_Embedding(torch.autograd.Function):
|
||||
|
|
|
|||
|
|
@ -12,9 +12,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__version__ = "2024.12.12"
|
||||
__version__ = "2025.1.1"
|
||||
|
||||
__all__ = [
|
||||
"SUPPORTS_BFLOAT16",
|
||||
"is_bfloat16_supported",
|
||||
|
||||
"prepare_model_for_kbit_training",
|
||||
"xformers",
|
||||
"xformers_attention",
|
||||
|
|
@ -30,7 +33,6 @@ __all__ = [
|
|||
"offload_to_disk",
|
||||
"offload_input_embeddings",
|
||||
"offload_output_embeddings",
|
||||
"is_bfloat16_supported",
|
||||
"unsloth_offloaded_gradient_checkpoint",
|
||||
"torch_compile_options",
|
||||
"patch_linear_scaling",
|
||||
|
|
@ -58,7 +60,6 @@ __all__ = [
|
|||
"fused_linear_cross_entropy",
|
||||
"patch_unsloth_smart_gradient_checkpointing",
|
||||
"unpatch_unsloth_smart_gradient_checkpointing",
|
||||
"create_gradient_checkpointing_buffer",
|
||||
|
||||
"patch_compiled_autograd",
|
||||
"process_vision_info",
|
||||
|
|
@ -97,7 +98,6 @@ from unsloth_zoo.gradient_checkpointing import (
|
|||
|
||||
patch_unsloth_smart_gradient_checkpointing,
|
||||
unpatch_unsloth_smart_gradient_checkpointing,
|
||||
create_gradient_checkpointing_buffer,
|
||||
)
|
||||
from unsloth_zoo.loss_utils import (
|
||||
HAS_CUT_CROSS_ENTROPY,
|
||||
|
|
@ -556,6 +556,7 @@ def prepare_model_for_kbit_training(
|
|||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||
pass
|
||||
|
||||
return model
|
||||
pass
|
||||
|
|
@ -1203,8 +1204,6 @@ def unsloth_compile_transformers(
|
|||
return
|
||||
pass
|
||||
|
||||
if disable: return
|
||||
|
||||
model_types = get_transformers_model_type(
|
||||
model_name = model_name,
|
||||
token = token,
|
||||
|
|
@ -1212,6 +1211,8 @@ def unsloth_compile_transformers(
|
|||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
|
||||
if disable: return
|
||||
|
||||
for model_type in model_types:
|
||||
_unsloth_compile_transformers(
|
||||
model_type,
|
||||
|
|
|
|||
|
|
@ -94,9 +94,9 @@ def CohereAttention_fast_forward(
|
|||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
n_heads = self.num_heads
|
||||
n_heads = self.config.num_attention_heads
|
||||
n_groups = self.num_key_value_groups
|
||||
n_kv_heads = self.num_key_value_heads
|
||||
n_kv_heads = self.config.num_key_value_heads
|
||||
head_dim = self.head_dim
|
||||
assert(n_kv_heads * n_groups == n_heads)
|
||||
|
||||
|
|
@ -259,12 +259,14 @@ def CohereAttention_fast_forward_inference(
|
|||
K1, V1 = past_key_value
|
||||
dtype = Xn.dtype
|
||||
|
||||
n_heads = self.num_heads
|
||||
n_heads = self.config.num_attention_heads
|
||||
n_groups = self.num_key_value_groups
|
||||
n_kv_heads = self.num_key_value_heads
|
||||
n_kv_heads = self.config.num_key_value_heads
|
||||
head_dim = self.head_dim
|
||||
attention_size = n_heads*head_dim
|
||||
# assert(n_kv_heads * n_groups == n_heads)
|
||||
|
||||
hidden_size = self.config.hidden_size
|
||||
attention_size = n_heads*head_dim
|
||||
seq_len = K1.shape[-2]
|
||||
kv_seq_len = seq_len + 1
|
||||
|
||||
|
|
@ -281,10 +283,10 @@ def CohereAttention_fast_forward_inference(
|
|||
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
|
||||
|
||||
# Mistral Nemo 12b has weird dimensions
|
||||
if attention_size != self.hidden_size:
|
||||
self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
|
||||
if attention_size != hidden_size:
|
||||
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0")
|
||||
else:
|
||||
self.temp_O = self.temp_QA[1][:,:,:self.hidden_size]
|
||||
self.temp_O = self.temp_QA[1][:,:,:hidden_size]
|
||||
pass
|
||||
|
||||
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
|
||||
|
|
|
|||
|
|
@ -98,9 +98,9 @@ def Gemma2Attention_fast_forward(
|
|||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
n_heads = self.num_heads
|
||||
n_heads = self.config.num_attention_heads
|
||||
n_groups = self.num_key_value_groups
|
||||
n_kv_heads = self.num_key_value_heads
|
||||
n_kv_heads = self.config.num_key_value_heads
|
||||
head_dim = self.head_dim
|
||||
assert(n_kv_heads * n_groups == n_heads)
|
||||
|
||||
|
|
@ -255,12 +255,14 @@ def Gemma2Attention_fast_forward_inference(
|
|||
K1, V1 = past_key_value
|
||||
dtype = Xn.dtype
|
||||
|
||||
n_heads = self.num_heads
|
||||
n_heads = self.config.num_attention_heads
|
||||
n_groups = self.num_key_value_groups
|
||||
n_kv_heads = self.num_key_value_heads
|
||||
n_kv_heads = self.config.num_key_value_heads
|
||||
head_dim = self.head_dim
|
||||
attention_size = n_heads*head_dim
|
||||
# assert(n_kv_heads * n_groups == n_heads)
|
||||
|
||||
hidden_size = self.config.hidden_size
|
||||
attention_size = n_heads*head_dim
|
||||
seq_len = K1.shape[-2]
|
||||
kv_seq_len = seq_len + 1
|
||||
|
||||
|
|
@ -276,7 +278,7 @@ def Gemma2Attention_fast_forward_inference(
|
|||
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
|
||||
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
|
||||
# Only for Gemma2
|
||||
self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
|
||||
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0")
|
||||
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
|
||||
|
||||
# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
|
||||
|
|
|
|||
|
|
@ -20,7 +20,8 @@ from .llama import (
|
|||
LlamaLinearScalingRotaryEmbedding,
|
||||
)
|
||||
from .mistral import *
|
||||
|
||||
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
|
||||
from peft.tuners.lora import Linear4bit as Peft_Linear4bit
|
||||
try:
|
||||
from transformers.models.granite.modeling_granite import (
|
||||
GraniteAttention,
|
||||
|
|
@ -84,9 +85,9 @@ def GraniteAttention_fast_forward(
|
|||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
n_heads = self.num_heads
|
||||
n_heads = self.config.num_attention_heads
|
||||
n_groups = self.num_key_value_groups
|
||||
n_kv_heads = self.num_key_value_heads
|
||||
n_kv_heads = self.config.num_key_value_heads
|
||||
head_dim = self.head_dim
|
||||
assert(n_kv_heads * n_groups == n_heads)
|
||||
|
||||
|
|
@ -181,6 +182,11 @@ def GraniteDecoderLayer_fast_forward(
|
|||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
*args, **kwargs,
|
||||
):
|
||||
residual_multiplier = \
|
||||
self.residual_multiplier \
|
||||
if hasattr(self, "residual_multiplier") else \
|
||||
self.config.residual_multiplier
|
||||
|
||||
if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
|
||||
|
|
@ -196,13 +202,13 @@ def GraniteDecoderLayer_fast_forward(
|
|||
position_embeddings = position_embeddings,
|
||||
_flag_for_generation=self._flag_for_generation,
|
||||
)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states)
|
||||
hidden_states = fast_swiglu_inference(self.mlp, hidden_states)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
|
||||
else:
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
|
||||
|
|
@ -217,13 +223,13 @@ def GraniteDecoderLayer_fast_forward(
|
|||
padding_mask=padding_mask,
|
||||
position_embeddings = position_embeddings,
|
||||
)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
|
||||
pass
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
|
@ -257,12 +263,14 @@ def GraniteAttention_fast_forward_inference(
|
|||
K1, V1 = past_key_value
|
||||
dtype = Xn.dtype
|
||||
|
||||
n_heads = self.num_heads
|
||||
n_heads = self.config.num_attention_heads
|
||||
n_groups = self.num_key_value_groups
|
||||
n_kv_heads = self.num_key_value_heads
|
||||
n_kv_heads = self.config.num_key_value_heads
|
||||
head_dim = self.head_dim
|
||||
attention_size = n_heads*head_dim
|
||||
# assert(n_kv_heads * n_groups == n_heads)
|
||||
|
||||
hidden_size = self.config.hidden_size
|
||||
attention_size = n_heads*head_dim
|
||||
seq_len = K1.shape[-2]
|
||||
kv_seq_len = seq_len + 1
|
||||
|
||||
|
|
@ -278,7 +286,7 @@ def GraniteAttention_fast_forward_inference(
|
|||
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
|
||||
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
|
||||
# Only for Gemma2
|
||||
self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
|
||||
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0")
|
||||
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
|
||||
|
||||
|
||||
|
|
@ -367,6 +375,10 @@ def GraniteModel_fast_forward_inference(
|
|||
hidden_states = self.model.embed_tokens(input_ids)
|
||||
hidden_states = hidden_states.to(self.config.torch_dtype)
|
||||
hidden_states *= self.model.embedding_multiplier
|
||||
residual_multiplier = \
|
||||
self.residual_multiplier \
|
||||
if hasattr(self, "residual_multiplier") else \
|
||||
self.config.residual_multiplier
|
||||
|
||||
bsz, q_len, hd = hidden_states.shape
|
||||
seq_len = past_key_values[0][0].shape[-2]
|
||||
|
|
@ -398,12 +410,12 @@ def GraniteModel_fast_forward_inference(
|
|||
position_embeddings = position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states)
|
||||
hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier)
|
||||
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
|
||||
|
||||
next_decoder_cache.append(present_key_value)
|
||||
pass
|
||||
|
|
@ -421,6 +433,18 @@ class GraniteRotaryEmbedding(LlamaRotaryEmbedding):
|
|||
def __init__(self, config):
|
||||
super().__init__(config = config)
|
||||
|
||||
def patched_init(original_init):
    """Wrap an ``__init__`` so the instance stores its config before it runs.

    Args:
        original_init: The ``__init__`` function to wrap.

    Returns:
        A replacement ``__init__`` that looks up the config (the ``config``
        keyword argument if given, otherwise the first positional argument),
        assigns it to ``self.config`` when it is not ``None``, and then calls
        ``original_init`` with the unchanged arguments.
    """
    import functools

    # functools.wraps keeps the wrapped __init__'s name, docstring and
    # __wrapped__ reference, so introspection still sees the original.
    @functools.wraps(original_init)
    def new_init(self, *args, **kwargs):
        # we can use self.residual_multiplier arg in GraniteDecoderLayer_fast_forward as mentioned here
        # https://github.com/huggingface/transformers/blob/e5fd865ebae062b7cf03a81b8c6affeb39f30bec/src/transformers/models/granite/modeling_granite.py#L243
        # The problem is, we don't have access to either the value or config in GraniteModel_fast_forward_inference
        # So we need a way to pass this value around. It is probably better to pass on the entire config in case we need it later
        config = kwargs.get("config", args[0] if args else None)
        if config is not None:
            self.config = config
        original_init(self, *args, **kwargs)
    return new_init
|
||||
|
||||
class FastGraniteModel(FastLlamaModel):
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -435,12 +459,13 @@ class FastGraniteModel(FastLlamaModel):
|
|||
exec(function, globals())
|
||||
GraniteAttention.__init__ = eval(init_name)
|
||||
pass
|
||||
GraniteAttention .forward = GraniteAttention_fast_forward
|
||||
GraniteSdpaAttention .forward = GraniteAttention_fast_forward
|
||||
GraniteFlashAttention2.forward = GraniteAttention_fast_forward
|
||||
GraniteDecoderLayer .forward = GraniteDecoderLayer_fast_forward
|
||||
GraniteModel .forward = LlamaModel_fast_forward
|
||||
GraniteForCausalLM .forward = CausalLM_fast_forward(GraniteModel_fast_forward_inference)
|
||||
GraniteAttention .forward = GraniteAttention_fast_forward
|
||||
GraniteSdpaAttention .forward = GraniteAttention_fast_forward
|
||||
GraniteFlashAttention2.forward = GraniteAttention_fast_forward
|
||||
GraniteDecoderLayer .forward = GraniteDecoderLayer_fast_forward
|
||||
GraniteModel .forward = LlamaModel_fast_forward
|
||||
GraniteForCausalLM .forward = CausalLM_fast_forward(GraniteModel_fast_forward_inference)
|
||||
GraniteForCausalLM .__init__ = patched_init(GraniteForCausalLM.__init__)
|
||||
PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
|
||||
fix_prepare_inputs_for_generation(GraniteForCausalLM)
|
||||
|
||||
|
|
@ -452,7 +477,7 @@ class FastGraniteModel(FastLlamaModel):
|
|||
|
||||
|
||||
@staticmethod
|
||||
def post_patch(model):
|
||||
def post_patch(model, tokenizer):
|
||||
|
||||
# Torch.compile fails on embedding matrix??
|
||||
# Workaround randomly fixes it for torch versions < 2.2
|
||||
|
|
@ -517,7 +542,7 @@ class FastGraniteModel(FastLlamaModel):
|
|||
for _ in range(3):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return model
|
||||
return model, tokenizer
|
||||
pass
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,10 @@ from ._utils import *
|
|||
from ._utils import __version__
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
from transformers import __version__ as transformers_version
|
||||
from unsloth_zoo.utils import Version
|
||||
transformers_version = Version(transformers_version)
|
||||
# Transformers moved rotary embeddings out of all attention layers
|
||||
IS_ATTENTION_REFACTOR = transformers_version > Version("4.47.1")
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
logger,
|
||||
BaseModelOutputWithPast,
|
||||
|
|
@ -146,12 +150,14 @@ def LlamaAttention_fast_forward_inference(
|
|||
K1, V1 = past_key_value
|
||||
dtype = Xn.dtype
|
||||
|
||||
n_heads = self.num_heads
|
||||
n_heads = self.config.num_attention_heads
|
||||
n_groups = self.num_key_value_groups
|
||||
n_kv_heads = self.num_key_value_heads
|
||||
n_kv_heads = self.config.num_key_value_heads
|
||||
head_dim = self.head_dim
|
||||
attention_size = n_heads*head_dim
|
||||
# assert(n_kv_heads * n_groups == n_heads)
|
||||
|
||||
hidden_size = self.config.hidden_size
|
||||
attention_size = n_heads*head_dim
|
||||
seq_len = K1.shape[-2]
|
||||
kv_seq_len = seq_len + 1
|
||||
|
||||
|
|
@ -168,10 +174,10 @@ def LlamaAttention_fast_forward_inference(
|
|||
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
|
||||
|
||||
# Mistral Nemo 12b has weird dimensions
|
||||
if attention_size != self.hidden_size:
|
||||
self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
|
||||
if attention_size != hidden_size:
|
||||
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0")
|
||||
else:
|
||||
self.temp_O = self.temp_QA[1][:,:,:self.hidden_size]
|
||||
self.temp_O = self.temp_QA[1][:,:,:hidden_size]
|
||||
pass
|
||||
|
||||
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
|
||||
|
|
@ -356,9 +362,9 @@ def LlamaAttention_fast_forward(
|
|||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
n_heads = self.num_heads
|
||||
n_heads = self.config.num_attention_heads
|
||||
n_groups = self.num_key_value_groups
|
||||
n_kv_heads = self.num_key_value_heads
|
||||
n_kv_heads = self.config.num_key_value_heads
|
||||
head_dim = self.head_dim
|
||||
assert(n_kv_heads * n_groups == n_heads)
|
||||
|
||||
|
|
@ -786,7 +792,7 @@ def LlamaModel_fast_forward(
|
|||
pass
|
||||
pass
|
||||
|
||||
if transformers_version > "4.47.1" and hasattr(self, "rotary_emb"):
|
||||
if IS_ATTENTION_REFACTOR and not hasattr(self.layers[0].self_attn, "rotary_emb"):
|
||||
# Transformers main has made it mandatory to pass position_embeddings
|
||||
# https://github.com/huggingface/transformers/pull/34858
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings)
|
||||
|
|
@ -996,18 +1002,20 @@ def CausalLM_fast_forward(fast_forward_inference):
|
|||
lm_head = self.lm_head.weight
|
||||
logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
|
||||
logit_scaling = getattr(self.config, "logit_scale", 0)
|
||||
dtype = lm_head.dtype
|
||||
|
||||
if bsz == 1 and q_len == 1:
|
||||
logits = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype))
|
||||
logits = torch.mv(lm_head, hidden_states.ravel().to(dtype))
|
||||
logits = logits.unsqueeze(0).unsqueeze(0)
|
||||
elif num_logits_to_keep != 0:
|
||||
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(lm_head.dtype))
|
||||
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(dtype))
|
||||
else:
|
||||
RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1"
|
||||
# < 1024 Normal Unsloth uses less VRAM!
|
||||
if bsz*q_len <= 1024: RETURN_LOGITS = True
|
||||
|
||||
if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None:
|
||||
|
||||
n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None)
|
||||
loss = fused_linear_cross_entropy(
|
||||
hidden_states = hidden_states,
|
||||
|
|
@ -1029,7 +1037,7 @@ def CausalLM_fast_forward(fast_forward_inference):
|
|||
)
|
||||
return output
|
||||
pass
|
||||
logits = self.lm_head(hidden_states.to(lm_head.dtype))
|
||||
logits = self.lm_head(hidden_states.to(dtype))
|
||||
pass
|
||||
|
||||
torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None)
|
||||
|
|
@ -1607,6 +1615,9 @@ class FastLlamaModel:
|
|||
elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
|
||||
logger.warning_once("Device does not support bfloat16. Will change to float16.")
|
||||
dtype = torch.float16
|
||||
elif dtype == torch.float16 and SUPPORTS_BFLOAT16:
|
||||
logger.warning_once("Device supports bfloat16 but you selected float16. Will change to bfloat16.")
|
||||
dtype = torch.bfloat16
|
||||
|
||||
assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)
|
||||
|
||||
|
|
@ -1879,6 +1890,13 @@ class FastLlamaModel:
|
|||
internal_model = internal_model.model
|
||||
pass
|
||||
internal_model._saved_temp_tokenizer = tokenizer
|
||||
|
||||
# For transformers > 4.47.1, we need to add rotary_emb to all attention layers
|
||||
if IS_ATTENTION_REFACTOR or hasattr(model.model, "rotary_emb"):
|
||||
rotary_emb = model.model.rotary_emb
|
||||
for layer in model.model.layers:
|
||||
layer.self_attn.rotary_emb = rotary_emb
|
||||
pass
|
||||
|
||||
return model, tokenizer
|
||||
pass
|
||||
|
|
@ -1967,29 +1985,41 @@ class FastLlamaModel:
|
|||
if "embed_tokens" in new_target_modules:
|
||||
print("Unsloth: Training embed_tokens in mixed precision to save VRAM")
|
||||
|
||||
dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype
|
||||
model.model.model.embed_tokens.modules_to_save.default\
|
||||
.to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True)
|
||||
model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True)
|
||||
new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype
|
||||
if new_dtype == torch.float16:
|
||||
# See https://github.com/unslothai/unsloth/pull/1200
|
||||
# Tesla T4 must use float32 and not float16
|
||||
new_dtype = torch.float32
|
||||
pass
|
||||
|
||||
model.get_input_embeddings().modules_to_save.default\
|
||||
.to(device = "cuda:0", dtype = new_dtype, non_blocking = True)
|
||||
model.get_input_embeddings().modules_to_save.default.requires_grad_(True)
|
||||
|
||||
# [TODO] Move old embed_tokens to CPU - should be disk!
|
||||
model.model.model.embed_tokens.original_module\
|
||||
model.get_input_embeddings().original_module\
|
||||
.to(device = "cpu", non_blocking = True)
|
||||
model.model.model.embed_tokens.original_module.requires_grad_(False)
|
||||
model.get_input_embeddings().original_module.requires_grad_(False)
|
||||
pass
|
||||
|
||||
if "lm_head" in new_target_modules:
|
||||
print("Unsloth: Training lm_head in mixed precision to save VRAM")
|
||||
|
||||
dtype = model.model.model.lm_head.modules_to_save.default.weight.dtype
|
||||
model.model.lm_head.modules_to_save.default\
|
||||
.to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True)
|
||||
model.model.lm_head.modules_to_save.default.requires_grad_(True)
|
||||
new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype
|
||||
if new_dtype == torch.float16:
|
||||
# See https://github.com/unslothai/unsloth/pull/1200
|
||||
# Tesla T4 must use float32 and not float16
|
||||
new_dtype = torch.float32
|
||||
pass
|
||||
|
||||
model.get_output_embeddings().modules_to_save.default\
|
||||
.to(device = "cuda:0", dtype = new_dtype, non_blocking = True)
|
||||
model.get_output_embeddings().modules_to_save.default.requires_grad_(True)
|
||||
|
||||
# [TODO] Move old lm_head to CPU - should be disk!
|
||||
model.model.lm_head.original_module\
|
||||
model.get_output_embeddings().original_module\
|
||||
.to(device = "cpu", non_blocking = True)
|
||||
model.model.lm_head.original_module.requires_grad_(False)
|
||||
model.get_output_embeddings().original_module.requires_grad_(False)
|
||||
pass
|
||||
|
||||
return model
|
||||
|
|
@ -2216,25 +2246,36 @@ class FastLlamaModel:
|
|||
|
||||
model = FastLlamaModel.patch_peft_model(model, use_gradient_checkpointing)
|
||||
|
||||
# Now patch lm_head and embed_tokens
|
||||
if train_embed_tokens:
|
||||
print("Unsloth: Training embed_tokens in mixed precision to save VRAM")
|
||||
assert(hasattr(model.model.model.embed_tokens, "modules_to_save"))
|
||||
assert(hasattr(model.get_input_embeddings(), "modules_to_save"))
|
||||
|
||||
dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype
|
||||
model.model.model.embed_tokens.modules_to_save.default\
|
||||
.to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True)
|
||||
model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True)
|
||||
new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype
|
||||
if new_dtype == torch.float16:
|
||||
# See https://github.com/unslothai/unsloth/pull/1200
|
||||
# Tesla T4 must use float32 and not float16
|
||||
new_dtype = torch.float32
|
||||
pass
|
||||
|
||||
model.get_input_embeddings().modules_to_save.default\
|
||||
.to(device = "cuda:0", dtype = new_dtype, non_blocking = True)
|
||||
model.get_input_embeddings().modules_to_save.default.requires_grad_(True)
|
||||
pass
|
||||
|
||||
if train_lm_head:
|
||||
print("Unsloth: Training lm_head in mixed precision to save VRAM")
|
||||
assert(hasattr(model.model.lm_head, "modules_to_save"))
|
||||
assert(hasattr(model.get_output_embeddings(), "modules_to_save"))
|
||||
|
||||
dtype = model.model.lm_head.modules_to_save.default.weight.dtype
|
||||
model.model.lm_head.modules_to_save.default\
|
||||
.to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True)
|
||||
model.model.lm_head.modules_to_save.default.requires_grad_(True)
|
||||
new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype
|
||||
if new_dtype == torch.float16:
|
||||
# See https://github.com/unslothai/unsloth/pull/1200
|
||||
# Tesla T4 must use float32 and not float16
|
||||
new_dtype = torch.float32
|
||||
pass
|
||||
|
||||
model.get_output_embeddings().modules_to_save.default\
|
||||
.to(device = "cuda:0", dtype = new_dtype, non_blocking = True)
|
||||
model.get_output_embeddings().modules_to_save.default.requires_grad_(True)
|
||||
pass
|
||||
|
||||
# Patch tokenizer to pad to the right
|
||||
|
|
|
|||
|
|
@ -31,8 +31,17 @@ except:
|
|||
pass
|
||||
from huggingface_hub import HfFileSystem
|
||||
|
||||
# [TODO] Move USE_MODELSCOPE to utils
|
||||
USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1"
|
||||
if USE_MODELSCOPE:
|
||||
import importlib
|
||||
if importlib.util.find_spec("modelscope") is None:
|
||||
raise ImportError(f'You are using the modelscope hub, please install modelscope by `pip install modelscope -U`')
|
||||
pass
|
||||
pass
|
||||
|
||||
# https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
|
||||
from unsloth_zoo.utils import Version
|
||||
from unsloth_zoo.utils import Version, _get_dtype
|
||||
transformers_version = Version(transformers_version)
|
||||
SUPPORTS_FOURBIT = transformers_version >= Version("4.37")
|
||||
SUPPORTS_GEMMA = transformers_version >= Version("4.38")
|
||||
|
|
@ -47,28 +56,11 @@ if SUPPORTS_GEMMA2:
|
|||
pass
|
||||
import torch
|
||||
|
||||
def _get_dtype(dtype):
|
||||
__DTYPE_MAP = {
|
||||
"float32": torch.float32,
|
||||
torch.float32: torch.float32,
|
||||
"float16": torch.float16,
|
||||
torch.float16: torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
torch.bfloat16: torch.bfloat16,
|
||||
}
|
||||
if dtype is None or dtype == None: return None
|
||||
elif dtype in __DTYPE_MAP: return __DTYPE_MAP[dtype]
|
||||
else:
|
||||
print(f"Unsloth: {dtype} is not recognized, so we'll default to None")
|
||||
return None
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class FastLanguageModel(FastLlamaModel):
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
model_name = "unsloth/llama-3-8b-bnb-4bit",
|
||||
model_name = "unsloth/Llama-3.2-1B-Instruct",
|
||||
max_seq_length = None,
|
||||
dtype = None,
|
||||
load_in_4bit = True,
|
||||
|
|
@ -80,12 +72,19 @@ class FastLanguageModel(FastLlamaModel):
|
|||
use_gradient_checkpointing = "unsloth",
|
||||
resize_model_vocab = None,
|
||||
revision = None,
|
||||
use_exact_model_name = False,
|
||||
*args, **kwargs,
|
||||
):
|
||||
if token is None: token = get_token()
|
||||
|
||||
old_model_name = model_name
|
||||
model_name = get_model_name(model_name, load_in_4bit)
|
||||
if not use_exact_model_name:
|
||||
model_name = get_model_name(model_name, load_in_4bit)
|
||||
|
||||
if USE_MODELSCOPE and not os.path.exists(model_name):
|
||||
from modelscope import snapshot_download
|
||||
model_name = snapshot_download(model_name)
|
||||
pass
|
||||
|
||||
# First check if it's a normal model via AutoConfig
|
||||
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
|
||||
|
|
@ -165,7 +164,9 @@ class FastLanguageModel(FastLlamaModel):
|
|||
# Get base model for PEFT:
|
||||
if is_peft:
|
||||
# Check base model again for PEFT
|
||||
model_name = get_model_name(peft_config.base_model_name_or_path, load_in_4bit)
|
||||
model_name = peft_config.base_model_name_or_path
|
||||
if not use_exact_model_name:
|
||||
model_name = get_model_name(model_name, load_in_4bit)
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
model_name,
|
||||
token = token,
|
||||
|
|
@ -354,6 +355,7 @@ class FastVisionModel(FastBaseVisionModel):
|
|||
revision = None,
|
||||
return_logits = False, # Return logits
|
||||
fullgraph = True, # No graph breaks
|
||||
use_exact_model_name = False,
|
||||
*args, **kwargs,
|
||||
):
|
||||
if token is None: token = get_token()
|
||||
|
|
@ -361,10 +363,16 @@ class FastVisionModel(FastBaseVisionModel):
|
|||
patch_compiled_autograd()
|
||||
patch_compiling_bitsandbytes()
|
||||
if use_gradient_checkpointing == "unsloth":
|
||||
patch_unsloth_smart_gradient_checkpointing()
|
||||
patch_unsloth_smart_gradient_checkpointing(dtype = dtype)
|
||||
|
||||
old_model_name = model_name
|
||||
model_name = get_model_name(model_name, load_in_4bit)
|
||||
if not use_exact_model_name:
|
||||
model_name = get_model_name(model_name, load_in_4bit)
|
||||
|
||||
if USE_MODELSCOPE and not os.path.exists(model_name):
|
||||
from modelscope import snapshot_download
|
||||
model_name = snapshot_download(model_name)
|
||||
pass
|
||||
|
||||
# First check if it's a normal model via AutoConfig
|
||||
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
|
||||
|
|
@ -410,7 +418,7 @@ class FastVisionModel(FastBaseVisionModel):
|
|||
exist_config = os.path.exists(os.path.join(model_name, "config.json"))
|
||||
both_exist = exist_adapter_config and exist_config
|
||||
else:
|
||||
files = HfFileSystem(token = token).glob(os.path.join(model_name, "*.json"))
|
||||
files = HfFileSystem(token = token).glob(f"{model_name}/*.json")
|
||||
files = (os.path.split(x)[-1] for x in files)
|
||||
if sum(x == "adapter_config.json" or x == "config.json" for x in files) >= 2:
|
||||
both_exist = True
|
||||
|
|
@ -443,7 +451,10 @@ class FastVisionModel(FastBaseVisionModel):
|
|||
# Get base model for PEFT:
|
||||
if is_peft:
|
||||
# Check base model again for PEFT
|
||||
model_name = get_model_name(peft_config.base_model_name_or_path, load_in_4bit)
|
||||
model_name = peft_config.base_model_name_or_path
|
||||
if not use_exact_model_name:
|
||||
model_name = get_model_name(model_name, load_in_4bit)
|
||||
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
model_name,
|
||||
token = token,
|
||||
|
|
@ -454,7 +465,10 @@ class FastVisionModel(FastBaseVisionModel):
|
|||
|
||||
if not was_disabled: enable_progress_bars()
|
||||
|
||||
with contextlib.redirect_stdout(open(os.devnull, "w")):
|
||||
do_logging = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1"
|
||||
redirector = sys.stdout if do_logging else open(os.devnull, "w")
|
||||
|
||||
with contextlib.redirect_stdout(redirector):
|
||||
patch_loss_functions(torch_compile = False)
|
||||
model_types = unsloth_compile_transformers(
|
||||
model_name = model_name,
|
||||
|
|
@ -470,7 +484,7 @@ class FastVisionModel(FastBaseVisionModel):
|
|||
fuse_lm_head = True,
|
||||
gradient_checkpointing = True,
|
||||
manual_replacements = True,
|
||||
fast_lora_forwards = False,
|
||||
fast_lora_forwards = True,
|
||||
fast_residual_stream = False,
|
||||
accurate_accumulation = True,
|
||||
epilogue_fusion = True,
|
||||
|
|
@ -484,6 +498,7 @@ class FastVisionModel(FastBaseVisionModel):
|
|||
return_logits = return_logits,
|
||||
)
|
||||
pass
|
||||
if do_logging: redirector.close()
|
||||
|
||||
# Check if this is local model since the tokenizer gets overwritten
|
||||
if os.path.exists(os.path.join(old_model_name, "tokenizer_config.json")) and \
|
||||
|
|
|
|||
|
|
@ -64,9 +64,9 @@ def MistralAttention_fast_forward(
|
|||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
n_heads = self.num_heads
|
||||
n_heads = self.config.num_attention_heads
|
||||
n_groups = self.num_key_value_groups
|
||||
n_kv_heads = self.num_key_value_heads
|
||||
n_kv_heads = self.config.num_key_value_heads
|
||||
head_dim = self.head_dim
|
||||
assert(n_kv_heads * n_groups == n_heads)
|
||||
|
||||
|
|
@ -278,16 +278,16 @@ pass
|
|||
# Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now.
|
||||
def patch_mistral_nemo_attention(function):
    """Rewrite attention ``__init__`` source text so head_dim need not equal hidden_size / heads.

    Three textual substitutions are applied to ``function`` (a string of
    source code):

    * the ``head_dim * num_heads != hidden_size`` check becomes ``False``;
    * ``head_dim`` is read from ``config.head_dim`` instead of being derived;
    * ``o_proj`` takes ``num_heads * head_dim`` input features.

    Both spellings of the attribute accesses are handled: the older
    ``self.num_heads`` / ``self.hidden_size`` form and the newer
    ``self.config.num_attention_heads`` / ``self.config.hidden_size`` form.
    Patterns that do not occur in ``function`` leave it unchanged.

    Args:
        function: Source code text to patch.

    Returns:
        The patched source code text.
    """
    replacements = (
        # Disable the divisibility / size check.
        (
            "(self.head_dim * self.num_heads) != self.hidden_size",
            "False",
        ),
        (
            "(self.head_dim * self.config.num_attention_heads) != self.config.hidden_size",
            "False",
        ),
        # Take head_dim straight from the config.
        (
            "self.head_dim = self.hidden_size // self.num_heads",
            "self.head_dim = config.head_dim",
        ),
        (
            "self.head_dim = self.config.hidden_size // self.config.num_attention_heads",
            "self.head_dim = config.head_dim",
        ),
        # The output projection consumes num_heads * head_dim features.
        (
            "self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)",
            "self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)",
        ),
        (
            "self.o_proj = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)",
            "self.o_proj = nn.Linear(self.config.num_attention_heads * self.head_dim, self.config.hidden_size, bias=False)",
        ),
    )
    for pattern, substitute in replacements:
        function = function.replace(pattern, substitute)
    pass
    return function
pass
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from transformers import set_seed as transformers_set_seed
|
|||
from unsloth_zoo.peft_utils import (
|
||||
get_peft_regex,
|
||||
SKIP_QUANTIZATION_MODULES,
|
||||
requires_grad_for_gradient_checkpointing,
|
||||
)
|
||||
from triton import __version__ as triton_version
|
||||
|
||||
|
|
@ -275,6 +276,8 @@ class FastBaseVisionModel:
|
|||
use_gradient_checkpointing = use_gradient_checkpointing,
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
# Enable gradients on modules which are trainable
|
||||
requires_grad_for_gradient_checkpointing(model)
|
||||
|
||||
model = FastBaseVisionModel.patch_peft_model(model, use_gradient_checkpointing)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue