From b7a8ff2833ee7d123c4fc875e0e3c10b3cd5218f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 15 Apr 2026 05:16:33 -0700 Subject: [PATCH] Respect classification head skip list on pre-quantized 4-bit checkpoints (#5027) (#5034) * Respect classification head skip list on pre-quantized 4-bit checkpoints (#5027) FastLanguageModel.from_pretrained(..., num_labels=N) crashed with "NotImplementedError: normal_kernel_cuda not implemented for 'Byte'" on pre-quantized bnb 4-bit checkpoints (e.g. unsloth/Qwen3-4B-bnb-4bit) when running on transformers 5.x. Two pieces were needed to close this out: 1. unsloth_zoo PR: add "score", "classifier", "qa_outputs" to SKIP_QUANTIZATION_MODULES so replace_with_bnb_linear leaves task heads in the compute dtype. 2. This commit: for pre-quantized checkpoints, transformers reads llm_int8_skip_modules from the quantization_config baked into config.json and ignores the runtime BitsAndBytesConfig we pass via kwargs. Unsloth must merge its skip list into model_config.quantization_config.llm_int8_skip_modules before the from_pretrained call, or the checkpoint's frozen list (e.g. ["lm_head", "multi_modal_projector", "merger", "modality_projection"]) wins and the `score` head gets converted to Linear4bit with uint8 storage, then _init_weights calls normal_ on uint8 and crashes. Also add a defensive post-load cast on the task head to guard against any residual path that ends up with a non-floating head dtype. Verified on transformers 4.57.6 and 5.5.0 with: - unsloth/Qwen3-4B-bnb-4bit + num_labels=3 - unsloth/Qwen3-4B (non-bnb repo, load_in_4bit=True) - unsloth/Llama-3.2-1B-Instruct + num_labels=3 - unsloth/ModernBERT-large classifier head (bert_classification notebook) - Regression: causal LM path unchanged, backbone still 4-bit - 3-step SFT on num_labels=3 confirms gradient flow and weight updates on score.weight Fixes unslothai/unsloth#5027 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- unsloth/models/llama.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 031cabcd3..63cd9b1c8 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2405,6 +2405,31 @@ class FastLlamaModel: bnb_4bit_compute_dtype = dtype, llm_int8_skip_modules = llm_int8_skip_modules, ) + # For pre-quantized checkpoints (e.g. unsloth/Qwen3-4B-bnb-4bit), + # transformers uses the quantization_config baked into the + # checkpoint's config.json and ignores the runtime BitsAndBytesConfig + # we pass via kwargs. Merge our skip list into that bundled config + # so task heads like `score` (for *ForSequenceClassification) stay + # in the compute dtype. See unslothai/unsloth#5027. + _ckpt_qcfg = getattr(model_config, "quantization_config", None) + if _ckpt_qcfg is not None: + if isinstance(_ckpt_qcfg, dict): + _ckpt_skip = list(_ckpt_qcfg.get("llm_int8_skip_modules") or []) + for _m in llm_int8_skip_modules: + if _m not in _ckpt_skip: + _ckpt_skip.append(_m) + _ckpt_qcfg["llm_int8_skip_modules"] = _ckpt_skip + else: + _ckpt_skip = list( + getattr(_ckpt_qcfg, "llm_int8_skip_modules", None) or [] + ) + for _m in llm_int8_skip_modules: + if _m not in _ckpt_skip: + _ckpt_skip.append(_m) + try: + _ckpt_qcfg.llm_int8_skip_modules = _ckpt_skip + except Exception: + pass # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/12 # RoPE Scaling's max_position_embeddings must be updated @@ -2441,6 +2466,18 @@ class FastLlamaModel: attn_implementation = preferred_attn_impl, **kwargs, ) + # Defensive: make sure the task head ended up in a floating dtype. + # The primary protection is SKIP_QUANTIZATION_MODULES plus the skip + # list merge above; this guards against a downstream path accidentally + # leaving the head in an integer storage. See unslothai/unsloth#5027. + for _head_name in ("score", "classifier", "qa_outputs"): + _head = getattr(model, _head_name, None) + if ( + _head is not None + and hasattr(_head, "weight") + and not _head.weight.is_floating_point() + ): + _head.to(dtype) elif not fast_inference: model = AutoModelForCausalLM.from_pretrained( model_name,