Auto-enable padding-free SFT (#3672)

* implement (sdpa, xformers, fa2) sample packing

* attention dispatching

* ddp working OOTB with CLI

* packed SWA and softcap support

* enable batch flattening

* LGPL license headers

* mask packed sequence boundaries

* auto-enable sample packing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add explicit toggle for sample packing

* Add explicit toggle for sample packing

* Update __init__.py

* Update unsloth/kernels/rope_embedding.py

* Update unsloth/kernels/rope_embedding.py

* remove grad output clones; restore deleted FastLanguageModel arg

* fix

* restore rope embedding clones

* xformers mask cache

* implement (sdpa, xformers, fa2) sample packing

* attention dispatching

* ddp working OOTB with CLI

* packed SWA and softcap support

* enable batch flattening

* LGPL license headers

* mask packed sequence boundaries

* auto-enable sample packing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add explicit toggle for sample packing

* Add explicit toggle for sample packing

* Update __init__.py

* Update unsloth/kernels/rope_embedding.py

* Update unsloth/kernels/rope_embedding.py

* remove grad output clones; restore deleted FastLanguageModel arg

* fix

* restore rope embedding clones

* xformers mask cache

* add back accidental deletion

* Update unsloth/kernels/rope_embedding.py

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix merge conflicts

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add **kwargs

* add back clobbered

* Update rope_embedding.py

* Update rope_embedding.py

* simplify trl warnings filter

* docstring

* nit

* bugfix

* add padding-free seqlen metadata

* auto-enable padding free

* gemma2 disable

* Apply suggestion from @danielhanchen

* Update trainer.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update trainer.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
This commit is contained in:
Dan Saunders 2025-12-10 06:07:29 -05:00 committed by GitHub
parent 496f84ff6b
commit 75e0d7ce62
4 changed files with 281 additions and 20 deletions

View file

@ -16,7 +16,9 @@
from unsloth import FastLanguageModel
from unsloth.utils import attention_dispatch as attention_dispatch_utils
from unsloth.utils.packing import (
configure_padding_free,
configure_sample_packing,
enable_padding_free_metadata,
enable_sample_packing,
mask_packed_sequence_boundaries,
)
@ -150,6 +152,14 @@ def test_configure_sample_packing():
assert config.remove_unused_columns is False
def test_configure_padding_free():
    """configure_padding_free must turn padding_free on and column pruning off."""
    cfg = SimpleNamespace(remove_unused_columns = True)
    configure_padding_free(cfg)
    assert cfg.padding_free is True
    assert cfg.remove_unused_columns is False
class _DummyChild(torch.nn.Module):
def __init__(self):
super().__init__()
@ -177,6 +187,20 @@ class _DummyTrainer:
)
class _PaddingFreeCollator:
def __init__(self):
self.padding_free = True
self.return_position_ids = False
self.calls = 0
def torch_call(self, examples):
self.calls += 1
return {
"input_ids": torch.tensor([[0]], dtype = torch.long),
"examples_seen": self.calls,
}
def test_enable_sample_packing():
model = _DummyModel()
trainer = _DummyTrainer()
@ -251,6 +275,34 @@ def test_enable_sample_packing_trl_collator(tmp_path):
trainer.accelerator.free_memory()
def test_enable_padding_free_metadata():
    """Wrapping a padding-free collator should attach packed_seq_lengths metadata."""
    model = _DummyModel()
    trainer = SimpleNamespace(
        args = SimpleNamespace(remove_unused_columns = True),
        data_collator = _PaddingFreeCollator(),
    )
    enable_padding_free_metadata(model, trainer)
    # The overlength marker must propagate down the module hierarchy.
    assert model._unsloth_allow_packed_overlength is True
    assert model.child._unsloth_allow_packed_overlength is True
    wrapped = trainer.data_collator
    assert wrapped.return_position_ids is True
    assert wrapped._unsloth_padding_free_lengths_wrapped is True
    batch = wrapped.torch_call(
        [{"input_ids": [0, 1, 2]}, {"input_ids": [3, 4]}]
    )
    expected = torch.tensor([3, 2], dtype = torch.int32)
    assert torch.equal(batch["packed_seq_lengths"], expected)
    assert trainer.args.remove_unused_columns is False
def test_packing_sdpa(tmp_path):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model, batch, trainer, llama_mod = _build_packed_training_setup(tmp_path, device)

View file

@ -23,13 +23,19 @@ import trl
import inspect
from trl import SFTTrainer
from . import is_bfloat16_supported
from unsloth.utils import configure_sample_packing, enable_sample_packing
from unsloth.utils import (
configure_padding_free,
configure_sample_packing,
enable_padding_free_metadata,
enable_sample_packing,
)
from unsloth_zoo.training_utils import (
unsloth_train as _unsloth_train,
)
from unsloth_zoo.vision_utils import (
UnslothVisionDataCollator,
)
from unsloth_zoo.hf_utils import get_transformers_model_type
from packaging.version import Version
import dataclasses
@ -47,6 +53,19 @@ _AUTO_PACKING_ENV_DISABLED = os.environ.get(
"UNSLOTH_DISABLE_AUTO_PACKING", ""
).strip().lower() in {"1", "true", "yes", "on"}
_AUTO_PADDING_FREE_ENV_DISABLED = os.environ.get(
"UNSLOTH_DISABLE_AUTO_PADDING_FREE", ""
).strip().lower() in {"1", "true", "yes", "on"}
# [TODO]
# Below cannot work with padding-free
_PADDING_FREE_BLOCK_LIST = {
"gemma2", # - gemma2: Uses slow_attention_softcapping which has torch.compile issues
"gpt_oss", # - gpt_oss: Uses Flex Attention which doesn't handle padding_free correctly
"mistral", # - mistral: Unfortunately I think sliding window attention doesn't work correctly?
}
def _should_auto_pack(config) -> bool:
if config is None or _AUTO_PACKING_ENV_DISABLED:
@ -56,6 +75,14 @@ def _should_auto_pack(config) -> bool:
return not getattr(config, "_unsloth_disable_auto_packing", False)
def _should_auto_padding_free(config) -> bool:
if config is None or _AUTO_PADDING_FREE_ENV_DISABLED:
return False
if getattr(config, "packing", False):
return False
return not getattr(config, "padding_free", False)
def _disable_sample_packing(config):
if config is None:
return
@ -269,11 +296,36 @@ def _patch_sft_trainer_auto_packing(trl_module):
else:
config_arg = kwargs.get("args")
# Check if model type is unsupported for padding_free
model = kwargs.get("model")
is_unsupported_model = False
is_vlm = False
if model is not None:
model_config = getattr(model, "config", None)
if model_config is not None:
model_types = get_transformers_model_type(model_config)
# Blocklist: models that don't work correctly with padding_free
is_unsupported_model = any(
x in _PADDING_FREE_BLOCK_LIST for x in model_types
)
# Check if VLM
architectures = getattr(model_config, "architectures", None)
if architectures is None:
architectures = []
is_vlm = any(
x.endswith("ForConditionalGeneration") for x in architectures
)
is_vlm = is_vlm or hasattr(model_config, "vision_config")
processing_class = kwargs.get("processing_class") or kwargs.get("tokenizer")
data_collator = kwargs.get("data_collator")
blocked = data_collator is not None or isinstance(
processing_class, ProcessorMixin
# We also disable vision language models for padding free collators
blocked = (
data_collator is not None
or isinstance(processing_class, ProcessorMixin)
or is_vlm
)
if blocked and _should_auto_pack(config_arg):
reason = (
@ -292,6 +344,22 @@ def _patch_sft_trainer_auto_packing(trl_module):
auto_pack_active = True
logger.info("Unsloth: Sample packing auto-enabled for SFTTrainer instance.")
auto_padding_free_active = False
padding_free_requested = getattr(config_arg, "padding_free", None) is True
if not blocked:
if padding_free_requested:
configure_padding_free(config_arg)
elif not is_unsupported_model and _should_auto_padding_free(config_arg):
configure_padding_free(config_arg)
auto_padding_free_active = True
logger.info(
"Unsloth: Padding-free batching auto-enabled for SFTTrainer instance."
)
elif is_unsupported_model and _should_auto_padding_free(config_arg):
logger.info(
"Unsloth: Padding-free batching auto-disabled for this unsupported model type."
)
try:
original_init(self, *args, **kwargs)
except ValueError as exc:
@ -307,11 +375,24 @@ def _patch_sft_trainer_auto_packing(trl_module):
raise
trainer_args = getattr(self, "args", None)
if auto_pack_active and _should_auto_pack(trainer_args):
trainer_packing = bool(trainer_args and getattr(trainer_args, "packing", False))
trainer_padding_free = bool(
trainer_args and getattr(trainer_args, "padding_free", False)
)
if trainer_packing and (auto_pack_active or _should_auto_pack(trainer_args)):
enable_sample_packing(self.model, self)
print(
"🦥 Unsloth: Packing enabled - training is >2x faster and uses less VRAM!"
)
elif trainer_padding_free:
enable_padding_free_metadata(self.model, self)
message = (
"🦥 Unsloth: Padding-free auto-enabled, enabling faster training."
if auto_padding_free_active
else "🦥 Unsloth: Padding-free enabled, enabling faster training."
)
print(message)
sft_trainer.__init__ = new_init
sft_trainer._unsloth_auto_packing_wrapped = True
@ -343,7 +424,6 @@ def _patch_trl_trainer():
except:
continue
if not _AUTO_PACKING_ENV_DISABLED:
_patch_sft_trainer_auto_packing(trl)
_patch_sft_trainer_auto_packing(trl)
trl.__UNSLOTH_BACKWARDS_COMPATIBLE__ = True

View file

@ -13,7 +13,13 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from .packing import configure_sample_packing, enable_sample_packing
from .packing import (
configure_padding_free,
configure_sample_packing,
enable_padding_free_metadata,
enable_sample_packing,
mark_allow_overlength,
)
from .attention_dispatch import (
AttentionConfig,
AttentionContext,
@ -27,7 +33,10 @@ from .attention_dispatch import (
__all__ = [
"configure_sample_packing",
"configure_padding_free",
"enable_sample_packing",
"enable_padding_free_metadata",
"mark_allow_overlength",
"AttentionConfig",
"AttentionContext",
"FLASH_VARLEN",

View file

@ -84,6 +84,19 @@ def _ensure_trl_warning_filter():
_TRL_FILTER_INSTALLED = True
def mark_allow_overlength(module):
    """Mark a module hierarchy so padding-free batches can exceed max_seq_length."""
    if module is None:
        return
    # Only modules that enforce a max_seq_length cap need the escape hatch.
    if hasattr(module, "max_seq_length"):
        module._unsloth_allow_packed_overlength = True
    child_iter = getattr(module, "children", None)
    if child_iter is not None:
        # Recurse through torch.nn.Module-style child iterators.
        for sub_module in child_iter():
            mark_allow_overlength(sub_module)
def configure_sample_packing(config):
"""Mutate an ``SFTConfig`` so TRL prepares packed batches."""
_ensure_trl_warning_filter()
@ -92,25 +105,37 @@ def configure_sample_packing(config):
setattr(config, "remove_unused_columns", False)
def enable_sample_packing(model, trainer):
def configure_padding_free(config):
    """Mutate an ``SFTConfig`` so TRL enables padding-free batching without packing."""
    _ensure_trl_warning_filter()
    config.padding_free = True
    # Extra columns (e.g. seq_lengths) must survive Trainer preprocessing.
    if hasattr(config, "remove_unused_columns"):
        config.remove_unused_columns = False
def enable_sample_packing(
model,
trainer,
*,
sequence_lengths_key: str = "seq_lengths",
) -> None:
"""Enable runtime support for packed batches on an existing trainer."""
if model is None or trainer is None:
raise ValueError("model and trainer must not be None")
def _mark_allow_overlength(module):
if hasattr(module, "max_seq_length"):
setattr(module, "_unsloth_allow_packed_overlength", True)
for child in module.children():
_mark_allow_overlength(child)
mark_allow_overlength(model)
_mark_allow_overlength(model)
if hasattr(trainer, "args") and hasattr(trainer.args, "remove_unused_columns"):
trainer.args.remove_unused_columns = False
collator = getattr(trainer, "data_collator", None)
if (
collator is None
or not hasattr(collator, "torch_call")
or getattr(collator, "_unsloth_packing_wrapped", False)
):
if collator is None or not hasattr(collator, "torch_call"):
return
if getattr(collator, "_unsloth_packing_wrapped", False):
return
if hasattr(collator, "padding_free"):
collator.padding_free = True
if hasattr(collator, "return_position_ids"):
collator.return_position_ids = True
@ -120,18 +145,107 @@ def enable_sample_packing(model, trainer):
batch = original_torch_call(examples)
if examples and isinstance(examples[0], dict):
seq_lengths: list[int] = []
per_example_counts: list[int] = []
for example in examples:
seq_lengths.extend(example["seq_lengths"])
lengths = example.get(sequence_lengths_key)
if isinstance(lengths, Iterable):
numeric_lengths = [int(length) for length in lengths]
seq_lengths.extend(numeric_lengths)
per_example_counts.append(len(numeric_lengths))
else:
per_example_counts.append(0)
if seq_lengths:
batch["packed_seq_lengths"] = torch.tensor(
seq_lengths, dtype = torch.int32
)
position_ids = batch.get("position_ids")
input_ids = batch.get("input_ids")
if position_ids is None and input_ids is not None:
position_ids = torch.zeros_like(
input_ids, dtype = torch.long, device = input_ids.device
)
if position_ids is not None and input_ids is not None:
seq_index = 0
for row_idx, count in enumerate(per_example_counts):
cursor = 0
for _ in range(count):
length = seq_lengths[seq_index]
if length > 0:
position_ids[row_idx, cursor : cursor + length] = (
torch.arange(
length,
dtype = torch.long,
device = position_ids.device,
)
)
cursor += length
seq_index += 1
batch["position_ids"] = position_ids
if "attention_mask" in batch and getattr(
collator, "return_position_ids", False
):
batch.pop("attention_mask")
return batch
collator.torch_call = torch_call_with_lengths
collator._unsloth_packing_wrapped = True
def enable_padding_free_metadata(model, trainer):
    """Inject seq-length metadata when padding-free batching is enabled without packing.

    Marks ``model`` and its children so padding-free batches may exceed
    ``max_seq_length``, keeps ``remove_unused_columns`` off so the
    ``seq_lengths`` column survives, and wraps the trainer's collator so each
    batch carries a ``packed_seq_lengths`` int32 tensor. Idempotent: a collator
    is wrapped at most once.
    """
    trainer_args = getattr(trainer, "args", None)
    # Stop the Trainer from dropping the seq_lengths column before collation.
    if (
        trainer_args is not None
        and hasattr(trainer_args, "remove_unused_columns")
        and trainer_args.remove_unused_columns
    ):
        trainer_args.remove_unused_columns = False
    _ensure_trl_warning_filter()
    collator = getattr(trainer, "data_collator", None)
    if (
        collator is None
        or getattr(collator, "_unsloth_padding_free_lengths_wrapped", False)
        or not getattr(collator, "padding_free", False)
    ):
        # Nothing to do if there's no collator, we've already wrapped it, or padding-free is off.
        return
    mark_allow_overlength(model)
    if hasattr(collator, "return_position_ids"):
        collator.return_position_ids = True
    original_torch_call = collator.torch_call

    def torch_call_with_padding_free_metadata(examples: Sequence[dict]):
        # Gather per-sequence lengths BEFORE delegating, because the original
        # collator may flatten/concatenate examples and lose the boundaries.
        seq_lengths: list[int] = []
        if examples and isinstance(examples[0], dict):
            for example in examples:
                lengths = example.get("seq_lengths")
                if lengths is None:
                    ids = example.get("input_ids")
                    if ids is None:
                        continue
                    # No packing metadata: treat the whole example as one sequence.
                    lengths = [len(ids)]
                    example["seq_lengths"] = lengths
                seq_lengths.extend(lengths)
        batch = original_torch_call(examples)
        if seq_lengths:
            batch["packed_seq_lengths"] = torch.tensor(
                seq_lengths,
                dtype = torch.int32,
            )
        return batch

    collator.torch_call = torch_call_with_padding_free_metadata
    collator._unsloth_padding_free_lengths_wrapped = True
def get_packed_info_from_kwargs(
kwargs: dict,
device: torch.device,
@ -261,6 +375,12 @@ def mask_packed_sequence_boundaries(
__all__ = [
"configure_sample_packing",
"configure_padding_free",
"enable_sample_packing",
"enable_padding_free_metadata",
"mark_allow_overlength",
"get_packed_info_from_kwargs",
"build_xformers_block_causal_mask",
"build_sdpa_packed_attention_mask",
"mask_packed_sequence_boundaries",
]