* use exact model name

* Update save.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* print

* Update _utils.py

* Update _utils.py

* Update llama.py

* Update _utils.py

* Update vision.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update loader.py

* accurate_accumulation

* Update loader.py

* Update loader.py

* Update _utils.py

* Update loader.py

* Update loader.py

* Update loader.py

* Update loader.py

* Update pyproject.toml

* Update __init__.py

* Update pyproject.toml

* Update __init__.py

* Update __init__.py

* Fix Triton heuristics

https://github.com/triton-lang/triton/issues/5224

* Update __init__.py

* Update __init__.py

* Update __init__.py

* Update __init__.py

* Xformers

* Update loader.py

* Update loader.py

* Rewind

* Update _utils.py

* Update _utils.py

* requires grad

* Update loader.py

* Update _utils.py

* Update loader.py

* changing model to base_model if peft model is already used

* Improve debugging experience (#1512)

* Create CONTRIBUTING.md (#1472)

Creating contributing guidelines

* Update CONTRIBUTING.md

improved sentence

* Improve logging control in `unsloth_compile_transformers` by conditionally redirecting stdout based on UNSLOTH_DISABLE_LOGGER environment variable

---------

Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com>
Co-authored-by: Nino Risteski <95188570+NinoRisteski@users.noreply.github.com>

* Update loader.py

* Update llama.py

* Update llama.py

* Revert "Update llama.py"

This reverts commit b7ddf962d2.

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Auto change is_bfloat16_supported

* Update llama.py

* Force data-type

* Update llama.py

* All attention refactor fix (#1491)

* change initialization of n_heads, n_kv_heads, hidden_size in llama.py

* do the same for cohere, mistral, gemma2, granite

* do the same for flexattention,cohere, mistral, granite

* Update llama.py

* Update llama.py

* Update granite to work with latest post_patch methods (#1502)

* Update granite to work with latest post_patch methods

* Pass position_embeddings for granite even if transformers<4.47

* Update llama.py

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Minor fixes for granite models (#1503)

* Update granite.py

Grab residual multiplier directly from layer

* Update llama.py

Version should read >= 4.47.1 as that is the version requiring the changes

* Update granite.py

* Update llama.py

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* support modelscope models and datasets (#1481)

* support modelscope

* change modelscope args

* remove useless import

* remove useless import

* fix

* wip

* fix

* remove useless code

* add readme

* add some comments

* change print to raise error

* update comment

* Update loader.py

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

---------

Co-authored-by: Itsuro Tajima <tajima@georepublic.de>
Co-authored-by: Muhammad Osama <muhammadosama1994@gmail.com>
Co-authored-by: Edd <68678137+Erland366@users.noreply.github.com>
Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com>
Co-authored-by: Nino Risteski <95188570+NinoRisteski@users.noreply.github.com>
Co-authored-by: Kareem <81531392+KareemMusleh@users.noreply.github.com>
Co-authored-by: Datta Nimmaturi <datta.nimmaturi@nutanix.com>
Co-authored-by: Z <coffeevampirebusiness@gmail.com>
Co-authored-by: tastelikefeet <58414341+tastelikefeet@users.noreply.github.com>
This commit is contained in:
Daniel Han 2025-01-07 04:23:14 -08:00 committed by GitHub
parent 48627f876c
commit 63782ea3af
16 changed files with 307 additions and 160 deletions

View file

@ -212,6 +212,9 @@ For **advanced installation instructions** or if you see weird errors during ins
- Go to our official [Documentation](https://docs.unsloth.ai) for saving to GGUF, checkpointing, evaluation and more!
- We support Huggingface's TRL, Trainer, Seq2SeqTrainer or even Pytorch code!
- We're in 🤗Hugging Face's official docs! Check out the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!
- If you want to download models from the ModelScope community, please use an environment variable: `UNSLOTH_USE_MODELSCOPE=1`, and install the modelscope library by: `pip install modelscope -U`.
> unsloth_cli.py also supports `UNSLOTH_USE_MODELSCOPE=1` to download models and datasets. Please remember to use the model and dataset IDs from the ModelScope community.
```python
from unsloth import FastLanguageModel

View file

@ -148,20 +148,20 @@ cu124onlytorch250 = [
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
]
cu121onlytorch251 = [
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
]
cu124onlytorch251 = [
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
]
cu118 = [
"unsloth[huggingface]",

View file

@ -30,11 +30,14 @@ Happy fine-tuning!
"""
import argparse
import os
def run(args):
import torch
from unsloth import FastLanguageModel
from datasets import load_dataset
from transformers.utils import strtobool
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
@ -86,8 +89,13 @@ def run(args):
texts.append(text)
return {"text": texts}
# Load and format dataset
dataset = load_dataset(args.dataset, split="train")
use_modelscope = strtobool(os.environ.get('UNSLOTH_USE_MODELSCOPE', 'False'))
if use_modelscope:
from modelscope import MsDataset
dataset = MsDataset.load(args.dataset, split="train")
else:
# Load and format dataset
dataset = load_dataset(args.dataset, split="train")
dataset = dataset.map(formatting_prompts_func, batched=True)
print("Data is formatted and ready!")

View file

@ -17,16 +17,6 @@ from packaging.version import Version
import os, re, subprocess, inspect
import numpy as np
# # Define a list of modules to check
# MODULES_TO_CHECK = ["bitsandbytes"]
# # Check if any of the modules in the list have been imported
# for module in MODULES_TO_CHECK:
# if module in sys.modules:
# raise ImportError(f"Unsloth: Please import Unsloth before {module}.")
# pass
# pass
# Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so
# enabling it will require much more work, so we have to prioritize. Please understand!
# We do have a beta version, which you can contact us about!
@ -55,7 +45,12 @@ else:
pass
# Reduce VRAM usage by reducing fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,roundup_power2_divisions:[64:128,256:64,>:32]"
# And optimize pinning of memory
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \
"expandable_segments:True,"\
"roundup_power2_divisions:[32:256,64:128,256:64,>:32],"\
"pinned_use_cuda_host_register:True,"\
"pinned_num_register_threads:8"
# Hugging Face Hub faster downloads
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
@ -89,6 +84,36 @@ elif (major_torch == 2) and (minor_torch < 2):
del os.environ["PYTORCH_CUDA_ALLOC_CONF"]
pass
# Fix Xformers performance issues since 0.0.25
import importlib.util
from pathlib import Path
from importlib.metadata import version as importlib_version
from packaging.version import Version
try:
xformers_version = importlib_version("xformers")
if Version(xformers_version) < Version("0.0.29"):
xformers_location = importlib.util.find_spec("xformers").origin
xformers_location = os.path.split(xformers_location)[0]
cutlass = Path(xformers_location) / "ops" / "fmha" / "cutlass.py"
if cutlass.exists():
with open(cutlass, "r+") as f:
text = f.read()
# See https://github.com/facebookresearch/xformers/issues/1176#issuecomment-2545829591
if "num_splits_key=-1," in text:
text = text.replace("num_splits_key=-1,", "num_splits_key=None,")
f.seek(0)
f.write(text)
f.truncate()
print("Unsloth: Patching Xformers to fix some performance issues.")
pass
pass
pass
pass
except:
pass
pass
# Torch 2.4 has including_emulation
major_version, minor_version = torch.cuda.get_device_capability()
SUPPORTS_BFLOAT16 = (major_version >= 8)
@ -166,9 +191,18 @@ pass
# Check for unsloth_zoo
try:
unsloth_zoo_version = importlib_version("unsloth_zoo")
if Version(unsloth_zoo_version) < Version("2025.1.1"):
try:
os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo")
except:
try:
os.system("pip install --upgrade --no-cache-dir --no-deps --user unsloth_zoo")
except:
raise ImportError("Unsloth: Please update unsloth_zoo via `pip install --upgrade --no-cache-dir --no-deps unsloth_zoo`")
import unsloth_zoo
except:
raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth-zoo`")
raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth_zoo`")
pass
from .models import *

View file

@ -25,11 +25,6 @@ from unsloth_zoo.loss_utils import (
)
@triton.heuristics({
"DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
"DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
})
@triton.jit
def _cross_entropy_forward(
logits_ptr ,
logits_row_stride ,
@ -95,13 +90,15 @@ def _cross_entropy_forward(
tl.store(logsumexp_ptr, logsumexp)
tl.store(loss_ptr, loss)
pass
_cross_entropy_forward = triton.jit(_cross_entropy_forward)
_cross_entropy_forward = triton.heuristics(
{
"DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
"DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
}
)(_cross_entropy_forward)
@triton.heuristics({
"DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
"DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
})
@triton.jit
def _chunked_cross_entropy_forward(
logits_ptr ,
logits_row_stride ,
@ -177,13 +174,15 @@ def _chunked_cross_entropy_forward(
pass
tl.store(logsumexp_ptr, logsumexp)
pass
_chunked_cross_entropy_forward = triton.jit(_chunked_cross_entropy_forward)
_chunked_cross_entropy_forward = triton.heuristics(
{
"DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
"DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
}
)(_chunked_cross_entropy_forward)
@triton.heuristics({
"DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
"DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
})
@triton.jit
def _cross_entropy_backward(
logits_ptr ,
logits_row_stride ,
@ -264,10 +263,16 @@ def _cross_entropy_backward(
# If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
pass
_cross_entropy_backward = triton.jit(_cross_entropy_backward)
_cross_entropy_backward = triton.heuristics(
{
"DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
"DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
}
)(_cross_entropy_backward)
MAX_FUSED_SIZE = 65536 # 2**16
class Fast_CrossEntropyLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : float = 0):

View file

@ -43,9 +43,9 @@ if not HAS_FLEX_ATTENTION:
# Logit softcapping
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
n_heads = self.num_heads
n_heads = self.config.num_attention_heads
head_dim = self.head_dim
n_kv_heads = self.num_key_value_heads
n_kv_heads = self.config.num_key_value_heads
n_groups = self.num_key_value_groups
# Grouped query attention
@ -130,7 +130,7 @@ else:
pass
def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
n_heads = self.num_heads
n_heads = self.config.num_attention_heads
head_dim = self.head_dim
s = self.config.query_pre_attn_scalar
t = self.config.attn_logit_softcapping
@ -147,9 +147,9 @@ torch_matmul = torch.matmul
torch_tanh = torch.tanh
torch_nn_functional_softmax = torch.nn.functional.softmax
def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
n_heads = self.num_heads
n_heads = self.config.num_attention_heads
head_dim = self.head_dim
n_kv_heads = self.num_key_value_heads
n_kv_heads = self.config.num_key_value_heads
n_groups = self.num_key_value_groups
# Grouped query attention

View file

@ -53,8 +53,6 @@ def _rms_layernorm_forward(
pass
@triton.heuristics({"GEMMA": lambda args: bool(args["GEMMA"]),})
@triton.jit
def _rms_layernorm_backward(
dY, dY_row_stride,
dX, dX_row_stride,
@ -97,6 +95,12 @@ def _rms_layernorm_backward(
output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
tl.store(dX + col_offsets, output, mask = mask)
pass
_rms_layernorm_backward = triton.jit(_rms_layernorm_backward)
_rms_layernorm_backward = triton.heuristics(
{
"GEMMA": lambda args: bool(args["GEMMA"]),
}
)(_rms_layernorm_backward)
@triton.jit

View file

@ -18,8 +18,6 @@ import torch
from .utils import calculate_settings
ROPE_GROUP_SIZE : int = 4
@triton.heuristics({"BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]),})
@triton.jit
def _rope_embedding(
Q, Q_row_stride,
cos, cos_row_stride,
@ -69,6 +67,12 @@ def _rope_embedding(
tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)
pass
pass
_rope_embedding = triton.jit(_rope_embedding)
_rope_embedding = triton.heuristics(
{
"BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]),
}
)(_rope_embedding)
class Fast_RoPE_Embedding(torch.autograd.Function):

View file

@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "2024.12.12"
__version__ = "2025.1.1"
__all__ = [
"SUPPORTS_BFLOAT16",
"is_bfloat16_supported",
"prepare_model_for_kbit_training",
"xformers",
"xformers_attention",
@ -30,7 +33,6 @@ __all__ = [
"offload_to_disk",
"offload_input_embeddings",
"offload_output_embeddings",
"is_bfloat16_supported",
"unsloth_offloaded_gradient_checkpoint",
"torch_compile_options",
"patch_linear_scaling",
@ -58,7 +60,6 @@ __all__ = [
"fused_linear_cross_entropy",
"patch_unsloth_smart_gradient_checkpointing",
"unpatch_unsloth_smart_gradient_checkpointing",
"create_gradient_checkpointing_buffer",
"patch_compiled_autograd",
"process_vision_info",
@ -97,7 +98,6 @@ from unsloth_zoo.gradient_checkpointing import (
patch_unsloth_smart_gradient_checkpointing,
unpatch_unsloth_smart_gradient_checkpointing,
create_gradient_checkpointing_buffer,
)
from unsloth_zoo.loss_utils import (
HAS_CUT_CROSS_ENTROPY,
@ -556,6 +556,7 @@ def prepare_model_for_kbit_training(
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
pass
return model
pass
@ -1203,8 +1204,6 @@ def unsloth_compile_transformers(
return
pass
if disable: return
model_types = get_transformers_model_type(
model_name = model_name,
token = token,
@ -1212,6 +1211,8 @@ def unsloth_compile_transformers(
trust_remote_code = trust_remote_code,
)
if disable: return
for model_type in model_types:
_unsloth_compile_transformers(
model_type,

View file

@ -94,9 +94,9 @@ def CohereAttention_fast_forward(
bsz, q_len, _ = hidden_states.size()
n_heads = self.num_heads
n_heads = self.config.num_attention_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
n_kv_heads = self.config.num_key_value_heads
head_dim = self.head_dim
assert(n_kv_heads * n_groups == n_heads)
@ -259,12 +259,14 @@ def CohereAttention_fast_forward_inference(
K1, V1 = past_key_value
dtype = Xn.dtype
n_heads = self.num_heads
n_heads = self.config.num_attention_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
n_kv_heads = self.config.num_key_value_heads
head_dim = self.head_dim
attention_size = n_heads*head_dim
# assert(n_kv_heads * n_groups == n_heads)
hidden_size = self.config.hidden_size
attention_size = n_heads*head_dim
seq_len = K1.shape[-2]
kv_seq_len = seq_len + 1
@ -281,10 +283,10 @@ def CohereAttention_fast_forward_inference(
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
# Mistral Nemo 12b has weird dimensions
if attention_size != self.hidden_size:
self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
if attention_size != hidden_size:
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0")
else:
self.temp_O = self.temp_QA[1][:,:,:self.hidden_size]
self.temp_O = self.temp_QA[1][:,:,:hidden_size]
pass
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")

View file

@ -98,9 +98,9 @@ def Gemma2Attention_fast_forward(
bsz, q_len, _ = hidden_states.size()
n_heads = self.num_heads
n_heads = self.config.num_attention_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
n_kv_heads = self.config.num_key_value_heads
head_dim = self.head_dim
assert(n_kv_heads * n_groups == n_heads)
@ -255,12 +255,14 @@ def Gemma2Attention_fast_forward_inference(
K1, V1 = past_key_value
dtype = Xn.dtype
n_heads = self.num_heads
n_heads = self.config.num_attention_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
n_kv_heads = self.config.num_key_value_heads
head_dim = self.head_dim
attention_size = n_heads*head_dim
# assert(n_kv_heads * n_groups == n_heads)
hidden_size = self.config.hidden_size
attention_size = n_heads*head_dim
seq_len = K1.shape[-2]
kv_seq_len = seq_len + 1
@ -276,7 +278,7 @@ def Gemma2Attention_fast_forward_inference(
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
# Only for Gemma2
self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0")
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e

View file

@ -20,7 +20,8 @@ from .llama import (
LlamaLinearScalingRotaryEmbedding,
)
from .mistral import *
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
from peft.tuners.lora import Linear4bit as Peft_Linear4bit
try:
from transformers.models.granite.modeling_granite import (
GraniteAttention,
@ -84,9 +85,9 @@ def GraniteAttention_fast_forward(
bsz, q_len, _ = hidden_states.size()
n_heads = self.num_heads
n_heads = self.config.num_attention_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
n_kv_heads = self.config.num_key_value_heads
head_dim = self.head_dim
assert(n_kv_heads * n_groups == n_heads)
@ -181,6 +182,11 @@ def GraniteDecoderLayer_fast_forward(
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args, **kwargs,
):
residual_multiplier = \
self.residual_multiplier \
if hasattr(self, "residual_multiplier") else \
self.config.residual_multiplier
if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
@ -196,13 +202,13 @@ def GraniteDecoderLayer_fast_forward(
position_embeddings = position_embeddings,
_flag_for_generation=self._flag_for_generation,
)
hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier)
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states)
hidden_states = fast_swiglu_inference(self.mlp, hidden_states)
hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier)
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
else:
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
@ -217,13 +223,13 @@ def GraniteDecoderLayer_fast_forward(
padding_mask=padding_mask,
position_embeddings = position_embeddings,
)
hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier)
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier)
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
pass
outputs = (hidden_states,)
@ -257,12 +263,14 @@ def GraniteAttention_fast_forward_inference(
K1, V1 = past_key_value
dtype = Xn.dtype
n_heads = self.num_heads
n_heads = self.config.num_attention_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
n_kv_heads = self.config.num_key_value_heads
head_dim = self.head_dim
attention_size = n_heads*head_dim
# assert(n_kv_heads * n_groups == n_heads)
hidden_size = self.config.hidden_size
attention_size = n_heads*head_dim
seq_len = K1.shape[-2]
kv_seq_len = seq_len + 1
@ -278,7 +286,7 @@ def GraniteAttention_fast_forward_inference(
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
# Only for Gemma2
self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0")
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
@ -367,6 +375,10 @@ def GraniteModel_fast_forward_inference(
hidden_states = self.model.embed_tokens(input_ids)
hidden_states = hidden_states.to(self.config.torch_dtype)
hidden_states *= self.model.embedding_multiplier
residual_multiplier = \
self.residual_multiplier \
if hasattr(self, "residual_multiplier") else \
self.config.residual_multiplier
bsz, q_len, hd = hidden_states.shape
seq_len = past_key_values[0][0].shape[-2]
@ -398,12 +410,12 @@ def GraniteModel_fast_forward_inference(
position_embeddings = position_embeddings,
)
hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier)
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states)
hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states)
hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier)
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
next_decoder_cache.append(present_key_value)
pass
@ -421,6 +433,18 @@ class GraniteRotaryEmbedding(LlamaRotaryEmbedding):
def __init__(self, config):
super().__init__(config = config)
def patched_init(original_init):
def new_init(self, *args, **kwargs):
# we can use self.residual_multiplier arg in GraniteDecoderLayer_fast_forward as mentioned here
# https://github.com/huggingface/transformers/blob/e5fd865ebae062b7cf03a81b8c6affeb39f30bec/src/transformers/models/granite/modeling_granite.py#L243
# The problem is, we don't have access to either the value or config in GraniteModel_fast_forward_inference
# So we need a way to pass this value around. It is probably better to pass on entire config just in case we need it later
config = kwargs.get("config", args[0] if args else None)
if config is not None:
self.config = config
original_init(self, *args, **kwargs)
return new_init
class FastGraniteModel(FastLlamaModel):
@staticmethod
@ -435,12 +459,13 @@ class FastGraniteModel(FastLlamaModel):
exec(function, globals())
GraniteAttention.__init__ = eval(init_name)
pass
GraniteAttention .forward = GraniteAttention_fast_forward
GraniteSdpaAttention .forward = GraniteAttention_fast_forward
GraniteFlashAttention2.forward = GraniteAttention_fast_forward
GraniteDecoderLayer .forward = GraniteDecoderLayer_fast_forward
GraniteModel .forward = LlamaModel_fast_forward
GraniteForCausalLM .forward = CausalLM_fast_forward(GraniteModel_fast_forward_inference)
GraniteAttention .forward = GraniteAttention_fast_forward
GraniteSdpaAttention .forward = GraniteAttention_fast_forward
GraniteFlashAttention2.forward = GraniteAttention_fast_forward
GraniteDecoderLayer .forward = GraniteDecoderLayer_fast_forward
GraniteModel .forward = LlamaModel_fast_forward
GraniteForCausalLM .forward = CausalLM_fast_forward(GraniteModel_fast_forward_inference)
GraniteForCausalLM .__init__ = patched_init(GraniteForCausalLM.__init__)
PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(GraniteForCausalLM)
@ -452,7 +477,7 @@ class FastGraniteModel(FastLlamaModel):
@staticmethod
def post_patch(model):
def post_patch(model, tokenizer):
# Torch.compile fails on embedding matrix??
# Workaround randomly fixes it for torch versions < 2.2
@ -517,7 +542,7 @@ class FastGraniteModel(FastLlamaModel):
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return model
return model, tokenizer
pass
pass

View file

@ -20,6 +20,10 @@ from ._utils import *
from ._utils import __version__
from torch.nn.functional import scaled_dot_product_attention
from transformers import __version__ as transformers_version
from unsloth_zoo.utils import Version
transformers_version = Version(transformers_version)
# Transformers moved rotary embeddings out of all attention layers
IS_ATTENTION_REFACTOR = transformers_version > Version("4.47.1")
from transformers.models.llama.modeling_llama import (
logger,
BaseModelOutputWithPast,
@ -146,12 +150,14 @@ def LlamaAttention_fast_forward_inference(
K1, V1 = past_key_value
dtype = Xn.dtype
n_heads = self.num_heads
n_heads = self.config.num_attention_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
n_kv_heads = self.config.num_key_value_heads
head_dim = self.head_dim
attention_size = n_heads*head_dim
# assert(n_kv_heads * n_groups == n_heads)
hidden_size = self.config.hidden_size
attention_size = n_heads*head_dim
seq_len = K1.shape[-2]
kv_seq_len = seq_len + 1
@ -168,10 +174,10 @@ def LlamaAttention_fast_forward_inference(
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
# Mistral Nemo 12b has weird dimensions
if attention_size != self.hidden_size:
self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
if attention_size != hidden_size:
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0")
else:
self.temp_O = self.temp_QA[1][:,:,:self.hidden_size]
self.temp_O = self.temp_QA[1][:,:,:hidden_size]
pass
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
@ -356,9 +362,9 @@ def LlamaAttention_fast_forward(
bsz, q_len, _ = hidden_states.size()
n_heads = self.num_heads
n_heads = self.config.num_attention_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
n_kv_heads = self.config.num_key_value_heads
head_dim = self.head_dim
assert(n_kv_heads * n_groups == n_heads)
@ -786,7 +792,7 @@ def LlamaModel_fast_forward(
pass
pass
if transformers_version > "4.47.1" and hasattr(self, "rotary_emb"):
if IS_ATTENTION_REFACTOR and not hasattr(self.layers[0].self_attn, "rotary_emb"):
# Transformers main has made it mandatory to pass position_embeddings
# https://github.com/huggingface/transformers/pull/34858
position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings)
@ -996,18 +1002,20 @@ def CausalLM_fast_forward(fast_forward_inference):
lm_head = self.lm_head.weight
logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
logit_scaling = getattr(self.config, "logit_scale", 0)
dtype = lm_head.dtype
if bsz == 1 and q_len == 1:
logits = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype))
logits = torch.mv(lm_head, hidden_states.ravel().to(dtype))
logits = logits.unsqueeze(0).unsqueeze(0)
elif num_logits_to_keep != 0:
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(lm_head.dtype))
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(dtype))
else:
RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1"
# < 1024 Normal Unsloth uses less VRAM!
if bsz*q_len <= 1024: RETURN_LOGITS = True
if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None:
n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None)
loss = fused_linear_cross_entropy(
hidden_states = hidden_states,
@ -1029,7 +1037,7 @@ def CausalLM_fast_forward(fast_forward_inference):
)
return output
pass
logits = self.lm_head(hidden_states.to(lm_head.dtype))
logits = self.lm_head(hidden_states.to(dtype))
pass
torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None)
@ -1607,6 +1615,9 @@ class FastLlamaModel:
elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
logger.warning_once("Device does not support bfloat16. Will change to float16.")
dtype = torch.float16
elif dtype == torch.float16 and SUPPORTS_BFLOAT16:
logger.warning_once("Device supports bfloat16 but you selected float16. Will change to bfloat16.")
dtype = torch.bfloat16
assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)
@ -1879,6 +1890,13 @@ class FastLlamaModel:
internal_model = internal_model.model
pass
internal_model._saved_temp_tokenizer = tokenizer
# For transformers > 4.47.1, we need to add rotary_emb to all attention layers
if IS_ATTENTION_REFACTOR or hasattr(model.model, "rotary_emb"):
rotary_emb = model.model.rotary_emb
for layer in model.model.layers:
layer.self_attn.rotary_emb = rotary_emb
pass
return model, tokenizer
pass
@ -1967,29 +1985,41 @@ class FastLlamaModel:
if "embed_tokens" in new_target_modules:
print("Unsloth: Training embed_tokens in mixed precision to save VRAM")
dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype
model.model.model.embed_tokens.modules_to_save.default\
.to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True)
model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True)
new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype
if new_dtype == torch.float16:
# See https://github.com/unslothai/unsloth/pull/1200
# Tesla T4 must use float32 and not float16
new_dtype = torch.float32
pass
model.get_input_embeddings().modules_to_save.default\
.to(device = "cuda:0", dtype = new_dtype, non_blocking = True)
model.get_input_embeddings().modules_to_save.default.requires_grad_(True)
# [TODO] Move old embed_tokens to CPU - should be disk!
model.model.model.embed_tokens.original_module\
model.get_input_embeddings().original_module\
.to(device = "cpu", non_blocking = True)
model.model.model.embed_tokens.original_module.requires_grad_(False)
model.get_input_embeddings().original_module.requires_grad_(False)
pass
if "lm_head" in new_target_modules:
print("Unsloth: Training lm_head in mixed precision to save VRAM")
dtype = model.model.model.lm_head.modules_to_save.default.weight.dtype
model.model.lm_head.modules_to_save.default\
.to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True)
model.model.lm_head.modules_to_save.default.requires_grad_(True)
new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype
if new_dtype == torch.float16:
# See https://github.com/unslothai/unsloth/pull/1200
# Tesla T4 must use float32 and not float16
new_dtype = torch.float32
pass
model.get_output_embeddings().modules_to_save.default\
.to(device = "cuda:0", dtype = new_dtype, non_blocking = True)
model.get_output_embeddings().modules_to_save.default.requires_grad_(True)
# [TODO] Move old lm_head to CPU - should be disk!
model.model.lm_head.original_module\
model.get_output_embeddings().original_module\
.to(device = "cpu", non_blocking = True)
model.model.lm_head.original_module.requires_grad_(False)
model.get_output_embeddings().original_module.requires_grad_(False)
pass
return model
@ -2216,25 +2246,36 @@ class FastLlamaModel:
model = FastLlamaModel.patch_peft_model(model, use_gradient_checkpointing)
# Now patch lm_head and embed_tokens
if train_embed_tokens:
print("Unsloth: Training embed_tokens in mixed precision to save VRAM")
assert(hasattr(model.model.model.embed_tokens, "modules_to_save"))
assert(hasattr(model.get_input_embeddings(), "modules_to_save"))
dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype
model.model.model.embed_tokens.modules_to_save.default\
.to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True)
model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True)
new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype
if new_dtype == torch.float16:
# See https://github.com/unslothai/unsloth/pull/1200
# Tesla T4 must use float32 and not float16
new_dtype = torch.float32
pass
model.get_input_embeddings().modules_to_save.default\
.to(device = "cuda:0", dtype = new_dtype, non_blocking = True)
model.get_input_embeddings().modules_to_save.default.requires_grad_(True)
pass
if train_lm_head:
print("Unsloth: Training lm_head in mixed precision to save VRAM")
assert(hasattr(model.model.lm_head, "modules_to_save"))
assert(hasattr(model.get_output_embeddings(), "modules_to_save"))
dtype = model.model.lm_head.modules_to_save.default.weight.dtype
model.model.lm_head.modules_to_save.default\
.to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True)
model.model.lm_head.modules_to_save.default.requires_grad_(True)
new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype
if new_dtype == torch.float16:
# See https://github.com/unslothai/unsloth/pull/1200
# Tesla T4 must use float32 and not float16
new_dtype = torch.float32
pass
model.get_output_embeddings().modules_to_save.default\
.to(device = "cuda:0", dtype = new_dtype, non_blocking = True)
model.get_output_embeddings().modules_to_save.default.requires_grad_(True)
pass
# Patch tokenizer to pad to the right

View file

@ -31,8 +31,17 @@ except:
pass
from huggingface_hub import HfFileSystem
# [TODO] Move USE_MODELSCOPE to utils
USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1"
if USE_MODELSCOPE:
import importlib
if importlib.util.find_spec("modelscope") is None:
raise ImportError(f'You are using the modelscope hub, please install modelscope by `pip install modelscope -U`')
pass
pass
# https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
from unsloth_zoo.utils import Version
from unsloth_zoo.utils import Version, _get_dtype
transformers_version = Version(transformers_version)
SUPPORTS_FOURBIT = transformers_version >= Version("4.37")
SUPPORTS_GEMMA = transformers_version >= Version("4.38")
@ -47,28 +56,11 @@ if SUPPORTS_GEMMA2:
pass
import torch
def _get_dtype(dtype):
__DTYPE_MAP = {
"float32": torch.float32,
torch.float32: torch.float32,
"float16": torch.float16,
torch.float16: torch.float16,
"bfloat16": torch.bfloat16,
torch.bfloat16: torch.bfloat16,
}
if dtype is None or dtype == None: return None
elif dtype in __DTYPE_MAP: return __DTYPE_MAP[dtype]
else:
print(f"Unsloth: {dtype} is not recognized, so we'll default to None")
return None
pass
pass
class FastLanguageModel(FastLlamaModel):
@staticmethod
def from_pretrained(
model_name = "unsloth/llama-3-8b-bnb-4bit",
model_name = "unsloth/Llama-3.2-1B-Instruct",
max_seq_length = None,
dtype = None,
load_in_4bit = True,
@ -80,12 +72,19 @@ class FastLanguageModel(FastLlamaModel):
use_gradient_checkpointing = "unsloth",
resize_model_vocab = None,
revision = None,
use_exact_model_name = False,
*args, **kwargs,
):
if token is None: token = get_token()
old_model_name = model_name
model_name = get_model_name(model_name, load_in_4bit)
if not use_exact_model_name:
model_name = get_model_name(model_name, load_in_4bit)
if USE_MODELSCOPE and not os.path.exists(model_name):
from modelscope import snapshot_download
model_name = snapshot_download(model_name)
pass
# First check if it's a normal model via AutoConfig
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
@ -165,7 +164,9 @@ class FastLanguageModel(FastLlamaModel):
# Get base model for PEFT:
if is_peft:
# Check base model again for PEFT
model_name = get_model_name(peft_config.base_model_name_or_path, load_in_4bit)
model_name = peft_config.base_model_name_or_path
if not use_exact_model_name:
model_name = get_model_name(model_name, load_in_4bit)
model_config = AutoConfig.from_pretrained(
model_name,
token = token,
@ -354,6 +355,7 @@ class FastVisionModel(FastBaseVisionModel):
revision = None,
return_logits = False, # Return logits
fullgraph = True, # No graph breaks
use_exact_model_name = False,
*args, **kwargs,
):
if token is None: token = get_token()
@ -361,10 +363,16 @@ class FastVisionModel(FastBaseVisionModel):
patch_compiled_autograd()
patch_compiling_bitsandbytes()
if use_gradient_checkpointing == "unsloth":
patch_unsloth_smart_gradient_checkpointing()
patch_unsloth_smart_gradient_checkpointing(dtype = dtype)
old_model_name = model_name
model_name = get_model_name(model_name, load_in_4bit)
if not use_exact_model_name:
model_name = get_model_name(model_name, load_in_4bit)
if USE_MODELSCOPE and not os.path.exists(model_name):
from modelscope import snapshot_download
model_name = snapshot_download(model_name)
pass
# First check if it's a normal model via AutoConfig
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
@ -410,7 +418,7 @@ class FastVisionModel(FastBaseVisionModel):
exist_config = os.path.exists(os.path.join(model_name, "config.json"))
both_exist = exist_adapter_config and exist_config
else:
files = HfFileSystem(token = token).glob(os.path.join(model_name, "*.json"))
files = HfFileSystem(token = token).glob(f"{model_name}/*.json")
files = (os.path.split(x)[-1] for x in files)
if sum(x == "adapter_config.json" or x == "config.json" for x in files) >= 2:
both_exist = True
@ -443,7 +451,10 @@ class FastVisionModel(FastBaseVisionModel):
# Get base model for PEFT:
if is_peft:
# Check base model again for PEFT
model_name = get_model_name(peft_config.base_model_name_or_path, load_in_4bit)
model_name = peft_config.base_model_name_or_path
if not use_exact_model_name:
model_name = get_model_name(model_name, load_in_4bit)
model_config = AutoConfig.from_pretrained(
model_name,
token = token,
@ -454,7 +465,10 @@ class FastVisionModel(FastBaseVisionModel):
if not was_disabled: enable_progress_bars()
with contextlib.redirect_stdout(open(os.devnull, "w")):
do_logging = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1"
redirector = sys.stdout if do_logging else open(os.devnull, "w")
with contextlib.redirect_stdout(redirector):
patch_loss_functions(torch_compile = False)
model_types = unsloth_compile_transformers(
model_name = model_name,
@ -470,7 +484,7 @@ class FastVisionModel(FastBaseVisionModel):
fuse_lm_head = True,
gradient_checkpointing = True,
manual_replacements = True,
fast_lora_forwards = False,
fast_lora_forwards = True,
fast_residual_stream = False,
accurate_accumulation = True,
epilogue_fusion = True,
@ -484,6 +498,7 @@ class FastVisionModel(FastBaseVisionModel):
return_logits = return_logits,
)
pass
if do_logging: redirector.close()
# Check if this is local model since the tokenizer gets overwritten
if os.path.exists(os.path.join(old_model_name, "tokenizer_config.json")) and \

View file

@ -64,9 +64,9 @@ def MistralAttention_fast_forward(
bsz, q_len, _ = hidden_states.size()
n_heads = self.num_heads
n_heads = self.config.num_attention_heads
n_groups = self.num_key_value_groups
n_kv_heads = self.num_key_value_heads
n_kv_heads = self.config.num_key_value_heads
head_dim = self.head_dim
assert(n_kv_heads * n_groups == n_heads)
@ -278,16 +278,16 @@ pass
# Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now.
def patch_mistral_nemo_attention(function):
function = function.replace(
"(self.head_dim * self.num_heads) != self.hidden_size",
"(self.head_dim * self.config.num_attention_heads) != self.config.hidden_size",
"False",
)
function = function.replace(
"self.head_dim = self.hidden_size // self.num_heads",
"self.head_dim = self.config.hidden_size // self.config.num_attention_heads",
"self.head_dim = config.head_dim",
)
function = function.replace(
"self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)",
"self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)",
"self.o_proj = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)",
"self.o_proj = nn.Linear(self.config.num_attention_heads * self.head_dim, self.config.hidden_size, bias=False)",
)
return function
pass

View file

@ -30,6 +30,7 @@ from transformers import set_seed as transformers_set_seed
from unsloth_zoo.peft_utils import (
get_peft_regex,
SKIP_QUANTIZATION_MODULES,
requires_grad_for_gradient_checkpointing,
)
from triton import __version__ as triton_version
@ -275,6 +276,8 @@ class FastBaseVisionModel:
use_gradient_checkpointing = use_gradient_checkpointing,
)
model = get_peft_model(model, lora_config)
# Enable gradients on modules which are trainable
requires_grad_for_gradient_checkpointing(model)
model = FastBaseVisionModel.patch_peft_model(model, use_gradient_checkpointing)