Respect classification head skip list on pre-quantized 4-bit checkpoints (#5027) (#5034)

* Respect classification head skip list on pre-quantized 4-bit checkpoints (#5027)

FastLanguageModel.from_pretrained(..., num_labels=N) crashed with
"NotImplementedError: normal_kernel_cuda not implemented for 'Byte'" on
pre-quantized bnb 4-bit checkpoints (e.g. unsloth/Qwen3-4B-bnb-4bit)
when running on transformers 5.x.

Two pieces were needed to close this out:

1. unsloth_zoo PR: add "score", "classifier", "qa_outputs" to
   SKIP_QUANTIZATION_MODULES so replace_with_bnb_linear leaves task
   heads in the compute dtype.

2. This commit: for pre-quantized checkpoints, transformers reads
   llm_int8_skip_modules from the quantization_config baked into
   config.json and ignores the runtime BitsAndBytesConfig we pass via
   kwargs. Unsloth must merge its skip list into
   model_config.quantization_config.llm_int8_skip_modules before the
   from_pretrained call, or the checkpoint's frozen list
   (e.g. ["lm_head", "multi_modal_projector", "merger",
   "modality_projection"]) wins and the `score` head gets converted to
   Linear4bit with uint8 storage, then _init_weights calls normal_ on
   uint8 and crashes.

Also add a defensive post-load cast on the task head to guard against
any residual path that ends up with a non-floating head dtype.

Verified on transformers 4.57.6 and 5.5.0 with:
- unsloth/Qwen3-4B-bnb-4bit + num_labels=3
- unsloth/Qwen3-4B (non-bnb repo, load_in_4bit=True)
- unsloth/Llama-3.2-1B-Instruct + num_labels=3
- unsloth/ModernBERT-large classifier head (bert_classification notebook)
- Regression: causal LM path unchanged, backbone still 4-bit
- 3-step SFT on num_labels=3 confirms gradient flow and weight updates
  on score.weight

Fixes unslothai/unsloth#5027

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Daniel Han 2026-04-15 05:16:33 -07:00 committed by GitHub
parent 1fcb2502cf
commit b7a8ff2833
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -2405,6 +2405,31 @@ class FastLlamaModel:
bnb_4bit_compute_dtype = dtype,
llm_int8_skip_modules = llm_int8_skip_modules,
)
# For pre-quantized checkpoints (e.g. unsloth/Qwen3-4B-bnb-4bit),
# transformers uses the quantization_config baked into the
# checkpoint's config.json and ignores the runtime BitsAndBytesConfig
# we pass via kwargs. Merge our skip list into that bundled config
# so task heads like `score` (for *ForSequenceClassification) stay
# in the compute dtype. See unslothai/unsloth#5027.
_ckpt_qcfg = getattr(model_config, "quantization_config", None)
if _ckpt_qcfg is not None:
if isinstance(_ckpt_qcfg, dict):
_ckpt_skip = list(_ckpt_qcfg.get("llm_int8_skip_modules") or [])
for _m in llm_int8_skip_modules:
if _m not in _ckpt_skip:
_ckpt_skip.append(_m)
_ckpt_qcfg["llm_int8_skip_modules"] = _ckpt_skip
else:
_ckpt_skip = list(
getattr(_ckpt_qcfg, "llm_int8_skip_modules", None) or []
)
for _m in llm_int8_skip_modules:
if _m not in _ckpt_skip:
_ckpt_skip.append(_m)
try:
_ckpt_qcfg.llm_int8_skip_modules = _ckpt_skip
except Exception:
pass
# https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/12
# RoPE Scaling's max_position_embeddings must be updated
@ -2441,6 +2466,18 @@ class FastLlamaModel:
attn_implementation = preferred_attn_impl,
**kwargs,
)
# Defensive: make sure the task head ended up in a floating dtype.
# The primary protection is SKIP_QUANTIZATION_MODULES plus the skip
# list merge above; this guards against a downstream path accidentally
# leaving the head in an integer storage. See unslothai/unsloth#5027.
for _head_name in ("score", "classifier", "qa_outputs"):
_head = getattr(model, _head_name, None)
if (
_head is not None
and hasattr(_head, "weight")
and not _head.weight.is_floating_point()
):
_head.to(dtype)
elif not fast_inference:
model = AutoModelForCausalLM.from_pretrained(
model_name,