diff --git a/tests/utils/test_qat.py b/tests/utils/test_qat.py
index 1083712d7..08b8cd393 100644
--- a/tests/utils/test_qat.py
+++ b/tests/utils/test_qat.py
@@ -70,6 +70,11 @@ def _test_linear_is_fake_quantized(linear: torch.nn.Linear, qat_scheme: str):
         weight_fq_class = IntxFakeQuantizer
         min_in_features = 128
         weight_only = True
+    elif qat_scheme == "cactus":
+        act_fq_class = None
+        weight_fq_class = IntxFakeQuantizer
+        min_in_features = 32
+        weight_only = True
     else:
         raise ValueError(f"Unknown qat_scheme: {qat_scheme}")
 
@@ -106,7 +111,7 @@ def _test_fake_quantizers_are_called(
     """
     Verify that the fake quantizers are actually called when the model is called.
     """
-    weight_only = qat_scheme == "int8"
+    weight_only = qat_scheme in ["int8", "cactus"]
 
     def _swap_fake_quantizers(model: torch.nn.Module):
         for name, child in model.named_children():
@@ -167,11 +172,11 @@ def _test_model_fake_quantize(qat_scheme: str, full_finetuning: bool):
 
 # TODO: there are bad interactions across tests right now, need to figure out
 # how to disable model caching before re-enabling this test
-@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8"])
+@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8", "cactus"])
 def _test_full_model_fake_quantize(qat_scheme: str):
     _test_model_fake_quantize(qat_scheme, full_finetuning = True)
 
-@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8"])
+@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8", "cactus"])
 def test_lora_model_fake_quantize(qat_scheme: str):
     _test_model_fake_quantize(qat_scheme, full_finetuning = False)
 
diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py
index c330dcc32..7d99ec393 100644
--- a/unsloth/models/_utils.py
+++ b/unsloth/models/_utils.py
@@ -2727,6 +2727,55 @@ def _prepare_model_for_qat(
             qat_scheme = qat_scheme,
             base_config_and_filter_fns = [(base_config, filter_fn)],
         )
+    elif qat_scheme == "cactus":
+        try:
+            from torchao.quantization import IntxWeightOnlyConfig, PerGroup
+        except ImportError:
+            raise ImportError(TORCHAO_MSG)
+
+        # IntxWeightOnlyConfig already defaults to
+        # `mapping_type = MappingType.SYMMETRIC`, so we intentionally do not
+        # import `MappingType` here. Matches the upstream Cactus runtime
+        # int8 / per-group-32 / symmetric weight-only configuration.
+        group_size = 32
+        base_config = IntxWeightOnlyConfig(
+            weight_dtype = torch.int8,
+            granularity = PerGroup(group_size),
+        )
+        filter_fn = (
+            lambda m, _: isinstance(m, torch.nn.Linear)
+            and m.in_features >= group_size
+            and m.in_features % group_size == 0
+        )
+        # Warn if any Linear layer is skipped by the cactus filter because
+        # its in_features is not divisible by `group_size`. torchao's
+        # PerGroup(32) quantizer rejects non-divisible widths at
+        # `quantize_()` time, so the filter excludes those layers to keep
+        # the QAT prepare step from crashing. Surface that silently-skipped
+        # coverage gap to the user so they know some Linears will stay in
+        # full precision during training.
+        skipped_cactus_layers = [
+            name
+            for name, module in model.named_modules()
+            if isinstance(module, torch.nn.Linear)
+            and module.in_features >= group_size
+            and module.in_features % group_size != 0
+        ]
+        if skipped_cactus_layers:
+            preview = ", ".join(skipped_cactus_layers[:8])
+            if len(skipped_cactus_layers) > 8:
+                preview += f", ... ({len(skipped_cactus_layers) - 8} more)"
+            warnings.warn(
+                f"Unsloth: qat_scheme='cactus' uses PerGroup({group_size}) "
+                "which requires in_features to be divisible by "
+                f"{group_size}. The following Linear layers will be kept "
+                f"in full precision during QAT: {preview}",
+                stacklevel = 2,
+            )
+        torchao_config = TorchAOConfig(
+            qat_scheme = qat_scheme,
+            base_config_and_filter_fns = [(base_config, filter_fn)],
+        )
     else:
         raise ValueError(f"Unexpected QAT scheme {qat_scheme}")
     assert torchao_config is not None, f"TorchAOConfig was not set for {qat_scheme}"