Restrict flash attn to <=256 head dim. Consolidate attn impl checks (#5051)

* Restrict flash attn to <=256 head dim. Consolidate attn impl checks

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Consolidate the changes into a single function

* Safeguard for dict configs instead of objects

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Datta Nimmaturi 2026-04-16 19:30:17 +05:30 committed by GitHub
parent c5be8b1cd2
commit 6764cb9b90
5 changed files with 239 additions and 134 deletions
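For reference, the gate this commit adds can be sketched in isolation. A minimal illustrative sketch of the same rule (not the library code), assuming a Hugging Face-style config that may be a dict or an object; max_head_dim_of and flash_attention_allowed are made-up names:

# Illustrative sketch of the <=256 head-dim gate; mirrors the
# _collect_attention_head_dims logic in the diff below: prefer an explicit
# head_dim field, else derive hidden_size // num_attention_heads.
FLASH_ATTENTION_MAX_HEAD_DIM = 256

def max_head_dim_of(config):
    get = config.get if isinstance(config, dict) else lambda k, d = None: getattr(config, k, d)
    head_dim = get("head_dim")
    if isinstance(head_dim, int) and head_dim > 0:
        return head_dim
    hidden_size = get("hidden_size")
    num_heads = get("num_attention_heads")
    if (
        isinstance(hidden_size, int)
        and isinstance(num_heads, int)
        and num_heads > 0
        and hidden_size % num_heads == 0
    ):
        return hidden_size // num_heads
    return None

def flash_attention_allowed(config):
    max_head_dim = max_head_dim_of(config)
    return max_head_dim is None or max_head_dim <= FLASH_ATTENTION_MAX_HEAD_DIM

A config with hidden_size = 4096 and num_attention_heads = 32 (head dim 128) passes; an explicit head_dim = 512, as in the Gemma 4 case noted in the diff, trips the fallback to SDPA or eager.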


@@ -65,7 +65,9 @@ __all__ = [
"patch_compiled_autograd",
"process_vision_info",
"unsloth_compile_transformers",
"determine_attention_implementation",
"resolve_model_class",
"resolve_attention_implementation",
"resolve_encoder_attention_implementation",
"_set_attn_impl",
"patch_fast_lora",
"validate_loftq_config",
@@ -233,7 +235,7 @@ def apply_unsloth_gradient_checkpointing(
# access on some GPU architectures (B200). Falls back to eager safely.
_FLEX_EXCLUDED_MODELS = ("gpt_oss", "mllama", "nemotron_h", "modernbert")
_EAGER_ONLY_PREFIXES = ("gemma3n",)
_FLASH_ATTENTION_DISABLED_MODELS = ("gemma4", "gemma4_text")
_FLASH_ATTENTION_MAX_HEAD_DIM = 256
_FLASH_ATTENTION_DISABLED_WARNED = set()
@@ -245,8 +247,102 @@ def _is_eager_only(model_type):
return any(model_type.startswith(p) for p in _EAGER_ONLY_PREFIXES)
def _is_flash_attention_disabled(model_type):
return model_type in _FLASH_ATTENTION_DISABLED_MODELS
def _config_items(config):
if isinstance(config, dict):
return config.items()
if hasattr(config, "__dict__"):
return vars(config).items()
return ()
def _config_get(config, field_name, default = None):
if isinstance(config, dict):
return config.get(field_name, default)
return getattr(config, field_name, default)
def _config_set(config, field_name, value):
if isinstance(config, dict):
config[field_name] = value
elif config is not None:
setattr(config, field_name, value)
def _iter_attention_configs(config, seen = None):
if config is None or (
not isinstance(config, dict) and not hasattr(config, "__dict__")
):
return
if seen is None:
seen = set()
config_id = id(config)
if config_id in seen:
return
seen.add(config_id)
yield config
for field_name, child_config in _config_items(config):
if not isinstance(field_name, str) or not field_name.endswith("_config"):
continue
if isinstance(child_config, dict) or hasattr(child_config, "__dict__"):
yield from _iter_attention_configs(child_config, seen)
def _collect_attention_head_dims(config):
explicit_head_dims = []
for field_name in (
"head_dim",
"global_head_dim",
"local_head_dim",
"kv_head_dim",
):
value = _config_get(config, field_name, None)
if isinstance(value, int) and value > 0:
explicit_head_dims.append(value)
if len(explicit_head_dims) != 0:
return explicit_head_dims
head_dims = []
hidden_size_names = ("hidden_size", "d_model", "embed_dim", "dim")
num_heads_names = ("num_attention_heads", "num_heads", "n_heads")
for hidden_size_name in hidden_size_names:
hidden_size = _config_get(config, hidden_size_name, None)
if not isinstance(hidden_size, int) or hidden_size <= 0:
continue
for num_heads_name in num_heads_names:
num_heads = _config_get(config, num_heads_name, None)
if (
isinstance(num_heads, int)
and num_heads > 0
and (hidden_size % num_heads) == 0
):
head_dims.append(hidden_size // num_heads)
return head_dims
def _get_max_attention_head_dim(config):
head_dims = []
for attention_config in _iter_attention_configs(config):
head_dims.extend(_collect_attention_head_dims(attention_config))
return max(head_dims) if len(head_dims) != 0 else None
def _get_flash_attention_disable_reason(config):
max_head_dim = _get_max_attention_head_dim(config)
if max_head_dim is not None and max_head_dim > _FLASH_ATTENTION_MAX_HEAD_DIM:
return (
f"max attention head dim {max_head_dim} exceeds the Flash Attention 2 "
f"limit of {_FLASH_ATTENTION_MAX_HEAD_DIM}"
)
return None
def _is_flash_attention_disabled(config):
return _get_flash_attention_disable_reason(config) is not None
def _is_flash_attention_requested(attn_implementation):
@@ -256,20 +352,24 @@ def _is_flash_attention_requested(attn_implementation):
def _disable_flash_attention_if_needed(
model_type,
config,
attn_implementation = None,
supports_sdpa = False,
would_use_flash_attention = False,
disable_reason = None,
):
if not _is_flash_attention_disabled(model_type):
if disable_reason is None:
disable_reason = _get_flash_attention_disable_reason(config)
if disable_reason is None:
return attn_implementation
requested_attn_implementation = attn_implementation
if requested_attn_implementation is None:
requested_attn_implementation = getattr(config, "_attn_implementation", None)
requested_attn_implementation = _config_get(
config, "_attn_implementation", None
)
if requested_attn_implementation is None:
requested_attn_implementation = getattr(config, "attn_implementation", None)
requested_attn_implementation = _config_get(config, "attn_implementation", None)
if requested_attn_implementation == "eager":
return _set_attn_impl(config, "eager")
@@ -284,16 +384,18 @@ def _disable_flash_attention_if_needed(
if _is_flash_attention_requested(requested_attn_implementation)
else "flash_attention_2"
)
model_type = _config_get(config, "model_type", "")
warning_key = (
model_type,
logged_attn_implementation,
fallback_attn_implementation,
disable_reason,
)
if warning_key not in _FLASH_ATTENTION_DISABLED_WARNED:
_FLASH_ATTENTION_DISABLED_WARNED.add(warning_key)
print(
f"Unsloth: `{logged_attn_implementation}` is not supported "
"for Gemma 4 - "
f"for `{model_type}` because {disable_reason} - "
f"defaulting to `{fallback_attn_implementation}`."
)
@@ -301,69 +403,125 @@ def _disable_flash_attention_if_needed(
def _set_attn_impl(config, impl):
"""Helper function to set attention implementation on config and return it."""
if config is not None:
setattr(config, "_attn_implementation", impl)
if hasattr(config, "attn_implementation"):
setattr(config, "attn_implementation", impl)
_config_set(config, "_attn_implementation", impl)
if isinstance(config, dict) or hasattr(config, "attn_implementation"):
_config_set(config, "attn_implementation", impl)
return impl
def determine_attention_implementation(model_class, config):
model_type = getattr(config, "model_type", "").lower()
def resolve_model_class(auto_model, config):
mapping = getattr(auto_model, "_model_mapping", {})
try:
result = mapping[config.__class__]
except Exception:
for config_class, model_class in mapping.items():
if isinstance(config, config_class):
result = model_class
break
else:
return None
# Eager-only models (e.g. gemma3n timm vision towers)
if _is_eager_only(model_type):
_set_attn_impl(config, "eager")
return "eager"
return result[0] if isinstance(result, (list, tuple)) else result
# Models with known Flash Attention incompatibilities. Gemma 4 full-attention
# layers use global_head_dim=512, which exceeds Flash Attention's dense
# head-dim support. Keep explicit eager requests, otherwise prefer SDPA.
if _is_flash_attention_disabled(model_type):
def resolve_attention_implementation(
model_class,
config,
requested_attn_implementation = None,
supports_sdpa = None,
):
model_type_name = _config_get(config, "model_type", "")
model_type = model_type_name.lower()
if supports_sdpa is None:
supports_sdpa = model_class is not None and getattr(
model_class, "_supports_sdpa", False
)
return _disable_flash_attention_if_needed(
model_type,
supports_flash_attention = model_class is not None and (
getattr(model_class, "_supports_flash_attn_2", False)
or getattr(model_class, "_supports_flash_attn", False)
)
disable_reason = _get_flash_attention_disable_reason(config)
flash_attention_disabled = disable_reason is not None
if model_class is None:
attn_impl = _set_attn_impl(config, "sdpa" if supports_sdpa else "eager")
else:
if _is_eager_only(model_type):
attn_impl = _set_attn_impl(config, "eager")
elif flash_attention_disabled:
attn_impl = _disable_flash_attention_if_needed(
config,
supports_sdpa = supports_sdpa,
would_use_flash_attention = (
HAS_FLASH_ATTENTION and supports_flash_attention
),
disable_reason = disable_reason,
)
elif HAS_FLASH_ATTENTION and supports_flash_attention:
attn_impl = _set_attn_impl(config, "flash_attention_2")
elif supports_sdpa:
attn_impl = _set_attn_impl(config, "sdpa")
else:
attn_impl = "eager"
if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") != "0":
try:
from transformers.utils.import_utils import (
is_torch_flex_attn_available,
)
if (
is_torch_flex_attn_available()
and getattr(model_class, "_supports_flex_attn", False)
and not _is_flex_excluded(model_type)
):
attention_dropout = (
_config_get(config, "attention_dropout", 0) or 0
)
if attention_dropout == 0:
attn_impl = _set_attn_impl(config, "flex_attention")
except Exception:
pass
if attn_impl == "eager":
attn_impl = _set_attn_impl(config, "eager")
if requested_attn_implementation is None:
final_attn_impl = attn_impl
elif flash_attention_disabled:
final_attn_impl = _disable_flash_attention_if_needed(
config,
requested_attn_implementation,
supports_sdpa = supports_sdpa,
disable_reason = disable_reason,
)
else:
final_attn_impl = requested_attn_implementation
_set_attn_impl(config, final_attn_impl)
# Flash Attention 2
if HAS_FLASH_ATTENTION and model_class is not None:
supports_fa2 = getattr(model_class, "_supports_flash_attn_2", False) or getattr(
model_class, "_supports_flash_attn", False
if not supports_sdpa and final_attn_impl == "sdpa":
print(
f"Unsloth: {(model_type_name or 'model').title()} does not support SDPA - switching to fast eager."
)
if supports_fa2:
_set_attn_impl(config, "flash_attention_2")
return "flash_attention_2"
final_attn_impl = _set_attn_impl(config, "eager")
# Flex Attention
if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") != "0":
try:
from transformers.utils.import_utils import is_torch_flex_attn_available
return final_attn_impl
if (
is_torch_flex_attn_available()
and model_class is not None
and getattr(model_class, "_supports_flex_attn", False)
and not _is_flex_excluded(model_type)
):
attention_dropout = getattr(config, "attention_dropout", 0) or 0
if attention_dropout == 0:
_set_attn_impl(config, "flex_attention")
return "flex_attention"
except Exception:
pass
# SDPA
if model_class is not None and getattr(model_class, "_supports_sdpa", False):
_set_attn_impl(config, "sdpa")
def resolve_encoder_attention_implementation(
auto_model,
config,
model_type = "",
disable_sdpa_model_names = (),
):
model_class = resolve_model_class(auto_model, config)
supports_sdpa = model_class is not None and getattr(
model_class, "_supports_sdpa", False
)
if any(name in model_type.lower() for name in disable_sdpa_model_names):
return "eager"
if supports_sdpa:
return "sdpa"
_set_attn_impl(config, "eager")
return "eager"
return None
def _run_temporary_patches(phase):
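
Taken together, the two new helpers replace the old single resolver. A hedged usage sketch (the model id, the AutoModelForCausalLM choice, and the import path are assumptions, not part of this commit):

from transformers import AutoConfig, AutoModelForCausalLM
from unsloth.models._utils import resolve_model_class, resolve_attention_implementation  # assumed path

config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")  # illustrative model id
model_class = resolve_model_class(AutoModelForCausalLM, config)  # None for unmapped/remote-code configs
attn_impl = resolve_attention_implementation(model_class, config)
# attn_impl is one of "flash_attention_2", "flex_attention", "sdpa" or "eager",
# and config._attn_implementation has been stamped to match via _set_attn_impl.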


@@ -2346,7 +2346,7 @@ class FastLlamaModel:
model_function = MODEL_FOR_CAUSAL_LM_MAPPING[model_config.__class__]
IS_FALCON_H1 = model_config.model_type.startswith("falcon_h1")
preferred_attn_impl = determine_attention_implementation(
preferred_attn_impl = resolve_attention_implementation(
model_function, model_config
)


@@ -1151,9 +1151,6 @@ class FastModel(FastBaseModel):
)
os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1"
# Disable flex_attention for Gemma-4: flex compile overhead is 2.7x slower
# than SDPA. Our attention patch ensures Q/K/V dtype alignment for SDPA.
os.environ["UNSLOTH_ENABLE_FLEX_ATTENTION"] = "0"
# Gemma 3N must be before Gemma 3
elif "gemma3n" in model_types_all:
if transformers_version < Version("4.53.0"):
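
The removed lines above drop the hard env-var override for this model family; flex attention selection now flows through resolve_attention_implementation instead. The flag it checks remains a global opt-out (illustrative usage):

import os
# Skip the flex_attention branch in resolve_attention_implementation entirely.
os.environ["UNSLOTH_ENABLE_FLEX_ATTENTION"] = "0"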


@@ -15,7 +15,11 @@
import logging
from .loader import FastModel, DISABLE_SDPA_MODEL_NAMES
from ._utils import SUPPORTS_BFLOAT16
from ._utils import (
SUPPORTS_BFLOAT16,
resolve_model_class,
resolve_encoder_attention_implementation,
)
import inspect
import json
import os
@@ -31,7 +35,6 @@ import transformers
from packaging.version import Version
import re
from transformers import AutoModel, AutoConfig
from transformers.models.auto.auto_factory import _get_model_class
import tempfile
from huggingface_hub import HfApi, get_token
from ..save import unsloth_save_pretrained_torchao, unsloth_save_pretrained_gguf
@@ -870,7 +873,7 @@ class FastSentenceTransformer(FastModel):
if auto_model_class is None:
auto_model_class = AutoModel
# try to resolve the class
model_class = _get_model_class(config, auto_model_class._model_mapping)
model_class = resolve_model_class(auto_model_class, config)
if model_class:
sig = inspect.signature(model_class.__init__)
@@ -1446,32 +1449,18 @@ class FastSentenceTransformer(FastModel):
):
st_device = "cuda"
# Check if model supports SDPA (Scaled Dot Product Attention) for extra speedup
supports_sdpa = False
if config is not None:
try:
model_class = _get_model_class(
config, kwargs.get("auto_model", AutoModel)._model_mapping
)
supports_sdpa = getattr(model_class, "_supports_sdpa", False)
except:
pass
# Build model_kwargs for SentenceTransformer
model_kwargs = {"torch_dtype": dtype}
# Enable SDPA if supported (1.2x extra speedup on top of torch.compile)
# But disable for models with known SDPA + torch.compile backward issues
_force_eager = False
for _sdpa_model in DISABLE_SDPA_MODEL_NAMES:
if _sdpa_model in model_type.lower():
supports_sdpa = False
_force_eager = True
break
if supports_sdpa:
model_kwargs["attn_implementation"] = "sdpa"
elif _force_eager:
model_kwargs["attn_implementation"] = "eager"
encoder_attn_impl = resolve_encoder_attention_implementation(
kwargs.get("auto_model", AutoModel),
config,
model_type = model_type,
disable_sdpa_model_names = DISABLE_SDPA_MODEL_NAMES,
)
supports_sdpa = encoder_attn_impl == "sdpa"
if encoder_attn_impl is not None:
model_kwargs["attn_implementation"] = encoder_attn_impl
# Print optimization status
sdpa_str = " + SDPA" if supports_sdpa else ""


@@ -33,8 +33,8 @@ from ._utils import (
__version__,
importlib_version,
_prepare_model_for_qat,
_is_flash_attention_disabled,
_disable_flash_attention_if_needed,
resolve_model_class,
resolve_attention_implementation,
)
from ._utils import *
from .loader_utils import _get_fp8_mode_and_check_settings
@@ -613,55 +613,18 @@ class FastBaseModel:
token = token,
trust_remote_code = trust_remote_code,
)
user_attn_implementation = kwargs.get("attn_implementation", None)
try:
model_class = auto_model._model_mapping[auto_config.__class__]
except Exception:
model_class = None
if model_class is None:
# When model_class cannot be resolved (remote-code or unmapped
# configs), preserve the old fallback of sdpa when supported.
attn_impl = _set_attn_impl(
auto_config, "sdpa" if supports_sdpa else "eager"
)
else:
attn_impl = determine_attention_implementation(model_class, auto_config)
model_class = resolve_model_class(auto_model, auto_config)
attn_impl = resolve_attention_implementation(
model_class,
auto_config,
requested_attn_implementation = kwargs.get("attn_implementation", None),
supports_sdpa = supports_sdpa,
)
# Handle FP8 models: get_model_name has already redirected this to BF16 sibling if the model ships with
# FP8 weights. We just need to update it here for sanity.
auto_config.model_name = model_name
# Re-resolve model_class after potential config change
try:
model_class = auto_model._model_mapping[auto_config.__class__]
except Exception:
model_class = None
if not ("attn_implementation" in kwargs):
kwargs["attn_implementation"] = attn_impl
model_type = getattr(auto_config, "model_type", "").lower()
if _is_flash_attention_disabled(model_type):
supports_fa2 = model_class is not None and (
getattr(model_class, "_supports_flash_attn_2", False)
or getattr(model_class, "_supports_flash_attn", False)
)
kwargs["attn_implementation"] = _disable_flash_attention_if_needed(
model_type,
auto_config,
kwargs.get("attn_implementation"),
supports_sdpa = supports_sdpa,
would_use_flash_attention = (
user_attn_implementation is None
and HAS_FLASH_ATTENTION
and supports_fa2
),
)
if not supports_sdpa and kwargs.get("attn_implementation") == "sdpa":
print(
f"Unsloth: {model_type_arch.title()} does not support SDPA - switching to fast eager."
)
del kwargs["attn_implementation"]
# Re-stamp config so it stays consistent with the actual impl
_set_attn_impl(auto_config, "eager")
kwargs["attn_implementation"] = attn_impl
bnb_config = None
user_quantization_config = kwargs.get("quantization_config", None)
@@ -804,9 +767,7 @@ class FastBaseModel:
token = token,
trust_remote_code = trust_remote_code,
)
setattr(auto_config, "_attn_implementation", config_attn_impl)
if hasattr(auto_config, "attn_implementation"):
setattr(auto_config, "attn_implementation", config_attn_impl)
_set_attn_impl(auto_config, config_attn_impl)
model_config = auto_config
verify_fp8_support_if_applicable(model_config)
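
The final hunk applies the dict-safeguard from the commit bullets: the paired setattr calls collapse into _set_attn_impl, which routes through _config_set and therefore also handles plain-dict configs. A quick illustrative check (the SimpleNamespace config is invented for the example, and the import path is assumed):

from types import SimpleNamespace
from unsloth.models._utils import _set_attn_impl  # assumed path; exported in __all__ above

obj_config = SimpleNamespace(attn_implementation = "sdpa")
dict_config = {"attn_implementation": "sdpa"}
for cfg in (obj_config, dict_config):
    _set_attn_impl(cfg, "eager")
print(obj_config._attn_implementation, dict_config["_attn_implementation"])  # -> eager eager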