mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
feat: Add cactus QAT scheme support (#4679)
* feat: Add cactus QAT scheme support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* test(qat): add tests for cactus QAT scheme and fix missing import

* Fix cactus QAT scheme: correct MappingType import, tighten PerGroup filter

  - Drop the broken `from torchao.dtypes import MappingType` import. `MappingType` lives in `torchao.quantization` (and `torchao.quantization.quant_primitives`); it is not exported from `torchao.dtypes` in any supported torchao release (verified on 0.14, 0.16, 0.17). The previous code raised `ImportError` on every cactus call, and the error was masked as a misleading 'torchao not found' message.
  - Since `IntxWeightOnlyConfig` already defaults `mapping_type` to `MappingType.SYMMETRIC`, drop the explicit kwarg entirely and remove the import. Behavior is unchanged.
  - Introduce a named `group_size = 32` constant (matching the int4 / fp8-int4 pattern in the surrounding branches) and add a `% group_size == 0` divisibility guard to the filter. `PerGroup(32)` requires `in_features % 32 == 0` at `quantize_()` time; otherwise torchao raises `ValueError: in_features (N) % group_size (32) must be == 0`. The old `in_features >= 32` filter would admit non-aligned widths (e.g. 33, 48, 65, 127) and crash `_prepare_model_for_qat` for those shapes.

* Warn when cactus QAT skips non-divisible Linear layers

  Multiple reviewers flagged that the divisibility guard added in the previous commit can silently leave Linear layers in full precision when their in_features is not a multiple of 32. For currently supported Unsloth models (Qwen, Llama, Gemma, Mistral, Phi) every Linear width is already a multiple of 32/64/128, so this never triggers; but surfacing the coverage gap is cheap and avoids users assuming 100% QAT coverage when they bring a custom model with unusual shapes.

  Emit a UserWarning listing up to the first 8 skipped layers whenever the cactus filter excludes any Linear due to the modulo guard.
This keeps the lenient skip behavior (consistent with int4 / fp8-int4), but the skip is no longer silent. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel Han <danielhanchen@gmail.com>
This commit is contained in:
parent
f18e9dddf0
commit
7c5464ad71
2 changed files with 57 additions and 3 deletions
|
|
@@ -70,6 +70,11 @@ def _test_linear_is_fake_quantized(linear: torch.nn.Linear, qat_scheme: str):
|
||||||
weight_fq_class = IntxFakeQuantizer
|
weight_fq_class = IntxFakeQuantizer
|
||||||
min_in_features = 128
|
min_in_features = 128
|
||||||
weight_only = True
|
weight_only = True
|
||||||
|
elif qat_scheme == "cactus":
|
||||||
|
act_fq_class = None
|
||||||
|
weight_fq_class = IntxFakeQuantizer
|
||||||
|
min_in_features = 32
|
||||||
|
weight_only = True
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown qat_scheme: {qat_scheme}")
|
raise ValueError(f"Unknown qat_scheme: {qat_scheme}")
|
||||||
|
|
||||||
|
|
@@ -106,7 +111,7 @@ def _test_fake_quantizers_are_called(
|
||||||
"""
|
"""
|
||||||
Verify that the fake quantizers are actually called when the model is called.
|
Verify that the fake quantizers are actually called when the model is called.
|
||||||
"""
|
"""
|
||||||
weight_only = qat_scheme == "int8"
|
weight_only = qat_scheme in ["int8", "cactus"]
|
||||||
|
|
||||||
def _swap_fake_quantizers(model: torch.nn.Module):
|
def _swap_fake_quantizers(model: torch.nn.Module):
|
||||||
for name, child in model.named_children():
|
for name, child in model.named_children():
|
||||||
|
|
@@ -167,11 +172,11 @@ def _test_model_fake_quantize(qat_scheme: str, full_finetuning: bool):
|
||||||
|
|
||||||
# TODO: there are bad interactions across tests right now, need to figure out
|
# TODO: there are bad interactions across tests right now, need to figure out
|
||||||
# how to disable model caching before re-enabling this test
|
# how to disable model caching before re-enabling this test
|
||||||
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8"])
|
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8", "cactus"])
|
||||||
def _test_full_model_fake_quantize(qat_scheme: str):
|
def _test_full_model_fake_quantize(qat_scheme: str):
|
||||||
_test_model_fake_quantize(qat_scheme, full_finetuning = True)
|
_test_model_fake_quantize(qat_scheme, full_finetuning = True)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8"])
|
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8", "cactus"])
|
||||||
def test_lora_model_fake_quantize(qat_scheme: str):
|
def test_lora_model_fake_quantize(qat_scheme: str):
|
||||||
_test_model_fake_quantize(qat_scheme, full_finetuning = False)
|
_test_model_fake_quantize(qat_scheme, full_finetuning = False)
|
||||||
|
|
|
||||||
|
|
@@ -2727,6 +2727,55 @@ def _prepare_model_for_qat(
|
||||||
qat_scheme = qat_scheme,
|
qat_scheme = qat_scheme,
|
||||||
base_config_and_filter_fns = [(base_config, filter_fn)],
|
base_config_and_filter_fns = [(base_config, filter_fn)],
|
||||||
)
|
)
|
||||||
|
elif qat_scheme == "cactus":
|
||||||
|
try:
|
||||||
|
from torchao.quantization import IntxWeightOnlyConfig
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(TORCHAO_MSG)
|
||||||
|
|
||||||
|
# IntxWeightOnlyConfig already defaults to
|
||||||
|
# `mapping_type = MappingType.SYMMETRIC`, so we intentionally do not
|
||||||
|
# import `MappingType` here. Matches the upstream Cactus runtime
|
||||||
|
# int8 / per-group-32 / symmetric weight-only configuration.
|
||||||
|
group_size = 32
|
||||||
|
base_config = IntxWeightOnlyConfig(
|
||||||
|
weight_dtype = torch.int8,
|
||||||
|
granularity = PerGroup(group_size),
|
||||||
|
)
|
||||||
|
filter_fn = (
|
||||||
|
lambda m, _: isinstance(m, torch.nn.Linear)
|
||||||
|
and m.in_features >= group_size
|
||||||
|
and m.in_features % group_size == 0
|
||||||
|
)
|
||||||
|
# Warn if any Linear layer is skipped by the cactus filter because
|
||||||
|
# its in_features is not divisible by `group_size`. torchao's
|
||||||
|
# PerGroup(32) quantizer rejects non-divisible widths at
|
||||||
|
# `quantize_()` time, so the filter excludes those layers to keep
|
||||||
|
# the QAT prepare step from crashing. Surface that silently-skipped
|
||||||
|
# coverage gap to the user so they know some Linears will stay in
|
||||||
|
# full precision during training.
|
||||||
|
skipped_cactus_layers = [
|
||||||
|
name
|
||||||
|
for name, module in model.named_modules()
|
||||||
|
if isinstance(module, torch.nn.Linear)
|
||||||
|
and module.in_features >= group_size
|
||||||
|
and module.in_features % group_size != 0
|
||||||
|
]
|
||||||
|
if skipped_cactus_layers:
|
||||||
|
preview = ", ".join(skipped_cactus_layers[:8])
|
||||||
|
if len(skipped_cactus_layers) > 8:
|
||||||
|
preview += f", ... ({len(skipped_cactus_layers) - 8} more)"
|
||||||
|
warnings.warn(
|
||||||
|
f"Unsloth: qat_scheme='cactus' uses PerGroup({group_size}) "
|
||||||
|
"which requires in_features to be divisible by "
|
||||||
|
f"{group_size}. The following Linear layers will be kept "
|
||||||
|
f"in full precision during QAT: {preview}",
|
||||||
|
stacklevel = 2,
|
||||||
|
)
|
||||||
|
torchao_config = TorchAOConfig(
|
||||||
|
qat_scheme = qat_scheme,
|
||||||
|
base_config_and_filter_fns = [(base_config, filter_fn)],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected QAT scheme {qat_scheme}")
|
raise ValueError(f"Unexpected QAT scheme {qat_scheme}")
|
||||||
assert torchao_config is not None, f"TorchAOConfig was not set for {qat_scheme}"
|
assert torchao_config is not None, f"TorchAOConfig was not set for {qat_scheme}"
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue