diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 031cabcd3..63cd9b1c8 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -2405,6 +2405,31 @@ class FastLlamaModel:
             bnb_4bit_compute_dtype = dtype,
             llm_int8_skip_modules = llm_int8_skip_modules,
         )
+        # For pre-quantized checkpoints (e.g. unsloth/Qwen3-4B-bnb-4bit),
+        # transformers uses the quantization_config baked into the
+        # checkpoint's config.json and ignores the runtime BitsAndBytesConfig
+        # we pass via kwargs. Merge our skip list into that bundled config
+        # so task heads like `score` (for *ForSequenceClassification) stay
+        # in the compute dtype. See unslothai/unsloth#5027.
+        _ckpt_qcfg = getattr(model_config, "quantization_config", None)
+        if _ckpt_qcfg is not None:
+            if isinstance(_ckpt_qcfg, dict):
+                _ckpt_skip = list(_ckpt_qcfg.get("llm_int8_skip_modules") or [])
+                for _m in llm_int8_skip_modules:
+                    if _m not in _ckpt_skip:
+                        _ckpt_skip.append(_m)
+                _ckpt_qcfg["llm_int8_skip_modules"] = _ckpt_skip
+            else:
+                _ckpt_skip = list(
+                    getattr(_ckpt_qcfg, "llm_int8_skip_modules", None) or []
+                )
+                for _m in llm_int8_skip_modules:
+                    if _m not in _ckpt_skip:
+                        _ckpt_skip.append(_m)
+                try:
+                    _ckpt_qcfg.llm_int8_skip_modules = _ckpt_skip
+                except Exception:
+                    pass

         # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/12
         # RoPE Scaling's max_position_embeddings must be updated
@@ -2441,6 +2466,18 @@ class FastLlamaModel:
                 attn_implementation = preferred_attn_impl,
                 **kwargs,
             )
+            # Defensive: make sure the task head ended up in a floating dtype.
+            # The primary protection is SKIP_QUANTIZATION_MODULES plus the skip
+            # list merge above; this guards against a downstream path accidentally
+            # leaving the head in an integer storage. See unslothai/unsloth#5027.
+            for _head_name in ("score", "classifier", "qa_outputs"):
+                _head = getattr(model, _head_name, None)
+                if (
+                    _head is not None
+                    and hasattr(_head, "weight")
+                    and not _head.weight.is_floating_point()
+                ):
+                    _head.to(dtype)
         elif not fast_inference:
             model = AutoModelForCausalLM.from_pretrained(
                 model_name,