mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Restrict flash attn to <=256 head dim. Consolidate attn impl checks (#5051)
* Restrict flash attn to <=256 head dim. Consolidate attn impl checks * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Consolidate the changes into single function * safeguard for dict instead of object * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
c5be8b1cd2
commit
6764cb9b90
5 changed files with 239 additions and 134 deletions
|
|
@ -65,7 +65,9 @@ __all__ = [
|
|||
"patch_compiled_autograd",
|
||||
"process_vision_info",
|
||||
"unsloth_compile_transformers",
|
||||
"determine_attention_implementation",
|
||||
"resolve_model_class",
|
||||
"resolve_attention_implementation",
|
||||
"resolve_encoder_attention_implementation",
|
||||
"_set_attn_impl",
|
||||
"patch_fast_lora",
|
||||
"validate_loftq_config",
|
||||
|
|
@ -233,7 +235,7 @@ def apply_unsloth_gradient_checkpointing(
|
|||
# access on some GPU architectures (B200). Falls back to eager safely.
|
||||
_FLEX_EXCLUDED_MODELS = ("gpt_oss", "mllama", "nemotron_h", "modernbert")
|
||||
_EAGER_ONLY_PREFIXES = ("gemma3n",)
|
||||
_FLASH_ATTENTION_DISABLED_MODELS = ("gemma4", "gemma4_text")
|
||||
_FLASH_ATTENTION_MAX_HEAD_DIM = 256
|
||||
_FLASH_ATTENTION_DISABLED_WARNED = set()
|
||||
|
||||
|
||||
|
|
@ -245,8 +247,102 @@ def _is_eager_only(model_type):
|
|||
return any(model_type.startswith(p) for p in _EAGER_ONLY_PREFIXES)
|
||||
|
||||
|
||||
def _is_flash_attention_disabled(model_type):
|
||||
return model_type in _FLASH_ATTENTION_DISABLED_MODELS
|
||||
def _config_items(config):
|
||||
if isinstance(config, dict):
|
||||
return config.items()
|
||||
if hasattr(config, "__dict__"):
|
||||
return vars(config).items()
|
||||
return ()
|
||||
|
||||
|
||||
def _config_get(config, field_name, default = None):
|
||||
if isinstance(config, dict):
|
||||
return config.get(field_name, default)
|
||||
return getattr(config, field_name, default)
|
||||
|
||||
|
||||
def _config_set(config, field_name, value):
|
||||
if isinstance(config, dict):
|
||||
config[field_name] = value
|
||||
elif config is not None:
|
||||
setattr(config, field_name, value)
|
||||
|
||||
|
||||
def _iter_attention_configs(config, seen = None):
|
||||
if config is None or (
|
||||
not isinstance(config, dict) and not hasattr(config, "__dict__")
|
||||
):
|
||||
return
|
||||
if seen is None:
|
||||
seen = set()
|
||||
config_id = id(config)
|
||||
if config_id in seen:
|
||||
return
|
||||
seen.add(config_id)
|
||||
yield config
|
||||
|
||||
for field_name, child_config in _config_items(config):
|
||||
if not isinstance(field_name, str) or not field_name.endswith("_config"):
|
||||
continue
|
||||
if isinstance(child_config, dict) or hasattr(child_config, "__dict__"):
|
||||
yield from _iter_attention_configs(child_config, seen)
|
||||
|
||||
|
||||
def _collect_attention_head_dims(config):
    """Return the attention head dimensions described by a single config.

    Explicit fields (``head_dim``, ``global_head_dim``, ``local_head_dim``,
    ``kv_head_dim``) take precedence: if any of them is a positive int, only
    those values are returned. Otherwise head dims are derived as
    ``hidden_size // num_heads`` for every pairing of the known hidden-size
    and head-count field aliases that divides evenly.

    Returns:
        A list of positive ints; empty when nothing usable is found.
    """

    def _positive_int(field_name):
        # Normalise "missing", non-int and non-positive values to None.
        value = _config_get(config, field_name, None)
        if isinstance(value, int) and value > 0:
            return value
        return None

    declared = []
    for field_name in ("head_dim", "global_head_dim", "local_head_dim", "kv_head_dim"):
        head_dim = _positive_int(field_name)
        if head_dim is not None:
            declared.append(head_dim)
    if declared:
        return declared

    derived = []
    for size_name in ("hidden_size", "d_model", "embed_dim", "dim"):
        width = _positive_int(size_name)
        if width is None:
            continue
        for heads_name in ("num_attention_heads", "num_heads", "n_heads"):
            heads = _positive_int(heads_name)
            if heads is not None and width % heads == 0:
                derived.append(width // heads)
    return derived
|
||||
|
||||
|
||||
def _get_max_attention_head_dim(config):
    """Return the largest head dim found in ``config`` or its sub-configs.

    Walks every config yielded by ``_iter_attention_configs`` and collects
    its head dims via ``_collect_attention_head_dims``. Returns ``None``
    when no head dim could be determined.
    """
    largest = None
    for sub_config in _iter_attention_configs(config):
        for head_dim in _collect_attention_head_dims(sub_config):
            if largest is None or head_dim > largest:
                largest = head_dim
    return largest
|
||||
|
||||
|
||||
def _get_flash_attention_disable_reason(config):
    """Explain why Flash Attention must be disabled for ``config``, if at all.

    Returns:
        A human-readable reason string when the largest attention head dim
        exceeds ``_FLASH_ATTENTION_MAX_HEAD_DIM``; ``None`` when the head dim
        is within the limit or could not be determined.
    """
    largest_head_dim = _get_max_attention_head_dim(config)
    if largest_head_dim is None:
        return None
    if largest_head_dim <= _FLASH_ATTENTION_MAX_HEAD_DIM:
        return None
    return (
        f"max attention head dim {largest_head_dim} exceeds the Flash Attention 2 "
        f"limit of {_FLASH_ATTENTION_MAX_HEAD_DIM}"
    )
|
||||
|
||||
|
||||
def _is_flash_attention_disabled(config):
    """Return ``True`` when a disable reason exists for ``config``."""
    reason = _get_flash_attention_disable_reason(config)
    return reason is not None
|
||||
|
||||
|
||||
def _is_flash_attention_requested(attn_implementation):
|
||||
|
|
@ -256,20 +352,24 @@ def _is_flash_attention_requested(attn_implementation):
|
|||
|
||||
|
||||
def _disable_flash_attention_if_needed(
|
||||
model_type,
|
||||
config,
|
||||
attn_implementation = None,
|
||||
supports_sdpa = False,
|
||||
would_use_flash_attention = False,
|
||||
disable_reason = None,
|
||||
):
|
||||
if not _is_flash_attention_disabled(model_type):
|
||||
if disable_reason is None:
|
||||
disable_reason = _get_flash_attention_disable_reason(config)
|
||||
if disable_reason is None:
|
||||
return attn_implementation
|
||||
|
||||
requested_attn_implementation = attn_implementation
|
||||
if requested_attn_implementation is None:
|
||||
requested_attn_implementation = getattr(config, "_attn_implementation", None)
|
||||
requested_attn_implementation = _config_get(
|
||||
config, "_attn_implementation", None
|
||||
)
|
||||
if requested_attn_implementation is None:
|
||||
requested_attn_implementation = getattr(config, "attn_implementation", None)
|
||||
requested_attn_implementation = _config_get(config, "attn_implementation", None)
|
||||
|
||||
if requested_attn_implementation == "eager":
|
||||
return _set_attn_impl(config, "eager")
|
||||
|
|
@ -284,16 +384,18 @@ def _disable_flash_attention_if_needed(
|
|||
if _is_flash_attention_requested(requested_attn_implementation)
|
||||
else "flash_attention_2"
|
||||
)
|
||||
model_type = _config_get(config, "model_type", "")
|
||||
warning_key = (
|
||||
model_type,
|
||||
logged_attn_implementation,
|
||||
fallback_attn_implementation,
|
||||
disable_reason,
|
||||
)
|
||||
if warning_key not in _FLASH_ATTENTION_DISABLED_WARNED:
|
||||
_FLASH_ATTENTION_DISABLED_WARNED.add(warning_key)
|
||||
print(
|
||||
f"Unsloth: `{logged_attn_implementation}` is not supported "
|
||||
"for Gemma 4 - "
|
||||
f"for `{model_type}` because {disable_reason} - "
|
||||
f"defaulting to `{fallback_attn_implementation}`."
|
||||
)
|
||||
|
||||
|
|
@ -301,69 +403,125 @@ def _disable_flash_attention_if_needed(
|
|||
|
||||
|
||||
def _set_attn_impl(config, impl):
|
||||
"""Helper function to set attention implementation on config and return it."""
|
||||
if config is not None:
|
||||
setattr(config, "_attn_implementation", impl)
|
||||
if hasattr(config, "attn_implementation"):
|
||||
setattr(config, "attn_implementation", impl)
|
||||
_config_set(config, "_attn_implementation", impl)
|
||||
if isinstance(config, dict) or hasattr(config, "attn_implementation"):
|
||||
_config_set(config, "attn_implementation", impl)
|
||||
return impl
|
||||
|
||||
|
||||
def determine_attention_implementation(model_class, config):
|
||||
model_type = getattr(config, "model_type", "").lower()
|
||||
def resolve_model_class(auto_model, config):
|
||||
mapping = getattr(auto_model, "_model_mapping", {})
|
||||
try:
|
||||
result = mapping[config.__class__]
|
||||
except Exception:
|
||||
for config_class, model_class in mapping.items():
|
||||
if isinstance(config, config_class):
|
||||
result = model_class
|
||||
break
|
||||
else:
|
||||
return None
|
||||
|
||||
# Eager-only models (e.g. gemma3n timm vision towers)
|
||||
if _is_eager_only(model_type):
|
||||
_set_attn_impl(config, "eager")
|
||||
return "eager"
|
||||
return result[0] if isinstance(result, (list, tuple)) else result
|
||||
|
||||
# Models with known Flash Attention incompatibilities. Gemma 4 full-attention
|
||||
# layers use global_head_dim=512, which exceeds Flash Attention's dense
|
||||
# head-dim support. Keep explicit eager requests, otherwise prefer SDPA.
|
||||
if _is_flash_attention_disabled(model_type):
|
||||
|
||||
def resolve_attention_implementation(
|
||||
model_class,
|
||||
config,
|
||||
requested_attn_implementation = None,
|
||||
supports_sdpa = None,
|
||||
):
|
||||
model_type_name = _config_get(config, "model_type", "")
|
||||
model_type = model_type_name.lower()
|
||||
if supports_sdpa is None:
|
||||
supports_sdpa = model_class is not None and getattr(
|
||||
model_class, "_supports_sdpa", False
|
||||
)
|
||||
return _disable_flash_attention_if_needed(
|
||||
model_type,
|
||||
supports_flash_attention = model_class is not None and (
|
||||
getattr(model_class, "_supports_flash_attn_2", False)
|
||||
or getattr(model_class, "_supports_flash_attn", False)
|
||||
)
|
||||
disable_reason = _get_flash_attention_disable_reason(config)
|
||||
flash_attention_disabled = disable_reason is not None
|
||||
|
||||
if model_class is None:
|
||||
attn_impl = _set_attn_impl(config, "sdpa" if supports_sdpa else "eager")
|
||||
else:
|
||||
if _is_eager_only(model_type):
|
||||
attn_impl = _set_attn_impl(config, "eager")
|
||||
elif flash_attention_disabled:
|
||||
attn_impl = _disable_flash_attention_if_needed(
|
||||
config,
|
||||
supports_sdpa = supports_sdpa,
|
||||
would_use_flash_attention = (
|
||||
HAS_FLASH_ATTENTION and supports_flash_attention
|
||||
),
|
||||
disable_reason = disable_reason,
|
||||
)
|
||||
elif HAS_FLASH_ATTENTION and supports_flash_attention:
|
||||
attn_impl = _set_attn_impl(config, "flash_attention_2")
|
||||
elif supports_sdpa:
|
||||
attn_impl = _set_attn_impl(config, "sdpa")
|
||||
else:
|
||||
attn_impl = "eager"
|
||||
if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") != "0":
|
||||
try:
|
||||
from transformers.utils.import_utils import (
|
||||
is_torch_flex_attn_available,
|
||||
)
|
||||
|
||||
if (
|
||||
is_torch_flex_attn_available()
|
||||
and getattr(model_class, "_supports_flex_attn", False)
|
||||
and not _is_flex_excluded(model_type)
|
||||
):
|
||||
attention_dropout = (
|
||||
_config_get(config, "attention_dropout", 0) or 0
|
||||
)
|
||||
if attention_dropout == 0:
|
||||
attn_impl = _set_attn_impl(config, "flex_attention")
|
||||
except Exception:
|
||||
pass
|
||||
if attn_impl == "eager":
|
||||
attn_impl = _set_attn_impl(config, "eager")
|
||||
|
||||
if requested_attn_implementation is None:
|
||||
final_attn_impl = attn_impl
|
||||
elif flash_attention_disabled:
|
||||
final_attn_impl = _disable_flash_attention_if_needed(
|
||||
config,
|
||||
requested_attn_implementation,
|
||||
supports_sdpa = supports_sdpa,
|
||||
disable_reason = disable_reason,
|
||||
)
|
||||
else:
|
||||
final_attn_impl = requested_attn_implementation
|
||||
_set_attn_impl(config, final_attn_impl)
|
||||
|
||||
# Flash Attention 2
|
||||
if HAS_FLASH_ATTENTION and model_class is not None:
|
||||
supports_fa2 = getattr(model_class, "_supports_flash_attn_2", False) or getattr(
|
||||
model_class, "_supports_flash_attn", False
|
||||
if not supports_sdpa and final_attn_impl == "sdpa":
|
||||
print(
|
||||
f"Unsloth: {(model_type_name or 'model').title()} does not support SDPA - switching to fast eager."
|
||||
)
|
||||
if supports_fa2:
|
||||
_set_attn_impl(config, "flash_attention_2")
|
||||
return "flash_attention_2"
|
||||
final_attn_impl = _set_attn_impl(config, "eager")
|
||||
|
||||
# Flex Attention
|
||||
if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") != "0":
|
||||
try:
|
||||
from transformers.utils.import_utils import is_torch_flex_attn_available
|
||||
return final_attn_impl
|
||||
|
||||
if (
|
||||
is_torch_flex_attn_available()
|
||||
and model_class is not None
|
||||
and getattr(model_class, "_supports_flex_attn", False)
|
||||
and not _is_flex_excluded(model_type)
|
||||
):
|
||||
attention_dropout = getattr(config, "attention_dropout", 0) or 0
|
||||
if attention_dropout == 0:
|
||||
_set_attn_impl(config, "flex_attention")
|
||||
return "flex_attention"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# SDPA
|
||||
if model_class is not None and getattr(model_class, "_supports_sdpa", False):
|
||||
_set_attn_impl(config, "sdpa")
|
||||
def resolve_encoder_attention_implementation(
|
||||
auto_model,
|
||||
config,
|
||||
model_type = "",
|
||||
disable_sdpa_model_names = (),
|
||||
):
|
||||
model_class = resolve_model_class(auto_model, config)
|
||||
supports_sdpa = model_class is not None and getattr(
|
||||
model_class, "_supports_sdpa", False
|
||||
)
|
||||
if any(name in model_type.lower() for name in disable_sdpa_model_names):
|
||||
return "eager"
|
||||
if supports_sdpa:
|
||||
return "sdpa"
|
||||
|
||||
_set_attn_impl(config, "eager")
|
||||
return "eager"
|
||||
return None
|
||||
|
||||
|
||||
def _run_temporary_patches(phase):
|
||||
|
|
|
|||
|
|
@ -2346,7 +2346,7 @@ class FastLlamaModel:
|
|||
model_function = MODEL_FOR_CAUSAL_LM_MAPPING[model_config.__class__]
|
||||
IS_FALCON_H1 = model_config.model_type.startswith("falcon_h1")
|
||||
|
||||
preferred_attn_impl = determine_attention_implementation(
|
||||
preferred_attn_impl = resolve_attention_implementation(
|
||||
model_function, model_config
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1151,9 +1151,6 @@ class FastModel(FastBaseModel):
|
|||
)
|
||||
os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
|
||||
os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1"
|
||||
# Disable flex_attention for Gemma-4: flex compile overhead is 2.7x slower
|
||||
# than SDPA. Our attention patch ensures Q/K/V dtype alignment for SDPA.
|
||||
os.environ["UNSLOTH_ENABLE_FLEX_ATTENTION"] = "0"
|
||||
# Gemma 3N must be before Gemma 3
|
||||
elif "gemma3n" in model_types_all:
|
||||
if transformers_version < Version("4.53.0"):
|
||||
|
|
|
|||
|
|
@ -15,7 +15,11 @@
|
|||
import logging
|
||||
|
||||
from .loader import FastModel, DISABLE_SDPA_MODEL_NAMES
|
||||
from ._utils import SUPPORTS_BFLOAT16
|
||||
from ._utils import (
|
||||
SUPPORTS_BFLOAT16,
|
||||
resolve_model_class,
|
||||
resolve_encoder_attention_implementation,
|
||||
)
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
|
|
@ -31,7 +35,6 @@ import transformers
|
|||
from packaging.version import Version
|
||||
import re
|
||||
from transformers import AutoModel, AutoConfig
|
||||
from transformers.models.auto.auto_factory import _get_model_class
|
||||
import tempfile
|
||||
from huggingface_hub import HfApi, get_token
|
||||
from ..save import unsloth_save_pretrained_torchao, unsloth_save_pretrained_gguf
|
||||
|
|
@ -870,7 +873,7 @@ class FastSentenceTransformer(FastModel):
|
|||
if auto_model_class is None:
|
||||
auto_model_class = AutoModel
|
||||
# try to resolve the class
|
||||
model_class = _get_model_class(config, auto_model_class._model_mapping)
|
||||
model_class = resolve_model_class(auto_model_class, config)
|
||||
|
||||
if model_class:
|
||||
sig = inspect.signature(model_class.__init__)
|
||||
|
|
@ -1446,32 +1449,18 @@ class FastSentenceTransformer(FastModel):
|
|||
):
|
||||
st_device = "cuda"
|
||||
|
||||
# Check if model supports SDPA (Scaled Dot Product Attention) for extra speedup
|
||||
supports_sdpa = False
|
||||
if config is not None:
|
||||
try:
|
||||
model_class = _get_model_class(
|
||||
config, kwargs.get("auto_model", AutoModel)._model_mapping
|
||||
)
|
||||
supports_sdpa = getattr(model_class, "_supports_sdpa", False)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Build model_kwargs for SentenceTransformer
|
||||
model_kwargs = {"torch_dtype": dtype}
|
||||
|
||||
# Enable SDPA if supported (1.2x extra speedup on top of torch.compile)
|
||||
# But disable for models with known SDPA + torch.compile backward issues
|
||||
_force_eager = False
|
||||
for _sdpa_model in DISABLE_SDPA_MODEL_NAMES:
|
||||
if _sdpa_model in model_type.lower():
|
||||
supports_sdpa = False
|
||||
_force_eager = True
|
||||
break
|
||||
if supports_sdpa:
|
||||
model_kwargs["attn_implementation"] = "sdpa"
|
||||
elif _force_eager:
|
||||
model_kwargs["attn_implementation"] = "eager"
|
||||
encoder_attn_impl = resolve_encoder_attention_implementation(
|
||||
kwargs.get("auto_model", AutoModel),
|
||||
config,
|
||||
model_type = model_type,
|
||||
disable_sdpa_model_names = DISABLE_SDPA_MODEL_NAMES,
|
||||
)
|
||||
supports_sdpa = encoder_attn_impl == "sdpa"
|
||||
if encoder_attn_impl is not None:
|
||||
model_kwargs["attn_implementation"] = encoder_attn_impl
|
||||
|
||||
# Print optimization status
|
||||
sdpa_str = " + SDPA" if supports_sdpa else ""
|
||||
|
|
|
|||
|
|
@ -33,8 +33,8 @@ from ._utils import (
|
|||
__version__,
|
||||
importlib_version,
|
||||
_prepare_model_for_qat,
|
||||
_is_flash_attention_disabled,
|
||||
_disable_flash_attention_if_needed,
|
||||
resolve_model_class,
|
||||
resolve_attention_implementation,
|
||||
)
|
||||
from ._utils import *
|
||||
from .loader_utils import _get_fp8_mode_and_check_settings
|
||||
|
|
@ -613,55 +613,18 @@ class FastBaseModel:
|
|||
token = token,
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
user_attn_implementation = kwargs.get("attn_implementation", None)
|
||||
try:
|
||||
model_class = auto_model._model_mapping[auto_config.__class__]
|
||||
except Exception:
|
||||
model_class = None
|
||||
if model_class is None:
|
||||
# When model_class cannot be resolved (remote-code or unmapped
|
||||
# configs), preserve the old fallback of sdpa when supported.
|
||||
attn_impl = _set_attn_impl(
|
||||
auto_config, "sdpa" if supports_sdpa else "eager"
|
||||
)
|
||||
else:
|
||||
attn_impl = determine_attention_implementation(model_class, auto_config)
|
||||
model_class = resolve_model_class(auto_model, auto_config)
|
||||
attn_impl = resolve_attention_implementation(
|
||||
model_class,
|
||||
auto_config,
|
||||
requested_attn_implementation = kwargs.get("attn_implementation", None),
|
||||
supports_sdpa = supports_sdpa,
|
||||
)
|
||||
|
||||
# Handle FP8 models: get_model_name has already redirected this to BF16 sibling if the model ships with
|
||||
# FP8 weights. We just need to update it here for sanity.
|
||||
auto_config.model_name = model_name
|
||||
# Re-resolve model_class after potential config change
|
||||
try:
|
||||
model_class = auto_model._model_mapping[auto_config.__class__]
|
||||
except Exception:
|
||||
model_class = None
|
||||
|
||||
if not ("attn_implementation" in kwargs):
|
||||
kwargs["attn_implementation"] = attn_impl
|
||||
model_type = getattr(auto_config, "model_type", "").lower()
|
||||
if _is_flash_attention_disabled(model_type):
|
||||
supports_fa2 = model_class is not None and (
|
||||
getattr(model_class, "_supports_flash_attn_2", False)
|
||||
or getattr(model_class, "_supports_flash_attn", False)
|
||||
)
|
||||
kwargs["attn_implementation"] = _disable_flash_attention_if_needed(
|
||||
model_type,
|
||||
auto_config,
|
||||
kwargs.get("attn_implementation"),
|
||||
supports_sdpa = supports_sdpa,
|
||||
would_use_flash_attention = (
|
||||
user_attn_implementation is None
|
||||
and HAS_FLASH_ATTENTION
|
||||
and supports_fa2
|
||||
),
|
||||
)
|
||||
if not supports_sdpa and kwargs.get("attn_implementation") == "sdpa":
|
||||
print(
|
||||
f"Unsloth: {model_type_arch.title()} does not support SDPA - switching to fast eager."
|
||||
)
|
||||
del kwargs["attn_implementation"]
|
||||
# Re-stamp config so it stays consistent with the actual impl
|
||||
_set_attn_impl(auto_config, "eager")
|
||||
kwargs["attn_implementation"] = attn_impl
|
||||
|
||||
bnb_config = None
|
||||
user_quantization_config = kwargs.get("quantization_config", None)
|
||||
|
|
@ -804,9 +767,7 @@ class FastBaseModel:
|
|||
token = token,
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
setattr(auto_config, "_attn_implementation", config_attn_impl)
|
||||
if hasattr(auto_config, "attn_implementation"):
|
||||
setattr(auto_config, "attn_implementation", config_attn_impl)
|
||||
_set_attn_impl(auto_config, config_attn_impl)
|
||||
model_config = auto_config
|
||||
|
||||
verify_fp8_support_if_applicable(model_config)
|
||||
|
|
|
|||
Loading…
Reference in a new issue