Fix QAT + LoRA fast path, add tests (#3307)

**Summary:** The existing QAT + LoRA path only applied fake
quantization to the original slow path, but the default is the
fast path that calls unsloth's fast LoRA primitives. This commit
integrates fake quantization into these fast primitives as well,
and adds unit tests to assert that fake quantization is actually
taking place.

**Test Plan:**

Unit tests:
```
pytest tests/utils/test_qat.py
```

End-to-end test: https://gist.github.com/andrewor14/6360dd69b5784c71c46e80c14f53e6b6

Full fine-tuning Llama3.1-8B with and without QAT + LoRA on yahma/alpaca-cleaned for 1 epoch:

- Batch size = 8 (no grad accum)
- Learning rate = 2e-4
- Quantization scheme = int4 weight only (with bf16 activations)

Wikitext perplexity:

- Baseline = int4 quantized model finetuned without QAT
- QAT int4 quantized model (with this PR) achieved 33% lower perplexity than the int4 baseline
- QAT int4 quantized model without this PR was worse than the int4 baseline

```
==> unsloth_model_lora_baseline_output/lm_eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |7.5551|±  |   N/A|

==> unsloth_model_lora_baseline_output/lm_eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |8.7655|±  |   N/A|

==> unsloth_model_lora_qat_int4_output/lm_eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |8.3548|±  |   N/A|
```
This commit is contained in:
andrewor14 2025-09-17 18:18:17 -04:00 committed by GitHub
parent 70f790a8e4
commit 3ffb3bdcfe
4 changed files with 208 additions and 2 deletions

154
tests/utils/test_qat.py Normal file
View file

@ -0,0 +1,154 @@
from unsloth import FastLanguageModel
from typing import Dict
import pytest
import torch
from torchao.quantization.qat import FakeQuantizedLinear
from torchao.quantization.qat.fake_quantizer import (
FakeQuantizerBase,
Float8FakeQuantizer,
Int4WeightPreshuffledFakeQuantizer,
)
class _CountingFakeQuantizer(torch.nn.Module):
"""
Dummy fake quantizer that counts the number of times it has been called.
"""
def __init__(self):
super().__init__()
self.count = 0
def forward(self, x: torch.Tensor) -> torch.Tensor:
self.count += 1
return x
def _get_model(qat_scheme: str, full_finetuning: bool):
    """
    Build a QAT-configured (model, tokenizer) pair.

    For full finetuning the QAT scheme is applied at load time; for LoRA
    the base model is loaded without QAT and the scheme is applied when
    the PEFT wrapper is created instead.
    """
    # Only pass the QAT scheme at load time for full finetuning; the LoRA
    # path applies it through `get_peft_model` below.
    load_time_scheme = qat_scheme if full_finetuning else None
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "unsloth/Qwen3-1.7B",
        load_in_4bit = False,
        full_finetuning = full_finetuning,
        qat_scheme = load_time_scheme,
    )
    if full_finetuning:
        return model, tokenizer
    peft_model = FastLanguageModel.get_peft_model(
        model,
        qat_scheme = qat_scheme,
    )
    return peft_model, tokenizer
def _test_linear_is_fake_quantized(linear: torch.nn.Linear, qat_scheme: str):
    """
    Verify that `linear` contains the fake quantizers expected for `qat_scheme`.

    Both the base layer and (when present) the default lora_A/lora_B adapters
    are checked. Layers whose `in_features` fall below the scheme's minimum
    are expected to be left unquantized and are skipped.
    """
    # Map each scheme to (activation fq class, weight fq class, min in_features)
    expectations = {
        "fp8-int4": (Float8FakeQuantizer, Int4WeightPreshuffledFakeQuantizer, 128),
        "fp8-fp8": (Float8FakeQuantizer, Float8FakeQuantizer, -1),
    }
    if qat_scheme not in expectations:
        raise ValueError(f"Unknown qat_scheme: {qat_scheme}")
    act_fq_class, weight_fq_class, min_in_features = expectations[qat_scheme]

    def _check(layer: torch.nn.Module):
        # Layers narrower than the scheme's minimum are filtered out of QAT
        if layer.in_features < min_in_features:
            return
        assert isinstance(layer, FakeQuantizedLinear)
        assert isinstance(layer.activation_fake_quantizer, act_fq_class)
        assert isinstance(layer.weight_fake_quantizer, weight_fq_class)

    # Base layer: for LoRA models this is `linear.base_layer`, else `linear`
    _check(getattr(linear, "base_layer", linear))
    # LoRA adapters exist only when full_finetuning=False
    if hasattr(linear, "lora_A") and hasattr(linear, "lora_B"):
        _check(linear.lora_A.default)
        _check(linear.lora_B.default)
def _test_fake_quantizers_are_called(
    model: torch.nn.Module,
    example_inputs: Dict,
    full_finetuning: bool,
):
    """
    Verify that the fake quantizers actually run during a forward pass.

    Every `FakeQuantizerBase` in the model is swapped for a counting stub,
    the model is called once, and per-quantizer call counts are asserted.
    """

    def _install_counters(module: torch.nn.Module):
        # Replace each fake quantizer child with a stub that counts calls
        for child_name, child in module.named_children():
            if isinstance(child, FakeQuantizerBase):
                setattr(module, child_name, _CountingFakeQuantizer())

    def _check_counts(module: torch.nn.Module):
        for child_name, child in module.named_children():
            if full_finetuning:
                if isinstance(child, FakeQuantizedLinear):
                    assert child.activation_fake_quantizer.count == 1
                    assert child.weight_fake_quantizer.count == 1
                continue
            # For LoRA, we only fake quantize the input activations once per
            # block: the q_proj's input for self_attn, the gate_proj's input
            # for mlp
            if child_name in ("self_attn", "mlp"):
                proj = child.q_proj if child_name == "self_attn" else child.gate_proj
                base_layer = proj.base_layer
                assert hasattr(base_layer, "activation_fake_quantizer")
                assert base_layer.activation_fake_quantizer.count == 1
            elif isinstance(child, FakeQuantizedLinear):
                # Weight fake quantizers should always be called
                assert child.weight_fake_quantizer.count == 1

    # Move the example inputs to GPU in place (model runs on CUDA)
    for key, value in example_inputs.items():
        example_inputs[key] = value.cuda()
    model.apply(_install_counters)
    model(**example_inputs)
    model.apply(_check_counts)
def _test_model_fake_quantize(qat_scheme: str, full_finetuning: bool):
    """
    Test that all linear layers in the model are fake quantized according to
    the `qat_scheme`, and that the fake quantizers run during a forward pass.

    Args:
        qat_scheme: QAT scheme name, e.g. "fp8-int4" or "fp8-fp8"
            (fixed: was wrongly annotated as `bool`).
        full_finetuning: if False, test the PEFT (LoRA) model instead.
    """
    model, tokenizer = _get_model(qat_scheme, full_finetuning)
    # Unwrap to the inner transformer so we can walk the decoder layers;
    # PEFT wraps the model one level deeper than the plain model
    if full_finetuning:
        model = model.model
    else:
        model = model.base_model.model.model
    for layer in model.layers:
        _test_linear_is_fake_quantized(layer.self_attn.q_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.self_attn.k_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.self_attn.v_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.mlp.gate_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.mlp.up_proj, qat_scheme)
        _test_linear_is_fake_quantized(layer.mlp.down_proj, qat_scheme)
    inputs = tokenizer("How are you?", return_tensors="pt")
    _test_fake_quantizers_are_called(model, inputs, full_finetuning)
# TODO: there are bad interactions across tests right now, need to figure out
# how to disable model caching before re-enabling this test
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8"])
def _test_full_model_fake_quantize(qat_scheme: str):
    """Full-finetuning QAT test (disabled via leading underscore, see TODO above)."""
    # Fixed annotation: qat_scheme is a scheme name string, not a bool
    _test_model_fake_quantize(qat_scheme, full_finetuning=True)
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8"])
def test_lora_model_fake_quantize(qat_scheme: str):
    """QAT + LoRA: assert fake quantization is applied and actually exercised."""
    # Fixed annotation: qat_scheme is a scheme name string, not a bool
    _test_model_fake_quantize(qat_scheme, full_finetuning=False)

View file

@ -14,6 +14,7 @@
import torch
from .utils import (
_maybe_fake_quantize_activations,
fast_dequantize,
QUANT_STATE,
get_lora_parameters,
@ -175,6 +176,7 @@ pass
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
def apply_lora_mlp_swiglu(self, X, inplace = True):
X = _maybe_fake_quantize_activations(X, self.gate_proj)
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
@ -190,6 +192,7 @@ pass
from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
def apply_lora_mlp_geglu_exact(self, X, inplace = True):
X = _maybe_fake_quantize_activations(X, self.gate_proj)
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
@ -205,6 +208,7 @@ pass
from .geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel
def apply_lora_mlp_geglu_approx(self, X):
X = _maybe_fake_quantize_activations(X, self.gate_proj)
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
@ -360,6 +364,7 @@ pass
def apply_lora_qkv(self, X, inplace = True):
X = _maybe_fake_quantize_activations(X, self.q_proj)
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
@ -453,6 +458,7 @@ pass
def apply_lora_o(self, X):
X = _maybe_fake_quantize_activations(X, self.o_proj)
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
return O

View file

@ -188,10 +188,19 @@ torch_bfloat16 = torch.bfloat16
def QUANT_STATE(W): return getattr(W, "quant_state", None)
def get_lora_parameters(proj):
"""
Return a 5-tuple of (weight, weight quant_state, lora A, lora B, and lora scale).
If QAT is enabled, additionally fake quantize the base layer and lora weights.
"""
# For DPO or disabled adapters
base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj)
W = base_layer.weight
# Optionally apply fake quantization to base layer weights for QAT
weight_fake_quantizer = getattr(base_layer, "weight_fake_quantizer", None)
if weight_fake_quantizer is not None:
W = weight_fake_quantizer(W)
# if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
if getattr(proj, "disable_adapters", True) or proj.merged:
return W, getattr(W, "quant_state", None), None, None, None
@ -201,11 +210,23 @@ def get_lora_parameters(proj):
if adapter is None: adapter = getattr(proj, "active_adapter", ("default"))
adapter = adapter[0]
# Optionally apply fake quantization to lora weights for QAT
lora_A_linear = proj.lora_A[adapter]
lora_B_linear = proj.lora_B[adapter]
lora_A_fake_quantizer = getattr(lora_A_linear, "weight_fake_quantizer", None)
lora_B_fake_quantizer = getattr(lora_B_linear, "weight_fake_quantizer", None)
A = lora_A_linear.weight
B = lora_B_linear.weight
if lora_A_fake_quantizer is not None:
A = lora_A_fake_quantizer(A)
if lora_B_fake_quantizer is not None:
B = lora_B_fake_quantizer(B)
return (
W,
getattr(W, "quant_state", None),
proj.lora_A [adapter].weight,
proj.lora_B [adapter].weight,
A,
B,
proj.scaling[adapter],
)
pass
@ -235,6 +256,21 @@ def get_lora_parameters_bias(proj):
)
pass
def _maybe_fake_quantize_activations(X: torch.Tensor, proj: torch.nn.Module) -> torch.Tensor:
"""
If QAT is enabled, fake quantize the input activations.
Otherwise, just return the input activations as is.
Weights are fake quantized separately in `get_lora_parameters`.
"""
base_layer = getattr(proj, "base_layer", proj)
activation_fake_quantizer = getattr(base_layer, "activation_fake_quantizer", None)
if activation_fake_quantizer is not None:
X = activation_fake_quantizer(X)
return X
pass
# INTEL GPU Specific Logic
if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
@torch.inference_mode

View file

@ -1651,6 +1651,8 @@ def _prepare_model_for_qat(model: torch.nn.Module, qat_scheme: str) -> torch.nn.
from torchao.quantization import (
Float8DynamicActivationInt4WeightConfig,
Float8DynamicActivationFloat8WeightConfig,
Int8DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
PerRow,
quantize_,
)
@ -1662,6 +1664,14 @@ def _prepare_model_for_qat(model: torch.nn.Module, qat_scheme: str) -> torch.nn.
filter_fn = lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= group_size
elif qat_scheme == "fp8-fp8":
base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
elif qat_scheme == "int8-int4":
group_size = 32
base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
filter_fn = lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= group_size
elif qat_scheme == "int4":
group_size = 128
base_config = Int4WeightOnlyConfig(group_size=group_size)
filter_fn = lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= group_size
else:
raise ValueError(f"Unexpected QAT scheme {qat_scheme}")
pass