fix Gemma4 flash attn disable (#5045)

* fix pass attn implementation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: DoubleMathew (committed by GitHub)
Date:   2026-04-15 17:50:48 -05:00
parent 3869fbe1cc
commit a4d4dfe4ac
2 changed files with 95 additions and 1 deletion


@@ -233,6 +233,8 @@ def apply_unsloth_gradient_checkpointing(
 # access on some GPU architectures (B200). Falls back to eager safely.
 _FLEX_EXCLUDED_MODELS = ("gpt_oss", "mllama", "nemotron_h", "modernbert")
 _EAGER_ONLY_PREFIXES = ("gemma3n",)
+_FLASH_ATTENTION_DISABLED_MODELS = ("gemma4", "gemma4_text")
+_FLASH_ATTENTION_DISABLED_WARNED = set()
 
 
 def _is_flex_excluded(model_type):
@@ -243,6 +245,61 @@ def _is_eager_only(model_type):
     return any(model_type.startswith(p) for p in _EAGER_ONLY_PREFIXES)
 
 
+def _is_flash_attention_disabled(model_type):
+    return model_type in _FLASH_ATTENTION_DISABLED_MODELS
+
+
+def _is_flash_attention_requested(attn_implementation):
+    return isinstance(attn_implementation, str) and attn_implementation.startswith(
+        "flash_attention"
+    )
+
+
+def _disable_flash_attention_if_needed(
+    model_type,
+    config,
+    attn_implementation = None,
+    supports_sdpa = False,
+    would_use_flash_attention = False,
+):
+    if not _is_flash_attention_disabled(model_type):
+        return attn_implementation
+
+    requested_attn_implementation = attn_implementation
+    if requested_attn_implementation is None:
+        requested_attn_implementation = getattr(config, "_attn_implementation", None)
+    if requested_attn_implementation is None:
+        requested_attn_implementation = getattr(config, "attn_implementation", None)
+
+    if requested_attn_implementation == "eager":
+        return _set_attn_impl(config, "eager")
+
+    fallback_attn_implementation = "sdpa" if supports_sdpa else "eager"
+    if (
+        _is_flash_attention_requested(requested_attn_implementation)
+        or would_use_flash_attention
+    ):
+        logged_attn_implementation = (
+            requested_attn_implementation
+            if _is_flash_attention_requested(requested_attn_implementation)
+            else "flash_attention_2"
+        )
+        warning_key = (
+            model_type,
+            logged_attn_implementation,
+            fallback_attn_implementation,
+        )
+        if warning_key not in _FLASH_ATTENTION_DISABLED_WARNED:
+            _FLASH_ATTENTION_DISABLED_WARNED.add(warning_key)
+            print(
+                f"Unsloth: `{logged_attn_implementation}` is not supported "
+                "for Gemma 4 - "
+                f"defaulting to `{fallback_attn_implementation}`."
+            )
+
+    return _set_attn_impl(config, fallback_attn_implementation)
+
+
 def _set_attn_impl(config, impl):
     """Helper function to set attention implementation on config and return it."""
     if config is not None:
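For reference, a minimal sketch of how the new helper resolves the implementation. Two assumptions not spelled out in the hunk: the helper is importable from unsloth.models._utils (the module the second file's relative imports point at), and _set_attn_impl returns the implementation string, which the assignment into kwargs["attn_implementation"] further below implies.

    # Sketch only; import path and _set_attn_impl return value are assumptions.
    from types import SimpleNamespace

    from unsloth.models._utils import _disable_flash_attention_if_needed

    # Gemma 4 with flash_attention_2 requested and SDPA available:
    # warns once, then falls back to "sdpa".
    config = SimpleNamespace(_attn_implementation = "flash_attention_2")
    assert _disable_flash_attention_if_needed(
        "gemma4", config, supports_sdpa = True,
    ) == "sdpa"

    # An explicit eager request is kept as-is, with no warning.
    config = SimpleNamespace(_attn_implementation = "eager")
    assert _disable_flash_attention_if_needed("gemma4", config) == "eager"

    # Model types outside _FLASH_ATTENTION_DISABLED_MODELS pass straight through.
    assert _disable_flash_attention_if_needed(
        "llama", None, "flash_attention_2"
    ) == "flash_attention_2"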
@@ -260,6 +317,19 @@ def determine_attention_implementation(model_class, config):
         _set_attn_impl(config, "eager")
         return "eager"
 
+    # Models with known Flash Attention incompatibilities. Gemma 4 full-attention
+    # layers use global_head_dim=512, which exceeds Flash Attention's dense
+    # head-dim support. Keep explicit eager requests, otherwise prefer SDPA.
+    if _is_flash_attention_disabled(model_type):
+        supports_sdpa = model_class is not None and getattr(
+            model_class, "_supports_sdpa", False
+        )
+        return _disable_flash_attention_if_needed(
+            model_type,
+            config,
+            supports_sdpa = supports_sdpa,
+        )
+
     # Flash Attention 2
     if HAS_FLASH_ATTENTION and model_class is not None:
         supports_fa2 = getattr(model_class, "_supports_flash_attn_2", False) or getattr(
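The comment added above names the root cause: Gemma 4's full-attention layers use a 512-dim global head, while FlashAttention-2's dense kernels only support head dims up to 256 (a known FlashAttention-2 limit). A toy illustration of that constraint; the constant and function names here are illustrative, not from the Unsloth or flash-attn codebases:

    # Illustrative only: these names are not real Unsloth or flash-attn symbols.
    FLASH_ATTN_MAX_HEAD_DIM = 256   # FlashAttention-2 dense kernel head-dim limit
    GEMMA4_GLOBAL_HEAD_DIM  = 512   # per the comment in the diff above

    def flash_attention_supports(head_dim):
        return head_dim <= FLASH_ATTN_MAX_HEAD_DIM

    assert not flash_attention_supports(GEMMA4_GLOBAL_HEAD_DIM)  # hence the fallback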

@@ -29,7 +29,13 @@ except:
 from ..kernels import (
     post_patch_loss_function,
 )
-from ._utils import __version__, importlib_version, _prepare_model_for_qat
+from ._utils import (
+    __version__,
+    importlib_version,
+    _prepare_model_for_qat,
+    _is_flash_attention_disabled,
+    _disable_flash_attention_if_needed,
+)
 from ._utils import *
 from .loader_utils import _get_fp8_mode_and_check_settings
 from ..save import patch_saving_functions
@@ -607,6 +613,7 @@ class FastBaseModel:
             token = token,
             trust_remote_code = trust_remote_code,
         )
+        user_attn_implementation = kwargs.get("attn_implementation", None)
         try:
             model_class = auto_model._model_mapping[auto_config.__class__]
         except Exception:
@@ -631,6 +638,23 @@ class FastBaseModel:
         if not ("attn_implementation" in kwargs):
             kwargs["attn_implementation"] = attn_impl
         model_type = getattr(auto_config, "model_type", "").lower()
+        if _is_flash_attention_disabled(model_type):
+            supports_fa2 = model_class is not None and (
+                getattr(model_class, "_supports_flash_attn_2", False)
+                or getattr(model_class, "_supports_flash_attn", False)
+            )
+            kwargs["attn_implementation"] = _disable_flash_attention_if_needed(
+                model_type,
+                auto_config,
+                kwargs.get("attn_implementation"),
+                supports_sdpa = supports_sdpa,
+                would_use_flash_attention = (
+                    user_attn_implementation is None
+                    and HAS_FLASH_ATTENTION
+                    and supports_fa2
+                ),
+            )
+
         if not supports_sdpa and kwargs.get("attn_implementation") == "sdpa":
             print(
                 f"Unsloth: {model_type_arch.title()} does not support SDPA - switching to fast eager."