diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py
index 1e0b015c4..820a111a9 100644
--- a/unsloth/models/_utils.py
+++ b/unsloth/models/_utils.py
@@ -65,7 +65,9 @@ __all__ = [
     "patch_compiled_autograd",
     "process_vision_info",
     "unsloth_compile_transformers",
-    "determine_attention_implementation",
+    "resolve_model_class",
+    "resolve_attention_implementation",
+    "resolve_encoder_attention_implementation",
     "_set_attn_impl",
     "patch_fast_lora",
     "validate_loftq_config",
@@ -233,7 +235,7 @@ def apply_unsloth_gradient_checkpointing(
 # access on some GPU architectures (B200). Falls back to eager safely.
 _FLEX_EXCLUDED_MODELS = ("gpt_oss", "mllama", "nemotron_h", "modernbert")
 _EAGER_ONLY_PREFIXES = ("gemma3n",)
-_FLASH_ATTENTION_DISABLED_MODELS = ("gemma4", "gemma4_text")
+_FLASH_ATTENTION_MAX_HEAD_DIM = 256
 _FLASH_ATTENTION_DISABLED_WARNED = set()
 
 
@@ -245,8 +247,102 @@ def _is_eager_only(model_type):
     return any(model_type.startswith(p) for p in _EAGER_ONLY_PREFIXES)
 
 
-def _is_flash_attention_disabled(model_type):
-    return model_type in _FLASH_ATTENTION_DISABLED_MODELS
+def _config_items(config):
+    if isinstance(config, dict):
+        return config.items()
+    if hasattr(config, "__dict__"):
+        return vars(config).items()
+    return ()
+
+
+def _config_get(config, field_name, default = None):
+    if isinstance(config, dict):
+        return config.get(field_name, default)
+    return getattr(config, field_name, default)
+
+
+def _config_set(config, field_name, value):
+    if isinstance(config, dict):
+        config[field_name] = value
+    elif config is not None:
+        setattr(config, field_name, value)
+
+
+def _iter_attention_configs(config, seen = None):
+    if config is None or (
+        not isinstance(config, dict) and not hasattr(config, "__dict__")
+    ):
+        return
+    if seen is None:
+        seen = set()
+    config_id = id(config)
+    if config_id in seen:
+        return
+    seen.add(config_id)
+    yield config
+
+    for field_name, child_config in _config_items(config):
+        if not isinstance(field_name, str) or not field_name.endswith("_config"):
+            continue
+        if isinstance(child_config, dict) or hasattr(child_config, "__dict__"):
+            yield from _iter_attention_configs(child_config, seen)
+
+
+def _collect_attention_head_dims(config):
+    explicit_head_dims = []
+
+    for field_name in (
+        "head_dim",
+        "global_head_dim",
+        "local_head_dim",
+        "kv_head_dim",
+    ):
+        value = _config_get(config, field_name, None)
+        if isinstance(value, int) and value > 0:
+            explicit_head_dims.append(value)
+
+    if len(explicit_head_dims) != 0:
+        return explicit_head_dims
+
+    head_dims = []
+
+    hidden_size_names = ("hidden_size", "d_model", "embed_dim", "dim")
+    num_heads_names = ("num_attention_heads", "num_heads", "n_heads")
+    for hidden_size_name in hidden_size_names:
+        hidden_size = _config_get(config, hidden_size_name, None)
+        if not isinstance(hidden_size, int) or hidden_size <= 0:
+            continue
+        for num_heads_name in num_heads_names:
+            num_heads = _config_get(config, num_heads_name, None)
+            if (
+                isinstance(num_heads, int)
+                and num_heads > 0
+                and (hidden_size % num_heads) == 0
+            ):
+                head_dims.append(hidden_size // num_heads)
+
+    return head_dims
+
+
+def _get_max_attention_head_dim(config):
+    head_dims = []
+    for attention_config in _iter_attention_configs(config):
+        head_dims.extend(_collect_attention_head_dims(attention_config))
+    return max(head_dims) if len(head_dims) != 0 else None
+
+
+def _get_flash_attention_disable_reason(config):
+    max_head_dim = _get_max_attention_head_dim(config)
+    if max_head_dim is not None and max_head_dim > _FLASH_ATTENTION_MAX_HEAD_DIM:
+        return (
+            f"max attention head dim {max_head_dim} exceeds the Flash Attention 2 "
+            f"limit of {_FLASH_ATTENTION_MAX_HEAD_DIM}"
+        )
+    return None
+
+
+def _is_flash_attention_disabled(config):
+    return _get_flash_attention_disable_reason(config) is not None
 
 
 def _is_flash_attention_requested(attn_implementation):
@@ -256,20 +352,24 @@
 
 
 def _disable_flash_attention_if_needed(
-    model_type,
     config,
     attn_implementation = None,
     supports_sdpa = False,
     would_use_flash_attention = False,
+    disable_reason = None,
 ):
-    if not _is_flash_attention_disabled(model_type):
+    if disable_reason is None:
+        disable_reason = _get_flash_attention_disable_reason(config)
+    if disable_reason is None:
         return attn_implementation
 
     requested_attn_implementation = attn_implementation
     if requested_attn_implementation is None:
-        requested_attn_implementation = getattr(config, "_attn_implementation", None)
+        requested_attn_implementation = _config_get(
+            config, "_attn_implementation", None
+        )
     if requested_attn_implementation is None:
-        requested_attn_implementation = getattr(config, "attn_implementation", None)
+        requested_attn_implementation = _config_get(config, "attn_implementation", None)
 
     if requested_attn_implementation == "eager":
         return _set_attn_impl(config, "eager")
@@ -284,16 +384,18 @@ def _disable_flash_attention_if_needed(
         if _is_flash_attention_requested(requested_attn_implementation)
         else "flash_attention_2"
     )
+    model_type = _config_get(config, "model_type", "")
     warning_key = (
         model_type,
         logged_attn_implementation,
         fallback_attn_implementation,
+        disable_reason,
     )
     if warning_key not in _FLASH_ATTENTION_DISABLED_WARNED:
         _FLASH_ATTENTION_DISABLED_WARNED.add(warning_key)
         print(
             f"Unsloth: `{logged_attn_implementation}` is not supported "
-            "for Gemma 4 - "
+            f"for `{model_type}` because {disable_reason} - "
             f"defaulting to `{fallback_attn_implementation}`."
         )
 
@@ -301,69 +403,125 @@
 
 
 def _set_attn_impl(config, impl):
-    """Helper function to set attention implementation on config and return it."""
     if config is not None:
-        setattr(config, "_attn_implementation", impl)
-        if hasattr(config, "attn_implementation"):
-            setattr(config, "attn_implementation", impl)
+        _config_set(config, "_attn_implementation", impl)
+        if isinstance(config, dict) or hasattr(config, "attn_implementation"):
+            _config_set(config, "attn_implementation", impl)
     return impl
 
 
-def determine_attention_implementation(model_class, config):
-    model_type = getattr(config, "model_type", "").lower()
+def resolve_model_class(auto_model, config):
+    mapping = getattr(auto_model, "_model_mapping", {})
+    try:
+        result = mapping[config.__class__]
+    except Exception:
+        for config_class, model_class in mapping.items():
+            if isinstance(config, config_class):
+                result = model_class
+                break
+        else:
+            return None
 
-    # Eager-only models (e.g. gemma3n timm vision towers)
-    if _is_eager_only(model_type):
-        _set_attn_impl(config, "eager")
-        return "eager"
+    return result[0] if isinstance(result, (list, tuple)) else result
 
-    # Models with known Flash Attention incompatibilities. Gemma 4 full-attention
-    # layers use global_head_dim=512, which exceeds Flash Attention's dense
-    # head-dim support. Keep explicit eager requests, otherwise prefer SDPA.
-    if _is_flash_attention_disabled(model_type):
+
+def resolve_attention_implementation(
+    model_class,
+    config,
+    requested_attn_implementation = None,
+    supports_sdpa = None,
+):
+    model_type_name = _config_get(config, "model_type", "")
+    model_type = model_type_name.lower()
+    if supports_sdpa is None:
         supports_sdpa = model_class is not None and getattr(
             model_class, "_supports_sdpa", False
         )
-        return _disable_flash_attention_if_needed(
-            model_type,
+    supports_flash_attention = model_class is not None and (
+        getattr(model_class, "_supports_flash_attn_2", False)
+        or getattr(model_class, "_supports_flash_attn", False)
+    )
+    disable_reason = _get_flash_attention_disable_reason(config)
+    flash_attention_disabled = disable_reason is not None
+
+    if model_class is None:
+        attn_impl = _set_attn_impl(config, "sdpa" if supports_sdpa else "eager")
+    else:
+        if _is_eager_only(model_type):
+            attn_impl = _set_attn_impl(config, "eager")
+        elif flash_attention_disabled:
+            attn_impl = _disable_flash_attention_if_needed(
+                config,
+                supports_sdpa = supports_sdpa,
+                would_use_flash_attention = (
+                    HAS_FLASH_ATTENTION and supports_flash_attention
+                ),
+                disable_reason = disable_reason,
+            )
+        elif HAS_FLASH_ATTENTION and supports_flash_attention:
+            attn_impl = _set_attn_impl(config, "flash_attention_2")
+        elif supports_sdpa:
+            attn_impl = _set_attn_impl(config, "sdpa")
+        else:
+            attn_impl = "eager"
+        if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") != "0":
+            try:
+                from transformers.utils.import_utils import (
+                    is_torch_flex_attn_available,
+                )
+
+                if (
+                    is_torch_flex_attn_available()
+                    and getattr(model_class, "_supports_flex_attn", False)
+                    and not _is_flex_excluded(model_type)
+                ):
+                    attention_dropout = (
+                        _config_get(config, "attention_dropout", 0) or 0
+                    )
+                    if attention_dropout == 0:
+                        attn_impl = _set_attn_impl(config, "flex_attention")
+            except Exception:
+                pass
+        if attn_impl == "eager":
+            attn_impl = _set_attn_impl(config, "eager")
+
+    if requested_attn_implementation is None:
+        final_attn_impl = attn_impl
+    elif flash_attention_disabled:
+        final_attn_impl = _disable_flash_attention_if_needed(
             config,
+            requested_attn_implementation,
             supports_sdpa = supports_sdpa,
+            disable_reason = disable_reason,
         )
+    else:
+        final_attn_impl = requested_attn_implementation
+        _set_attn_impl(config, final_attn_impl)
 
-    # Flash Attention 2
-    if HAS_FLASH_ATTENTION and model_class is not None:
-        supports_fa2 = getattr(model_class, "_supports_flash_attn_2", False) or getattr(
-            model_class, "_supports_flash_attn", False
+    if not supports_sdpa and final_attn_impl == "sdpa":
+        print(
+            f"Unsloth: {(model_type_name or 'model').title()} does not support SDPA - switching to fast eager."
         )
-        if supports_fa2:
-            _set_attn_impl(config, "flash_attention_2")
-            return "flash_attention_2"
+        final_attn_impl = _set_attn_impl(config, "eager")
 
-    # Flex Attention
-    if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") != "0":
-        try:
-            from transformers.utils.import_utils import is_torch_flex_attn_available
+    return final_attn_impl
 
-            if (
-                is_torch_flex_attn_available()
-                and model_class is not None
-                and getattr(model_class, "_supports_flex_attn", False)
-                and not _is_flex_excluded(model_type)
-            ):
-                attention_dropout = getattr(config, "attention_dropout", 0) or 0
-                if attention_dropout == 0:
-                    _set_attn_impl(config, "flex_attention")
-                    return "flex_attention"
-        except Exception:
-            pass
 
-    # SDPA
-    if model_class is not None and getattr(model_class, "_supports_sdpa", False):
-        _set_attn_impl(config, "sdpa")
+def resolve_encoder_attention_implementation(
+    auto_model,
+    config,
+    model_type = "",
+    disable_sdpa_model_names = (),
+):
+    model_class = resolve_model_class(auto_model, config)
+    supports_sdpa = model_class is not None and getattr(
+        model_class, "_supports_sdpa", False
+    )
+    if any(name in model_type.lower() for name in disable_sdpa_model_names):
+        return "eager"
+    if supports_sdpa:
         return "sdpa"
-
-    _set_attn_impl(config, "eager")
-    return "eager"
+    return None
 
 
 def _run_temporary_patches(phase):
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 425df1c08..999711efd 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -2346,7 +2346,7 @@ class FastLlamaModel:
         model_function = MODEL_FOR_CAUSAL_LM_MAPPING[model_config.__class__]
         IS_FALCON_H1 = model_config.model_type.startswith("falcon_h1")
 
-        preferred_attn_impl = determine_attention_implementation(
+        preferred_attn_impl = resolve_attention_implementation(
             model_function, model_config
         )
 
diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py
index cd12544ae..fc91178d8 100644
--- a/unsloth/models/loader.py
+++ b/unsloth/models/loader.py
@@ -1151,9 +1151,6 @@ class FastModel(FastBaseModel):
             )
             os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
             os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1"
-            # Disable flex_attention for Gemma-4: flex compile overhead is 2.7x slower
-            # than SDPA. Our attention patch ensures Q/K/V dtype alignment for SDPA.
-            os.environ["UNSLOTH_ENABLE_FLEX_ATTENTION"] = "0"
         # Gemma 3N must be before Gemma 3
         elif "gemma3n" in model_types_all:
             if transformers_version < Version("4.53.0"):
diff --git a/unsloth/models/sentence_transformer.py b/unsloth/models/sentence_transformer.py
index ad59165a5..541875a3f 100644
--- a/unsloth/models/sentence_transformer.py
+++ b/unsloth/models/sentence_transformer.py
@@ -15,7 +15,11 @@
 
 import logging
 from .loader import FastModel, DISABLE_SDPA_MODEL_NAMES
-from ._utils import SUPPORTS_BFLOAT16
+from ._utils import (
+    SUPPORTS_BFLOAT16,
+    resolve_model_class,
+    resolve_encoder_attention_implementation,
+)
 import inspect
 import json
 import os
@@ -31,7 +35,6 @@ import transformers
 from packaging.version import Version
 import re
 from transformers import AutoModel, AutoConfig
-from transformers.models.auto.auto_factory import _get_model_class
 import tempfile
 from huggingface_hub import HfApi, get_token
 from ..save import unsloth_save_pretrained_torchao, unsloth_save_pretrained_gguf
@@ -870,7 +873,7 @@ class FastSentenceTransformer(FastModel):
         if auto_model_class is None:
             auto_model_class = AutoModel
         # try to resolve the class
-        model_class = _get_model_class(config, auto_model_class._model_mapping)
+        model_class = resolve_model_class(auto_model_class, config)
 
         if model_class:
             sig = inspect.signature(model_class.__init__)
@@ -1446,32 +1449,18 @@ class FastSentenceTransformer(FastModel):
         ):
             st_device = "cuda"
 
-        # Check if model supports SDPA (Scaled Dot Product Attention) for extra speedup
-        supports_sdpa = False
-        if config is not None:
-            try:
-                model_class = _get_model_class(
-                    config, kwargs.get("auto_model", AutoModel)._model_mapping
-                )
-                supports_sdpa = getattr(model_class, "_supports_sdpa", False)
-            except:
-                pass
-
         # Build model_kwargs for SentenceTransformer
         model_kwargs = {"torch_dtype": dtype}
 
-        # Enable SDPA if supported (1.2x extra speedup on top of torch.compile)
-        # But disable for models with known SDPA + torch.compile backward issues
-        _force_eager = False
-        for _sdpa_model in DISABLE_SDPA_MODEL_NAMES:
-            if _sdpa_model in model_type.lower():
-                supports_sdpa = False
-                _force_eager = True
-                break
-        if supports_sdpa:
-            model_kwargs["attn_implementation"] = "sdpa"
-        elif _force_eager:
-            model_kwargs["attn_implementation"] = "eager"
+        encoder_attn_impl = resolve_encoder_attention_implementation(
+            kwargs.get("auto_model", AutoModel),
+            config,
+            model_type = model_type,
+            disable_sdpa_model_names = DISABLE_SDPA_MODEL_NAMES,
+        )
+        supports_sdpa = encoder_attn_impl == "sdpa"
+        if encoder_attn_impl is not None:
+            model_kwargs["attn_implementation"] = encoder_attn_impl
 
         # Print optimization status
         sdpa_str = " + SDPA" if supports_sdpa else ""
diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py
index e31617f89..90c93ea3f 100644
--- a/unsloth/models/vision.py
+++ b/unsloth/models/vision.py
@@ -33,8 +33,8 @@ from ._utils import (
     __version__,
     importlib_version,
    _prepare_model_for_qat,
-    _is_flash_attention_disabled,
-    _disable_flash_attention_if_needed,
+    resolve_model_class,
+    resolve_attention_implementation,
 )
 from ._utils import *
 from .loader_utils import _get_fp8_mode_and_check_settings
@@ -613,55 +613,18 @@ class FastBaseModel:
             token = token,
             trust_remote_code = trust_remote_code,
         )
-        user_attn_implementation = kwargs.get("attn_implementation", None)
-        try:
-            model_class = auto_model._model_mapping[auto_config.__class__]
-        except Exception:
-            model_class = None
-        if model_class is None:
-            # When model_class cannot be resolved (remote-code or unmapped
-            # configs), preserve the old fallback of sdpa when supported.
-            attn_impl = _set_attn_impl(
-                auto_config, "sdpa" if supports_sdpa else "eager"
-            )
-        else:
-            attn_impl = determine_attention_implementation(model_class, auto_config)
+        model_class = resolve_model_class(auto_model, auto_config)
+        attn_impl = resolve_attention_implementation(
+            model_class,
+            auto_config,
+            requested_attn_implementation = kwargs.get("attn_implementation", None),
+            supports_sdpa = supports_sdpa,
+        )
 
         # Handle FP8 models: get_model_name has already redirected this to BF16 sibling if the model ships with
         # FP8 weights. We just need to update it here for sanity.
         auto_config.model_name = model_name
-        # Re-resolve model_class after potential config change
-        try:
-            model_class = auto_model._model_mapping[auto_config.__class__]
-        except Exception:
-            model_class = None
-
-        if not ("attn_implementation" in kwargs):
-            kwargs["attn_implementation"] = attn_impl
-        model_type = getattr(auto_config, "model_type", "").lower()
-        if _is_flash_attention_disabled(model_type):
-            supports_fa2 = model_class is not None and (
-                getattr(model_class, "_supports_flash_attn_2", False)
-                or getattr(model_class, "_supports_flash_attn", False)
-            )
-            kwargs["attn_implementation"] = _disable_flash_attention_if_needed(
-                model_type,
-                auto_config,
-                kwargs.get("attn_implementation"),
-                supports_sdpa = supports_sdpa,
-                would_use_flash_attention = (
-                    user_attn_implementation is None
-                    and HAS_FLASH_ATTENTION
-                    and supports_fa2
-                ),
-            )
-        if not supports_sdpa and kwargs.get("attn_implementation") == "sdpa":
-            print(
-                f"Unsloth: {model_type_arch.title()} does not support SDPA - switching to fast eager."
-            )
-            del kwargs["attn_implementation"]
-            # Re-stamp config so it stays consistent with the actual impl
-            _set_attn_impl(auto_config, "eager")
+        kwargs["attn_implementation"] = attn_impl
 
         bnb_config = None
         user_quantization_config = kwargs.get("quantization_config", None)
@@ -804,9 +767,7 @@ class FastBaseModel:
             token = token,
             trust_remote_code = trust_remote_code,
         )
-        setattr(auto_config, "_attn_implementation", config_attn_impl)
-        if hasattr(auto_config, "attn_implementation"):
-            setattr(auto_config, "attn_implementation", config_attn_impl)
+        _set_attn_impl(auto_config, config_attn_impl)
         model_config = auto_config
 
         verify_fp8_support_if_applicable(model_config)
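Illustrative usage (not part of the patch): a minimal sketch of the new head-dim guard, assuming an unsloth build that already contains the _utils.py changes above. The config values are hypothetical stand-ins; only the nested "*_config" traversal, the explicit global_head_dim field, and the 256 head-dim limit come from the diff itself.

    from types import SimpleNamespace

    from unsloth.models._utils import _get_flash_attention_disable_reason

    # Hypothetical gemma4-style config: the nested text sub-config declares an
    # explicit full-attention head dim of 512, above _FLASH_ATTENTION_MAX_HEAD_DIM (256).
    config = SimpleNamespace(
        model_type = "gemma4",
        text_config = SimpleNamespace(
            model_type = "gemma4_text",
            global_head_dim = 512,
        ),
    )

    # _iter_attention_configs walks the config and every "*_config" child, explicit
    # head dims take priority over hidden_size // num_heads, and the guard returns
    # the reason that resolve_attention_implementation later uses to skip
    # flash_attention_2 and fall back to SDPA or eager.
    print(_get_flash_attention_disable_reason(config))
    # -> max attention head dim 512 exceeds the Flash Attention 2 limit of 256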