Fix grad-accum accepts_loss_kwargs detection for vision wrappers (#5036)

* Fix grad-accum model_accepts_loss_kwargs detection for vision wrappers

Replace the source-string rewrite of Trainer.__init__ with an instance-level
accepts_loss_kwargs shadow applied on the loaded model. Covers:

  1. Unsloth-compiled forward -> True, so HF Trainer does not double-scale
     on top of unsloth_fixed_cross_entropy's num_items_in_batch division.
  2. Stock forward on a conditional-generation wrapper (Gemma3n, Gemma3
     pre-4.57, Qwen-VL family, etc.) where the outer class has no
     accepts_loss_kwargs but the inner .model declares False -> False.
     This is the case that reproduces issue #4982 under trust_remote_code
     or UNSLOTH_COMPILE_DISABLE, where the previous fix's outer-attr
     check walked past the inner model and fell through to signature
     inspection.
  3. Text LMs without any explicit accepts_loss_kwargs -> leave HF default.

The previous .replace()-based patch silently no-ops on transformers 4.48
through 4.52 (variable named model, not unwrapped_model) and is fragile
against any upstream reformat. The new helper walks the PEFT / HF wrapper
chain, finds the first class that declares accepts_loss_kwargs on its own
class dict (type(m).__dict__, not hasattr, to avoid PEFT __getattr__
forwarding), and setattr-shadows that value at every wrapper level so
HF Trainer's hasattr(unwrapped_model, ...) check picks it up at whichever
level accelerate.unwrap_model returns.

Also adds an unconditional post-init clamp of
accelerator.gradient_accumulation_steps = 1 to work around the
transformers 5.0 through 5.5 GradientAccumulationPlugin regression that
makes accelerator.backward divide loss by GA on top of training_step's
own /GA division. Fixed upstream in 5.6.0.dev0; no-op on 4.x and 5.6+.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Trim comments

* Address review: cover PEFT-after-load and custom compile location

Two review findings from the 20-reviewer pass (per-item counts below):

1. [3 of 20 reviewers] apply_accepts_loss_kwargs_fix was called from the
   loaders before get_peft_model wraps the base model, so on transformers
   4.48-4.52 (which does hasattr on the outer model) the instance shadow
   on the base model was lost after PEFT wrapping. Fix: also call it from
   the wrapped Trainer.__init__ so it runs on whatever model the user
   actually hands to Trainer, which is always the final wrapped form.

2. [1 of 20 reviewers] _forward_is_unsloth_compiled hard-coded the
   substrings "unsloth_compiled" / "unsloth_cache" in the co_filename
   check, which misclassifies compiled forwards when
   UNSLOTH_COMPILE_LOCATION is set to a custom directory. Fix: new
   _unsloth_compile_cache_leaves helper that reads the env var and
   matches the basename against path components, honoring both the
   default and any user override.

Verified locally:
- PEFT-after-load simulation: HF's hasattr(peft, "accepts_loss_kwargs")
  now returns True after our init wrapper runs, and value resolves to
  False on Gemma3n-style inner wrappers.
- Custom UNSLOTH_COMPILE_LOCATION simulation: compiled detection returns
  True for /tmp/my_custom_cache/compiled.py when the env var is set.
- End-to-end Gemma-3 270m + LoRA SFT unchanged: loss 4.9626, grad-norm
  matches prior run, all 4 wrapper levels now carry the shadowed attr.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Daniel Han 2026-04-15 06:59:36 -07:00 committed by GitHub
parent 1ccfd2e0a5
commit 1a4ca5eca8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 142 additions and 38 deletions

View file

@ -45,6 +45,7 @@ __all__ = [
# "accelerate_old_send_to_device",
# "accelerate_new_send_to_device",
"patch_gradient_accumulation_fix",
"apply_accepts_loss_kwargs_fix",
"patch_compiling_bitsandbytes",
"patch_regional_compilation",
"patch_layernorm",
@ -2083,47 +2084,148 @@ def patch_gradient_accumulation_fix(Trainer):
exec(function, globals())
Trainer.training_step = _unsloth_training_step
# Prevent double scaling gradient accumulation
# https://github.com/huggingface/transformers/pull/37208
# Patch model_accepts_loss_kwargs detection in Trainer.__init__
# Wrap Trainer.__init__: (1) pre-init, shadow accepts_loss_kwargs on whatever
# model was passed in (covers PEFT wrapping done after FastModel.from_pretrained);
# (2) post-init, clamp accelerator GA to 1 for the transformers 5.0-5.5
# GradientAccumulationPlugin regression. No-op on 4.x and 5.6+. See #4982.
if not getattr(Trainer, "_unsloth_init_wrapped_for_accelerate_gas", False):
    _original_trainer_init = Trainer.__init__

    def _unsloth_trainer_init(self, *args, **kwargs):
        # Trainer accepts the model either positionally (first arg) or by keyword.
        model = kwargs.get("model")
        if model is None and len(args) > 0:
            model = args[0]
        if model is not None:
            # Best-effort: the shadow must never break Trainer construction.
            try:
                apply_accepts_loss_kwargs_fix(model)
            except Exception:
                pass
        _original_trainer_init(self, *args, **kwargs)
        # Post-init: clamp accelerator-side GA so accelerator.backward does not
        # divide the loss by GA on top of training_step's own /GA division.
        try:
            accelerator = getattr(self, "accelerator", None)
            if (
                accelerator is not None
                and getattr(accelerator, "gradient_accumulation_steps", 1) > 1
            ):
                accelerator.gradient_accumulation_steps = 1
                gs = getattr(accelerator, "gradient_state", None)
                if gs is not None and hasattr(gs, "plugin_kwargs"):
                    try:
                        gs.plugin_kwargs["num_steps"] = 1
                    except Exception:
                        pass
        except Exception:
            # Never let the workaround raise out of Trainer.__init__.
            pass

    _unsloth_trainer_init.__wrapped__ = _original_trainer_init
    Trainer.__init__ = _unsloth_trainer_init
    Trainer._unsloth_init_wrapped_for_accelerate_gas = True
def _unsloth_compile_cache_leaves():
# Accepts `UNSLOTH_COMPILE_LOCATION` overrides (the env var unsloth_zoo honors).
leaves = {"unsloth_compiled_cache", "unsloth_cache", "unsloth_compiled"}
loc = os.environ.get("UNSLOTH_COMPILE_LOCATION", "") or ""
loc = loc.rstrip("/\\")
if loc:
leaves.add(os.path.basename(loc) or loc)
return leaves
def _forward_is_unsloth_compiled(model):
    """Return True iff `model` (or a wrapped inner model) uses a forward installed
    from the Unsloth compile cache directory.

    Compiled forwards keep the original transformers `__module__`, so the
    reliable marker is the cache directory component in
    `forward.__code__.co_filename`. Walks up to 4 `base_model` / `model`
    wrapper levels with cycle protection.
    """
    cache_leaves = _unsloth_compile_cache_leaves()

    def _compiled(candidate):
        if candidate is None:
            return False
        # Class-level lookup: compiled forwards are installed on the class.
        forward = getattr(type(candidate), "forward", None)
        if forward is None:
            return False
        code = getattr(forward, "__code__", None)
        filename = getattr(code, "co_filename", "") if code is not None else ""
        components = set(filename.replace("\\", "/").split("/"))
        return not components.isdisjoint(cache_leaves)

    if _compiled(model):
        return True
    visited = set()
    current = model
    for _ in range(4):
        if current is None or id(current) in visited:
            return False
        visited.add(id(current))
        inner = getattr(current, "base_model", None)
        if inner is None or inner is current:
            inner = getattr(current, "model", None)
        if inner is None or inner is current:
            return False
        if _compiled(inner):
            return True
        current = inner
    return False
def _find_concrete_accepts_loss_kwargs(model):
# Walk wrapper chain for first class that declares accepts_loss_kwargs in its
# own __mro__ dict. Avoids PEFT __getattr__ forwarding and our own shadow.
seen = set()
m = model
for _ in range(6):
if m is None or id(m) in seen:
break
seen.add(id(m))
for klass in type(m).__mro__:
if "accepts_loss_kwargs" in klass.__dict__:
return klass.__dict__[
"accepts_loss_kwargs"
], f"{klass.__name__}.accepts_loss_kwargs"
nxt = getattr(m, "base_model", None)
if nxt is None or nxt is m:
nxt = getattr(m, "model", None)
if nxt is None or nxt is m:
break
m = nxt
return None, "no explicit accepts_loss_kwargs on any wrapper level"
def _shadow_accepts_loss_kwargs(model, value):
# Set the attribute at every wrapper level so HF's hasattr check resolves
# regardless of where accelerator / peft unwrap lands.
seen = set()
m = model
for _ in range(8):
if m is None or id(m) in seen:
break
seen.add(id(m))
try:
init_function = inspect.getsource(Trainer.__init__)
setattr(m, "accepts_loss_kwargs", value)
except Exception:
init_function = ""
if init_function is not None:
init_function = textwrap.dedent(init_function)
pass
nxt = getattr(m, "base_model", None)
if nxt is None or nxt is m:
nxt = getattr(m, "model", None)
if nxt is None or nxt is m:
break
m = nxt
# Import all variables that need importing
import transformers.trainer
items_in_trainer = dir(transformers.trainer)
good_items = []
for item in items_in_trainer:
if item in init_function:
good_items.append(item)
exec(
"from transformers.trainer import ("
+ ", ".join(x for x in good_items)
+ ")",
globals(),
)
def apply_accepts_loss_kwargs_fix(model):
    """Shadow the correct `accepts_loss_kwargs` on `model` so HF Trainer picks it
    up via `hasattr(unwrapped_model, ...)`. Replaces the old Trainer.__init__
    source rewrite. See issue #4982.

    Priority:
      1. Unsloth-compiled forward -> True (the compiled loss path already
         handles num_items_in_batch scaling, so Trainer must not rescale).
      2. First explicit class-level flag found in the wrapper chain.
      3. Otherwise leave HF's default signature inspection untouched.
    Returns a short human-readable description of the decision taken.
    """
    if _forward_is_unsloth_compiled(model):
        _shadow_accepts_loss_kwargs(model, True)
        return "True (Unsloth compiled forward)"
    # Respect an inner wrapped model's explicit accepts_loss_kwargs flag before
    # falling back to HF's forward(**kwargs) signature inference.
    value, reason = _find_concrete_accepts_loss_kwargs(model)
    if value is None:
        return f"default (signature inspection, {reason})"
    _shadow_accepts_loss_kwargs(model, value)
    return f"{value} ({reason})"
def patch_tokenizer(model, tokenizer):

View file

@ -2695,7 +2695,8 @@ class FastLlamaModel:
patch_saving_functions(model)
Trainer._inner_training_loop = _fast_inner_training_loop
# Fix gradient accumulation
# Fix gradient accumulation. See issue #4982.
apply_accepts_loss_kwargs_fix(model)
patch_gradient_accumulation_fix(Trainer)
# Save tokenizer for inference purposes

View file

@ -1123,9 +1123,10 @@ class FastBaseModel:
)
patch_saving_functions(tokenizer, vision = True)
# Fix gradient accumulation
# Fix gradient accumulation. See issue #4982.
from transformers.trainer import Trainer
apply_accepts_loss_kwargs_fix(model)
patch_gradient_accumulation_fix(Trainer)
# Save tokenizer for inference purposes