feat: Add cactus QAT scheme support (#4679)

* feat: Add cactus QAT scheme support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test(qat): add tests for cactus QAT scheme and fix missing import

* Fix cactus QAT scheme: correct MappingType import, tighten PerGroup filter

- Drop the broken `from torchao.dtypes import MappingType` import. `MappingType`
  lives in `torchao.quantization` (and `torchao.quantization.quant_primitives`);
  it is not exported from `torchao.dtypes` in any supported torchao release
  (verified on 0.14, 0.16, 0.17). The previous code raised `ImportError` on
  every cactus call and was masked as a misleading 'torchao not found' error.
- Since `IntxWeightOnlyConfig` already defaults `mapping_type` to
  `MappingType.SYMMETRIC`, drop the explicit kwarg entirely and remove the
  import. Behavior is unchanged.
- Introduce a named `group_size = 32` constant (matches the int4 / fp8-int4
  pattern in the surrounding branches) and add a `% group_size == 0`
  divisibility guard to the filter. `PerGroup(32)` requires
  `in_features % 32 == 0` at `quantize_()` time, otherwise torchao raises
  `ValueError: in_features (N) % group_size (32) must be == 0`. The old
  `in_features >= 32` filter would admit non-aligned widths (e.g. 33, 48, 65,
  127) and crash `_prepare_model_for_qat` for those shapes.

* Warn when cactus QAT skips non-divisible Linear layers

Multiple reviewers flagged that the divisibility guard added in the
previous commit can silently leave Linear layers in full precision when
their in_features is not a multiple of 32. For currently supported
Unsloth models (Qwen, Llama, Gemma, Mistral, Phi) every Linear width is
already a multiple of 32/64/128 so this never triggers, but surfacing
the coverage gap is cheap and avoids users assuming 100% QAT coverage
when they bring a custom model with unusual shapes.

Emit a UserWarning listing up to the first 8 skipped layers whenever
the cactus filter excludes any Linear due to the modulo guard. This
keeps the lenient silent-skip behavior (consistent with int4 /
fp8-int4), but stops making it silent.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
This commit is contained in:
Avaya Aggarwal 2026-04-15 20:10:03 +05:30 committed by GitHub
parent f18e9dddf0
commit 7c5464ad71
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 57 additions and 3 deletions

View file

@ -70,6 +70,11 @@ def _test_linear_is_fake_quantized(linear: torch.nn.Linear, qat_scheme: str):
weight_fq_class = IntxFakeQuantizer weight_fq_class = IntxFakeQuantizer
min_in_features = 128 min_in_features = 128
weight_only = True weight_only = True
elif qat_scheme == "cactus":
act_fq_class = None
weight_fq_class = IntxFakeQuantizer
min_in_features = 32
weight_only = True
else: else:
raise ValueError(f"Unknown qat_scheme: {qat_scheme}") raise ValueError(f"Unknown qat_scheme: {qat_scheme}")
@ -106,7 +111,7 @@ def _test_fake_quantizers_are_called(
""" """
Verify that the fake quantizers are actually called when the model is called. Verify that the fake quantizers are actually called when the model is called.
""" """
weight_only = qat_scheme == "int8" weight_only = qat_scheme in ["int8", "cactus"]
def _swap_fake_quantizers(model: torch.nn.Module): def _swap_fake_quantizers(model: torch.nn.Module):
for name, child in model.named_children(): for name, child in model.named_children():
@ -167,11 +172,11 @@ def _test_model_fake_quantize(qat_scheme: str, full_finetuning: bool):
# TODO: there are bad interactions across tests right now, need to figure out # TODO: there are bad interactions across tests right now, need to figure out
# how to disable model caching before re-enabling this test # how to disable model caching before re-enabling this test
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8"]) @pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8", "cactus"])
def _test_full_model_fake_quantize(qat_scheme: str): def _test_full_model_fake_quantize(qat_scheme: str):
_test_model_fake_quantize(qat_scheme, full_finetuning = True) _test_model_fake_quantize(qat_scheme, full_finetuning = True)
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8"]) @pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8", "cactus"])
def test_lora_model_fake_quantize(qat_scheme: str): def test_lora_model_fake_quantize(qat_scheme: str):
_test_model_fake_quantize(qat_scheme, full_finetuning = False) _test_model_fake_quantize(qat_scheme, full_finetuning = False)

View file

@ -2727,6 +2727,55 @@ def _prepare_model_for_qat(
qat_scheme = qat_scheme, qat_scheme = qat_scheme,
base_config_and_filter_fns = [(base_config, filter_fn)], base_config_and_filter_fns = [(base_config, filter_fn)],
) )
elif qat_scheme == "cactus":
try:
from torchao.quantization import IntxWeightOnlyConfig
except ImportError:
raise ImportError(TORCHAO_MSG)
# IntxWeightOnlyConfig already defaults to
# `mapping_type = MappingType.SYMMETRIC`, so we intentionally do not
# import `MappingType` here. Matches the upstream Cactus runtime
# int8 / per-group-32 / symmetric weight-only configuration.
group_size = 32
base_config = IntxWeightOnlyConfig(
weight_dtype = torch.int8,
granularity = PerGroup(group_size),
)
filter_fn = (
lambda m, _: isinstance(m, torch.nn.Linear)
and m.in_features >= group_size
and m.in_features % group_size == 0
)
# Warn if any Linear layer is skipped by the cactus filter because
# its in_features is not divisible by `group_size`. torchao's
# PerGroup(32) quantizer rejects non-divisible widths at
# `quantize_()` time, so the filter excludes those layers to keep
# the QAT prepare step from crashing. Surface that silently-skipped
# coverage gap to the user so they know some Linears will stay in
# full precision during training.
skipped_cactus_layers = [
name
for name, module in model.named_modules()
if isinstance(module, torch.nn.Linear)
and module.in_features >= group_size
and module.in_features % group_size != 0
]
if skipped_cactus_layers:
preview = ", ".join(skipped_cactus_layers[:8])
if len(skipped_cactus_layers) > 8:
preview += f", ... ({len(skipped_cactus_layers) - 8} more)"
warnings.warn(
f"Unsloth: qat_scheme='cactus' uses PerGroup({group_size}) "
"which requires in_features to be divisible by "
f"{group_size}. The following Linear layers will be kept "
f"in full precision during QAT: {preview}",
stacklevel = 2,
)
torchao_config = TorchAOConfig(
qat_scheme = qat_scheme,
base_config_and_filter_fns = [(base_config, filter_fn)],
)
else: else:
raise ValueError(f"Unexpected QAT scheme {qat_scheme}") raise ValueError(f"Unexpected QAT scheme {qat_scheme}")
assert torchao_config is not None, f"TorchAOConfig was not set for {qat_scheme}" assert torchao_config is not None, f"TorchAOConfig was not set for {qat_scheme}"