mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Fix grad-accum accepts_loss_kwargs detection for vision wrappers (#5036)
* Fix grad-accum model_accepts_loss_kwargs detection for vision wrappers
Replace the source-string rewrite of Trainer.__init__ with an instance-level
accepts_loss_kwargs shadow applied on the loaded model. Covers:
1. Unsloth-compiled forward -> True, so HF Trainer does not double-scale
on top of unsloth_fixed_cross_entropy's num_items_in_batch division.
2. Stock forward on a conditional-generation wrapper (Gemma3n, Gemma3
pre-4.57, Qwen-VL family, etc.) where the outer class has no
accepts_loss_kwargs but the inner .model declares False -> False.
This is the case that reproduces issue #4982 under trust_remote_code
or UNSLOTH_COMPILE_DISABLE, where the previous fix's outer-attr
check walked past the inner model and fell through to signature
inspection.
3. Text LMs without any explicit accepts_loss_kwargs -> leave HF default.
The previous .replace()-based patch silently no-ops on transformers 4.48
through 4.52 (variable named model, not unwrapped_model) and is fragile
against any upstream reformat. The new helper walks the PEFT / HF wrapper
chain, finds the first class that declares accepts_loss_kwargs on its own
class dict (type(m).__dict__, not hasattr, to avoid PEFT __getattr__
forwarding), and setattr-shadows that value at every wrapper level so
HF Trainer's hasattr(unwrapped_model, ...) check picks it up at whichever
level accelerate.unwrap_model returns.
Also adds an unconditional post-init clamp of
accelerator.gradient_accumulation_steps = 1 to work around the
transformers 5.0 through 5.5 GradientAccumulationPlugin regression that
makes accelerator.backward divide loss by GA on top of training_step's
own /GA division. Fixed upstream in 5.6.0.dev0; no-op on 4.x and 5.6+.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Trim comments
* Address review: cover PEFT-after-load and custom compile location
Two review findings from the 20-reviewer pass:
1. [3 of 20 reviewers] apply_accepts_loss_kwargs_fix was called from the
loaders before get_peft_model wraps the base model, so on transformers
4.48-4.52 (which does hasattr on the outer model) the instance shadow
on the base model was lost after PEFT wrapping. Fix: also call it from
the wrapped Trainer.__init__ so it runs on whatever model the user
actually hands to Trainer, which is always the final wrapped form.
2. [1 of 20 reviewers] _forward_is_unsloth_compiled hard-coded the
substrings "unsloth_compiled" / "unsloth_cache" in the co_filename
check, which misclassifies compiled forwards when
UNSLOTH_COMPILE_LOCATION is set to a custom directory. Fix: new
_unsloth_compile_cache_leaves helper that reads the env var and
matches the basename against path components, honoring both the
default and any user override.
Verified locally:
- PEFT-after-load simulation: HF's hasattr(peft, "accepts_loss_kwargs")
now returns True after our init wrapper runs, and value resolves to
False on Gemma3n-style inner wrappers.
- Custom UNSLOTH_COMPILE_LOCATION simulation: compiled detection returns
True for /tmp/my_custom_cache/compiled.py when the env var is set.
- End-to-end Gemma-3 270m + LoRA SFT unchanged: loss 4.9626, grad-norm
matches prior run, all 4 wrapper levels now carry the shadowed attr.
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
1ccfd2e0a5
commit
1a4ca5eca8
3 changed files with 142 additions and 38 deletions
|
|
@ -45,6 +45,7 @@ __all__ = [
|
|||
# "accelerate_old_send_to_device",
|
||||
# "accelerate_new_send_to_device",
|
||||
"patch_gradient_accumulation_fix",
|
||||
"apply_accepts_loss_kwargs_fix",
|
||||
"patch_compiling_bitsandbytes",
|
||||
"patch_regional_compilation",
|
||||
"patch_layernorm",
|
||||
|
|
@ -2083,47 +2084,148 @@ def patch_gradient_accumulation_fix(Trainer):
|
|||
exec(function, globals())
|
||||
Trainer.training_step = _unsloth_training_step
|
||||
|
||||
# Prevent double scaling gradient accumulation
|
||||
# https://github.com/huggingface/transformers/pull/37208
|
||||
# Patch model_accepts_loss_kwargs detection in Trainer.__init__
|
||||
if Trainer.__init__.__name__ != "_unsloth___init__":
|
||||
# Wrap Trainer.__init__: (1) pre-init, shadow accepts_loss_kwargs on whatever
|
||||
# model was passed in (covers PEFT wrapping done after FastModel.from_pretrained);
|
||||
# (2) post-init, clamp accelerator GA to 1 for the transformers 5.0-5.5
|
||||
# GradientAccumulationPlugin regression. No-op on 4.x and 5.6+. See #4982.
|
||||
if not getattr(Trainer, "_unsloth_init_wrapped_for_accelerate_gas", False):
|
||||
_original_trainer_init = Trainer.__init__
|
||||
|
||||
def _unsloth_trainer_init(self, *args, **kwargs):
|
||||
model = kwargs.get("model")
|
||||
if model is None and len(args) > 0:
|
||||
model = args[0]
|
||||
if model is not None:
|
||||
try:
|
||||
apply_accepts_loss_kwargs_fix(model)
|
||||
except Exception:
|
||||
pass
|
||||
_original_trainer_init(self, *args, **kwargs)
|
||||
try:
|
||||
accelerator = getattr(self, "accelerator", None)
|
||||
if (
|
||||
accelerator is not None
|
||||
and getattr(accelerator, "gradient_accumulation_steps", 1) > 1
|
||||
):
|
||||
accelerator.gradient_accumulation_steps = 1
|
||||
gs = getattr(accelerator, "gradient_state", None)
|
||||
if gs is not None and hasattr(gs, "plugin_kwargs"):
|
||||
try:
|
||||
gs.plugin_kwargs["num_steps"] = 1
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_unsloth_trainer_init.__wrapped__ = _original_trainer_init
|
||||
Trainer.__init__ = _unsloth_trainer_init
|
||||
Trainer._unsloth_init_wrapped_for_accelerate_gas = True
|
||||
|
||||
|
||||
def _unsloth_compile_cache_leaves():
|
||||
# Accepts `UNSLOTH_COMPILE_LOCATION` overrides (the env var unsloth_zoo honors).
|
||||
leaves = {"unsloth_compiled_cache", "unsloth_cache", "unsloth_compiled"}
|
||||
loc = os.environ.get("UNSLOTH_COMPILE_LOCATION", "") or ""
|
||||
loc = loc.rstrip("/\\")
|
||||
if loc:
|
||||
leaves.add(os.path.basename(loc) or loc)
|
||||
return leaves
|
||||
|
||||
|
||||
def _forward_is_unsloth_compiled(model):
    """Return True iff ``forward`` on *model* (or a wrapped inner model, up to
    4 levels deep) was installed from the Unsloth compile cache directory.

    Compiled forwards keep the original transformers ``__module__``, so the
    test inspects the code object's ``co_filename`` path components instead.
    """
    cache_names = _unsloth_compile_cache_leaves()

    def _is_compiled(candidate):
        # Resolve forward on the class, not the instance, to avoid
        # bound-method / monkeypatch indirection on the instance dict.
        if candidate is None:
            return False
        fwd = getattr(type(candidate), "forward", None)
        if fwd is None:
            return False
        code = getattr(fwd, "__code__", None)
        filename = getattr(code, "co_filename", "") if code is not None else ""
        components = set(filename.replace("\\", "/").split("/"))
        return not cache_names.isdisjoint(components)

    if _is_compiled(model):
        return True

    # Walk the wrapper chain (base_model preferred, then .model), guarding
    # against cycles and self-references.
    visited = {id(model)}
    current = model
    for _ in range(4):
        inner = getattr(current, "base_model", None)
        if inner is None or inner is current:
            inner = getattr(current, "model", None)
        if inner is None or inner is current or id(inner) in visited:
            return False
        if _is_compiled(inner):
            return True
        visited.add(id(inner))
        current = inner
    return False
|
||||
|
||||
|
||||
def _find_concrete_accepts_loss_kwargs(model):
|
||||
# Walk wrapper chain for first class that declares accepts_loss_kwargs in its
|
||||
# own __mro__ dict. Avoids PEFT __getattr__ forwarding and our own shadow.
|
||||
seen = set()
|
||||
m = model
|
||||
for _ in range(6):
|
||||
if m is None or id(m) in seen:
|
||||
break
|
||||
seen.add(id(m))
|
||||
for klass in type(m).__mro__:
|
||||
if "accepts_loss_kwargs" in klass.__dict__:
|
||||
return klass.__dict__[
|
||||
"accepts_loss_kwargs"
|
||||
], f"{klass.__name__}.accepts_loss_kwargs"
|
||||
nxt = getattr(m, "base_model", None)
|
||||
if nxt is None or nxt is m:
|
||||
nxt = getattr(m, "model", None)
|
||||
if nxt is None or nxt is m:
|
||||
break
|
||||
m = nxt
|
||||
return None, "no explicit accepts_loss_kwargs on any wrapper level"
|
||||
|
||||
|
||||
def _shadow_accepts_loss_kwargs(model, value):
|
||||
# Set the attribute at every wrapper level so HF's hasattr check resolves
|
||||
# regardless of where accelerator / peft unwrap lands.
|
||||
seen = set()
|
||||
m = model
|
||||
for _ in range(8):
|
||||
if m is None or id(m) in seen:
|
||||
break
|
||||
seen.add(id(m))
|
||||
try:
|
||||
init_function = inspect.getsource(Trainer.__init__)
|
||||
setattr(m, "accepts_loss_kwargs", value)
|
||||
except Exception:
|
||||
init_function = ""
|
||||
if init_function is not None:
|
||||
init_function = textwrap.dedent(init_function)
|
||||
pass
|
||||
nxt = getattr(m, "base_model", None)
|
||||
if nxt is None or nxt is m:
|
||||
nxt = getattr(m, "model", None)
|
||||
if nxt is None or nxt is m:
|
||||
break
|
||||
m = nxt
|
||||
|
||||
# Import all variables that need importing
|
||||
import transformers.trainer
|
||||
|
||||
items_in_trainer = dir(transformers.trainer)
|
||||
good_items = []
|
||||
for item in items_in_trainer:
|
||||
if item in init_function:
|
||||
good_items.append(item)
|
||||
exec(
|
||||
"from transformers.trainer import ("
|
||||
+ ", ".join(x for x in good_items)
|
||||
+ ")",
|
||||
globals(),
|
||||
)
|
||||
def apply_accepts_loss_kwargs_fix(model):
    """Shadow the correct ``accepts_loss_kwargs`` on *model* so HF Trainer
    picks it up via ``hasattr(unwrapped_model, ...)``. Replaces the old
    ``Trainer.__init__`` source rewrite. Issue #4982.

    Priority:
      1. Unsloth-compiled forward -> ``True`` (the compiled loss already
         divides by ``num_items_in_batch``, so Trainer must not scale again).
      2. First explicit class-level ``accepts_loss_kwargs`` found on the
         wrapper chain -> that value (covers conditional-generation wrappers
         whose inner ``.model`` declares ``False``).
      3. Otherwise -> leave transformers' default signature inspection alone.

    Returns a short human-readable string describing which rule applied.

    NOTE(review): the scraped diff interleaved deleted old
    ``init_function.replace(...)`` / ``exec`` lines into this span; this body
    is the reconstructed new-side function.
    """
    if _forward_is_unsloth_compiled(model):
        _shadow_accepts_loss_kwargs(model, True)
        return "True (Unsloth compiled forward)"

    value, reason = _find_concrete_accepts_loss_kwargs(model)
    if value is None:
        # No wrapper declared the flag: keep transformers' own detection.
        return f"default (signature inspection, {reason})"
    _shadow_accepts_loss_kwargs(model, value)
    return f"{value} ({reason})"
|
||||
|
||||
|
||||
def patch_tokenizer(model, tokenizer):
|
||||
|
|
|
|||
|
|
@ -2695,7 +2695,8 @@ class FastLlamaModel:
|
|||
patch_saving_functions(model)
|
||||
Trainer._inner_training_loop = _fast_inner_training_loop
|
||||
|
||||
# Fix gradient accumulation
|
||||
# Fix gradient accumulation. See issue #4982.
|
||||
apply_accepts_loss_kwargs_fix(model)
|
||||
patch_gradient_accumulation_fix(Trainer)
|
||||
|
||||
# Save tokenizer for inference purposes
|
||||
|
|
|
|||
|
|
@ -1123,9 +1123,10 @@ class FastBaseModel:
|
|||
)
|
||||
patch_saving_functions(tokenizer, vision = True)
|
||||
|
||||
# Fix gradient accumulation
|
||||
# Fix gradient accumulation. See issue #4982.
|
||||
from transformers.trainer import Trainer
|
||||
|
||||
apply_accepts_loss_kwargs_fix(model)
|
||||
patch_gradient_accumulation_fix(Trainer)
|
||||
|
||||
# Save tokenizer for inference purposes
|
||||
|
|
|
|||
Loading…
Reference in a new issue