From 7c5464ad71f93365773f68bb9469c9d55a859ad9 Mon Sep 17 00:00:00 2001 From: Avaya Aggarwal <119044997+OnePunchMonk@users.noreply.github.com> Date: Wed, 15 Apr 2026 20:10:03 +0530 Subject: [PATCH] feat: Add cactus QAT scheme support (#4679) * feat: Add cactus QAT scheme support * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test(qat): add tests for cactus QAT scheme and fix missing import * Fix cactus QAT scheme: correct MappingType import, tighten PerGroup filter - Drop the broken `from torchao.dtypes import MappingType` import. `MappingType` lives in `torchao.quantization` (and `torchao.quantization.quant_primitives`); it is not exported from `torchao.dtypes` in any supported torchao release (verified on 0.14, 0.16, 0.17). The previous code raised `ImportError` on every cactus call and was masked as a misleading 'torchao not found' error. - Since `IntxWeightOnlyConfig` already defaults `mapping_type` to `MappingType.SYMMETRIC`, drop the explicit kwarg entirely and remove the import. Behavior is unchanged. - Introduce a named `group_size = 32` constant (matches the int4 / fp8-int4 pattern in the surrounding branches) and add a `% group_size == 0` divisibility guard to the filter. `PerGroup(32)` requires `in_features % 32 == 0` at `quantize_()` time, otherwise torchao raises `ValueError: in_features (N) % group_size (32) must be == 0`. The old `in_features >= 32` filter would admit non-aligned widths (e.g. 33, 48, 65, 127) and crash `_prepare_model_for_qat` for those shapes. * Warn when cactus QAT skips non-divisible Linear layers Multiple reviewers flagged that the divisibility guard added in the previous commit can silently leave Linear layers in full precision when their in_features is not a multiple of 32. For currently supported Unsloth models (Qwen, Llama, Gemma, Mistral, Phi) every Linear width is already a multiple of 32/64/128 so this never triggers, but surfacing the coverage gap is cheap and avoids users assuming 100% QAT coverage when they bring a custom model with unusual shapes. Emit a UserWarning listing up to the first 8 skipped layers whenever the cactus filter excludes any Linear due to the modulo guard. This keeps the lenient silent-skip behavior (consistent with int4 / fp8-int4), but stops making it silent. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel Han --- tests/utils/test_qat.py | 11 ++++++--- unsloth/models/_utils.py | 49 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/tests/utils/test_qat.py b/tests/utils/test_qat.py index 1083712d7..08b8cd393 100644 --- a/tests/utils/test_qat.py +++ b/tests/utils/test_qat.py @@ -70,6 +70,11 @@ def _test_linear_is_fake_quantized(linear: torch.nn.Linear, qat_scheme: str): weight_fq_class = IntxFakeQuantizer min_in_features = 128 weight_only = True + elif qat_scheme == "cactus": + act_fq_class = None + weight_fq_class = IntxFakeQuantizer + min_in_features = 32 + weight_only = True else: raise ValueError(f"Unknown qat_scheme: {qat_scheme}") @@ -106,7 +111,7 @@ def _test_fake_quantizers_are_called( """ Verify that the fake quantizers are actually called when the model is called. """ - weight_only = qat_scheme == "int8" + weight_only = qat_scheme in ["int8", "cactus"] def _swap_fake_quantizers(model: torch.nn.Module): for name, child in model.named_children(): @@ -167,11 +172,11 @@ def _test_model_fake_quantize(qat_scheme: str, full_finetuning: bool): # TODO: there are bad interactions across tests right now, need to figure out # how to disable model caching before re-enabling this test -@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8"]) +@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8", "cactus"]) def _test_full_model_fake_quantize(qat_scheme: str): _test_model_fake_quantize(qat_scheme, full_finetuning = True) -@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8"]) +@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8", "cactus"]) def test_lora_model_fake_quantize(qat_scheme: str): _test_model_fake_quantize(qat_scheme, full_finetuning = False) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index c330dcc32..7d99ec393 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -2727,6 +2727,55 @@ def _prepare_model_for_qat( qat_scheme = qat_scheme, base_config_and_filter_fns = [(base_config, filter_fn)], ) + elif qat_scheme == "cactus": + try: + from torchao.quantization import IntxWeightOnlyConfig + except ImportError: + raise ImportError(TORCHAO_MSG) + + # IntxWeightOnlyConfig already defaults to + # `mapping_type = MappingType.SYMMETRIC`, so we intentionally do not + # import `MappingType` here. Matches the upstream Cactus runtime + # int8 / per-group-32 / symmetric weight-only configuration. + group_size = 32 + base_config = IntxWeightOnlyConfig( + weight_dtype = torch.int8, + granularity = PerGroup(group_size), + ) + filter_fn = ( + lambda m, _: isinstance(m, torch.nn.Linear) + and m.in_features >= group_size + and m.in_features % group_size == 0 + ) + # Warn if any Linear layer is skipped by the cactus filter because + # its in_features is not divisible by `group_size`. torchao's + # PerGroup(32) quantizer rejects non-divisible widths at + # `quantize_()` time, so the filter excludes those layers to keep + # the QAT prepare step from crashing. Surface that silently-skipped + # coverage gap to the user so they know some Linears will stay in + # full precision during training. + skipped_cactus_layers = [ + name + for name, module in model.named_modules() + if isinstance(module, torch.nn.Linear) + and module.in_features >= group_size + and module.in_features % group_size != 0 + ] + if skipped_cactus_layers: + preview = ", ".join(skipped_cactus_layers[:8]) + if len(skipped_cactus_layers) > 8: + preview += f", ... ({len(skipped_cactus_layers) - 8} more)" + warnings.warn( + f"Unsloth: qat_scheme='cactus' uses PerGroup({group_size}) " + "which requires in_features to be divisible by " + f"{group_size}. The following Linear layers will be kept " + f"in full precision during QAT: {preview}", + stacklevel = 2, + ) + torchao_config = TorchAOConfig( + qat_scheme = qat_scheme, + base_config_and_filter_fns = [(base_config, filter_fn)], + ) else: raise ValueError(f"Unexpected QAT scheme {qat_scheme}") assert torchao_config is not None, f"TorchAOConfig was not set for {qat_scheme}"