mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Auto-enable padding-free SFT (#3672)
* implement (sdpa, xformers, fa2) sample packing * attention dispatching * ddp working OOTB with CLI * packed SWA and softcap support * enable batch flattening * LGPL license headers * mask packed sequence boundaries * auto-enable sample packing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add explicit toggle for sample packing * Add explicit toggle for sample packing * Update __init__.py * Update unsloth/kernels/rope_embedding.py * Update unsloth/kernels/rope_embedding.py * remove grad output clones; restore deleted FastLanguageModel arg * fix * restore rope embedding clones * xformers mask cache * implement (sdpa, xformers, fa2) sample packing * attention dispatching * ddp working OOTB with CLI * packed SWA and softcap support * enable batch flattening * LGPL license headers * mask packed sequence boundaries * auto-enable sample packing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add explicit toggle for sample packing * Add explicit toggle for sample packing * Update __init__.py * Update unsloth/kernels/rope_embedding.py * Update unsloth/kernels/rope_embedding.py * remove grad output clones; restore deleted FastLanguageModel arg * fix * restore rope embedding clones * xformers mask cache * add back accidental deletion * Update unsloth/kernels/rope_embedding.py Co-authored-by: Daniel Han <danielhanchen@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix merge conflicts * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add **kwargs * add back clobbered * Update rope_embedding.py * Update rope_embedding.py * simplify trl warnings filter * docstring * nit * bugfix * add padding-free seqlen metadata * auto-enable padding free * gemma2 disable * Apply suggestion from @danielhanchen * Update trainer.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update trainer.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel Han <danielhanchen@gmail.com>
This commit is contained in:
parent
496f84ff6b
commit
75e0d7ce62
4 changed files with 281 additions and 20 deletions
|
|
@ -16,7 +16,9 @@
|
|||
from unsloth import FastLanguageModel
|
||||
from unsloth.utils import attention_dispatch as attention_dispatch_utils
|
||||
from unsloth.utils.packing import (
|
||||
configure_padding_free,
|
||||
configure_sample_packing,
|
||||
enable_padding_free_metadata,
|
||||
enable_sample_packing,
|
||||
mask_packed_sequence_boundaries,
|
||||
)
|
||||
|
|
@ -150,6 +152,14 @@ def test_configure_sample_packing():
|
|||
assert config.remove_unused_columns is False
|
||||
|
||||
|
||||
def test_configure_padding_free():
|
||||
config = SimpleNamespace(remove_unused_columns = True)
|
||||
configure_padding_free(config)
|
||||
|
||||
assert config.padding_free is True
|
||||
assert config.remove_unused_columns is False
|
||||
|
||||
|
||||
class _DummyChild(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
@ -177,6 +187,20 @@ class _DummyTrainer:
|
|||
)
|
||||
|
||||
|
||||
class _PaddingFreeCollator:
|
||||
def __init__(self):
|
||||
self.padding_free = True
|
||||
self.return_position_ids = False
|
||||
self.calls = 0
|
||||
|
||||
def torch_call(self, examples):
|
||||
self.calls += 1
|
||||
return {
|
||||
"input_ids": torch.tensor([[0]], dtype = torch.long),
|
||||
"examples_seen": self.calls,
|
||||
}
|
||||
|
||||
|
||||
def test_enable_sample_packing():
|
||||
model = _DummyModel()
|
||||
trainer = _DummyTrainer()
|
||||
|
|
@ -251,6 +275,34 @@ def test_enable_sample_packing_trl_collator(tmp_path):
|
|||
trainer.accelerator.free_memory()
|
||||
|
||||
|
||||
def test_enable_padding_free_metadata():
|
||||
model = _DummyModel()
|
||||
trainer = SimpleNamespace(
|
||||
args = SimpleNamespace(remove_unused_columns = True),
|
||||
data_collator = _PaddingFreeCollator(),
|
||||
)
|
||||
|
||||
enable_padding_free_metadata(model, trainer)
|
||||
|
||||
assert getattr(model, "_unsloth_allow_packed_overlength") is True
|
||||
assert getattr(model.child, "_unsloth_allow_packed_overlength") is True
|
||||
|
||||
collator = trainer.data_collator
|
||||
assert collator.return_position_ids is True
|
||||
assert getattr(collator, "_unsloth_padding_free_lengths_wrapped") is True
|
||||
|
||||
examples = [
|
||||
{"input_ids": [0, 1, 2]},
|
||||
{"input_ids": [3, 4]},
|
||||
]
|
||||
batch = collator.torch_call(examples)
|
||||
assert torch.equal(
|
||||
batch["packed_seq_lengths"],
|
||||
torch.tensor([3, 2], dtype = torch.int32),
|
||||
)
|
||||
assert trainer.args.remove_unused_columns is False
|
||||
|
||||
|
||||
def test_packing_sdpa(tmp_path):
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
model, batch, trainer, llama_mod = _build_packed_training_setup(tmp_path, device)
|
||||
|
|
|
|||
|
|
@ -23,13 +23,19 @@ import trl
|
|||
import inspect
|
||||
from trl import SFTTrainer
|
||||
from . import is_bfloat16_supported
|
||||
from unsloth.utils import configure_sample_packing, enable_sample_packing
|
||||
from unsloth.utils import (
|
||||
configure_padding_free,
|
||||
configure_sample_packing,
|
||||
enable_padding_free_metadata,
|
||||
enable_sample_packing,
|
||||
)
|
||||
from unsloth_zoo.training_utils import (
|
||||
unsloth_train as _unsloth_train,
|
||||
)
|
||||
from unsloth_zoo.vision_utils import (
|
||||
UnslothVisionDataCollator,
|
||||
)
|
||||
from unsloth_zoo.hf_utils import get_transformers_model_type
|
||||
from packaging.version import Version
|
||||
import dataclasses
|
||||
|
||||
|
|
@ -47,6 +53,19 @@ _AUTO_PACKING_ENV_DISABLED = os.environ.get(
|
|||
"UNSLOTH_DISABLE_AUTO_PACKING", ""
|
||||
).strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
_AUTO_PADDING_FREE_ENV_DISABLED = os.environ.get(
|
||||
"UNSLOTH_DISABLE_AUTO_PADDING_FREE", ""
|
||||
).strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
# [TODO]
|
||||
# Below cannot work with padding-free
|
||||
_PADDING_FREE_BLOCK_LIST = {
|
||||
"gemma2", # - gemma2: Uses slow_attention_softcapping which has torch.compile issues
|
||||
"gpt_oss", # - gpt_oss: Uses Flex Attention which doesn't handle padding_free correctly
|
||||
"mistral", # - mistral: Unfortunately I think sliding window attention doesn't work correctly?
|
||||
}
|
||||
|
||||
|
||||
def _should_auto_pack(config) -> bool:
|
||||
if config is None or _AUTO_PACKING_ENV_DISABLED:
|
||||
|
|
@ -56,6 +75,14 @@ def _should_auto_pack(config) -> bool:
|
|||
return not getattr(config, "_unsloth_disable_auto_packing", False)
|
||||
|
||||
|
||||
def _should_auto_padding_free(config) -> bool:
|
||||
if config is None or _AUTO_PADDING_FREE_ENV_DISABLED:
|
||||
return False
|
||||
if getattr(config, "packing", False):
|
||||
return False
|
||||
return not getattr(config, "padding_free", False)
|
||||
|
||||
|
||||
def _disable_sample_packing(config):
|
||||
if config is None:
|
||||
return
|
||||
|
|
@ -269,11 +296,36 @@ def _patch_sft_trainer_auto_packing(trl_module):
|
|||
else:
|
||||
config_arg = kwargs.get("args")
|
||||
|
||||
# Check if model type is unsupported for padding_free
|
||||
model = kwargs.get("model")
|
||||
is_unsupported_model = False
|
||||
is_vlm = False
|
||||
if model is not None:
|
||||
model_config = getattr(model, "config", None)
|
||||
if model_config is not None:
|
||||
model_types = get_transformers_model_type(model_config)
|
||||
# Blocklist: models that don't work correctly with padding_free
|
||||
is_unsupported_model = any(
|
||||
x in PADDING_FREE_BLOCKLIST for x in model_types
|
||||
)
|
||||
|
||||
# Check if VLM
|
||||
architectures = getattr(model_config, "architectures", None)
|
||||
if architectures is None:
|
||||
architectures = []
|
||||
is_vlm = any(
|
||||
x.endswith("ForConditionalGeneration") for x in architectures
|
||||
)
|
||||
is_vlm = is_vlm or hasattr(model_config, "vision_config")
|
||||
|
||||
processing_class = kwargs.get("processing_class") or kwargs.get("tokenizer")
|
||||
data_collator = kwargs.get("data_collator")
|
||||
|
||||
blocked = data_collator is not None or isinstance(
|
||||
processing_class, ProcessorMixin
|
||||
# We also disable vision language models for padding free collators
|
||||
blocked = (
|
||||
data_collator is not None
|
||||
or isinstance(processing_class, ProcessorMixin)
|
||||
or is_vlm
|
||||
)
|
||||
if blocked and _should_auto_pack(config_arg):
|
||||
reason = (
|
||||
|
|
@ -292,6 +344,22 @@ def _patch_sft_trainer_auto_packing(trl_module):
|
|||
auto_pack_active = True
|
||||
logger.info("Unsloth: Sample packing auto-enabled for SFTTrainer instance.")
|
||||
|
||||
auto_padding_free_active = False
|
||||
padding_free_requested = getattr(config_arg, "padding_free", None) is True
|
||||
if not blocked:
|
||||
if padding_free_requested:
|
||||
configure_padding_free(config_arg)
|
||||
elif not is_unsupported_gemma and _should_auto_padding_free(config_arg):
|
||||
configure_padding_free(config_arg)
|
||||
auto_padding_free_active = True
|
||||
logger.info(
|
||||
"Unsloth: Padding-free batching auto-enabled for SFTTrainer instance."
|
||||
)
|
||||
elif is_unsupported_gemma and _should_auto_padding_free(config_arg):
|
||||
logger.info(
|
||||
"Unsloth: Padding-free batching auto-disabled for Gemma 2 (requires flash attention)."
|
||||
)
|
||||
|
||||
try:
|
||||
original_init(self, *args, **kwargs)
|
||||
except ValueError as exc:
|
||||
|
|
@ -307,11 +375,24 @@ def _patch_sft_trainer_auto_packing(trl_module):
|
|||
raise
|
||||
|
||||
trainer_args = getattr(self, "args", None)
|
||||
if auto_pack_active and _should_auto_pack(trainer_args):
|
||||
trainer_packing = bool(trainer_args and getattr(trainer_args, "packing", False))
|
||||
trainer_padding_free = bool(
|
||||
trainer_args and getattr(trainer_args, "padding_free", False)
|
||||
)
|
||||
|
||||
if trainer_packing and (auto_pack_active or _should_auto_pack(trainer_args)):
|
||||
enable_sample_packing(self.model, self)
|
||||
print(
|
||||
"🦥 Unsloth: Packing enabled - training is >2x faster and uses less VRAM!"
|
||||
)
|
||||
elif trainer_padding_free:
|
||||
enable_padding_free_metadata(self.model, self)
|
||||
message = (
|
||||
"🦥 Unsloth: Padding-free auto-enabled, enabling faster training."
|
||||
if auto_padding_free_active
|
||||
else "🦥 Unsloth: Padding-free enabled, enabling faster training."
|
||||
)
|
||||
print(message)
|
||||
|
||||
sft_trainer.__init__ = new_init
|
||||
sft_trainer._unsloth_auto_packing_wrapped = True
|
||||
|
|
@ -343,7 +424,6 @@ def _patch_trl_trainer():
|
|||
except:
|
||||
continue
|
||||
|
||||
if not _AUTO_PACKING_ENV_DISABLED:
|
||||
_patch_sft_trainer_auto_packing(trl)
|
||||
_patch_sft_trainer_auto_packing(trl)
|
||||
|
||||
trl.__UNSLOTH_BACKWARDS_COMPATIBLE__ = True
|
||||
|
|
|
|||
|
|
@ -13,7 +13,13 @@
|
|||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
from .packing import configure_sample_packing, enable_sample_packing
|
||||
from .packing import (
|
||||
configure_padding_free,
|
||||
configure_sample_packing,
|
||||
enable_padding_free_metadata,
|
||||
enable_sample_packing,
|
||||
mark_allow_overlength,
|
||||
)
|
||||
from .attention_dispatch import (
|
||||
AttentionConfig,
|
||||
AttentionContext,
|
||||
|
|
@ -27,7 +33,10 @@ from .attention_dispatch import (
|
|||
|
||||
__all__ = [
|
||||
"configure_sample_packing",
|
||||
"configure_padding_free",
|
||||
"enable_sample_packing",
|
||||
"enable_padding_free_metadata",
|
||||
"mark_allow_overlength",
|
||||
"AttentionConfig",
|
||||
"AttentionContext",
|
||||
"FLASH_VARLEN",
|
||||
|
|
|
|||
|
|
@ -84,6 +84,19 @@ def _ensure_trl_warning_filter():
|
|||
_TRL_FILTER_INSTALLED = True
|
||||
|
||||
|
||||
def mark_allow_overlength(module):
|
||||
"""Mark a module hierarchy so padding-free batches can exceed max_seq_length."""
|
||||
if module is None:
|
||||
return
|
||||
if hasattr(module, "max_seq_length"):
|
||||
setattr(module, "_unsloth_allow_packed_overlength", True)
|
||||
children = getattr(module, "children", None)
|
||||
if children is None:
|
||||
return
|
||||
for child in children():
|
||||
mark_allow_overlength(child)
|
||||
|
||||
|
||||
def configure_sample_packing(config):
|
||||
"""Mutate an ``SFTConfig`` so TRL prepares packed batches."""
|
||||
_ensure_trl_warning_filter()
|
||||
|
|
@ -92,25 +105,37 @@ def configure_sample_packing(config):
|
|||
setattr(config, "remove_unused_columns", False)
|
||||
|
||||
|
||||
def enable_sample_packing(model, trainer):
|
||||
def configure_padding_free(config):
|
||||
"""Mutate an ``SFTConfig`` so TRL enables padding-free batching without packing."""
|
||||
_ensure_trl_warning_filter()
|
||||
setattr(config, "padding_free", True)
|
||||
if hasattr(config, "remove_unused_columns"):
|
||||
setattr(config, "remove_unused_columns", False)
|
||||
|
||||
|
||||
def enable_sample_packing(
|
||||
model,
|
||||
trainer,
|
||||
*,
|
||||
sequence_lengths_key: str = "seq_lengths",
|
||||
) -> None:
|
||||
"""Enable runtime support for packed batches on an existing trainer."""
|
||||
if model is None or trainer is None:
|
||||
raise ValueError("model and trainer must not be None")
|
||||
|
||||
def _mark_allow_overlength(module):
|
||||
if hasattr(module, "max_seq_length"):
|
||||
setattr(module, "_unsloth_allow_packed_overlength", True)
|
||||
for child in module.children():
|
||||
_mark_allow_overlength(child)
|
||||
mark_allow_overlength(model)
|
||||
|
||||
_mark_allow_overlength(model)
|
||||
if hasattr(trainer, "args") and hasattr(trainer.args, "remove_unused_columns"):
|
||||
trainer.args.remove_unused_columns = False
|
||||
|
||||
collator = getattr(trainer, "data_collator", None)
|
||||
if (
|
||||
collator is None
|
||||
or not hasattr(collator, "torch_call")
|
||||
or getattr(collator, "_unsloth_packing_wrapped", False)
|
||||
):
|
||||
if collator is None or not hasattr(collator, "torch_call"):
|
||||
return
|
||||
if getattr(collator, "_unsloth_packing_wrapped", False):
|
||||
return
|
||||
|
||||
if hasattr(collator, "padding_free"):
|
||||
collator.padding_free = True
|
||||
if hasattr(collator, "return_position_ids"):
|
||||
collator.return_position_ids = True
|
||||
|
||||
|
|
@ -120,18 +145,107 @@ def enable_sample_packing(model, trainer):
|
|||
batch = original_torch_call(examples)
|
||||
if examples and isinstance(examples[0], dict):
|
||||
seq_lengths: list[int] = []
|
||||
per_example_counts: list[int] = []
|
||||
for example in examples:
|
||||
seq_lengths.extend(example["seq_lengths"])
|
||||
lengths = example.get(sequence_lengths_key)
|
||||
if isinstance(lengths, Iterable):
|
||||
numeric_lengths = [int(length) for length in lengths]
|
||||
seq_lengths.extend(numeric_lengths)
|
||||
per_example_counts.append(len(numeric_lengths))
|
||||
else:
|
||||
per_example_counts.append(0)
|
||||
if seq_lengths:
|
||||
batch["packed_seq_lengths"] = torch.tensor(
|
||||
seq_lengths, dtype = torch.int32
|
||||
)
|
||||
|
||||
position_ids = batch.get("position_ids")
|
||||
input_ids = batch.get("input_ids")
|
||||
if position_ids is None and input_ids is not None:
|
||||
position_ids = torch.zeros_like(
|
||||
input_ids, dtype = torch.long, device = input_ids.device
|
||||
)
|
||||
|
||||
if position_ids is not None and input_ids is not None:
|
||||
seq_index = 0
|
||||
for row_idx, count in enumerate(per_example_counts):
|
||||
cursor = 0
|
||||
for _ in range(count):
|
||||
length = seq_lengths[seq_index]
|
||||
if length > 0:
|
||||
position_ids[row_idx, cursor : cursor + length] = (
|
||||
torch.arange(
|
||||
length,
|
||||
dtype = torch.long,
|
||||
device = position_ids.device,
|
||||
)
|
||||
)
|
||||
cursor += length
|
||||
seq_index += 1
|
||||
batch["position_ids"] = position_ids
|
||||
|
||||
if "attention_mask" in batch and getattr(
|
||||
collator, "return_position_ids", False
|
||||
):
|
||||
batch.pop("attention_mask")
|
||||
return batch
|
||||
|
||||
collator.torch_call = torch_call_with_lengths
|
||||
collator._unsloth_packing_wrapped = True
|
||||
|
||||
|
||||
def enable_padding_free_metadata(model, trainer):
|
||||
"""Inject seq-length metadata when padding-free batching is enabled without packing."""
|
||||
|
||||
trainer_args = getattr(trainer, "args", None)
|
||||
if (
|
||||
trainer_args is not None
|
||||
and hasattr(trainer_args, "remove_unused_columns")
|
||||
and trainer_args.remove_unused_columns
|
||||
):
|
||||
trainer_args.remove_unused_columns = False
|
||||
|
||||
_ensure_trl_warning_filter()
|
||||
collator = getattr(trainer, "data_collator", None)
|
||||
if (
|
||||
collator is None
|
||||
or getattr(collator, "_unsloth_padding_free_lengths_wrapped", False)
|
||||
or not getattr(collator, "padding_free", False)
|
||||
):
|
||||
# Nothing to do if there's no collator, we've already wrapped it, or padding-free is off.
|
||||
return
|
||||
|
||||
mark_allow_overlength(model)
|
||||
if hasattr(collator, "return_position_ids"):
|
||||
collator.return_position_ids = True
|
||||
|
||||
original_torch_call = collator.torch_call
|
||||
|
||||
def torch_call_with_padding_free_metadata(examples: Sequence[dict]):
|
||||
seq_lengths: list[int] = []
|
||||
if examples and isinstance(examples[0], dict):
|
||||
for example in examples:
|
||||
lengths = example.get("seq_lengths")
|
||||
if lengths is None:
|
||||
ids = example.get("input_ids")
|
||||
if ids is None:
|
||||
continue
|
||||
lengths = [len(ids)]
|
||||
example["seq_lengths"] = lengths
|
||||
seq_lengths.extend(lengths)
|
||||
|
||||
batch = original_torch_call(examples)
|
||||
if seq_lengths:
|
||||
batch["packed_seq_lengths"] = torch.tensor(
|
||||
seq_lengths,
|
||||
dtype = torch.int32,
|
||||
)
|
||||
return batch
|
||||
|
||||
collator.torch_call = torch_call_with_padding_free_metadata
|
||||
collator._unsloth_padding_free_lengths_wrapped = True
|
||||
|
||||
|
||||
def get_packed_info_from_kwargs(
|
||||
kwargs: dict,
|
||||
device: torch.device,
|
||||
|
|
@ -261,6 +375,12 @@ def mask_packed_sequence_boundaries(
|
|||
|
||||
__all__ = [
|
||||
"configure_sample_packing",
|
||||
"configure_padding_free",
|
||||
"enable_sample_packing",
|
||||
"enable_padding_free_metadata",
|
||||
"mark_allow_overlength",
|
||||
"get_packed_info_from_kwargs",
|
||||
"build_xformers_block_causal_mask",
|
||||
"build_sdpa_packed_attention_mask",
|
||||
"mask_packed_sequence_boundaries",
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in a new issue