mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Fix QAT + LoRA fast path, add tests (#3307)
**Summary:** The existing QAT + LoRA path only applied fake quantization to the original slow path, but the default is the fast path that calls unsloth's fast LoRA primitives. This commit integrates fake quantization into these fast primitives as well, and add unit tests to assert that fake quantization is actually taking place. **Test Plan:** Unit tests: ``` pytest tests/utils/test_qat.py ``` End-to-end test: https://gist.github.com/andrewor14/6360dd69b5784c71c46e80c14f53e6b6 Full fine-tuning Llama3.1-8B with and without QAT + LoRA on yahma/alpaca-cleaned for 1 epoch: - Batch size = 8 (no grad accum) - Learning rate = 2e-4 - Quantization scheme = int4 weight only (with bf16 activations) Wikitext perplexity: - Baseline = int4 quantized model finetuned without QAT - QAT int4 quantized model (with this PR) achieved 33% lower perplexity than the int4 baseline - QAT int4 quantized model without this PR was worse than the int4 baseline ``` ==> unsloth_model_lora_baseline_output/lm_eval_float.log <== | | |none | 0|word_perplexity|↓ |7.5551|± | N/A| ==> unsloth_model_lora_baseline_output/lm_eval_quantized.log <== | | |none | 0|word_perplexity|↓ |8.7655|± | N/A| ==> unsloth_model_lora_qat_int4_output/lm_eval_quantized.log <== | | |none | 0|word_perplexity|↓ |8.3548|± | N/A| ```
This commit is contained in:
parent
70f790a8e4
commit
3ffb3bdcfe
4 changed files with 208 additions and 2 deletions
154
tests/utils/test_qat.py
Normal file
154
tests/utils/test_qat.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
from unsloth import FastLanguageModel
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torchao.quantization.qat import FakeQuantizedLinear
|
||||
from torchao.quantization.qat.fake_quantizer import (
|
||||
FakeQuantizerBase,
|
||||
Float8FakeQuantizer,
|
||||
Int4WeightPreshuffledFakeQuantizer,
|
||||
)
|
||||
|
||||
|
||||
class _CountingFakeQuantizer(torch.nn.Module):
|
||||
"""
|
||||
Dummy fake quantizer that counts the number of times it has been called.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.count = 0
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
self.count += 1
|
||||
return x
|
||||
|
||||
|
||||
def _get_model(qat_scheme: str, full_finetuning: bool):
|
||||
"""
|
||||
Return a 2-tuple of (model, tokenizer), where the model has been configured
|
||||
to use QAT. If `full_finetuning` is False, return the PEFT (LoRA) model.
|
||||
"""
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name = "unsloth/Qwen3-1.7B",
|
||||
load_in_4bit = False,
|
||||
full_finetuning = full_finetuning,
|
||||
qat_scheme = qat_scheme if full_finetuning else None,
|
||||
)
|
||||
if not full_finetuning:
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
qat_scheme = qat_scheme,
|
||||
)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def _test_linear_is_fake_quantized(linear: torch.nn.Linear, qat_scheme: str):
|
||||
"""
|
||||
Verify that the given linear contains fake quantizers according to the `qat_scheme`.
|
||||
"""
|
||||
if qat_scheme == "fp8-int4":
|
||||
act_fq_class = Float8FakeQuantizer
|
||||
weight_fq_class = Int4WeightPreshuffledFakeQuantizer
|
||||
min_in_features = 128
|
||||
elif qat_scheme == "fp8-fp8":
|
||||
act_fq_class = Float8FakeQuantizer
|
||||
weight_fq_class = Float8FakeQuantizer
|
||||
min_in_features = -1
|
||||
else:
|
||||
raise ValueError(f"Unknown qat_scheme: {qat_scheme}")
|
||||
|
||||
# Check base layer activations and weights
|
||||
base_layer = getattr(linear, "base_layer", linear)
|
||||
if base_layer.in_features >= min_in_features:
|
||||
assert isinstance(base_layer, FakeQuantizedLinear)
|
||||
assert isinstance(base_layer.activation_fake_quantizer, act_fq_class)
|
||||
assert isinstance(base_layer.weight_fake_quantizer, weight_fq_class)
|
||||
|
||||
# Check lora A and B (only for full_finetuning=False)
|
||||
if hasattr(linear, "lora_A") and hasattr(linear, "lora_B"):
|
||||
lora_A = linear.lora_A.default
|
||||
lora_B = linear.lora_B.default
|
||||
if lora_A.in_features >= min_in_features:
|
||||
assert isinstance(lora_A, FakeQuantizedLinear)
|
||||
assert isinstance(lora_A.activation_fake_quantizer, act_fq_class)
|
||||
assert isinstance(lora_A.weight_fake_quantizer, weight_fq_class)
|
||||
if lora_B.in_features >= min_in_features:
|
||||
assert isinstance(lora_B, FakeQuantizedLinear)
|
||||
assert isinstance(lora_B.activation_fake_quantizer, act_fq_class)
|
||||
assert isinstance(lora_B.weight_fake_quantizer, weight_fq_class)
|
||||
|
||||
|
||||
def _test_fake_quantizers_are_called(
|
||||
model: torch.nn.Module,
|
||||
example_inputs: Dict,
|
||||
full_finetuning: bool,
|
||||
):
|
||||
"""
|
||||
Verify that the fake quantizers are actually called when the model is called.
|
||||
"""
|
||||
def _swap_fake_quantizers(model: torch.nn.Module):
|
||||
for name, child in model.named_children():
|
||||
if isinstance(child, FakeQuantizerBase):
|
||||
setattr(model, name, _CountingFakeQuantizer())
|
||||
|
||||
def _assert_fake_quantizers_are_called(model: torch.nn.Module):
|
||||
for name, child in model.named_children():
|
||||
if full_finetuning:
|
||||
if isinstance(child, FakeQuantizedLinear):
|
||||
assert child.activation_fake_quantizer.count == 1
|
||||
assert child.weight_fake_quantizer.count == 1
|
||||
else:
|
||||
# For LoRA, we only fake quantize the input activations once per block:
|
||||
# For self_attn, we only fake quantize the q_proj's input activations
|
||||
# For mlp, we only fake quantize the gate_proj's input activations
|
||||
if name == "self_attn":
|
||||
base_layer = child.q_proj.base_layer
|
||||
assert hasattr(base_layer, "activation_fake_quantizer")
|
||||
assert base_layer.activation_fake_quantizer.count == 1
|
||||
elif name == "mlp":
|
||||
base_layer = child.gate_proj.base_layer
|
||||
assert hasattr(base_layer, "activation_fake_quantizer")
|
||||
assert base_layer.activation_fake_quantizer.count == 1
|
||||
elif isinstance(child, FakeQuantizedLinear):
|
||||
# Weight fake quantizers should always be called
|
||||
assert child.weight_fake_quantizer.count == 1
|
||||
|
||||
for k, v in example_inputs.items():
|
||||
example_inputs[k] = v.cuda()
|
||||
model.apply(_swap_fake_quantizers)
|
||||
model(**example_inputs)
|
||||
model.apply(_assert_fake_quantizers_are_called)
|
||||
|
||||
|
||||
def _test_model_fake_quantize(qat_scheme: bool, full_finetuning: bool):
|
||||
"""
|
||||
Test that all linear layers in the model are fake quantized according to the `qat_scheme`.
|
||||
"""
|
||||
model, tokenizer = _get_model(qat_scheme, full_finetuning)
|
||||
if full_finetuning:
|
||||
model = model.model
|
||||
else:
|
||||
model = model.base_model.model.model
|
||||
for layer in model.layers:
|
||||
_test_linear_is_fake_quantized(layer.self_attn.q_proj, qat_scheme)
|
||||
_test_linear_is_fake_quantized(layer.self_attn.k_proj, qat_scheme)
|
||||
_test_linear_is_fake_quantized(layer.self_attn.v_proj, qat_scheme)
|
||||
_test_linear_is_fake_quantized(layer.mlp.gate_proj, qat_scheme)
|
||||
_test_linear_is_fake_quantized(layer.mlp.up_proj, qat_scheme)
|
||||
_test_linear_is_fake_quantized(layer.mlp.down_proj, qat_scheme)
|
||||
inputs = tokenizer("How are you?", return_tensors="pt")
|
||||
_test_fake_quantizers_are_called(model, inputs, full_finetuning)
|
||||
|
||||
|
||||
# TODO: there are bad interactions across tests right now, need to figure out
|
||||
# how to disable model caching before re-enabling this test
|
||||
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8"])
|
||||
def _test_full_model_fake_quantize(qat_scheme: bool):
|
||||
_test_model_fake_quantize(qat_scheme, full_finetuning=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8"])
|
||||
def test_lora_model_fake_quantize(qat_scheme: bool):
|
||||
_test_model_fake_quantize(qat_scheme, full_finetuning=False)
|
||||
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
import torch
|
||||
from .utils import (
|
||||
_maybe_fake_quantize_activations,
|
||||
fast_dequantize,
|
||||
QUANT_STATE,
|
||||
get_lora_parameters,
|
||||
|
|
@ -175,6 +176,7 @@ pass
|
|||
|
||||
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
|
||||
def apply_lora_mlp_swiglu(self, X, inplace = True):
|
||||
X = _maybe_fake_quantize_activations(X, self.gate_proj)
|
||||
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
|
||||
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
||||
|
|
@ -190,6 +192,7 @@ pass
|
|||
|
||||
from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
|
||||
def apply_lora_mlp_geglu_exact(self, X, inplace = True):
|
||||
X = _maybe_fake_quantize_activations(X, self.gate_proj)
|
||||
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
|
||||
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
||||
|
|
@ -205,6 +208,7 @@ pass
|
|||
|
||||
from .geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel
|
||||
def apply_lora_mlp_geglu_approx(self, X):
|
||||
X = _maybe_fake_quantize_activations(X, self.gate_proj)
|
||||
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
|
||||
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
||||
|
|
@ -360,6 +364,7 @@ pass
|
|||
|
||||
|
||||
def apply_lora_qkv(self, X, inplace = True):
|
||||
X = _maybe_fake_quantize_activations(X, self.q_proj)
|
||||
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
|
||||
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
|
||||
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
|
||||
|
|
@ -453,6 +458,7 @@ pass
|
|||
|
||||
|
||||
def apply_lora_o(self, X):
|
||||
X = _maybe_fake_quantize_activations(X, self.o_proj)
|
||||
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
|
||||
O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
|
||||
return O
|
||||
|
|
|
|||
|
|
@ -188,10 +188,19 @@ torch_bfloat16 = torch.bfloat16
|
|||
def QUANT_STATE(W): return getattr(W, "quant_state", None)
|
||||
|
||||
def get_lora_parameters(proj):
|
||||
"""
|
||||
Return a 5-tuple of (weight, weight quant_state, lora A, lora B, and lora scale).
|
||||
If QAT is enabled, additionally fake quantize the base layer and lora weights.
|
||||
"""
|
||||
# For DPO or disabled adapters
|
||||
base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj)
|
||||
W = base_layer.weight
|
||||
|
||||
# Optionally apply fake quantization to base layer weights for QAT
|
||||
weight_fake_quantizer = getattr(base_layer, "weight_fake_quantizer", None)
|
||||
if weight_fake_quantizer is not None:
|
||||
W = weight_fake_quantizer(W)
|
||||
|
||||
# if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
|
||||
if getattr(proj, "disable_adapters", True) or proj.merged:
|
||||
return W, getattr(W, "quant_state", None), None, None, None
|
||||
|
|
@ -201,11 +210,23 @@ def get_lora_parameters(proj):
|
|||
if adapter is None: adapter = getattr(proj, "active_adapter", ("default"))
|
||||
adapter = adapter[0]
|
||||
|
||||
# Optionally apply fake quantization to lora weights for QAT
|
||||
lora_A_linear = proj.lora_A[adapter]
|
||||
lora_B_linear = proj.lora_B[adapter]
|
||||
lora_A_fake_quantizer = getattr(lora_A_linear, "weight_fake_quantizer", None)
|
||||
lora_B_fake_quantizer = getattr(lora_B_linear, "weight_fake_quantizer", None)
|
||||
A = lora_A_linear.weight
|
||||
B = lora_B_linear.weight
|
||||
if lora_A_fake_quantizer is not None:
|
||||
A = lora_A_fake_quantizer(A)
|
||||
if lora_B_fake_quantizer is not None:
|
||||
B = lora_B_fake_quantizer(B)
|
||||
|
||||
return (
|
||||
W,
|
||||
getattr(W, "quant_state", None),
|
||||
proj.lora_A [adapter].weight,
|
||||
proj.lora_B [adapter].weight,
|
||||
A,
|
||||
B,
|
||||
proj.scaling[adapter],
|
||||
)
|
||||
pass
|
||||
|
|
@ -235,6 +256,21 @@ def get_lora_parameters_bias(proj):
|
|||
)
|
||||
pass
|
||||
|
||||
|
||||
def _maybe_fake_quantize_activations(X: torch.Tensor, proj: torch.nn.Module) -> torch.Tensor:
|
||||
"""
|
||||
If QAT is enabled, fake quantize the input activations.
|
||||
Otherwise, just return the input activations as is.
|
||||
Weights are fake quantized separately in `get_lora_parameters`.
|
||||
"""
|
||||
base_layer = getattr(proj, "base_layer", proj)
|
||||
activation_fake_quantizer = getattr(base_layer, "activation_fake_quantizer", None)
|
||||
if activation_fake_quantizer is not None:
|
||||
X = activation_fake_quantizer(X)
|
||||
return X
|
||||
pass
|
||||
|
||||
|
||||
# INTEL GPU Specific Logic
|
||||
if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
|
||||
@torch.inference_mode
|
||||
|
|
|
|||
|
|
@ -1651,6 +1651,8 @@ def _prepare_model_for_qat(model: torch.nn.Module, qat_scheme: str) -> torch.nn.
|
|||
from torchao.quantization import (
|
||||
Float8DynamicActivationInt4WeightConfig,
|
||||
Float8DynamicActivationFloat8WeightConfig,
|
||||
Int8DynamicActivationInt4WeightConfig,
|
||||
Int4WeightOnlyConfig,
|
||||
PerRow,
|
||||
quantize_,
|
||||
)
|
||||
|
|
@ -1662,6 +1664,14 @@ def _prepare_model_for_qat(model: torch.nn.Module, qat_scheme: str) -> torch.nn.
|
|||
filter_fn = lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= group_size
|
||||
elif qat_scheme == "fp8-fp8":
|
||||
base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
|
||||
elif qat_scheme == "int8-int4":
|
||||
group_size = 32
|
||||
base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
|
||||
filter_fn = lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= group_size
|
||||
elif qat_scheme == "int4":
|
||||
group_size = 128
|
||||
base_config = Int4WeightOnlyConfig(group_size=group_size)
|
||||
filter_fn = lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= group_size
|
||||
else:
|
||||
raise ValueError(f"Unexpected QAT scheme {qat_scheme}")
|
||||
pass
|
||||
|
|
|
|||
Loading…
Reference in a new issue