mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
* Respect classification head skip list on pre-quantized 4-bit checkpoints (#5027)

FastLanguageModel.from_pretrained(..., num_labels=N) crashed with
"NotImplementedError: normal_kernel_cuda not implemented for 'Byte'" on
pre-quantized bnb 4-bit checkpoints (e.g. unsloth/Qwen3-4B-bnb-4bit) when
running on transformers 5.x.

Two pieces were needed to close this out:

1. unsloth_zoo PR: add "score", "classifier", "qa_outputs" to
   SKIP_QUANTIZATION_MODULES so replace_with_bnb_linear leaves task heads
   in the compute dtype.
2. This commit: for pre-quantized checkpoints, transformers reads
   llm_int8_skip_modules from the quantization_config baked into
   config.json and ignores the runtime BitsAndBytesConfig we pass via
   kwargs. Unsloth must merge its skip list into
   model_config.quantization_config.llm_int8_skip_modules before the
   from_pretrained call, or the checkpoint's frozen list (e.g.
   ["lm_head", "multi_modal_projector", "merger", "modality_projection"])
   wins and the `score` head gets converted to Linear4bit with uint8
   storage; _init_weights then calls normal_ on uint8 and crashes.

Also add a defensive post-load cast on the task head to guard against any
residual path that ends up with a non-floating head dtype.

Verified on transformers 4.57.6 and 5.5.0 with:

- unsloth/Qwen3-4B-bnb-4bit + num_labels=3
- unsloth/Qwen3-4B (non-bnb repo, load_in_4bit=True)
- unsloth/Llama-3.2-1B-Instruct + num_labels=3
- unsloth/ModernBERT-large classifier head (bert_classification notebook)
- Regression: causal LM path unchanged, backbone still 4-bit
- 3-step SFT on num_labels=3 confirms gradient flow and weight updates on score.weight

Fixes unslothai/unsloth#5027

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
1fcb2502cf
commit
b7a8ff2833
1 changed file with 37 additions and 0 deletions
@@ -2405,6 +2405,31 @@ class FastLlamaModel:
             bnb_4bit_compute_dtype = dtype,
             llm_int8_skip_modules = llm_int8_skip_modules,
         )
+        # For pre-quantized checkpoints (e.g. unsloth/Qwen3-4B-bnb-4bit),
+        # transformers uses the quantization_config baked into the
+        # checkpoint's config.json and ignores the runtime BitsAndBytesConfig
+        # we pass via kwargs. Merge our skip list into that bundled config
+        # so task heads like `score` (for *ForSequenceClassification) stay
+        # in the compute dtype. See unslothai/unsloth#5027.
+        _ckpt_qcfg = getattr(model_config, "quantization_config", None)
+        if _ckpt_qcfg is not None:
+            if isinstance(_ckpt_qcfg, dict):
+                _ckpt_skip = list(_ckpt_qcfg.get("llm_int8_skip_modules") or [])
+                for _m in llm_int8_skip_modules:
+                    if _m not in _ckpt_skip:
+                        _ckpt_skip.append(_m)
+                _ckpt_qcfg["llm_int8_skip_modules"] = _ckpt_skip
+            else:
+                _ckpt_skip = list(
+                    getattr(_ckpt_qcfg, "llm_int8_skip_modules", None) or []
+                )
+                for _m in llm_int8_skip_modules:
+                    if _m not in _ckpt_skip:
+                        _ckpt_skip.append(_m)
+                try:
+                    _ckpt_qcfg.llm_int8_skip_modules = _ckpt_skip
+                except Exception:
+                    pass

         # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/12
         # RoPE Scaling's max_position_embeddings must be updated
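The merge logic in the first hunk can be exercised on its own. The sketch below extracts it into a hypothetical helper, `merge_skip_modules` (a name not in the commit; the real code runs inline), covering both the dict form of `quantization_config` found in raw `config.json` and the object form transformers may hydrate it into:

```python
def merge_skip_modules(ckpt_qcfg, runtime_skip):
    """Merge a runtime llm_int8_skip_modules list into a checkpoint's
    bundled quantization_config (dict or attribute-style object),
    appending only names not already present."""
    if ckpt_qcfg is None:
        return None
    if isinstance(ckpt_qcfg, dict):
        merged = list(ckpt_qcfg.get("llm_int8_skip_modules") or [])
        for name in runtime_skip:
            if name not in merged:
                merged.append(name)
        ckpt_qcfg["llm_int8_skip_modules"] = merged
    else:
        merged = list(getattr(ckpt_qcfg, "llm_int8_skip_modules", None) or [])
        for name in runtime_skip:
            if name not in merged:
                merged.append(name)
        try:
            ckpt_qcfg.llm_int8_skip_modules = merged
        except Exception:
            pass  # frozen config object: leave the checkpoint list as-is
    return ckpt_qcfg

# A checkpoint's frozen list, as in the commit message, gains the task heads:
cfg = {"llm_int8_skip_modules": ["lm_head", "multi_modal_projector"]}
merge_skip_modules(cfg, ["lm_head", "score", "classifier", "qa_outputs"])
```

After the call, `cfg["llm_int8_skip_modules"]` contains the original two entries followed by "score", "classifier", "qa_outputs", with no duplicate "lm_head" — the order-preserving append is why the checkpoint's own entries stay first.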
@@ -2441,6 +2466,18 @@ class FastLlamaModel:
                 attn_implementation = preferred_attn_impl,
                 **kwargs,
             )
+            # Defensive: make sure the task head ended up in a floating dtype.
+            # The primary protection is SKIP_QUANTIZATION_MODULES plus the skip
+            # list merge above; this guards against a downstream path accidentally
+            # leaving the head in an integer storage. See unslothai/unsloth#5027.
+            for _head_name in ("score", "classifier", "qa_outputs"):
+                _head = getattr(model, _head_name, None)
+                if (
+                    _head is not None
+                    and hasattr(_head, "weight")
+                    and not _head.weight.is_floating_point()
+                ):
+                    _head.to(dtype)
         elif not fast_inference:
             model = AutoModelForCausalLM.from_pretrained(
                 model_name,
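The defensive cast in the second hunk only touches heads whose weight storage is non-floating (e.g. the uint8 storage of a Linear4bit from the description above) and leaves correctly loaded heads alone. A torch-free sketch with stand-in classes (FakeWeight/FakeHead/FakeModel are illustrative, not part of the codebase) shows that control flow:

```python
class FakeWeight:
    """Mimics torch.Tensor.is_floating_point() with a boolean flag."""
    def __init__(self, floating):
        self.floating = floating
    def is_floating_point(self):
        return self.floating

class FakeHead:
    """Mimics an nn.Linear head; .to() records the requested dtype."""
    def __init__(self, floating):
        self.weight = FakeWeight(floating)
        self.cast_to = None
    def to(self, dtype):
        self.cast_to = dtype
        self.weight.floating = True
        return self

class FakeModel:
    pass

def fix_task_heads(model, dtype):
    """Cast any integer-storage task head back to the compute dtype,
    mirroring the loop in the diff above."""
    for head_name in ("score", "classifier", "qa_outputs"):
        head = getattr(model, head_name, None)
        if (
            head is not None
            and hasattr(head, "weight")
            and not head.weight.is_floating_point()
        ):
            head.to(dtype)

model = FakeModel()
model.score = FakeHead(floating=False)      # uint8-like storage: gets cast
model.classifier = FakeHead(floating=True)  # already floating: left alone
fix_task_heads(model, "float16")            # no qa_outputs attribute: skipped
```

Note the `getattr(..., None)` guard is what lets one loop serve *ForSequenceClassification (`score`), BERT-style classifiers (`classifier`), and QA models (`qa_outputs`) without knowing which head the model actually has.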