unsloth/unsloth/models/_utils.py
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "2026.4.6"
__all__ = [
"SUPPORTS_BFLOAT16",
"is_bfloat16_supported",
"is_vLLM_available",
"prepare_model_for_kbit_training",
"xformers",
"xformers_attention",
"xformers_version",
"__version__",
"importlib_version",
"HAS_FLASH_ATTENTION",
"HAS_FLASH_ATTENTION_SOFTCAPPING",
"USE_MODELSCOPE",
"platform_system",
"resolve_hip_gpu_stats_name",
"patch_tokenizer",
"get_statistics",
"Unsloth_Offloaded_Gradient_Checkpointer",
"offload_to_disk",
"offload_input_embeddings",
"offload_output_embeddings",
"unsloth_offloaded_gradient_checkpoint",
"torch_compile_options",
"patch_linear_scaling",
"patch_llama_rope_scaling",
"create_boolean_mask",
"torch_amp_custom_fwd",
"torch_amp_custom_bwd",
# "accelerate_old_send_to_device",
# "accelerate_new_send_to_device",
"patch_gradient_accumulation_fix",
"apply_accepts_loss_kwargs_fix",
"patch_compiling_bitsandbytes",
"patch_regional_compilation",
"patch_layernorm",
"patch_torch_compile",
"patch_model_and_tokenizer",
"patch_unsloth_gradient_checkpointing",
"unpatch_unsloth_gradient_checkpointing",
"patch_gradient_checkpointing",
"unpatch_gradient_checkpointing",
"HAS_CUT_CROSS_ENTROPY",
"EMPTY_LOGITS",
"fused_linear_cross_entropy",
"unsloth_fused_ce_loss",
"patch_unsloth_smart_gradient_checkpointing",
"unpatch_unsloth_smart_gradient_checkpointing",
"apply_unsloth_gradient_checkpointing",
"patch_compiled_autograd",
"process_vision_info",
"unsloth_compile_transformers",
"resolve_model_class",
"resolve_attention_implementation",
"resolve_encoder_attention_implementation",
"_set_attn_impl",
"patch_fast_lora",
"validate_loftq_config",
"RaiseUninitialized",
"fast_inference_setup",
"patch_peft_fast_inference",
"error_out_no_vllm",
"dequantize_module_weight",
"patch_hf_quantizer",
"verify_fp8_support_if_applicable",
"_get_inference_mode_context_manager",
"hf_login",
"is_moe_model",
"get_moe_target_parameters",
"make_fast_generate_wrapper",
]
import torch
from typing import Union, Optional, List, Any, Callable, Tuple, Iterator
from platform import system as platform_system
platform_system = platform_system()
import numpy as np
import contextlib
import re
from dataclasses import dataclass, field
import functools
import textwrap
import logging
import warnings, subprocess, inspect, psutil, os, math
from unsloth_zoo.utils import Version, get_quant_type
from importlib.metadata import version as importlib_version
from ..device_type import (
is_hip,
get_device_type,
DEVICE_TYPE,
DEVICE_TYPE_TORCH,
DEVICE_COUNT,
ALLOW_PREQUANTIZED_MODELS,
)
from ..import_fixes import UNSLOTH_ENABLE_LOGGING
from unsloth_zoo.log import logger
from unsloth_zoo.tokenizer_utils import (
patch_tokenizer as _patch_tokenizer,
)
from unsloth_zoo.rl_environments import (
check_python_modules,
create_locked_down_function,
execute_with_time_limit,
Benchmarker,
)
from unsloth_zoo.patching_utils import (
patch_compiling_bitsandbytes,
patch_layernorm,
patch_torch_compile,
patch_model_and_tokenizer,
patch_compiled_autograd,
)
from unsloth_zoo.gradient_checkpointing import (
Unsloth_Offloaded_Gradient_Checkpointer,
unsloth_offloaded_gradient_checkpoint,
patch_unsloth_gradient_checkpointing,
unpatch_unsloth_gradient_checkpointing,
Unsloth_Gradient_Checkpointer,
unsloth_gradient_checkpoint,
patch_gradient_checkpointing,
unpatch_gradient_checkpointing,
patch_unsloth_smart_gradient_checkpointing,
unpatch_unsloth_smart_gradient_checkpointing,
)
from unsloth_zoo.loss_utils import (
HAS_CUT_CROSS_ENTROPY,
fused_linear_cross_entropy,
_unsloth_get_batch_samples,
unsloth_fused_ce_loss,
)
from unsloth_zoo.vision_utils import (
process_vision_info,
)
from unsloth_zoo.compiler import (
get_transformers_model_type,
unsloth_compile_transformers as _unsloth_compile_transformers,
)
from unsloth_zoo.training_utils import (
prepare_model_for_training,
)
def resolve_hip_gpu_stats_name(gpu_stats):
name = str(getattr(gpu_stats, "name", "") or "").strip()
name = re.sub(r"\s*\([^)]*\)\s*$", "", name).strip()
normalized_name = name.lower().strip(". ")
if normalized_name and normalized_name not in ("amd radeon graphics",):
return name + ". "
try:
torch_name = str(torch.cuda.get_device_name(0) or "").strip()
torch_name = re.sub(r"\s*\([^)]*\)\s*$", "", torch_name).strip()
except Exception:
torch_name = ""
normalized_torch_name = torch_name.lower().strip(". ")
if normalized_torch_name and normalized_torch_name not in ("amd radeon graphics",):
return torch_name + ". "
arch_name = ""
for key in ("gcnArchName", "gcn_arch_name", "arch_name", "gfx_arch_name"):
value = getattr(gpu_stats, key, None)
if value is not None and str(value).strip():
arch_name = str(value).strip()
break
if arch_name:
arch_name = arch_name.strip()
match = re.search(r"(gfx[0-9a-z]+)", arch_name, flags = re.I)
if match:
return f"AMD {match.group(1).lower()} GPU. "
return "AMD GPU. "
from unsloth_zoo.temporary_patches import (
TEMPORARY_PATCHES,
)
def apply_unsloth_gradient_checkpointing(
use_gradient_checkpointing, max_seq_length, dtype
):
"""
Apply gradient checkpointing with smart heuristics.
For seq < 512, the overhead of gradient offloading in gc="unsloth" mode
is not worth it. Benchmarks show standard gc is faster for small sequences.
Args:
use_gradient_checkpointing: "unsloth", True, False, or None
max_seq_length: The maximum sequence length
dtype: The model dtype for patching
Returns:
The effective use_gradient_checkpointing value (may change from "unsloth" to True)
"""
if use_gradient_checkpointing == "unsloth":
# Gradient offloading overhead is not worth it for small sequences.
# Benchmarks show crossover point is around seq_len 384-512.
# For seq < 512, standard gradient checkpointing is faster.
if max_seq_length < 512:
unpatch_unsloth_smart_gradient_checkpointing()
return True
else:
patch_unsloth_smart_gradient_checkpointing(dtype = dtype)
return "unsloth"
elif use_gradient_checkpointing in (True, False):
# User explicitly set True or False - unpatch any previous "unsloth" patching
unpatch_unsloth_smart_gradient_checkpointing()
return use_gradient_checkpointing
return use_gradient_checkpointing
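# Example (illustrative only, not executed at import time): with the heuristic above,
#   apply_unsloth_gradient_checkpointing("unsloth", max_seq_length = 2048, dtype = torch.bfloat16)
# keeps the offloaded "unsloth" checkpointer, whereas max_seq_length = 256 falls back
# to returning True (standard gradient checkpointing).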
# Models that don't work with flex_attention:
# GPT-OSS: left padding issues cause incorrect outputs.
# Mllama: BlockMask Q_LEN!=KV_LEN ValueError on decode.
# NemotronH: hybrid Mamba-2 + Transformer, raises NotImplementedError.
# Gemma3N: timm vision wrappers don't support flex_attention.
# ModernBERT: create_block_mask with _compile=True hits CUDA illegal memory
# access on some GPU architectures (B200). Falls back to eager safely.
_FLEX_EXCLUDED_MODELS = ("gpt_oss", "mllama", "nemotron_h", "modernbert")
_EAGER_ONLY_PREFIXES = ("gemma3n",)
_FLASH_ATTENTION_MAX_HEAD_DIM = 256
_FLASH_ATTENTION_DISABLED_WARNED = set()
def _is_flex_excluded(model_type):
return model_type in _FLEX_EXCLUDED_MODELS
def _is_eager_only(model_type):
return any(model_type.startswith(p) for p in _EAGER_ONLY_PREFIXES)
def _config_items(config):
if isinstance(config, dict):
return config.items()
if hasattr(config, "__dict__"):
return vars(config).items()
return ()
def _config_get(config, field_name, default = None):
if isinstance(config, dict):
return config.get(field_name, default)
return getattr(config, field_name, default)
def _config_set(config, field_name, value):
if isinstance(config, dict):
config[field_name] = value
elif config is not None:
setattr(config, field_name, value)
def _iter_attention_configs(config, seen = None):
if config is None or (
not isinstance(config, dict) and not hasattr(config, "__dict__")
):
return
if seen is None:
seen = set()
config_id = id(config)
if config_id in seen:
return
seen.add(config_id)
yield config
for field_name, child_config in _config_items(config):
if not isinstance(field_name, str) or not field_name.endswith("_config"):
continue
if isinstance(child_config, dict) or hasattr(child_config, "__dict__"):
yield from _iter_attention_configs(child_config, seen)
def _collect_attention_head_dims(config):
explicit_head_dims = []
for field_name in (
"head_dim",
"global_head_dim",
"local_head_dim",
"kv_head_dim",
):
value = _config_get(config, field_name, None)
if isinstance(value, int) and value > 0:
explicit_head_dims.append(value)
if len(explicit_head_dims) != 0:
return explicit_head_dims
head_dims = []
hidden_size_names = ("hidden_size", "d_model", "embed_dim", "dim")
num_heads_names = ("num_attention_heads", "num_heads", "n_heads")
for hidden_size_name in hidden_size_names:
hidden_size = _config_get(config, hidden_size_name, None)
if not isinstance(hidden_size, int) or hidden_size <= 0:
continue
for num_heads_name in num_heads_names:
num_heads = _config_get(config, num_heads_name, None)
if (
isinstance(num_heads, int)
and num_heads > 0
and (hidden_size % num_heads) == 0
):
head_dims.append(hidden_size // num_heads)
return head_dims
def _get_max_attention_head_dim(config):
head_dims = []
for attention_config in _iter_attention_configs(config):
head_dims.extend(_collect_attention_head_dims(attention_config))
return max(head_dims) if len(head_dims) != 0 else None
def _get_flash_attention_disable_reason(config):
max_head_dim = _get_max_attention_head_dim(config)
if max_head_dim is not None and max_head_dim > _FLASH_ATTENTION_MAX_HEAD_DIM:
return (
f"max attention head dim {max_head_dim} exceeds the Flash Attention 2 "
f"limit of {_FLASH_ATTENTION_MAX_HEAD_DIM}"
)
return None
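# Worked example (illustrative): a config with hidden_size = 4096 and
# num_attention_heads = 32 implies head_dim = 128, well under the limit of 256, so no
# disable reason is produced; an explicit head_dim larger than 256 would exceed the
# limit and force the SDPA/eager fallback below.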
def _is_flash_attention_disabled(config):
return _get_flash_attention_disable_reason(config) is not None
def _is_flash_attention_requested(attn_implementation):
return isinstance(attn_implementation, str) and attn_implementation.startswith(
"flash_attention"
)
def _disable_flash_attention_if_needed(
config,
attn_implementation = None,
supports_sdpa = False,
would_use_flash_attention = False,
disable_reason = None,
):
if disable_reason is None:
disable_reason = _get_flash_attention_disable_reason(config)
if disable_reason is None:
return attn_implementation
requested_attn_implementation = attn_implementation
if requested_attn_implementation is None:
requested_attn_implementation = _config_get(
config, "_attn_implementation", None
)
if requested_attn_implementation is None:
requested_attn_implementation = _config_get(config, "attn_implementation", None)
if requested_attn_implementation == "eager":
return _set_attn_impl(config, "eager")
fallback_attn_implementation = "sdpa" if supports_sdpa else "eager"
if (
_is_flash_attention_requested(requested_attn_implementation)
or would_use_flash_attention
):
logged_attn_implementation = (
requested_attn_implementation
if _is_flash_attention_requested(requested_attn_implementation)
else "flash_attention_2"
)
model_type = _config_get(config, "model_type", "")
warning_key = (
model_type,
logged_attn_implementation,
fallback_attn_implementation,
disable_reason,
)
if warning_key not in _FLASH_ATTENTION_DISABLED_WARNED:
_FLASH_ATTENTION_DISABLED_WARNED.add(warning_key)
print(
f"Unsloth: `{logged_attn_implementation}` is not supported "
f"for `{model_type}` because {disable_reason} - "
f"defaulting to `{fallback_attn_implementation}`."
)
return _set_attn_impl(config, fallback_attn_implementation)
def _set_attn_impl(config, impl):
if config is not None:
_config_set(config, "_attn_implementation", impl)
if isinstance(config, dict) or hasattr(config, "attn_implementation"):
_config_set(config, "attn_implementation", impl)
return impl
def resolve_model_class(auto_model, config):
mapping = getattr(auto_model, "_model_mapping", {})
try:
result = mapping[config.__class__]
except Exception:
for config_class, model_class in mapping.items():
if isinstance(config, config_class):
result = model_class
break
else:
return None
return result[0] if isinstance(result, (list, tuple)) else result
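# Example (illustrative, assumes the standard transformers auto classes): passing
# AutoModelForCausalLM together with a LlamaConfig instance looks up the auto mapping
# and returns LlamaForCausalLM; a config missing from the mapping returns None.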
def resolve_attention_implementation(
model_class,
config,
requested_attn_implementation = None,
supports_sdpa = None,
):
model_type_name = _config_get(config, "model_type", "")
model_type = model_type_name.lower()
if supports_sdpa is None:
supports_sdpa = model_class is not None and getattr(
model_class, "_supports_sdpa", False
)
supports_flash_attention = model_class is not None and (
getattr(model_class, "_supports_flash_attn_2", False)
or getattr(model_class, "_supports_flash_attn", False)
)
disable_reason = _get_flash_attention_disable_reason(config)
flash_attention_disabled = disable_reason is not None
if model_class is None:
attn_impl = _set_attn_impl(config, "sdpa" if supports_sdpa else "eager")
else:
if _is_eager_only(model_type):
attn_impl = _set_attn_impl(config, "eager")
elif flash_attention_disabled:
attn_impl = _disable_flash_attention_if_needed(
config,
supports_sdpa = supports_sdpa,
would_use_flash_attention = (
HAS_FLASH_ATTENTION and supports_flash_attention
),
disable_reason = disable_reason,
)
elif HAS_FLASH_ATTENTION and supports_flash_attention:
attn_impl = _set_attn_impl(config, "flash_attention_2")
elif supports_sdpa:
attn_impl = _set_attn_impl(config, "sdpa")
else:
attn_impl = "eager"
if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") != "0":
try:
from transformers.utils.import_utils import (
is_torch_flex_attn_available,
)
if (
is_torch_flex_attn_available()
and getattr(model_class, "_supports_flex_attn", False)
and not _is_flex_excluded(model_type)
):
attention_dropout = (
_config_get(config, "attention_dropout", 0) or 0
)
if attention_dropout == 0:
attn_impl = _set_attn_impl(config, "flex_attention")
except Exception:
pass
if attn_impl == "eager":
attn_impl = _set_attn_impl(config, "eager")
if requested_attn_implementation is None:
final_attn_impl = attn_impl
elif flash_attention_disabled:
final_attn_impl = _disable_flash_attention_if_needed(
config,
requested_attn_implementation,
supports_sdpa = supports_sdpa,
disable_reason = disable_reason,
)
else:
final_attn_impl = requested_attn_implementation
_set_attn_impl(config, final_attn_impl)
if not supports_sdpa and final_attn_impl == "sdpa":
print(
f"Unsloth: {(model_type_name or 'model').title()} does not support SDPA - switching to fast eager."
)
final_attn_impl = _set_attn_impl(config, "eager")
return final_attn_impl
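# Usage sketch (hypothetical values): for a model class that supports SDPA and flash
# attention, with flash-attn installed, the resolver writes "flash_attention_2" into
# config._attn_implementation; setting UNSLOTH_ENABLE_FLEX_ATTENTION=0 disables the
# flex_attention upgrade, and an explicit requested_attn_implementation wins unless
# flash attention had to be disabled for head-dim reasons (and "sdpa" still drops to
# eager when the class does not support SDPA).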
def resolve_encoder_attention_implementation(
auto_model,
config,
model_type = "",
disable_sdpa_model_names = (),
):
model_class = resolve_model_class(auto_model, config)
supports_sdpa = model_class is not None and getattr(
model_class, "_supports_sdpa", False
)
if any(name in model_type.lower() for name in disable_sdpa_model_names):
return "eager"
if supports_sdpa:
return "sdpa"
return None
def _run_temporary_patches(phase):
import inspect
for temporary_patch in TEMPORARY_PATCHES:
try:
sig = inspect.signature(temporary_patch)
if "phase" in sig.parameters:
temporary_patch(phase = phase)
else:
temporary_patch()
except (ValueError, TypeError):
temporary_patch()
_run_temporary_patches("init")
# =============================================
# Disable some warnings which can get annoying
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "torch")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
warnings.filterwarnings(
action = "ignore", category = FutureWarning, module = "huggingface_hub"
)
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "trl")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "trl")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "xformers")
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate")
warnings.filterwarnings(
action = "ignore", category = RuntimeWarning, module = "multiprocessing"
)
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocess")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "triton")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "bitsandbytes")
# Stop "Special tokens have been added in the vocabulary, ..."
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL + 1)
TORCHAO_MSG = "Error: torchao not found, please install with `pip install torchao`"
# Ignore logging messages
class HideLoggingMessage(logging.Filter):
__slots__ = ("text",)
def __init__(self, text):
self.text = text
def filter(self, x):
return not (self.text in x.getMessage())
# Replace warning messages (analogous to HideLoggingMessage but for warnings.warn)
class ReplaceWarningMessage:
"""
Intercepts warnings.warn calls and replaces matching messages with Unsloth branded ones.
Uses a list of registered (match_text, replacement, category) rules checked in order.
"""
_rules = []
_original_showwarning = None
_installed = False
@classmethod
def add_rule(cls, match_text, replacement, category = None):
cls._rules.append((match_text, replacement, category))
if not cls._installed:
cls._install()
@classmethod
def _install(cls):
cls._original_showwarning = warnings.showwarning
cls._installed = True
def _patched_showwarning(
message, category, filename, lineno, file = None, line = None
):
msg_str = str(message)
for match_text, replacement, match_category in cls._rules:
if match_text in msg_str and (
match_category is None or category is match_category
):
print(replacement)
return
cls._original_showwarning(message, category, filename, lineno, file, line)
warnings.showwarning = _patched_showwarning
# Stop vLLM messages
if not UNSLOTH_ENABLE_LOGGING:
try:
from vllm.worker.worker import logger as vllm_worker_logger
vllm_worker_logger.addFilter(HideLoggingMessage("Sleep mode freed"))
del vllm_worker_logger
except:
pass
try:
from vllm.v1.worker.gpu_worker import logger as vllm_gpu_worker_logger
vllm_gpu_worker_logger.addFilter(HideLoggingMessage("Sleep mode freed"))
del vllm_gpu_worker_logger
except:
pass
try:
from vllm.executor.executor_base import logger as vllm_executor_logger
vllm_executor_logger.addFilter(HideLoggingMessage("to fall asleep"))
vllm_executor_logger.addFilter(HideLoggingMessage("to wake up"))
vllm_executor_logger.addFilter(HideLoggingMessage("Executor is not sleeping"))
del vllm_executor_logger
except:
pass
try:
from vllm.v1.executor.abstract import logger as vllm_v1_executor_logger
vllm_v1_executor_logger.addFilter(HideLoggingMessage("to fall asleep"))
vllm_v1_executor_logger.addFilter(HideLoggingMessage("to wake up"))
vllm_v1_executor_logger.addFilter(
HideLoggingMessage("Executor is not sleeping")
)
del vllm_v1_executor_logger
except:
pass
try:
from vllm.core.block.prefix_caching_block import (
logger as vllm_prefix_caching_logger,
)
vllm_prefix_caching_logger.addFilter(HideLoggingMessage("reset prefix cache"))
del vllm_prefix_caching_logger
except:
pass
try:
from vllm.v1.core.block_pool import logger as vllm_block_pool_logger
vllm_block_pool_logger.addFilter(HideLoggingMessage("reset prefix cache"))
del vllm_block_pool_logger
except:
pass
try:
from vllm.lora.models import logger as vllm_lora_model_logger
vllm_lora_model_logger.addFilter(
HideLoggingMessage(
"Regarding multimodal models, vLLM currently only supports adding"
)
)
del vllm_lora_model_logger
except:
pass
try:
from vllm.attention.utils.fa_utils import (
logger as vllm_attention_utils_fa_utils_logger,
)
vllm_attention_utils_fa_utils_logger.addFilter(
HideLoggingMessage("Cannot use FA version")
)
del vllm_attention_utils_fa_utils_logger
except:
pass
# The speedups for torchdynamo mostly come with GPU Ampere or higher and which is not detected here.
from transformers.training_args import logger as transformers_training_args_logger
transformers_training_args_logger.addFilter(HideLoggingMessage("The speedups"))
# torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED.
transformers_training_args_logger.addFilter(HideLoggingMessage("torch.distributed"))
# average_tokens_across_devices is set to True but it is invalid when world size is 1
transformers_training_args_logger.addFilter(
HideLoggingMessage("average_tokens_across_devices")
)
del transformers_training_args_logger
# No label_names provided for model class
from transformers.trainer import logger as transformers_trainer_logger
transformers_trainer_logger.addFilter(HideLoggingMessage("No label_names"))
# The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config.
transformers_trainer_logger.addFilter(HideLoggingMessage("The tokenizer has new"))
del transformers_trainer_logger
# Using the default loss: `ForCausalLMLoss`.
try:
from transformers.modeling_utils import logger as transformers_modeling_utils_logger
transformers_modeling_utils_logger.addFilter(HideLoggingMessage("ForCausalLMLoss"))
del transformers_modeling_utils_logger
except:
pass
# The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
try:
from accelerate.utils.modeling import logger as accelerate_utils_modeling_logger
accelerate_utils_modeling_logger.addFilter(
HideLoggingMessage("The model weights are not tied")
)
del accelerate_utils_modeling_logger
except:
pass
# Setting `pad_token_id` to `eos_token_id`
try:
from transformers.generation.utils import (
logger as transformers_generation_utils_logger,
)
transformers_generation_utils_logger.addFilter(
HideLoggingMessage("Setting `pad_token_id` to `eos_token_id`")
)
# "You have set `compile_config`
transformers_generation_utils_logger.addFilter(HideLoggingMessage("compile_config"))
del transformers_generation_utils_logger
except:
pass
# The following generation flags are not valid and may be ignored:
try:
from transformers.generation.configuration_utils import (
logger as configuration_logger,
)
configuration_logger.addFilter(HideLoggingMessage("following generation flags"))
del configuration_logger
except:
pass
# Gemma3 It is strongly recommended to train Gemma3 models with the `eager`
try:
from transformers.models.gemma3.modeling_gemma3 import logger as gemma3_logger
gemma3_logger.addFilter(HideLoggingMessage("strongly recommended"))
del gemma3_logger
except:
pass
# Gemma4 It is strongly recommended to train Gemma4 models with the `eager`
try:
from transformers.models.gemma4.modeling_gemma4 import logger as gemma4_logger
gemma4_logger.addFilter(HideLoggingMessage("strongly recommended"))
del gemma4_logger
except:
pass
# Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed.
try:
from huggingface_hub.file_download import logger as hub_logger
hub_logger.addFilter(HideLoggingMessage("hf_xet"))
del hub_logger
except:
pass
# MXFP4 quantization requires triton >= 3.4.0
try:
from transformers.quantizers.quantizer_mxfp4 import logger as mxfp4_logger
mxfp4_logger.addFilter(HideLoggingMessage("requires triton"))
del mxfp4_logger
except:
pass
# You passed `quantization_config` or equivalent parameters
try:
warnings.filterwarnings(
action = "ignore",
message = r".*quantization_config.*",
category = UserWarning,
append = True,
)
except:
pass
# UserWarning: Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead
# Will be fixed in torch 2.8.1 https://github.com/pytorch/pytorch/issues/158463
try:
warnings.filterwarnings(
action = "ignore",
message = r".*Logical operators 'and' and 'or'.*",
category = UserWarning,
append = True,
)
except:
pass
# Using a slow image processor as `use_fast`
try:
from transformers.processing_utils import logger as processing_utils_logger
processing_utils_logger.addFilter(HideLoggingMessage("`use_fast`"))
del processing_utils_logger
except:
pass
# Using a slow image processor as `use_fast`
try:
from transformers.models.auto.image_processing_auto import (
logger as processing_utils_logger,
)
processing_utils_logger.addFilter(HideLoggingMessage("`use_fast`"))
del processing_utils_logger
except:
pass
# `use_cache=True` is incompatible with gradient checkpointing
try:
from transformers.trainer import logger as trainer_logger
trainer_logger.addFilter(HideLoggingMessage("`use_cache=True`"))
del trainer_logger
except:
pass
# `use_cache=True` is incompatible with gradient checkpointing
try:
from transformers.utils.generic import logger as trainer_logger
trainer_logger.addFilter(HideLoggingMessage("`use_cache=True`"))
del trainer_logger
except:
pass
# We detected that you are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')
try:
from transformers.modeling_utils import logger as modeling_utils_logger
modeling_utils_logger.addFilter(HideLoggingMessage("anti-pattern"))
del modeling_utils_logger
except:
pass
# Raise an error on messages like:
# "Some weights of Gemma3nForConditionalGeneration were not initialized from the model checkpoint"
from transformers.modeling_utils import logger as transformers_logger
class _RaiseUninitialized(logging.Handler):
def __init__(self):
super().__init__()
def emit(self, record):
record_lower = str(record).lower()
if (
("some weights of" in record_lower)
and ("score.weight" not in record_lower)
and ("classifier.weight" not in record_lower)
and ("cls.predictions" not in record_lower)
and ("predictions.decoder" not in record_lower)
and (os.environ.get("UNSLOTH_WARN_UNINITIALIZED", "1") == "1")
):
raise Exception(
f"Unsloth: Critical error since some weights are not initialized.\n"
f"Please try updating Unsloth, transformers and timm via:\n"
f"`pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo transformers timm`\n"
f"{str(record)}"
)
class RaiseUninitialized:
def __init__(self):
self.error_handler = _RaiseUninitialized()
transformers_logger.addHandler(self.error_handler)
def remove(self):
transformers_logger.removeHandler(self.error_handler)
try:
from transformers.trainer import logger as transformers_trainer_logger
transformers_trainer_logger.addFilter(
HideLoggingMessage("The model is already on multiple devices.")
)
except:
pass
# Hide HF Hub unauthenticated request warnings
try:
from huggingface_hub.utils._http import logger as hf_http_logger
hf_http_logger.addFilter(
HideLoggingMessage("You are sending unauthenticated requests")
)
del hf_http_logger
except:
pass
# Replace PEFT target_parameters warning with Unsloth branded message for MoE models
ReplaceWarningMessage.add_rule(
match_text = "target_parameters",
replacement = (
"Unsloth: PEFT set target_parameters but found no matching parameters.\n"
"This is expected for MoE models - Unsloth handles MoE expert LoRA targeting separately."
),
category = RuntimeWarning,
)
# Patch get_model_param_count to record correct 4bit / 8bit
from transformers.trainer_pt_utils import is_deepspeed_zero3_enabled
def extract_quant_model_param_count(model):
"""
Calculate quant model param count based on difference in param class. Returns int for param count.
"""
count: int = 0
for name, p in model.named_parameters():
if p.__class__.__name__ == "Params4bit":
count += 2 * p.numel()
else:
count += p.numel()
return count
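# Worked example: bitsandbytes Params4bit packs two 4-bit values per stored element,
# so a quantized weight whose logical shape is (4096, 4096) reports numel() of
# 8,388,608; doubling it recovers the true 16,777,216 parameters counted above.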
def get_model_param_count(model, trainable_only = False):
"""
Calculate model's total param count. If trainable_only is True then count only those requiring grads
"""
if is_deepspeed_zero3_enabled():
def numel(p):
return p.ds_numel if hasattr(p, "ds_numel") else p.numel()
else:
def numel(p):
return p.numel()
s = sum(
numel(p) for p in model.parameters() if not trainable_only or p.requires_grad
)
if (
(not trainable_only)
and hasattr(model, "config")
and hasattr(model.config, "quantization_config")
):
approx = extract_quant_model_param_count(model)
if approx is not None:
s = approx
return s
import transformers.trainer_pt_utils
transformers.trainer_pt_utils.get_model_param_count = get_model_param_count
import transformers.trainer
transformers.trainer.get_model_param_count = get_model_param_count
# =============================================
# =============================================
# Edits all Config files to enable RoPE Scaling for all models
# Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now.
def patch_mistral_nemo_config(config):
if "head_dim (" not in config:
add_head_dim = (
"If it is not specified, will default to `8`.\n"
" head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):\n"
" The attention head dimension."
)
config = config.replace(
"If it is not specified, will default to `8`.", add_head_dim
)
add_head_dim = "num_key_value_heads=8,\n head_dim=None,"
config = config.replace("num_key_value_heads=8,", add_head_dim)
add_head_dim = "self.sliding_window = sliding_window\n self.head_dim = head_dim or hidden_size // num_attention_heads\n"
config = config.replace("self.sliding_window = sliding_window", add_head_dim)
return config
try:
# Some Config files use layer_type_validation
# for eg Gemma-2, so we must import it to stop errors.
from transformers.configuration_utils import layer_type_validation
except:
pass
try:
# Transformers 5.0+ uses RotaryEmbeddingConfigMixin as a base class for configs
from transformers.modeling_rope_utils import RotaryEmbeddingConfigMixin
except:
pass
from transformers import __version__ as transformers_version
try:
from transformers import PreTrainedConfig
except:
from transformers import PretrainedConfig
model_architectures = [
"llama",
"mistral",
"gemma",
"gemma2",
"qwen2",
"granite",
"qwen3",
"qwen3_moe",
"falcon_h1",
]
# Transformers 5.x uses class-level annotations with @strict, @auto_docstring,
# and interval() in config classes. exec(inspect.getsource(...)) fails because
# those symbols are not in scope. Skip the exec-based config patching for 5.x
# since those configs already use rope_parameters (the v5 replacement for
# rope_scaling).
_skip_config_exec_patch = Version(transformers_version) >= Version("5.0.0")
for model_name in model_architectures:
if _skip_config_exec_patch:
break
config_filepath = f"transformers.models.{model_name}.configuration_{model_name}"
model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
config_filename = f"{model_name.title().replace('_','')}Config" # qwen3 arch folder is qwen3_moe but config is Qwen3Config. Need to remove underscore(_) for now
try:
exec(f"from {config_filepath} import {config_filename}", globals())
except:
continue
try:
config = inspect.getsource(eval(config_filename))
except:
continue
if "RopeParameters" in config:
try:
exec(f"from {config_filepath} import RopeParameters", globals())
except:
continue
if "rope_scaling" in config:
continue
config = re.sub(
r"(\*\*kwargs)[\s]{0,}\,[\s]{0,}\)[\s]{0,}\:",
r"rope_scaling=None,"
r"\n **kwargs):\n"
r"\n self.rope_scaling = rope_scaling\n",
config,
)
# Just for Mistral Nemo
if model_name == "mistral":
if Version(transformers_version) <= Version("4.42.4"):
config = patch_mistral_nemo_config(config)
try:
exec(config, globals())
exec(f"import {config_filepath}", globals())
exec(f"{config_filepath}.{config_filename} = {config_filename}", globals())
except Exception:
continue
# =============================================
# =============================================
# torch.cuda.amp.custom_fwd is deprecated >= 2.4
torch_version = torch.__version__
if DEVICE_TYPE in ("cuda", "hip"):
if Version(torch_version) < Version("2.4.0"):
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
elif DEVICE_TYPE == "xpu":
if Version(torch_version) < Version("2.6.0"):
raise RuntimeError("torch.xpu currently only supports torch.version >= 2.6.0")
else:
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "xpu")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "xpu")
# =============================================
# =============================================
# Fix KeyError: 'Cache only has 0 layers, attempted to access layer with index 0'
# import transformers.cache_utils
# if hasattr(transformers.cache_utils, "DynamicCache") and \
# transformers.cache_utils.DynamicCache.__getitem__.__name__ != "__cache_utils_getitem__":
# source = inspect.getsource(transformers.cache_utils.DynamicCache.__getitem__)
# start = source.find("def")
# spaces = start*" "
# source = source.split("\n")
# source = "\n".join(x[start:] for x in source)
# where = source.find("raise KeyError")
# source = source[:where] + \
# f"if len(self) == 0:\n{spaces}{spaces}"\
# " raise RuntimeError('Unsloth: You must call `FastLanguageModel.for_inference(model)` before doing inference for Unsloth models.')\n" + \
# f"{spaces}{spaces}else:\n{spaces}{spaces}{spaces}" + source[where:]
# source = source.replace("__getitem__", "__cache_utils_getitem__", 1)
# exec(source)
# transformers.cache_utils.DynamicCache.__getitem__ = __cache_utils_getitem__
# pass
# =============================================
# =============================================
# Weird Databricks errors
from transformers.utils import is_openai_available
if is_openai_available():
try:
from openai import OpenAI
except:
print("Unsloth: OpenAI failed to import - ignoring for now.")
import transformers.utils
def _is_openai_available():
return False
transformers.utils.is_openai_available = _is_openai_available
# =============================================
# Get Flash Attention v2 if Ampere (RTX 30xx, A100)
import bitsandbytes as bnb
from transformers import AutoTokenizer
from transformers.utils.import_utils import _is_package_available
SUPPORTS_BFLOAT16 = False
HAS_FLASH_ATTENTION = False
HAS_FLASH_ATTENTION_SOFTCAPPING = False
if DEVICE_TYPE == "cuda":
major_version, minor_version = torch.cuda.get_device_capability()
torch.cuda.get_device_capability = functools.cache(torch.cuda.get_device_capability)
if major_version >= 8:
SUPPORTS_BFLOAT16 = True
if _is_package_available("flash_attn"):
# Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl"
try:
try:
# See https://github.com/unslothai/unsloth/issues/1437
from flash_attn.flash_attn_interface import flash_attn_gpu
except:
from flash_attn.flash_attn_interface import flash_attn_cuda
HAS_FLASH_ATTENTION = True
# Also check for softcapping
from flash_attn import __version__ as flash_attn_version
HAS_FLASH_ATTENTION_SOFTCAPPING = Version(
flash_attn_version
) >= Version("2.6.3")
if not HAS_FLASH_ATTENTION_SOFTCAPPING:
print(
"Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"
"Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"
"To update flash-attn, do the below:\n"
'\npip install --no-deps --no-build-isolation --upgrade "flash-attn>=2.6.3"'
)
except:
print(
"Unsloth: Your Flash Attention 2 installation seems to be broken. "
"Using Xformers instead. No performance changes will be seen."
)
# Stop Flash Attention from importing!
import transformers.utils.import_utils
transformers.utils.import_utils.is_flash_attn_2_available = (
lambda *args, **kwargs: False
)
import transformers.utils
transformers.utils.is_flash_attn_2_available = (
lambda *args, **kwargs: False
)
HAS_FLASH_ATTENTION = False
else:
HAS_FLASH_ATTENTION = False
else:
# Tri Dao's benchmark shows xformers is faster for now.
HAS_FLASH_ATTENTION = False
elif DEVICE_TYPE == "hip":
SUPPORTS_BFLOAT16 = True
if _is_package_available("flash_attn"):
# Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl"
try:
try:
# See https://github.com/unslothai/unsloth/issues/1437
from flash_attn.flash_attn_interface import flash_attn_gpu
except:
from flash_attn.flash_attn_interface import flash_attn_cuda
HAS_FLASH_ATTENTION = True
# Also check for softcapping
from flash_attn import __version__ as flash_attn_version
HAS_FLASH_ATTENTION_SOFTCAPPING = Version(flash_attn_version) >= Version(
"2.6.3"
)
if not HAS_FLASH_ATTENTION_SOFTCAPPING:
print(
"Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"
"Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"
"To update flash-attn, do the below:\n"
'\npip install --no-deps --no-build-isolation --upgrade "flash-attn>=2.6.3"'
)
except:
print(
"Unsloth: Your Flash Attention 2 installation seems to be broken. "
"Using Xformers instead. No performance changes will be seen."
)
# Stop Flash Attention from importing!
import transformers.utils.import_utils
transformers.utils.import_utils.is_flash_attn_2_available = (
lambda *args, **kwargs: False
)
import transformers.utils
transformers.utils.is_flash_attn_2_available = lambda *args, **kwargs: False
HAS_FLASH_ATTENTION = False
elif DEVICE_TYPE == "xpu":
SUPPORTS_BFLOAT16 = True
# =============================================
# Get Xformers
# Silence xformers CUDA mismatch warnings before import
try:
_xformers_logger = logging.getLogger("xformers")
_xformers_logger.setLevel(logging.ERROR)
del _xformers_logger
except:
pass
try:
from xformers import __version__ as xformers_version
# Xformers <= 0.0.32.post2 has a broken FA3 dispatch on Blackwell/RTX 50x GPUs.
# The FA3 check used `capability >= (9, 0)` which matches SM 10.0/11.0/12.0,
# causing sm_90a kernels to be attempted on non-Hopper GPUs (CUDA error in
# flash_fwd_launch_template.h:188). Fixed in 0.0.33 with `<= (9, 0)`.
# See https://github.com/facebookresearch/xformers/issues/1329
if DEVICE_TYPE == "cuda":
major_version, minor_version = torch.cuda.get_device_capability()
if (f"{major_version}.{minor_version}" in ("10.0", "11.0", "12.0")) and (
Version(xformers_version) <= Version("0.0.32.post2")
):
raise NotImplementedError(
f"Unsloth: Xformers {xformers_version} has a broken FA3 dispatch on "
f"SM {major_version}.{minor_version} GPUs. Please upgrade to >= 0.0.33 or build from source via\n"
"```\n"
"pip install ninja\n"
"pip install -v --no-build-isolation -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers\n"
"```\n"
)
# Temporarily disable 0.0.27 and higher - inference issues
if False: # Version(xformers_version) >= Version("0.0.27"):
raise ImportError(
"Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "
"then press Disconnect Runtime and then Restart it.\n"
"\n"
"%%capture\n"
"# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n"
'!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
'!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'
"\n"
f"Otherwise in local machines, your xformers version of {xformers_version} is too new.\n"
'Please downgrade xformers via `pip install --force-reinstall "xformers<=0.0.27"'
)
if Version(torch_version) < Version("2.2.0") and Version(
xformers_version
) >= Version("0.0.24"):
raise ImportError(
f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"
f"Please install xformers < 0.0.24 for torch = {torch_version}."
)
elif Version(torch_version) < Version("2.3.0") and Version(
xformers_version
) >= Version("0.0.26"):
raise ImportError(
f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"
f"Please install xformers < 0.0.26 for torch = {torch_version}."
)
elif Version(torch_version) < Version("2.4.0") and Version(
xformers_version
) > Version("0.0.27"):
raise ImportError(
f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"
f"Please install xformers <= 0.0.27 for torch = {torch_version}."
)
from xformers._cpp_lib import _register_extensions
try:
_register_extensions() # Check if C++ modules are loaded correctly
except Exception as error:
raise ImportError(
"Unsloth: Xformers was not installed correctly.\n"
"Please install xformers separately first.\n"
"Then confirm if it's correctly installed by running:\n"
"python -m xformers.info\n\n"
"Longer error message:\n" + str(error)
)
import xformers.ops.fmha as xformers
xformers_attention = xformers.memory_efficient_attention
except ModuleNotFoundError:
xformers = None
xformers_attention = None
xformers_version = None
except Exception as e:
if UNSLOTH_ENABLE_LOGGING:
print(
"========\nSwitching to PyTorch attention since your Xformers is broken.\n========\n"
)
print(str(e))
xformers = None
xformers_attention = None
xformers_version = None
# Check TRL version
from trl import __version__ as trl_version
# Unsloth now supports all TRL versions!
if False: # Version(trl_version) >= Version("0.9.0"):
raise ImportError(
"Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "
"then press Disconnect Runtime and then Restart it.\n"
"\n"
"%%capture\n"
"# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n"
'!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
'!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'
"\n"
f"Otherwise in local machines, your TRL version of {trl_version} is too new.\n"
"Please downgrade TRL via `pip install --force-reinstall trl"
)
# =============================================
# Fix new Xformers versions TypeError: Multiple dispatch failed for 'torch._ops.aten.to.dtype_layout'
# accelerate_old_send_to_device = None
# accelerate_new_send_to_device = None
# if xformers_version is not None and Version(xformers_version) >= Version("0.0.27"):
# import accelerate.utils.operations
# if hasattr(accelerate.utils.operations, "send_to_device") and \
# accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device":
# accelerate_old_send_to_device = accelerate.utils.operations.send_to_device
# from accelerate.utils.operations import *
# send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device)
# send_to_device = re.sub(
# r"([ ]{4,})return tensor\.to\(device\)",
# r"\1try: return tensor.to(device)\n\1except: return tensor",
# send_to_device,
# ).replace("def send_to_device", "def _fixed_send_to_device")
# exec(send_to_device)
# # accelerate.utils.operations.send_to_device = _fixed_send_to_device
# accelerate_new_send_to_device = _fixed_send_to_device
# pass
# pass
# Transformers 4.46 breaks dynamic caching. This is a hack to restore the "dynamic" option.
import transformers.generation.configuration_utils
if hasattr(transformers.generation.configuration_utils, "ALL_CACHE_IMPLEMENTATIONS"):
if (
type(transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS)
is list
):
if (
"dynamic"
not in transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS
):
transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS.append(
"dynamic"
)
# =============================================
# =============================================
# Torch compile settings
UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1"
UNSLOTH_COMPILE_MAXIMUM = os.environ.get("UNSLOTH_COMPILE_MAXIMUM", "0") == "1"
UNSLOTH_COMPILE_IGNORE_ERRORS = (
os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "1") == "1"
)
# Just remove max_autotune_gemm warning
from torch._inductor.runtime.hints import DeviceProperties
@functools.lru_cache(None)
def is_big_gpu(index) -> bool:
if DEVICE_TYPE == "xpu":
prop = DeviceProperties.create(
torch.device("xpu", index) if type(index) is int else index
)
min_sms = 16
else:
prop = DeviceProperties.create(
torch.device("cuda", index) if type(index) is int else index
)
min_sms = 80
avail_sms = prop.multi_processor_count
if avail_sms < min_sms:
return False
return True
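# For reference (approximate SM counts): a T4 has 40 SMs so it falls below the 80-SM
# CUDA threshold, while an A100 (108 SMs) or H100 passes; the lower 16-SM threshold
# is used for XPU devices.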
import torch._inductor.utils
torch._inductor.utils.is_big_gpu = is_big_gpu
patch_torch_compile(
debug = UNSLOTH_COMPILE_DEBUG,
O3 = UNSLOTH_COMPILE_MAXIMUM,
ignore_errors = UNSLOTH_COMPILE_IGNORE_ERRORS,
)
torch_compile_options = {
"epilogue_fusion": True,
"max_autotune": True,
"shape_padding": True,
"trace.enabled": UNSLOTH_COMPILE_DEBUG,
"triton.cudagraphs": False,
}
import accelerate
def torch_compile_kwargs(*args, **kwargs):
print("Unsloth: Enabled auto compiling")
return {
"dynamic": True,
"fullgraph": False,
"options": torch_compile_options,
}
accelerate.utils.dataclasses.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs
accelerate.utils.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs
accelerate.accelerator.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs
del accelerate
def patch_regional_compilation():
# Regional torch 2.5 Recompilation - weirdly very slow??
if torch.nn.ModuleList.__name__ == "UnslothModuleList":
return
# Only works for torch 2.5
if Version(torch.__version__) < Version("2.5.0"):
return
old_module_list = torch.nn.ModuleList
os.environ["UNSLOTH_PATCHED"] = "1"
def UnslothModuleList(*args, **kwargs):
if len(args) == 1 and len(kwargs) == 0 and type(args[0]) is list:
args = [
old_module_list(
[
torch.compile(
x,
dynamic = True,
options = torch_compile_options,
fullgraph = False,
)
for x in args[0]
]
)
]
return old_module_list(*args, **kwargs)
UnslothModuleList.__doc__ = old_module_list.__doc__
torch.nn.ModuleList = UnslothModuleList
return
# =============================================
def prepare_model_for_kbit_training(
model: Any,
    use_gradient_checkpointing: Optional[Union[bool, str]] = True,
use_reentrant: Optional[bool] = True,
) -> Any:
return prepare_model_for_training(
model = model,
use_gradient_checkpointing = use_gradient_checkpointing,
use_reentrant = use_reentrant,
full_finetuning = False,
train_layernorms = False,
train_embedding = False,
train_lm_head = False,
float32_mixed_precision = True,
)
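# This keeps the familiar prepare_model_for_kbit_training name but delegates to
# unsloth_zoo's prepare_model_for_training with LoRA-friendly defaults: layernorms,
# embeddings and lm_head stay frozen, and float32 mixed precision is enabled.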
# =============================================
# Weirdly LoraLayer.update_layer downcasts PEFT layers to float16??
# For mixed precision, we need it to be in float32 not float16.
from peft import __version__ as peft_version
from peft.utils.integrations import dequantize_module_weight
if Version(peft_version) < Version("0.12.0"):
from peft.tuners.lora.layer import LoraLayer
try:
source = inspect.getsource(LoraLayer.update_layer)
text = "if weight is not None:\n"
start = source.find(text) + len(text)
end = source.find("self.to(weight.device)", start)
spaces = re.findall(r"^([ ]{1,})break", source, flags = re.MULTILINE)[0]
source = source.replace(source[start:end], spaces)
spaces = len(re.match(r"[\s]{1,}", source).group(0))
lines = source.split("\n")
source = "\n".join(x[spaces:] for x in lines)
source = re.sub(r"([^\.])nn\.", r"\1torch.nn.", source)
source = source.replace("def update_layer", "def LoraLayer_update_layer")
exec(source, globals())
# Fix up incorrect downcasting of LoRA weights
from peft.tuners.lora.layer import LoraLayer
LoraLayer.update_layer = LoraLayer_update_layer
from peft.tuners.lora import LoraLayer
LoraLayer.update_layer = LoraLayer_update_layer
except:
logger.warning_once(
"Unsloth unsuccessfully patched LoraLayer.update_layer. Please file a bug report.\n"
"Luckily, your training run will still work in the meantime!"
)
# =============================================
import importlib
global USE_MODELSCOPE
USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1"
if USE_MODELSCOPE:
if importlib.util.find_spec("modelscope") is None:
raise ImportError(
f"You are using the modelscope hub, please install modelscope by `pip install modelscope -U`"
)
import socket
@functools.lru_cache(1)
def has_internet(host = "8.8.8.8", port = 53, timeout = 3):
if os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1":
return False
OFFLINE_TRUE = {"1", "true", "yes", "on"}
if os.environ.get("HF_HUB_OFFLINE", "").strip().lower() in OFFLINE_TRUE:
return False
try:
socket.setdefaulttimeout(timeout)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
sock.connect((host, port))
return True
finally:
sock.close()
except socket.error as ex:
return False
import psutil
def _get_statistics(statistics = None, force_download = True):
# We log some basic stats about which environment is being used.
# We simply download a README.md file from HF - all data is made public.
# This is simply so we can check if some envs are broken or not.
# You can disable this by commenting the below out
n_cpus = psutil.cpu_count(logical = False)
keynames = "\n" + "\n".join(os.environ.keys())
# Check modelscope for down detection
global USE_MODELSCOPE
USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1"
if statistics is None:
# Prefer filesystem markers (harder to misidentify) before env-key matching
try:
from pathlib import Path
if Path("/kaggle/working").exists():
statistics = "kaggle"
elif Path("/content").exists() and Path("/opt/colab").exists():
statistics = "colab" if n_cpus == 1 else "colabpro"
elif Path("/runpod-volume").exists():
statistics = "runpod"
except Exception:
pass
# Fallback to env-key detection
if statistics is None:
if "\nKAGGLE_" in keynames:
statistics = "kaggle"
elif "\nCOLAB_" in keynames and n_cpus == 1:
statistics = "colab"
elif "\nCOLAB_" in keynames:
statistics = "colabpro"
elif "\nRUNPOD_" in keynames:
statistics = "runpod"
elif "\nAWS_" in keynames:
statistics = "aws"
elif "\nAZURE_" in keynames:
statistics = "azure"
# elif "\nK_" in keynames or "\nFUNCTION_" in keynames: statistics = "gcp"
elif "\nINVOCATION_ID" in keynames:
statistics = "lambda"
# else: statistics = "other"
else:
def try_vllm_check():
vendor_files = (
"/sys/class/dmi/id/product_version",
"/sys/class/dmi/id/bios_vendor",
"/sys/class/dmi/id/product_name",
"/sys/class/dmi/id/chassis_asset_tag",
"/sys/class/dmi/id/sys_vendor",
)
for vendor_file in vendor_files:
path = Path(vendor_file)
if path.is_file():
file_content = path.read_text().lower()
if "amazon" in file_content:
return "aws"
elif "microsoft corporation" in file_content:
return "azure"
elif "google" in file_content:
return "gcp"
return "other"
try:
statistics = try_vllm_check()
except Exception:
statistics = "other"
if statistics is not None:
import tempfile
from huggingface_hub import snapshot_download
from unsloth_zoo.rl_environments import execute_with_time_limit
if has_internet():
def stats_check():
with tempfile.TemporaryDirectory(ignore_cleanup_errors = True) as f:
snapshot_download(
f"unslothai/{statistics}",
force_download = True,
cache_dir = f,
local_dir = f,
)
time_limited_stats_check = execute_with_time_limit(120)(stats_check)
try:
time_limited_stats_check()
except TimeoutError:
raise TimeoutError(
"Unsloth: HuggingFace seems to be down after trying for 120 seconds :(\n"
"Check https://status.huggingface.co/ for more details.\n"
"As a temporary measure, use modelscope with the same model name ie:\n"
"```\n"
"pip install modelscope\n"
"import os; os.environ['UNSLOTH_USE_MODELSCOPE'] = '1'\n"
"from unsloth import FastLanguageModel\n"
"model = FastLanguageModel.from_pretrained('unsloth/gpt-oss-20b')\n"
"```"
)
except Exception:
logger.debug("Unsloth: stats_check failed with an exception.")
# Don't retry without a time limit — would freeze offline
def get_statistics(local_files_only = False):
# We log some basic stats about which environment is being used.
# This is also to check if HuggingFace is down or not!
# We simply download a README.md file from HF - all data is made public.
# This is simply so we can check if some envs are broken or not.
# You can disable this by setting UNSLOTH_DISABLE_STATISTICS
import os
if (
"UNSLOTH_DISABLE_STATISTICS" in os.environ
or os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1"
):
return
if local_files_only:
return
from huggingface_hub.utils import (
disable_progress_bars,
enable_progress_bars,
are_progress_bars_disabled,
)
disabled = False
if not are_progress_bars_disabled():
disable_progress_bars()
disabled = True
_get_statistics(None)
_get_statistics("repeat", force_download = False)
total_memory = (
torch.xpu.get_device_properties(0).total_memory
if DEVICE_TYPE == "xpu"
else torch.cuda.get_device_properties(0).total_memory
)
vram = total_memory / 1024 / 1024 / 1024
if vram <= 8:
vram = 8
elif vram <= 16:
vram = 16
elif vram <= 20:
vram = 20
elif vram <= 24:
vram = 24
elif vram <= 40:
vram = 40
elif vram <= 48:
vram = 48
elif vram <= 80:
vram = 80
else:
vram = 96
_get_statistics(f"vram-{vram}")
_get_statistics(f"{DEVICE_COUNT if DEVICE_COUNT <= 8 else 9}")
if disabled:
enable_progress_bars()
# =============================================
# Fixes Bitsandbytes to remove missing warnings
from transformers.utils.quantization_config import (
BitsAndBytesConfig,
QuantizationMethod,
)
BitsAndBytesConfig__init__ = inspect.getsource(BitsAndBytesConfig.__init__)
BitsAndBytesConfig__init__ = re.sub(
r"if[\s]{1,}kwargs\:[\s]{1,}.+?\n",
"",
BitsAndBytesConfig__init__,
flags = re.MULTILINE,
)
BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.split("\n")
length_spaces = len(re.match(r"[\s]{1,}", BitsAndBytesConfig__init__[0]).group(0))
BitsAndBytesConfig__init__ = "\n".join(
x[length_spaces:] for x in BitsAndBytesConfig__init__
)
BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.replace(
"__init__",
"_BitsAndBytesConfig__init__",
)
exec(BitsAndBytesConfig__init__, globals())
if DEVICE_COUNT == 1 and int(os.environ.get("WORLD_SIZE", "1")) <= 1:
from accelerate.utils.dataclasses import DistributedType
def _prepare_backend(self, *args, **kwargs):
return None, DistributedType.NO
import accelerate.state
accelerate.state.PartialState._prepare_backend = _prepare_backend
accelerate.accelerator.Accelerator.distributed_type = (
lambda *args, **kwargs: DistributedType.NO
)
# to move multiple tensors to the same device
def move_to_device(target_device, *tensors):
"""
Move multiple tensors to target device if they're not already there.
Args:
target_device: The target device to move tensors to
*tensors: Variable number of tensors to potentially move
Returns:
tuple: The tensors on the target device (same objects if already on device, new if moved)
"""
if isinstance(target_device, int):
target_device = torch.device(target_device)
elif isinstance(target_device, str):
# if string we expect it to be a device name like "cuda:0"
target_device = torch.device(target_device)
elif isinstance(target_device, torch.device):
pass
else:
raise ValueError(f"Invalid target device: {target_device}")
moved_tensors = []
for tensor in tensors:
if tensor.device != target_device:
moved_tensors.append(tensor.to(target_device))
else:
moved_tensors.append(tensor)
return tuple(moved_tensors) if len(moved_tensors) > 1 else moved_tensors[0]
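# Usage sketch (hypothetical tensors): `q, k = move_to_device(torch.device("cuda:0"), q, k)`
# moves each tensor only if it is not already on the target device; note that a
# single-tensor call returns the tensor itself rather than a one-element tuple.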
import transformers.utils.quantization_config
transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = (
_BitsAndBytesConfig__init__
)
# =============================================
# Offloading to disk for modules (lm_head, embed_tokens)
import pickle
def offload_to_disk(
W, model, name, temporary_location: str = "_unsloth_temporary_saved_buffers"
):
file_location = os.path.join(temporary_location, model.config._name_or_path)
if not os.path.exists(file_location):
os.makedirs(file_location)
filename = os.path.join(file_location, f"{name}.pt")
W = W.weight if hasattr(W, "weight") else W
torch.save(
W,
filename,
pickle_module = pickle,
pickle_protocol = pickle.HIGHEST_PROTOCOL,
)
# We must use weights_only = False due to pickling
offloaded_W = torch.load(
filename, map_location = "cpu", mmap = True, weights_only = False
)
offloaded_W._offloaded_file_location = filename
return offloaded_W
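# Note: the returned tensor is re-loaded with mmap = True, so it is backed by the
# saved file and paged into CPU RAM on access; _offloaded_file_location records where
# the .pt file was written (under temporary_location/<model name>/<name>.pt).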
def offload_input_embeddings(
model, temporary_location: str = "_unsloth_temporary_saved_buffers"
):
offloaded_W = offload_to_disk(
model.get_input_embeddings(), model, "input_embeddings", temporary_location
)
new_input_embeddings = torch.nn.Embedding.from_pretrained(offloaded_W)
new_input_embeddings._offloaded_file_location = offloaded_W._offloaded_file_location
model.set_input_embeddings(new_input_embeddings)
return
def offload_output_embeddings(
model, temporary_location: str = "_unsloth_temporary_saved_buffers"
):
offloaded_W = offload_to_disk(
model.get_output_embeddings(), model, "output_embeddings", temporary_location
)
new_output_embeddings = torch.nn.Linear(1, 1, bias = None)
del new_output_embeddings.weight
new_output_embeddings.weight = offloaded_W
new_output_embeddings.in_features = offloaded_W.shape[1]
new_output_embeddings.out_features = offloaded_W.shape[0]
new_output_embeddings._offloaded_file_location = (
offloaded_W._offloaded_file_location
)
model.set_output_embeddings(new_output_embeddings)
return
# Fixes a weird Torch 2.3 bug which says T4s have bfloat16
def is_bfloat16_supported():
return SUPPORTS_BFLOAT16
def is_vLLM_available():
return _is_package_available("vllm")
# Patches models to add RoPE Scaling
def patch_linear_scaling(
model_name = "gemma2",
rope_module = None,
scaled_rope_module = None,
attention_module = None,
):
assert rope_module is not None and scaled_rope_module is not None
assert attention_module is not None
rope_name = rope_module.__name__
scaled_rope_name = scaled_rope_module.__name__
model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
exec_code = (
f"import torch.nn as nn\n"
f"from typing import Union, Optional, List, Any, Callable, Tuple\n"
f"from {model_filepath} import logger, "
f"{model_name.title()}Attention, {model_name.title()}Config"
)
try:
function = inspect.getsource(attention_module.__init__)
except:
# Most likely already patched!
return None, None
where = function.find("def")
function = function.split("\n")
function = "\n".join(x[where:] for x in function)
init_name = f"{model_name.title()}Attention__init__"
function = function.replace("def __init__", f"def {init_name}")
function = function.replace(
"super().__init__()",
f"super({model_name.title()}Attention, self).__init__()",
)
fix_rope_function = """
if getattr(self.config, "rope_scaling", None) is None:
self.rotary_emb = {rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = {scaled_rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {{scaling_type}}")
pass
"""
fix_rope_function = fix_rope_function.format(
rope_function = rope_module.__name__,
scaled_rope_function = scaled_rope_module.__name__,
)
rotary_emb = re.findall(
r"self\.rotary\_emb \= .+?\)",
function,
flags = re.DOTALL | re.MULTILINE,
)
if len(rotary_emb) == 0:
return None, exec_code + "\n\n" + function
rotary_emb = rotary_emb[0]
function = function.replace(rotary_emb, fix_rope_function, 1)
function = exec_code + "\n\n" + function
return init_name, function
# Patches for Llama-3 LlamaExtendedRotaryEmbedding
def patch_llama_rope_scaling(
model_name = "llama",
rope_module = None,
scaled_rope_module = None,
extended_rope_module = None,
attention_module = None,
longrope_module = None,
):
assert (
rope_module is not None
and scaled_rope_module is not None
and extended_rope_module is not None
)
assert attention_module is not None
rope_name = rope_module.__name__
scaled_rope_name = scaled_rope_module.__name__
model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
exec_code = (
f"import torch.nn as nn\n"
f"from typing import Union, Optional, List, Any, Callable, Tuple\n"
f"from {model_filepath} import logger, "
f"{model_name.title()}Attention, {model_name.title()}Config"
)
try:
function = inspect.getsource(attention_module.__init__)
except:
# Most likely already patched!
return None, None
where = function.find("def")
function = function.split("\n")
function = "\n".join(x[where:] for x in function)
init_name = f"{model_name.title()}Attention__init__"
function = function.replace("def __init__", f"def {init_name}")
function = function.replace(
"super().__init__()",
f"super({model_name.title()}Attention, self).__init__()",
)
fix_rope_function = """
if getattr(self.config, "rope_scaling", None) is None:
self.rotary_emb = {rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type1 = self.config.rope_scaling.get("type", None)
scaling_type2 = self.config.rope_scaling.get("rope_type", None)
scaling_type = scaling_type1 if scaling_type1 is not None else scaling_type2
scaling_factor = self.config.rope_scaling.get("factor")
if scaling_type == "linear":
self.rotary_emb = {scaled_rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "llama3":
self.rotary_emb = {extended_rope_function}(
dim = self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
elif scaling_type == "longrope":
self.rotary_emb = {longrope_rope_function}(
dim = self.head_dim,
max_position_embeddings = self.max_position_embeddings,
original_max_position_embeddings = self.config.original_max_position_embeddings,
base = self.rope_theta,
short_factor = self.config.rope_scaling['short_factor'],
long_factor = self.config.rope_scaling['long_factor' ],
)
else:
raise ValueError(f"Unknown RoPE scaling type {{scaling_type}}")
pass
"""
fix_rope_function = fix_rope_function.format(
rope_function = rope_module.__name__,
scaled_rope_function = scaled_rope_module.__name__,
extended_rope_function = extended_rope_module.__name__,
longrope_rope_function = (
longrope_module if longrope_module is not None else rope_module
).__name__,
)
rotary_emb = re.findall(
r"self\.rotary\_emb \= .+?\)",
function,
flags = re.DOTALL | re.MULTILINE,
)
if len(rotary_emb) == 0:
return None, function
rotary_emb = rotary_emb[0]
function = function.replace(rotary_emb, fix_rope_function, 1)
function = exec_code + "\n\n" + function
return init_name, function
def create_boolean_mask(n = 4096, sliding_window = 2048):
# Creates a boolean mask for attention
mask = torch.ones(n, n, dtype = torch.bool)
if sliding_window == 0:
return torch.triu(mask, diagonal = 1, out = mask)
torch.triu(mask, diagonal = 0, out = mask)
torch.triu(mask.T, diagonal = -sliding_window, out = mask.T)
mask = mask.T
torch.logical_not(mask, out = mask)
return mask
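# Worked example: create_boolean_mask(n = 4, sliding_window = 1) returns True where a
# position must NOT attend, i.e. future tokens (j > i) or tokens more than
# `sliding_window` steps behind (i - j > sliding_window):
#     [[False,  True,  True,  True],
#      [False, False,  True,  True],
#      [ True, False, False,  True],
#      [ True,  True, False, False]]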
def test_mask_creation():
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
for n in range(2, 23):
for s in range(1, 23):
correct_mask = (
AttentionMaskConverter(
is_causal = True,
sliding_window = s,
)
.to_causal_4d(
1,
n,
n,
dtype = torch.float16,
)
.squeeze(0)
.squeeze(0)
)
correct_mask = correct_mask == correct_mask.min()
our_mask = create_boolean_mask(n = n, sliding_window = s)
assert torch.all(correct_mask == our_mask)
correct_mask = (
AttentionMaskConverter(
is_causal = True,
sliding_window = None,
)
.to_causal_4d(
1,
n,
n,
dtype = torch.float16,
)
.squeeze(0)
.squeeze(0)
)
correct_mask = correct_mask == correct_mask.min()
our_mask = create_boolean_mask(n = n, sliding_window = 0)
assert torch.all(correct_mask == our_mask)
def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs):
num_items_in_batch = None
if "num_items_in_batch" in kwargs:
num_items_in_batch = kwargs["num_items_in_batch"]
if num_items_in_batch is None:
# Remove it since the model does not support it!
kwargs.pop("num_items_in_batch")
elif "num_items_in_batch" not in inputs:
inputs["num_items_in_batch"] = num_items_in_batch
# Get gradient accumulation steps if possible
if (
num_items_in_batch is None
and getattr(getattr(self, "args", self), "gradient_accumulation_steps", 1) != 1
):
inner_model = model
if hasattr(inner_model, "base_model"):
inner_model = inner_model.base_model
if hasattr(inner_model, "model"):
inner_model = inner_model.model
name = inner_model.__class__.__name__
logger.warning_once(
f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n"
"Using gradient accumulation will be very slightly less accurate.\n"
"Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient"
)
# Gemma3 multimodal models in transformers 5.x require token_type_ids during training.
# For text-only SFT, token_type_ids should be all zeros (no image tokens).
if "token_type_ids" not in inputs and "input_ids" in inputs:
_inner = model
for _attr in ("base_model", "model", "model"):
_inner = getattr(_inner, _attr, _inner)
if getattr(getattr(_inner, "config", None), "model_type", "") in ("gemma3",):
import sys as _sys
_mod = _sys.modules.get(type(_inner).__module__)
_has_ccm = _mod is not None and hasattr(_mod, "create_causal_mask_mapping")
if _has_ccm and _inner.training:
inputs["token_type_ids"] = torch.zeros_like(inputs["input_ids"])
# Gemma4 uses mm_token_type_ids (not token_type_ids) for VLM masking
if "mm_token_type_ids" not in inputs and "input_ids" in inputs:
_inner = model
for _attr in ("base_model", "model", "model"):
_inner = getattr(_inner, _attr, _inner)
if getattr(getattr(_inner, "config", None), "model_type", "") in ("gemma4",):
import sys as _sys
_mod = _sys.modules.get(type(_inner).__module__)
_has_ccm = _mod is not None and hasattr(_mod, "create_causal_mask_mapping")
if _has_ccm and _inner.training:
inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
outputs = self._old_compute_loss(model, inputs, *args, **kwargs)
return outputs
def patch_gradient_accumulation_fix(Trainer):
# Fixes gradient accumulation
# Fixes Output 0 of UnslothFusedLossBackward is a view and is being modified inplace.
import inspect
if hasattr(Trainer, "get_batch_samples"):
if Trainer.get_batch_samples.__name__ == "_unsloth_get_batch_samples":
return
if (
not inspect.getsource(Trainer.get_batch_samples)
.strip()
.endswith("return batch_samples, num_items_in_batch")
):
raise NotImplementedError(
"Unsloth: Please make a Github issue immediately!!"
)
else:
if Trainer.get_batch_samples.__name__ != "_unsloth_get_batch_samples":
Trainer.get_batch_samples = _unsloth_get_batch_samples
# Also fix passing in num_items_in_batch
if not hasattr(Trainer, "_old_compute_loss"):
# Fix transformers 4.57.0 causing `Output 0 of UnslothFusedLossBackward is a view and is being modified inplace.`
function = inspect.getsource(Trainer.compute_loss)
if "loss *=" in function or "loss*=" in function:
where = function.find("def")
function = function.split("\n")
function = "\n".join(x[where:] for x in function)
# Import all variables that need importing
import transformers.trainer
items_in_trainer = dir(transformers.trainer)
good_items = []
for item in items_in_trainer:
if item in function:
good_items.append(item)
exec(
"from transformers.trainer import ("
+ ", ".join(x for x in good_items)
+ ")",
globals(),
)
# Replace loss*= with loss = loss *
function = re.sub(
r"loss[\s]{0,}\*\=",
"loss = loss *",
function,
)
exec(function, globals())
Trainer.compute_loss = compute_loss
Trainer._old_compute_loss = Trainer.compute_loss
Trainer.compute_loss = _unsloth_pre_compute_loss
else:
logger.warning_once(
"Unsloth: We fixed a gradient accumulation bug, "
"but it seems like you don't have the latest transformers version!\n"
"Please update transformers, TRL and unsloth via:\n"
"`pip install --upgrade --no-cache-dir --no-deps unsloth transformers git+https://github.com/huggingface/trl.git`"
)
    # Also fix up loss scaling, i.e. neutralize `loss *= self.args.gradient_accumulation_steps`
if not (
Trainer.training_step.__name__ == "_unsloth_training_step"
or "num_items_in_batch"
not in inspect.signature(Trainer.training_step).parameters
):
function = inspect.getsource(Trainer.training_step)
where = function.find("def")
function = function.split("\n")
function = "\n".join(x[where:] for x in function)
# Import all variables that need importing
import transformers.trainer
items_in_trainer = dir(transformers.trainer)
good_items = []
for item in items_in_trainer:
if item in function:
good_items.append(item)
exec(
"from transformers.trainer import ("
+ ", ".join(x for x in good_items)
+ ")",
globals(),
)
        # Accelerate divides by self.args.gradient_accumulation_steps internally, so if we
        # already summed it up and did the division beforehand, we have to undo it.
function = function.replace(
"loss *= self.args.gradient_accumulation_steps",
"if num_items_in_batch is not None: loss *= self.args.gradient_accumulation_steps",
)
function = function.replace(
"def training_step", "def _unsloth_training_step", 1
)
# Fix 4.47.0 issue where num_items_in_batch was removed
# See https://github.com/huggingface/transformers/pull/35121
function = function.replace(
"if self.model_accepts_loss_kwargs:",
"if False:",
)
# Fix when num_items_in_batch is nothing
# https://github.com/huggingface/transformers/pull/35207
function = re.sub(
r"else:\n"
r"([\s]{4,})self\.accelerator\.backward\(loss, \*\*kwargs\)\n"
r"(.+?)if num_items_in_batch is None\:\n"
r"(.+?)return loss\.detach\(\) \/ self\.args\.gradient_accumulation_steps",
"else:\n"
"\2if num_items_in_batch is None:\n"
"\3loss = loss / self.args.gradient_accumulation_steps\n"
"\1self.accelerator.backward(loss, **kwargs)",
function,
)
exec(function, globals())
Trainer.training_step = _unsloth_training_step
# Wrap Trainer.__init__: (1) pre-init, shadow accepts_loss_kwargs on whatever
# model was passed in (covers PEFT wrapping done after FastModel.from_pretrained);
# (2) post-init, clamp accelerator GA to 1 for the transformers 5.0-5.5
# GradientAccumulationPlugin regression. No-op on 4.x and 5.6+. See #4982.
if not getattr(Trainer, "_unsloth_init_wrapped_for_accelerate_gas", False):
_original_trainer_init = Trainer.__init__
def _unsloth_trainer_init(self, *args, **kwargs):
model = kwargs.get("model")
if model is None and len(args) > 0:
model = args[0]
if model is not None:
try:
apply_accepts_loss_kwargs_fix(model)
except Exception:
pass
_original_trainer_init(self, *args, **kwargs)
try:
accelerator = getattr(self, "accelerator", None)
if (
accelerator is not None
and getattr(accelerator, "gradient_accumulation_steps", 1) > 1
):
accelerator.gradient_accumulation_steps = 1
gs = getattr(accelerator, "gradient_state", None)
if gs is not None and hasattr(gs, "plugin_kwargs"):
try:
gs.plugin_kwargs["num_steps"] = 1
except Exception:
pass
except Exception:
pass
_unsloth_trainer_init.__wrapped__ = _original_trainer_init
Trainer.__init__ = _unsloth_trainer_init
Trainer._unsloth_init_wrapped_for_accelerate_gas = True
def _unsloth_compile_cache_leaves():
# Accepts `UNSLOTH_COMPILE_LOCATION` overrides (the env var unsloth_zoo honors).
leaves = {"unsloth_compiled_cache", "unsloth_cache", "unsloth_compiled"}
loc = os.environ.get("UNSLOTH_COMPILE_LOCATION", "") or ""
loc = loc.rstrip("/\\")
if loc:
leaves.add(os.path.basename(loc) or loc)
return leaves
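# Illustrative: with UNSLOTH_COMPILE_LOCATION = "/tmp/my_unsloth_cache" (a hypothetical
# path), the returned leaves are the three defaults plus "my_unsloth_cache".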
def _forward_is_unsloth_compiled(model):
# True iff forward was installed from the Unsloth compile cache directory.
# __module__ stays as the transformers module, so check co_filename.
leaves = _unsloth_compile_cache_leaves()
def check(m):
if m is None:
return False
fwd = getattr(type(m), "forward", None)
if fwd is None:
return False
code = getattr(fwd, "__code__", None)
fn = getattr(code, "co_filename", "") if code is not None else ""
fn = fn.replace("\\", "/")
parts = set(fn.split("/"))
return any(leaf in parts for leaf in leaves)
if check(model):
return True
seen = set()
m = model
for _ in range(4):
if m is None or id(m) in seen:
break
seen.add(id(m))
nxt = getattr(m, "base_model", None)
if nxt is None or nxt is m:
nxt = getattr(m, "model", None)
if nxt is None or nxt is m:
break
if check(nxt):
return True
m = nxt
return False
def _find_concrete_accepts_loss_kwargs(model):
# Walk wrapper chain for first class that declares accepts_loss_kwargs in its
# own __mro__ dict. Avoids PEFT __getattr__ forwarding and our own shadow.
seen = set()
m = model
for _ in range(6):
if m is None or id(m) in seen:
break
seen.add(id(m))
for klass in type(m).__mro__:
if "accepts_loss_kwargs" in klass.__dict__:
return klass.__dict__[
"accepts_loss_kwargs"
], f"{klass.__name__}.accepts_loss_kwargs"
nxt = getattr(m, "base_model", None)
if nxt is None or nxt is m:
nxt = getattr(m, "model", None)
if nxt is None or nxt is m:
break
m = nxt
return None, "no explicit accepts_loss_kwargs on any wrapper level"
def _shadow_accepts_loss_kwargs(model, value):
# Set the attribute at every wrapper level so HF's hasattr check resolves
# regardless of where accelerator / peft unwrap lands.
seen = set()
m = model
for _ in range(8):
if m is None or id(m) in seen:
break
seen.add(id(m))
try:
setattr(m, "accepts_loss_kwargs", value)
except Exception:
pass
nxt = getattr(m, "base_model", None)
if nxt is None or nxt is m:
nxt = getattr(m, "model", None)
if nxt is None or nxt is m:
break
m = nxt
def apply_accepts_loss_kwargs_fix(model):
# Shadow the correct accepts_loss_kwargs on the model so HF Trainer picks it
# up via hasattr(unwrapped_model, ...). Replaces the old Trainer.__init__
# source rewrite. Priority: compiled forward -> True; else first class attr
# in wrapper chain; else leave HF default. Issue #4982.
if _forward_is_unsloth_compiled(model):
_shadow_accepts_loss_kwargs(model, True)
return "True (Unsloth compiled forward)"
value, reason = _find_concrete_accepts_loss_kwargs(model)
if value is None:
return f"default (signature inspection, {reason})"
_shadow_accepts_loss_kwargs(model, value)
return f"{value} ({reason})"
def patch_tokenizer(model, tokenizer):
model, tokenizer = _patch_tokenizer(model, tokenizer)
if model is not None:
model.config.update({"unsloth_version": __version__})
return model, tokenizer
def patch_fast_lora():
import peft.tuners.lora.bnb
peft.tuners.lora.bnb.Linear4bit.forward = fast_lora_forward
def unsloth_compile_transformers(
dtype,
model_name,
model_types,
token = None,
revision = None,
trust_remote_code = False,
sdpa_dynamic_mask = True,
sdpa_bool_masks = True,
sdpa_gqa_replace = True,
sdpa_dynamic_compile = True,
compile_attention = True,
disable_causal_masks = True,
compile_torch_modules = True,
compile_custom_modules = True,
compile_function_calls = True,
fuse_lm_head = True,
gradient_checkpointing = True,
manual_replacements = True,
fast_lora_forwards = True,
fast_residual_stream = True,
accurate_accumulation = True,
epilogue_fusion = True,
max_autotune = False,
shape_padding = True,
cudagraphs = False,
debug = False,
fullgraph = True,
import_from_cache = False,
disable = False,
return_logits = False,
unsloth_force_compile = False,
):
if Version(torch_version) < Version("2.4.0"):
print(
"="
* 30
+ "Unsloth: Unfortunately Unsloth vision and other newer optimized models need Torch 2.4 or later.\n"
f"You have Torch version {torch_version}. Please upgrade your Torch version by visiting https://pytorch.org/\n"
"For now your models will not get optimized, but will still work for now!"
)
return
if trust_remote_code and unsloth_force_compile == False:
print(
"Unsloth: We can't trace models if `trust_remote_code = True`, "
"so turning off some optimizations!"
)
return model_types, False
model_types = list(dict().fromkeys(model_types).keys())
if disable:
return model_types, False
supports_sdpa = [True]
# Run patches BEFORE compiler so class replacements (e.g. GptOssTopKRouter,
# GptOssExperts) are in place before the compiler caches references to them.
_run_temporary_patches("pre_compile")
for model_type in model_types:
_unsloth_compile_transformers(
model_type,
sdpa_dynamic_mask = sdpa_dynamic_mask,
sdpa_bool_masks = sdpa_bool_masks,
sdpa_gqa_replace = sdpa_gqa_replace,
sdpa_dynamic_compile = sdpa_dynamic_compile,
compile_attention = compile_attention,
disable_causal_masks = disable_causal_masks,
compile_torch_modules = compile_torch_modules,
compile_custom_modules = compile_custom_modules,
compile_function_calls = compile_function_calls,
fuse_lm_head = fuse_lm_head,
gradient_checkpointing = gradient_checkpointing,
manual_replacements = manual_replacements,
fast_lora_forwards = fast_lora_forwards,
fast_residual_stream = fast_residual_stream,
accurate_accumulation = accurate_accumulation,
epilogue_fusion = epilogue_fusion,
max_autotune = max_autotune,
shape_padding = shape_padding,
cudagraphs = cudagraphs,
debug = debug,
fullgraph = fullgraph,
import_from_cache = import_from_cache,
disable = disable,
return_logits = return_logits,
supports_sdpa = supports_sdpa,
)
# Redo patches which override compiler
_run_temporary_patches("post_compile")
return model_types, supports_sdpa[0]
# We need an empty logits flag to warn people that logits are no longer returned unless
# explicitly requested, i.e. os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
LOGITS_ERROR_STRING = (
"Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "
    'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1"` BEFORE starting to train, i.e. before `trainer.train()`. For example:\n'
"```\nimport os\n"
"os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"
"trainer.train()\n```\n"
"No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!"
)
def raise_logits_error(*args, **kwargs):
raise NotImplementedError(LOGITS_ERROR_STRING)
def return_none(*args, **kwargs):
return None
class EmptyLogits:
def __init__(self):
return
def raise_getattr_error(self, attr):
return return_none if attr == "to" else raise_logits_error
__getitem__ = raise_logits_error
__getattr__ = raise_getattr_error
def __repr__(self):
return LOGITS_ERROR_STRING
def __str__(self):
return LOGITS_ERROR_STRING
EMPTY_LOGITS = EmptyLogits()
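# Best-effort guard: stub every dunder that torch.Tensor defines onto EMPTY_LOGITS so that
# tensor-style use of the empty logits object prints the offending method name; dunders
# that cannot be assigned on the instance are skipped via the try/except below.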
functions = dir(torch.Tensor)
for j, function in enumerate(functions):
if function.startswith("__") and function.endswith("__"):
exec(
f"def raise_{j}(*args, **kwargs): print('{function}')", globals(), locals()
)
try:
exec(f"EMPTY_LOGITS.{function} = raise_{j}", globals(), locals())
except:
continue
def validate_loftq_config(loftq_config, lora_dropout, bias, init_lora_weights, model):
from peft import LoraConfig
if loftq_config is None:
loftq_config = {}
signature = str(inspect.signature(LoraConfig))
SUPPORTS_LOFTQ = "loftq_config" in signature
if lora_dropout != 0:
logger.warning_once(
f"Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = {lora_dropout}.\n"
f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
)
if bias != "none":
logger.warning_once(
f"Unsloth: bias = `none` is supported for fast patching. You are using bias = {bias}.\n"
f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
)
if not (
type(init_lora_weights) is bool
or init_lora_weights == "gaussian"
or init_lora_weights == "loftq"
or init_lora_weights == "corda"
):
raise ValueError(
            'Unsloth: `init_lora_weights` must be one of [True, False, "gaussian", "loftq", "corda"].'
)
if init_lora_weights == "loftq":
if not SUPPORTS_LOFTQ:
import peft
raise RuntimeError(
f"Unsloth: Your PEFT version of {peft.__version__} does not support LoftQ init.\n"
"Please install PEFT 0.7.2 or higher.\n"
"You can also install from source: `pip install git+https://github.com/huggingface/peft.git"
)
if loftq_config == {}:
from peft import LoftQConfig
logger.warning_once(
"Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"
"We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
)
loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)
if hasattr(model.config, "quantization_config"):
raise ValueError(
"Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\n"
"Reload your model without any quantization by setting `load_in_4bit = False`."
)
return loftq_config
def fast_inference_setup(model_name, model_config):
fast_inference = True
if not is_vLLM_available():
logger.warning_once(
"Unsloth: vLLM is not installed! Will use Unsloth inference!"
)
fast_inference = False
from unsloth_zoo.vllm_utils import (
patch_vllm,
vllm_dynamic_quant_supported,
)
patch_vllm()
if model_name.endswith("unsloth-bnb-4bit"):
if not vllm_dynamic_quant_supported(model_name, model_config):
# Instead use -bnb-4bit variant
logger.warning_once(
f"Unsloth: Switching from Unsloth dynamic quant to normal quant since\n"
f"we do not yet support fast inference for {model_name}"
)
model_name = model_name[: -len("unsloth-bnb-4bit")] + "bnb-4bit"
return fast_inference, model_name
def patch_peft_fast_inference(model):
vllm_engine = getattr(model.model, "vllm_engine", None)
if vllm_engine is not None:
model.vllm_engine = model.model.vllm_engine
model.fast_generate = model.model.fast_generate
model.fast_generate_batches = model.model.fast_generate_batches
# Also saving and loading LoRA
from unsloth_zoo.vllm_utils import save_lora, load_lora
model.save_lora = functools.partial(save_lora, model)
model.load_lora = functools.partial(load_lora, model)
def error_out_no_vllm(*args, **kwargs):
raise NotImplementedError(
"Unsloth: vLLM is not yet supported for fast inference for this model! Please use `.generate` instead"
)
try:
from torchao.core.config import AOBaseConfig
try:
from torchao.quantization import Int4WeightOnlyConfig
except:
print("Unsloth: TorchAO changed `torchao.quantization.Int4WeightOnlyConfig`")
Int4WeightOnlyConfig = None
except:
AOBaseConfig = None
Int4WeightOnlyConfig = None
@dataclass
class TorchAOConfig:
qat_scheme: Optional[str] = "int4"
# Each (config, filter_fn) pair defines a quantization rule
base_config_and_filter_fns: List[
Tuple["AOBaseConfig", Optional[Callable[[torch.nn.Module, str], bool]]]
] = field(
default_factory = lambda: [
(
Int4WeightOnlyConfig(group_size = 128),
lambda m, _: isinstance(m, torch.nn.Linear)
and getattr(m, "in_features", 0) >= 128,
),
]
)
# Optional transformation to apply before quantization setup
prequantization_transform: Optional[Callable[[torch.nn.Module], None]] = None
def _untie_input_output_embeddings(model: torch.nn.Module) -> None:
"""
Utility to untie input/output embeddings in a HuggingFace model.
    This is useful if we want to quantize the input/output embeddings differently.
Model is modified in-place.
"""
# 1) Persist setting in config
if hasattr(model.config, "tie_word_embeddings"):
model.config.tie_word_embeddings = False
# 2) Find input and output embeddings
in_emb = model.get_input_embeddings()
out_proj = model.get_output_embeddings() or getattr(model, "lm_head", None)
if out_proj is None:
raise AttributeError("Couldn't locate output projection (lm_head).")
# (Optional) sanity: shapes should match [vocab, hidden]
assert (
out_proj.weight.shape == in_emb.weight.shape
), f"Shape mismatch: out_proj {out_proj.weight.shape} vs in_emb {in_emb.weight.shape}"
# 3) Only clone if they are actually tied (shared storage)
if out_proj.weight.data_ptr() == in_emb.weight.data_ptr():
with torch.no_grad():
W = in_emb.weight.detach().clone()
out_proj.weight = torch.nn.Parameter(W) # new storage, keeps dtype/device
# 4) Prevent future automatic re-tying
def _no_tie(self):
return
model.tie_weights = _no_tie.__get__(model, model.__class__)
# 5) Verify no shared storage
assert (
out_proj.weight.data_ptr() != in_emb.weight.data_ptr()
), "Embeddings still tied!"
def _filter_fn_to_fqns(
model: torch.nn.Module,
filter_fn: Callable[[torch.nn.Module, str], bool],
) -> Iterator[str]:
"""
Given a model and a filter function (m, fqn) -> bool,
yield fully qualified names (FQNs) of modules that match.
"""
for fqn, module in model.named_modules():
if filter_fn(module, fqn):
yield fqn
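# Illustrative usage (hypothetical filter, not executed): collect every Linear wide
# enough for group-128 quantization:
#     wide_linears = list(_filter_fn_to_fqns(
#         model, lambda m, fqn: isinstance(m, torch.nn.Linear) and m.in_features >= 128
#     ))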
def _convert_torchao_model(model):
from transformers import TorchAoConfig
from torchao.quantization import quantize_, ModuleFqnToConfig
from torchao.quantization.qat import QATConfig
from torchao.utils import TorchAOBaseTensor
module_to_fqn_dict = {}
for base_config, filter_fn in model._torchao_config.base_config_and_filter_fns:
quantize_(model, QATConfig(base_config, step = "convert"), filter_fn = filter_fn)
# Default filter function used for quantize_
if filter_fn is None:
if "_default" in module_to_fqn_dict:
raise ValueError("Cannot use multiple default quantization configs")
module_to_fqn_dict["_default"] = base_config
else:
for fqn in _filter_fn_to_fqns(model, filter_fn):
if fqn in module_to_fqn_dict:
raise ValueError(f"Found multiple quantization configs for {fqn}")
module_to_fqn_dict[fqn] = base_config
in_emb = model.get_input_embeddings()
out_proj = model.get_output_embeddings() or getattr(model, "lm_head", None)
kwargs = {}
if isinstance(in_emb.weight, TorchAOBaseTensor) or (
out_proj is not None and isinstance(out_proj.weight, TorchAOBaseTensor)
):
kwargs["include_input_output_embeddings"] = True
kwargs["modules_to_not_convert"] = []
quant_config = ModuleFqnToConfig(module_to_fqn_dict)
quantization_config = TorchAoConfig(quant_type = quant_config, **kwargs)
model.config.quantization_config = quantization_config
def _prepare_model_for_qat(
model: torch.nn.Module, qat_scheme: Union[str, TorchAOConfig]
) -> torch.nn.Module:
"""
Transform a model for Quantization-Aware Training (QAT) during fine-tuning.
On a high level, this means fake quantizing the base (frozen) model during training.
Fake quantization refers to simulating quantization numerics in high precision (e.g. bf16).
This helps mitigate quantization degradations when the model is quantized after training.
    QAT can optionally be combined with LoRA fine-tuning for additional throughput improvement.
For more details: https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700
"""
try:
from torchao.quantization import PerRow, quantize_
from torchao.quantization.granularity import PerGroup, PerAxis
from torchao.quantization.qat import QATConfig
except ImportError:
raise ImportError(TORCHAO_MSG)
# Gemma3 models have issues with int8 embedding quantization due to their
# large vocabulary size (262144). Auto-switch to int4 weight-only instead.
if qat_scheme == "int8-int4":
model_types = get_transformers_model_type(model.config)
is_gemma3 = any("gemma3" in mt or "gemma_3" in mt for mt in model_types)
if is_gemma3:
print(
"Unsloth: Gemma3 has a large vocabulary causing int8 embedding issues. "
"Switching to int4 weight-only QAT for training stability."
)
qat_scheme = "int4"
if not isinstance(qat_scheme, TorchAOConfig):
torchao_config: Optional[TorchAOConfig] = None
if qat_scheme == "fp8-int4":
try:
from torchao.quantization import Float8DynamicActivationInt4WeightConfig
except ImportError:
raise ImportError(TORCHAO_MSG)
group_size = 128
base_config = Float8DynamicActivationInt4WeightConfig()
filter_fn = (
lambda m, _: isinstance(m, torch.nn.Linear)
and m.in_features >= group_size
)
torchao_config = TorchAOConfig(
qat_scheme = qat_scheme,
base_config_and_filter_fns = [(base_config, filter_fn)],
)
elif qat_scheme == "fp8-fp8":
try:
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
)
except ImportError:
raise ImportError(TORCHAO_MSG)
base_config = Float8DynamicActivationFloat8WeightConfig(
granularity = PerRow()
)
torchao_config = TorchAOConfig(
qat_scheme = qat_scheme, base_config_and_filter_fns = [(base_config, None)]
)
elif qat_scheme == "int8-int4":
try:
from torchao.quantization import (
Int8DynamicActivationIntxWeightConfig,
IntxWeightOnlyConfig,
)
except ImportError:
raise ImportError(TORCHAO_MSG)
torchao_config = TorchAOConfig(
qat_scheme = qat_scheme,
base_config_and_filter_fns = [
(
IntxWeightOnlyConfig(
weight_dtype = torch.int8, granularity = PerAxis(0)
),
lambda m, fqn: isinstance(m, torch.nn.Embedding),
),
(
Int8DynamicActivationIntxWeightConfig(
weight_dtype = torch.int4, weight_granularity = PerGroup(32)
),
None,
),
],
prequantization_transform = _untie_input_output_embeddings,
)
elif qat_scheme == "int4":
try:
from torchao.quantization import Int4WeightOnlyConfig
except ImportError:
raise ImportError(TORCHAO_MSG)
group_size = 128
base_config = Int4WeightOnlyConfig(group_size = group_size)
filter_fn = (
lambda m, _: isinstance(m, torch.nn.Linear)
and m.in_features >= group_size
)
torchao_config = TorchAOConfig(
qat_scheme = qat_scheme,
base_config_and_filter_fns = [(base_config, filter_fn)],
)
elif qat_scheme == "int8":
try:
from torchao.quantization import IntxWeightOnlyConfig
from torchao.quantization.granularity import PerAxis
except ImportError:
raise ImportError(TORCHAO_MSG)
base_config = IntxWeightOnlyConfig(
weight_dtype = torch.int8,
granularity = PerAxis(0),
)
filter_fn = lambda m, _: isinstance(m, torch.nn.Linear)
torchao_config = TorchAOConfig(
qat_scheme = qat_scheme,
base_config_and_filter_fns = [(base_config, filter_fn)],
)
elif qat_scheme == "cactus":
try:
from torchao.quantization import IntxWeightOnlyConfig
except ImportError:
raise ImportError(TORCHAO_MSG)
# IntxWeightOnlyConfig already defaults to
# `mapping_type = MappingType.SYMMETRIC`, so we intentionally do not
# import `MappingType` here. Matches the upstream Cactus runtime
# int8 / per-group-32 / symmetric weight-only configuration.
group_size = 32
base_config = IntxWeightOnlyConfig(
weight_dtype = torch.int8,
granularity = PerGroup(group_size),
)
filter_fn = (
lambda m, _: isinstance(m, torch.nn.Linear)
and m.in_features >= group_size
and m.in_features % group_size == 0
)
# Warn if any Linear layer is skipped by the cactus filter because
# its in_features is not divisible by `group_size`. torchao's
# PerGroup(32) quantizer rejects non-divisible widths at
# `quantize_()` time, so the filter excludes those layers to keep
# the QAT prepare step from crashing. Surface that silently-skipped
# coverage gap to the user so they know some Linears will stay in
# full precision during training.
skipped_cactus_layers = [
name
for name, module in model.named_modules()
if isinstance(module, torch.nn.Linear)
and module.in_features >= group_size
and module.in_features % group_size != 0
]
if skipped_cactus_layers:
preview = ", ".join(skipped_cactus_layers[:8])
if len(skipped_cactus_layers) > 8:
preview += f", ... ({len(skipped_cactus_layers) - 8} more)"
warnings.warn(
f"Unsloth: qat_scheme='cactus' uses PerGroup({group_size}) "
"which requires in_features to be divisible by "
f"{group_size}. The following Linear layers will be kept "
f"in full precision during QAT: {preview}",
stacklevel = 2,
)
torchao_config = TorchAOConfig(
qat_scheme = qat_scheme,
base_config_and_filter_fns = [(base_config, filter_fn)],
)
else:
raise ValueError(f"Unexpected QAT scheme {qat_scheme}")
assert torchao_config is not None, f"TorchAOConfig was not set for {qat_scheme}"
else:
torchao_config = qat_scheme
# Save Torchao metadata everywhere
inner_model = model
while hasattr(inner_model, "model"):
inner_model._torchao_config = torchao_config
inner_model = inner_model.model
inner_model._torchao_config = torchao_config
if torchao_config.prequantization_transform is not None:
torchao_config.prequantization_transform(model)
for base_config, filter_fn in torchao_config.base_config_and_filter_fns:
quantize_(model, QATConfig(base_config, step = "prepare"), filter_fn = filter_fn)
return model
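# Illustrative usage (not executed): fake-quantize the frozen base model before LoRA
# fine-tuning with one of the built-in schemes ("int4", "int8", "int8-int4", "fp8-int4",
# "fp8-fp8", "cactus"), or pass a custom TorchAOConfig:
#     model = _prepare_model_for_qat(model, "int8-int4")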
def patch_hf_quantizer():
# To tell hf trainer that the quantized model is trainable
def make_trainable(self):
return True
try:
from transformers.quantizers.quantizer_finegrained_fp8 import (
FineGrainedFP8HfQuantizer,
)
FineGrainedFP8HfQuantizer.is_trainable = property(make_trainable)
FineGrainedFP8HfQuantizer.is_qat_trainable = property(make_trainable)
except Exception as e:
logger.warning(f"Failed to patch FineGrainedFP8HfQuantizer. Error {e}")
try:
from transformers.quantizers.quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
FbgemmFp8HfQuantizer.is_trainable = property(make_trainable)
FbgemmFp8HfQuantizer.is_qat_trainable = property(make_trainable)
except Exception as e:
logger.warning(f"Failed to patch FbgemmFp8HfQuantizer. Error {e}")
try:
from transformers.quantizers.quantizer_torchao import TorchAoHfQuantizer
TorchAoHfQuantizer.is_trainable = property(make_trainable)
TorchAoHfQuantizer.is_qat_trainable = property(make_trainable)
except Exception as e:
logger.warning(f"Failed to patch TorchAoHfQuantizer. Error {e}")
patch_hf_quantizer()
def verify_fp8_support_if_applicable(model_config):
quant_method = get_quant_type(model_config)
if quant_method in ["fbgemm_fp8", "fp8"] and DEVICE_TYPE != "cuda":
raise ValueError(
f"Unsloth: FP8 quantization is only supported on CUDA GPUs. You are using {DEVICE_TYPE}."
)
# [TODO] Need to add FP8 support for Intel XPUs
if DEVICE_TYPE == "cuda":
major_version, minor_version = torch.cuda.get_device_capability()
if quant_method == "fbgemm_fp8" and major_version < 9:
# While L4 does support FP8 as data type, it doesn't have fbgemm (package) support yet. So we restrict it.
raise ValueError(
f"Unsloth: FBGEMM FP8 quantization is only supported on H100 and higher GPUs. L4 is not supported. You are using {torch.cuda.get_device_name()}. Refer to https://developer.nvidia.com/cuda-gpus for more details."
)
if quant_method == "fp8" and major_version * 10 + minor_version < 89:
            # For block-quantized FP8 we allow L4 (compute capability 8.9) because we fall back to torchao kernels.
raise ValueError(
f"Unsloth: FP8 quantization is only supported on L4 and higher GPUs with compute capability 8.9 or higher. You are using {torch.cuda.get_device_name()}. Refer to https://developer.nvidia.com/cuda-gpus for more details."
)
def _get_inference_mode_context_manager(model: torch.nn.Module):
"""
    If the state dict was quantized using torchao, calling ops like `aten.t()` under
    `torch.inference_mode()` raises:
        Cannot set version_counter for inference tensor
    This is a PyTorch bug that affects all tensor subclasses, so for now we work around
    it by returning `torch.no_grad()` in that case.
    See https://github.com/pytorch/pytorch/issues/164872 for more details.
    Otherwise, just return `torch.inference_mode()`.
"""
torchao_config = getattr(model, "torchao_config", None)
if torchao_config is not None and torchao_config.qat_scheme is None:
return torch.no_grad()
else:
return torch.inference_mode()
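# Illustrative usage (not executed):
#     with _get_inference_mode_context_manager(model):
#         outputs = model(**inputs)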
def hf_login(token: Optional[str] = None) -> Optional[str]:
if token is None:
try:
from huggingface_hub import get_token
token = get_token()
if token is None:
return None
except:
return None
try:
from huggingface_hub import login
login(token = token)
return token
except Exception as e:
logger.info(f"Failed to login to huggingface using token with error: {e}")
return token
# =============================================
# MoE (Mixture of Experts) Detection and LoRA Utilities
def is_moe_model(model) -> bool:
"""
Detect if a model is a Mixture of Experts (MoE) model.
Args:
model: The model to check (can be HF model or config)
Returns:
True if the model is an MoE model, False otherwise
"""
config = getattr(model, "config", model)
# Different MoE models use different config attribute names:
# - Qwen3-MoE: num_experts
# - GLM4-MoE: n_routed_experts, num_local_experts
# - Mixtral: num_local_experts
num_experts = None
for attr in ("num_experts", "n_routed_experts", "num_local_experts"):
num_experts = getattr(config, attr, None)
if num_experts is not None:
break
# Check text_config for VL models
if num_experts is None and hasattr(config, "text_config"):
for attr in ("num_experts", "n_routed_experts", "num_local_experts"):
num_experts = getattr(config.text_config, attr, None)
if num_experts is not None:
break
return num_experts is not None and num_experts > 0
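# Illustrative (hypothetical objects, not executed):
#     is_moe_model(model)         # True for e.g. Mixtral (num_local_experts) or Qwen3-MoE (num_experts)
#     is_moe_model(model.config)  # a bare config is also accepted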
def _resolve_moe_parameter_name(model, default_name: str, alternate_name: str) -> str:
"""
Resolve the actual parameter path for MoE expert weights.
Most current Unsloth MoE models expose expert weights under
``mlp.experts.*``. Gemma4 stores them directly under ``experts.*``.
Prefer the path that exists on the loaded module when possible.
"""
if hasattr(model, "named_parameters"):
try:
for name, _ in model.named_parameters():
if name == default_name or name.endswith("." + default_name):
return default_name
if name == alternate_name or name.endswith("." + alternate_name):
return alternate_name
except Exception:
pass
config = getattr(model, "config", model)
model_types = {getattr(config, "model_type", None)}
text_config = getattr(config, "text_config", None)
if text_config is not None:
model_types.add(getattr(text_config, "model_type", None))
if any(
isinstance(model_type, str) and model_type.startswith("gemma4")
for model_type in model_types
):
return alternate_name
return default_name
def get_moe_target_parameters(model, target_modules = None) -> Optional[List[str]]:
"""
Get the target_parameters for MoE expert layers if applicable.
For MoE models, returns the parameter paths for expert weights
(gate_up_proj, down_proj) that should be targeted by PEFT's
target_parameters for LoRA on nn.Parameter. The exact parameter path
depends on the model layout, for example ``mlp.experts.*`` or
``experts.*``.
Only includes MoE parameters that match what's in target_modules:
- If "down_proj" is in target_modules -> includes "mlp.experts.down_proj"
- If "gate_proj" or "up_proj" is in target_modules -> includes "mlp.experts.gate_up_proj"
Args:
model: The model to get target parameters for
target_modules: List/tuple of target module names to match against
Returns:
List of parameter paths for MoE experts, or None if not an MoE model
"""
if not is_moe_model(model):
return None
config = getattr(model, "config", model)
# Get num_experts from various possible config attributes
num_experts = None
for attr in ("num_experts", "n_routed_experts", "num_local_experts"):
num_experts = getattr(config, attr, None)
if num_experts is not None:
break
if num_experts is None and hasattr(config, "text_config"):
for attr in ("num_experts", "n_routed_experts", "num_local_experts"):
num_experts = getattr(config.text_config, attr, None)
if num_experts is not None:
break
if num_experts is None:
num_experts = 0
# Determine which MoE parameters to include based on target_modules
moe_params = []
# Normalize target_modules to a set for efficient lookup
if target_modules is None:
# If no target_modules specified, include all MoE params
target_set = {"gate_proj", "up_proj", "down_proj", "gate_up_proj"}
elif isinstance(target_modules, str):
target_set = {target_modules}
# Heuristic for regex matching MLPs
if "proj" in target_modules and (
"mlp" in target_modules or "ffn" in target_modules
):
target_set.update({"gate_proj", "up_proj", "down_proj", "gate_up_proj"})
else:
target_set = set(target_modules) if target_modules else set()
gate_up_name = _resolve_moe_parameter_name(
model,
default_name = "mlp.experts.gate_up_proj",
alternate_name = "experts.gate_up_proj",
)
down_name = _resolve_moe_parameter_name(
model,
default_name = "mlp.experts.down_proj",
alternate_name = "experts.down_proj",
)
# gate_up_proj combines both gate_proj and up_proj in MoE
# Also match "gate_up_proj" directly since users may specify the fused name
if (
"gate_proj" in target_set
or "up_proj" in target_set
or "gate_up_proj" in target_set
):
moe_params.append(gate_up_name)
if "down_proj" in target_set:
moe_params.append(down_name)
if moe_params:
print(
f"Unsloth: Detected MoE model with {num_experts = } and {target_modules = }. Enabling LoRA on MoE parameters: {moe_params}"
)
return moe_params
return None
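# Illustrative (hypothetical MoE model, not executed):
#     get_moe_target_parameters(model, ["gate_proj", "up_proj", "down_proj"])
#     # -> ["mlp.experts.gate_up_proj", "mlp.experts.down_proj"]
#     #    (or the "experts.*" variants for Gemma4-style layouts)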
def make_fast_generate_wrapper(original_generate):
"""
Creates a wrapper around model.generate that checks for incorrect
vLLM-style usage when fast_inference=False.
"""
@functools.wraps(original_generate)
def _fast_generate_wrapper(*args, **kwargs):
# Check for vLLM-specific arguments
if "sampling_params" in kwargs:
raise ValueError(
"Unsloth: `sampling_params` is only supported when `fast_inference=True` (vLLM). "
"Since `fast_inference=False`, use HuggingFace generate arguments instead:\n"
" model.fast_generate(**tokens.to('cuda'), max_new_tokens=64, temperature=1.0, top_p=0.95)"
)
if "lora_request" in kwargs:
raise ValueError(
"Unsloth: `lora_request` is only supported when `fast_inference=True` (vLLM). "
"Since `fast_inference=False`, LoRA weights are already merged into the model."
)
# Check if first positional argument is a string or list of strings
if len(args) > 0:
first_arg = args[0]
is_string_input = False
if isinstance(first_arg, str):
is_string_input = True
elif isinstance(first_arg, (list, tuple)) and len(first_arg) > 0:
if isinstance(first_arg[0], str):
is_string_input = True
if is_string_input:
raise ValueError(
"Unsloth: Passing text strings to `fast_generate` is only supported "
"when `fast_inference=True` (vLLM). Since `fast_inference=False`, you must "
"tokenize the input first:\n\n"
" messages = tokenizer.apply_chat_template(\n"
' [{"role": "user", "content": "Your prompt here"}],\n'
" tokenize=True, add_generation_prompt=True,\n"
' return_tensors="pt", return_dict=True\n'
" )\n"
" output = model.fast_generate(\n"
" **messages.to('cuda'),\n"
" max_new_tokens=64,\n"
" temperature=1.0,\n"
" )"
)
# Call original generate
return original_generate(*args, **kwargs)
return _fast_generate_wrapper
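# Illustrative usage (hypothetical assignment, not executed): wrap the HF generate so
# vLLM-only arguments fail fast with a clear message when fast_inference = False:
#     model.fast_generate = make_fast_generate_wrapper(model.generate)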
# Fix llm_int8_skip_modules not being respected for VLMs with dynamic quantization.
# Dynamic quant checkpoints (eg gemma-3-4b-it-unsloth-bnb-4bit) encode skip paths as
# "language_model.model.layers.*", but the live module tree surfaces them as
# "model.language_model.layers.*". This prefix mismatch causes should_convert_module
# to miss the skip list, so modules meant to stay in 16-bit get wrapped in Linear4bit
# without a quant_state, producing "Skipping ... no quant_state found" warnings.
# We patch should_convert_module to expand both the module name and the skip patterns
# into all equivalent alias forms before delegating to the original matcher.
# Ref: https://github.com/unslothai/unsloth/issues/4208
import transformers.quantizers.quantizers_utils as _quantizers_utils
if (
hasattr(_quantizers_utils, "should_convert_module")
and getattr(_quantizers_utils.should_convert_module, "__name__", "")
!= "patched_should_convert_module"
):
_original_should_convert_module = _quantizers_utils.should_convert_module
def _get_full_name_aliases(full_name):
aliases = {full_name}
if not isinstance(full_name, str):
return aliases
if full_name.startswith("model.language_model."):
aliases.add(full_name[len("model.") :])
if "language_model.model." in full_name:
aliases.add(full_name.replace("language_model.model.", "language_model."))
if full_name.startswith("model.language_model.model."):
aliases.add(
full_name[len("model.") :].replace(
"language_model.model.", "language_model."
)
)
return aliases
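    # Illustrative alias expansion (example value):
    #     _get_full_name_aliases("model.language_model.layers.0.mlp.down_proj")
    #     # -> {"model.language_model.layers.0.mlp.down_proj",
    #     #     "language_model.layers.0.mlp.down_proj"}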
def _get_pattern_aliases(pattern):
aliases = {pattern}
if not isinstance(pattern, str):
return aliases
if "language_model.model." in pattern:
aliases.add(pattern.replace("language_model.model.", "language_model."))
return aliases
def _expand_patterns(patterns):
expanded = set()
for pattern in patterns:
expanded.update(_get_pattern_aliases(pattern))
return expanded
def patched_should_convert_module(full_name, patterns = None):
if patterns is None:
return _original_should_convert_module(full_name, patterns)
expanded_patterns = _expand_patterns(patterns)
return all(
_original_should_convert_module(candidate, expanded_patterns)
for candidate in _get_full_name_aliases(full_name)
)
patched_should_convert_module._original_should_convert_module = (
_original_should_convert_module
)
_quantizers_utils.should_convert_module = patched_should_convert_module
try:
import transformers.integrations.bitsandbytes
transformers.integrations.bitsandbytes.should_convert_module = (
patched_should_convert_module
)
except Exception:
pass