Mirror of https://github.com/unslothai/unsloth (synced 2026-04-21 13:37:39 +00:00)
* feat: Add cactus QAT scheme support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* test(qat): add tests for cactus QAT scheme and fix missing import

* Fix cactus QAT scheme: correct MappingType import, tighten PerGroup filter

  - Drop the broken `from torchao.dtypes import MappingType` import.
    `MappingType` lives in `torchao.quantization` (and
    `torchao.quantization.quant_primitives`); it is not exported from
    `torchao.dtypes` in any supported torchao release (verified on 0.14,
    0.16, 0.17). The previous code raised `ImportError` on every cactus
    call, which was then masked as a misleading 'torchao not found' error.

  - Since `IntxWeightOnlyConfig` already defaults `mapping_type` to
    `MappingType.SYMMETRIC`, drop the explicit kwarg entirely and remove
    the import. Behavior is unchanged.

  - Introduce a named `group_size = 32` constant (matching the int4 /
    fp8-int4 pattern in the surrounding branches) and add a
    `% group_size == 0` divisibility guard to the filter. `PerGroup(32)`
    requires `in_features % 32 == 0` at `quantize_()` time; otherwise
    torchao raises `ValueError: in_features (N) % group_size (32) must
    be == 0`. The old `in_features >= 32` filter would admit non-aligned
    widths (e.g. 33, 48, 65, 127) and crash `_prepare_model_for_qat` for
    those shapes.

* Warn when cactus QAT skips non-divisible Linear layers

  Multiple reviewers flagged that the divisibility guard added in the
  previous commit can silently leave Linear layers in full precision when
  their in_features is not a multiple of 32. For currently supported
  Unsloth models (Qwen, Llama, Gemma, Mistral, Phi) every Linear width is
  already a multiple of 32/64/128, so this never triggers, but surfacing
  the coverage gap is cheap and stops users from assuming 100% QAT
  coverage when they bring a custom model with unusual shapes.

  Emit a UserWarning listing up to the first 8 skipped layers whenever the
  cactus filter excludes a Linear due to the modulo guard. This keeps the
  lenient silent-skip behavior (consistent with int4 / fp8-int4), but
  stops making it silent.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
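For reference, the net effect of these three fixes looks roughly like the sketch below. This is a minimal illustration, not the actual `_prepare_model_for_qat` source: the `torch.int4` weight dtype, the toy `model`, and the helper names (`_cactus_filter`, `skipped`) are assumptions for the example, and it presumes a torchao version recent enough to expose `QATConfig`.

import warnings

import torch
from torchao.quantization import IntxWeightOnlyConfig, quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.qat import QATConfig

group_size = 32  # named constant, matching the int4 / fp8-int4 branches

# No MappingType import needed: IntxWeightOnlyConfig already defaults
# mapping_type to MappingType.SYMMETRIC.
base_config = IntxWeightOnlyConfig(
    weight_dtype = torch.int4,  # assumed dtype; the real cactus scheme may differ
    granularity = PerGroup(group_size),
)

skipped = []

def _cactus_filter(module: torch.nn.Module, fqn: str) -> bool:
    # Only Linear layers at least one group wide are candidates.
    if not isinstance(module, torch.nn.Linear) or module.in_features < group_size:
        return False
    # PerGroup(group_size) requires in_features % group_size == 0 at
    # quantize_() time, so skip non-aligned widths instead of letting
    # torchao raise ValueError.
    if module.in_features % group_size != 0:
        skipped.append(fqn)
        return False
    return True

# Toy model with one aligned and one non-aligned Linear. quantize_() never
# runs a forward pass, so the mismatched shapes do not matter here.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),  # 64 % 32 == 0  -> swapped to FakeQuantizedLinear
    torch.nn.Linear(33, 8),   # 33 % 32 != 0  -> skipped, triggers the warning
)

# The filter runs during module traversal, so `skipped` is fully populated
# once quantize_() returns.
quantize_(model, QATConfig(base_config, step = "prepare"), filter_fn = _cactus_filter)

if skipped:
    warnings.warn(
        f"cactus QAT left {len(skipped)} Linear layer(s) in full precision "
        f"(in_features not divisible by {group_size}): {skipped[:8]}",
        UserWarning,
    )

On supported Unsloth models every Linear width is already 32-aligned, so `skipped` stays empty and no warning fires; the second Linear above exists only to exercise the non-aligned path. The test file below checks the resulting fake-quantized layers per scheme.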
from unsloth import FastLanguageModel

from typing import Dict

import pytest
import torch

try:
    from torchao.quantization.qat import FakeQuantizedLinear
    from torchao.quantization.qat.fake_quantizer import (
        FakeQuantizerBase,
        Float8FakeQuantizer,
        Int4WeightFakeQuantizer,
        IntxFakeQuantizer,
    )
except ImportError:
    print(
        "Missing torchao import, please install or upgrade torchao with: pip install 'torchao>=0.15.0'"
    )


class _CountingFakeQuantizer(torch.nn.Module):
    """
    Dummy fake quantizer that counts the number of times it has been called.
    """

    def __init__(self):
        super().__init__()
        self.count = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.count += 1
        return x


def _get_model(qat_scheme: str, full_finetuning: bool):
    """
    Return a 2-tuple of (model, tokenizer), where the model has been configured
    to use QAT. If `full_finetuning` is False, return the PEFT (LoRA) model.
    """
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "unsloth/Qwen3-1.7B",
        load_in_4bit = False,
        full_finetuning = full_finetuning,
        qat_scheme = qat_scheme if full_finetuning else None,
    )
    if not full_finetuning:
        model = FastLanguageModel.get_peft_model(
            model,
            qat_scheme = qat_scheme,
        )
    return model, tokenizer


def _test_linear_is_fake_quantized(linear: torch.nn.Linear, qat_scheme: str):
    """
    Verify that the given linear contains fake quantizers according to the `qat_scheme`.
    """
    weight_only = False
    if qat_scheme == "fp8-int4":
        act_fq_class = Float8FakeQuantizer
        weight_fq_class = Int4WeightFakeQuantizer
        min_in_features = 128
    elif qat_scheme == "fp8-fp8":
        act_fq_class = Float8FakeQuantizer
        weight_fq_class = Float8FakeQuantizer
        min_in_features = -1
    elif qat_scheme == "int8":
        act_fq_class = None
        weight_fq_class = IntxFakeQuantizer
        min_in_features = 128
        weight_only = True
    elif qat_scheme == "cactus":
        act_fq_class = None
        weight_fq_class = IntxFakeQuantizer
        min_in_features = 32
        weight_only = True
    else:
        raise ValueError(f"Unknown qat_scheme: {qat_scheme}")

    # Check base layer activations and weights
    base_layer = getattr(linear, "base_layer", linear)
    if base_layer.in_features >= min_in_features:
        assert isinstance(base_layer, FakeQuantizedLinear)
        if not weight_only:
            assert isinstance(base_layer.activation_fake_quantizer, act_fq_class)
        assert isinstance(base_layer.weight_fake_quantizer, weight_fq_class)

    # Check lora A and B (only for full_finetuning=False)
    if hasattr(linear, "lora_A") and hasattr(linear, "lora_B"):
        lora_A = linear.lora_A.default
        lora_B = linear.lora_B.default
        if lora_A.in_features >= min_in_features:
            assert isinstance(lora_A, FakeQuantizedLinear)
            if not weight_only:
                assert isinstance(lora_A.activation_fake_quantizer, act_fq_class)
            assert isinstance(lora_A.weight_fake_quantizer, weight_fq_class)
        if lora_B.in_features >= min_in_features:
            assert isinstance(lora_B, FakeQuantizedLinear)
            if not weight_only:
                assert isinstance(lora_B.activation_fake_quantizer, act_fq_class)
            assert isinstance(lora_B.weight_fake_quantizer, weight_fq_class)


def _test_fake_quantizers_are_called(
    model: torch.nn.Module,
    example_inputs: Dict,
    full_finetuning: bool,
    qat_scheme: str,
):
    """
    Verify that the fake quantizers are actually called when the model is called.
    """
    weight_only = qat_scheme in ["int8", "cactus"]

    def _swap_fake_quantizers(model: torch.nn.Module):
        for name, child in model.named_children():
            if isinstance(child, FakeQuantizerBase):
                setattr(model, name, _CountingFakeQuantizer())

    def _assert_fake_quantizers_are_called(model: torch.nn.Module):
        for name, child in model.named_children():
            if full_finetuning:
                if isinstance(child, FakeQuantizedLinear):
                    if not weight_only:
                        assert child.activation_fake_quantizer.count == 1
                    assert child.weight_fake_quantizer.count == 1
            else:
                # For LoRA, we only fake quantize the input activations once per block:
                # For self_attn, we only fake quantize the q_proj's input activations
                # For mlp, we only fake quantize the gate_proj's input activations
                if name == "self_attn":
                    base_layer = child.q_proj.base_layer
                    if not weight_only:
                        assert hasattr(base_layer, "activation_fake_quantizer")
                        assert base_layer.activation_fake_quantizer.count == 1
                elif name == "mlp":
                    base_layer = child.gate_proj.base_layer
                    if not weight_only:
                        assert hasattr(base_layer, "activation_fake_quantizer")
                        assert base_layer.activation_fake_quantizer.count == 1
                elif isinstance(child, FakeQuantizedLinear):
                    # Weight fake quantizers should always be called
                    assert child.weight_fake_quantizer.count == 1

    for k, v in example_inputs.items():
        example_inputs[k] = v.cuda()
    model.apply(_swap_fake_quantizers)
    model(**example_inputs)
    model.apply(_assert_fake_quantizers_are_called)


def _test_model_fake_quantize(qat_scheme: str, full_finetuning: bool):
    """
    Test that all linear layers in the model are fake quantized according to the `qat_scheme`.
    """
    model, tokenizer = _get_model(qat_scheme, full_finetuning)
    if full_finetuning:
        model = model.model
    else:
        model = model.base_model.model.model
    for layer in model.layers:
        _test_linear_is_fake_quantized(layer.self_attn.q_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.self_attn.k_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.self_attn.v_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.mlp.gate_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.mlp.up_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.mlp.down_proj, qat_scheme)
    inputs = tokenizer("How are you?", return_tensors = "pt")
    _test_fake_quantizers_are_called(model, inputs, full_finetuning, qat_scheme)


# TODO: there are bad interactions across tests right now, need to figure out
# how to disable model caching before re-enabling this test
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8", "cactus"])
def _test_full_model_fake_quantize(qat_scheme: str):
    _test_model_fake_quantize(qat_scheme, full_finetuning = True)


@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8", "cactus"])
def test_lora_model_fake_quantize(qat_scheme: str):
    _test_model_fake_quantize(qat_scheme, full_finetuning = False)