mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
fix Gemma4 flash attn disable (#5045)
* fix pass attn implementation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
3869fbe1cc
commit
a4d4dfe4ac
2 changed files with 95 additions and 1 deletions
|
|
@ -233,6 +233,8 @@ def apply_unsloth_gradient_checkpointing(
|
|||
# access on some GPU architectures (B200). Falls back to eager safely.
|
||||
# Model types excluded from flex attention.
_FLEX_EXCLUDED_MODELS = ("gpt_oss", "mllama", "nemotron_h", "modernbert")
# Model-type prefixes that are forced to eager attention.
_EAGER_ONLY_PREFIXES = ("gemma3n",)
# Model types for which Flash Attention is disabled. Gemma 4 full-attention
# layers use global_head_dim=512, which exceeds Flash Attention's dense
# head-dim support, so these fall back to SDPA/eager.
_FLASH_ATTENTION_DISABLED_MODELS = ("gemma4", "gemma4_text")
# De-duplicates the "flash attention disabled" warning; holds
# (model_type, requested_impl, fallback_impl) keys already printed.
_FLASH_ATTENTION_DISABLED_WARNED = set()
|
||||
|
||||
|
||||
def _is_flex_excluded(model_type):
|
||||
|
|
@ -243,6 +245,61 @@ def _is_eager_only(model_type):
|
|||
return any(model_type.startswith(p) for p in _EAGER_ONLY_PREFIXES)
|
||||
|
||||
|
||||
def _is_flash_attention_disabled(model_type):
    """Return True when `model_type` is on the Flash Attention blocklist."""
    for blocked in _FLASH_ATTENTION_DISABLED_MODELS:
        if model_type == blocked:
            return True
    return False
|
||||
|
||||
|
||||
def _is_flash_attention_requested(attn_implementation):
|
||||
return isinstance(attn_implementation, str) and attn_implementation.startswith(
|
||||
"flash_attention"
|
||||
)
|
||||
|
||||
|
||||
def _disable_flash_attention_if_needed(
    model_type,
    config,
    attn_implementation = None,
    supports_sdpa = False,
    would_use_flash_attention = False,
):
    """Force blocklisted models off Flash Attention.

    For model types in `_FLASH_ATTENTION_DISABLED_MODELS`, returns a safe
    attention implementation ("sdpa" when `supports_sdpa`, else "eager")
    and records it on `config` via `_set_attn_impl`. Explicit "eager"
    requests are honoured as-is. For any other model type the caller's
    `attn_implementation` is returned untouched. A warning is printed once
    per (model_type, requested, fallback) combination when flash attention
    was requested explicitly or would have been selected implicitly.
    """
    if not _is_flash_attention_disabled(model_type):
        # Not blocklisted - leave the caller's choice alone.
        return attn_implementation

    # Resolve the effective request: explicit argument first, then whatever
    # is already recorded on the config (private attr, then public).
    requested = attn_implementation
    if requested is None:
        requested = getattr(config, "_attn_implementation", None)
    if requested is None:
        requested = getattr(config, "attn_implementation", None)

    # A deliberate eager request is always respected.
    if requested == "eager":
        return _set_attn_impl(config, "eager")

    fallback = "sdpa" if supports_sdpa else "eager"
    flash_requested = _is_flash_attention_requested(requested)

    if flash_requested or would_use_flash_attention:
        # Name the implementation being refused; when flash would only have
        # been chosen implicitly, report it as "flash_attention_2".
        logged = requested if flash_requested else "flash_attention_2"
        warning_key = (model_type, logged, fallback)
        if warning_key not in _FLASH_ATTENTION_DISABLED_WARNED:
            # Warn only once per distinct combination.
            _FLASH_ATTENTION_DISABLED_WARNED.add(warning_key)
            print(
                f"Unsloth: `{logged}` is not supported "
                "for Gemma 4 - "
                f"defaulting to `{fallback}`."
            )

    return _set_attn_impl(config, fallback)
|
||||
|
||||
|
||||
def _set_attn_impl(config, impl):
|
||||
"""Helper function to set attention implementation on config and return it."""
|
||||
if config is not None:
|
||||
|
|
@ -260,6 +317,19 @@ def determine_attention_implementation(model_class, config):
|
|||
_set_attn_impl(config, "eager")
|
||||
return "eager"
|
||||
|
||||
# Models with known Flash Attention incompatibilities. Gemma 4 full-attention
|
||||
# layers use global_head_dim=512, which exceeds Flash Attention's dense
|
||||
# head-dim support. Keep explicit eager requests, otherwise prefer SDPA.
|
||||
if _is_flash_attention_disabled(model_type):
|
||||
supports_sdpa = model_class is not None and getattr(
|
||||
model_class, "_supports_sdpa", False
|
||||
)
|
||||
return _disable_flash_attention_if_needed(
|
||||
model_type,
|
||||
config,
|
||||
supports_sdpa = supports_sdpa,
|
||||
)
|
||||
|
||||
# Flash Attention 2
|
||||
if HAS_FLASH_ATTENTION and model_class is not None:
|
||||
supports_fa2 = getattr(model_class, "_supports_flash_attn_2", False) or getattr(
|
||||
|
|
|
|||
|
|
@ -29,7 +29,13 @@ except:
|
|||
from ..kernels import (
|
||||
post_patch_loss_function,
|
||||
)
|
||||
from ._utils import __version__, importlib_version, _prepare_model_for_qat
|
||||
from ._utils import (
|
||||
__version__,
|
||||
importlib_version,
|
||||
_prepare_model_for_qat,
|
||||
_is_flash_attention_disabled,
|
||||
_disable_flash_attention_if_needed,
|
||||
)
|
||||
from ._utils import *
|
||||
from .loader_utils import _get_fp8_mode_and_check_settings
|
||||
from ..save import patch_saving_functions
|
||||
|
|
@ -607,6 +613,7 @@ class FastBaseModel:
|
|||
token = token,
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
user_attn_implementation = kwargs.get("attn_implementation", None)
|
||||
try:
|
||||
model_class = auto_model._model_mapping[auto_config.__class__]
|
||||
except Exception:
|
||||
|
|
@ -631,6 +638,23 @@ class FastBaseModel:
|
|||
|
||||
if not ("attn_implementation" in kwargs):
|
||||
kwargs["attn_implementation"] = attn_impl
|
||||
model_type = getattr(auto_config, "model_type", "").lower()
|
||||
if _is_flash_attention_disabled(model_type):
|
||||
supports_fa2 = model_class is not None and (
|
||||
getattr(model_class, "_supports_flash_attn_2", False)
|
||||
or getattr(model_class, "_supports_flash_attn", False)
|
||||
)
|
||||
kwargs["attn_implementation"] = _disable_flash_attention_if_needed(
|
||||
model_type,
|
||||
auto_config,
|
||||
kwargs.get("attn_implementation"),
|
||||
supports_sdpa = supports_sdpa,
|
||||
would_use_flash_attention = (
|
||||
user_attn_implementation is None
|
||||
and HAS_FLASH_ATTENTION
|
||||
and supports_fa2
|
||||
),
|
||||
)
|
||||
if not supports_sdpa and kwargs.get("attn_implementation") == "sdpa":
|
||||
print(
|
||||
f"Unsloth: {model_type_arch.title()} does not support SDPA - switching to fast eager."
|
||||
|
|
|
|||
Loading…
Reference in a new issue