Fixup mapper issues and resolve properly (#4124)

* Fixup mapper issues and resolve properly

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Datta Nimmaturi 2026-03-03 20:27:25 +05:30 committed by GitHub
parent e238fd14aa
commit f840119fa4
7 changed files with 224 additions and 128 deletions
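For context, a minimal sketch of the name resolution this change is meant to guarantee. The calls and expected results mirror the test matrix added below (the tests compare names case-insensitively); it assumes an environment where unsloth is importable, and nothing here is taken beyond what the new tests assert.

from unsloth.models.loader_utils import get_model_name

# Official 16-bit repos resolve to Unsloth pre-quantized repos when load_in_4bit = True.
get_model_name("meta-llama/Llama-2-7b-hf", load_in_4bit = True)   # unsloth/llama-2-7b-bnb-4bit
# FP8 checkpoints are steered to the Unsloth dynamic 4-bit repo for 4-bit loading.
get_model_name("Qwen/Qwen3-8B-FP8", load_in_4bit = True)          # unsloth/Qwen3-8B-unsloth-bnb-4bit
# Unknown names fall back to the original string (the fallback-to-original cases below).
get_model_name("nonexistent-user/nonexistent-model-123")          # unchanged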


@@ -0,0 +1,127 @@
import unittest
from unittest.mock import patch
from unsloth.models.loader_utils import get_model_name
from unsloth.models import loader_utils
from unsloth.models.mapper import FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit
def _no_remote_mapper():
return {}, {}, {}
class TestGetModelName(unittest.TestCase):
def _assert_mapping(self, model_name, load_in_4bit, expected, should_change):
mapped = get_model_name(model_name, load_in_4bit = load_in_4bit)
self.assertEqual(mapped.lower(), expected.lower())
if should_change:
self.assertNotEqual(mapped.lower(), model_name.lower())
else:
self.assertEqual(mapped.lower(), model_name.lower())
@patch.object(loader_utils, "_get_new_mapper", _no_remote_mapper)
def test_resolution_matrix(self):
cases = [
# Core mappings
("meta-llama/Llama-2-7b-hf", True, "unsloth/llama-2-7b-bnb-4bit", True),
("meta-llama/Llama-2-7b-hf", False, "unsloth/llama-2-7b", True),
(
"mistralai/Ministral-8B-Instruct-2410",
True,
"mistralai/Ministral-8B-Instruct-2410",
False,
),
(
"meta-llama/Llama-3.2-1B-Instruct",
False,
"unsloth/Llama-3.2-1B-Instruct",
True,
),
(
"meta-llama/Llama-2-7b-chat-hf",
True,
"unsloth/llama-2-7b-chat-bnb-4bit",
True,
),
(
"meta-llama/Llama-3.3-70B-Instruct",
True,
"unsloth/llama-3.3-70b-instruct-unsloth-bnb-4bit",
True,
),
("Qwen/Qwen3-8B", True, "unsloth/Qwen3-8B-unsloth-bnb-4bit", True),
("Qwen/Qwen3-8B", False, "unsloth/Qwen3-8B", True),
("Qwen/Qwen3-8B-FP8", False, "unsloth/Qwen3-8B-FP8", True),
("Qwen/Qwen3-8B-FP8", True, "unsloth/Qwen3-8B-unsloth-bnb-4bit", True),
(
"mistralai/Ministral-3-3B-Instruct-2512",
True,
"unsloth/Ministral-3-3B-Instruct-2512-unsloth-bnb-4bit",
True,
),
(
"mistralai/Ministral-3-3B-Instruct-2512",
False,
"unsloth/Ministral-3-3B-Instruct-2512",
True,
),
("unsloth/Kimi-K2-Instruct", True, "unsloth/Kimi-K2-Instruct-BF16", True),
("unsloth/Kimi-K2-Instruct", False, "unsloth/Kimi-K2-Instruct", False),
# Fallback-to-original behavior
"nonexistent-user/nonexistent-model-123",
"google/gemma-3-random-prototype-123",
"imdatta0/nanoqwen-fp8",
"imdatta0/nanoqwen-bf16",
# Backward compatibility for legacy 4bit names
("unsloth/llama-2-7b-bnb-4bit", True, "unsloth/llama-2-7b-bnb-4bit", False),
("unsloth/llama-2-7b-bnb-4bit", False, "unsloth/llama-2-7b", True),
("google/gemma-2-9b", True, "unsloth/gemma-2-9b-bnb-4bit", True),
# GPT-OSS behavior
("openai/gpt-oss-20b", False, "unsloth/gpt-oss-20b", True),
("openai/gpt-oss-20b", True, "unsloth/gpt-oss-20b-unsloth-bnb-4bit", True),
("unsloth/gpt-oss-20b", True, "unsloth/gpt-oss-20b-unsloth-bnb-4bit", True),
("unsloth/gpt-oss-20b-bf16", True, "unsloth/gpt-oss-20b-bf16", False),
(
"unsloth/gpt-oss-20b-unsloth-bnb-4bit",
False,
"unsloth/gpt-oss-20b",
True,
),
(
"unsloth/gpt-oss-20b-bnb-4bit",
True,
"unsloth/gpt-oss-20b-bnb-4bit",
False,
),
]
for case in cases:
if isinstance(case, str):
model_name = case
with self.subTest(model_name = model_name, load_in_4bit = True):
self._assert_mapping(model_name, True, model_name, False)
else:
model_name, load_in_4bit, expected, should_change = case
with self.subTest(model_name = model_name, load_in_4bit = load_in_4bit):
self._assert_mapping(
model_name, load_in_4bit, expected, should_change
)
def test_static_mapper_contract(self):
contracts = [
("qwen/qwen3-8b", "unsloth/qwen3-8b-unsloth-bnb-4bit"),
("qwen/qwen3-8b-fp8", "unsloth/qwen3-8b-unsloth-bnb-4bit"),
(
"mistralai/ministral-3-3b-instruct-2512",
"unsloth/ministral-3-3b-instruct-2512-unsloth-bnb-4bit",
),
("unsloth/kimi-k2-instruct", "unsloth/kimi-k2-instruct-bf16"),
]
for src, expected in contracts:
with self.subTest(src = src):
self.assertEqual(FLOAT_TO_INT_MAPPER[src], expected)
self.assertEqual(
MAP_TO_UNSLOTH_16bit["qwen/qwen3-8b-fp8"], "unsloth/Qwen3-8B-FP8"
)
if __name__ == "__main__":
unittest.main()
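Note that test_resolution_matrix patches out the remote mapper (_get_new_mapper), so name resolution in that test needs no network access. One way to run the new tests, assuming the file lives under a tests/ directory (its actual path is not shown in this view):

import unittest

# Discover and run the tests; the start directory and glob pattern are illustrative.
suite = unittest.defaultTestLoader.discover("tests", pattern = "test_*.py")
unittest.TextTestRunner(verbosity = 2).run(suite)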


@@ -74,7 +74,6 @@ __all__ = [
"dequantize_module_weight",
"patch_hf_quantizer",
"verify_fp8_support_if_applicable",
"_redirect_fp8_to_bf16",
"_get_inference_mode_context_manager",
"hf_login",
"is_moe_model",
@@ -2584,59 +2583,6 @@ def patch_hf_quantizer():
patch_hf_quantizer()
def _redirect_fp8_to_bf16(
model_name, auto_config, load_in_fp8, token, trust_remote_code
):
"""
Detect FP8 quantization in model config and redirect to BF16 sibling.
Models shipping FP8 as default (e.g. mistralai/Ministral-3-*B-Instruct)
cannot be loaded with BNB 4-bit/8-bit or 16-bit mode. This detects
quant_method in ("fp8", "fbgemm_fp8") and redirects to {model_name}-BF16.
Redirect is SKIPPED when load_in_fp8 is truthy (True or 'block'),
meaning the user explicitly wants FP8 loading.
Returns (model_name, auto_config) -- possibly updated.
"""
if not hasattr(auto_config, "quantization_config"):
return model_name, auto_config
_qc = auto_config.quantization_config
_qm = (
_qc.get("quant_method", "")
if isinstance(_qc, dict)
else getattr(_qc, "quant_method", "")
)
if _qm not in ("fp8", "fbgemm_fp8") or load_in_fp8:
return model_name, auto_config
_bf16_name = model_name.rstrip("/") + "-BF16"
_original_name = model_name
try:
from huggingface_hub import model_info as _hf_model_info
from transformers import AutoConfig
_hf_model_info(_bf16_name, token = token)
_bf16_config = AutoConfig.from_pretrained(
_bf16_name,
token = token,
trust_remote_code = trust_remote_code,
)
print(
f"Unsloth: {_original_name} uses FP8 weights. "
f"Redirecting to {_bf16_name}."
)
return _bf16_name, _bf16_config
except Exception:
raise RuntimeError(
f"Unsloth: {_original_name} uses FP8 weights but no BF16 version "
f"was found at {_bf16_name}.\n"
f"Loading FP8 weights with BitsAndBytes or in 16-bit will fail.\n"
f"Set load_in_fp8=True to use FP8 mode, or upload a BF16 version."
)
def verify_fp8_support_if_applicable(model_config):
quant_method = get_quant_type(model_config)
if quant_method in ["fbgemm_fp8", "fp8"] and DEVICE_TYPE != "cuda":


@@ -25,7 +25,8 @@ from ._utils import move_to_device
from ._utils import (
_get_inference_mode_context_manager,
_prepare_model_for_qat,
_redirect_fp8_to_bf16,
is_bfloat16_supported,
get_quant_type,
)
from .loader_utils import _get_fp8_mode_and_check_settings
from ..utils.packing import (
@@ -2331,15 +2332,6 @@ class FastLlamaModel:
token = token,
attn_implementation = "sdpa",
)
# Handle FP8 models: redirect to BF16 sibling when the model ships with
# FP8 weights. Redirect is skipped when load_in_fp8 is truthy (True or 'block').
model_name, model_config = _redirect_fp8_to_bf16(
model_name,
model_config,
load_in_fp8,
token,
trust_remote_code,
)
model_config.model_name = model_name
model_max_seq_length = model_config.max_position_embeddings


@@ -372,7 +372,11 @@ class FastLanguageModel(FastLlamaModel):
fp8_mode = None
if not use_exact_model_name:
new_model_name = get_model_name(
model_name, load_in_4bit = load_in_4bit, load_in_fp8 = load_in_fp8
model_name,
load_in_4bit = load_in_4bit,
load_in_fp8 = load_in_fp8,
token = token,
trust_remote_code = trust_remote_code,
)
if new_model_name is None and load_in_fp8 != False:
fp8_mode = _get_fp8_mode_and_check_settings(
@@ -525,7 +529,13 @@ class FastLanguageModel(FastLlamaModel):
# Check base model again for PEFT
model_name = peft_config.base_model_name_or_path
if not use_exact_model_name:
model_name = get_model_name(model_name, load_in_4bit)
model_name = get_model_name(
model_name,
load_in_4bit = load_in_4bit,
load_in_fp8 = load_in_fp8,
token = token,
trust_remote_code = trust_remote_code,
)
# Check if pre-quantized models are allowed
# For eg AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64
if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
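For reference, a minimal end-to-end sketch of the call path these two hunks change. It assumes a GPU environment with unsloth installed; the resolved repo name follows the test matrix added above and is not verified here.

from unsloth import FastLanguageModel

# token and trust_remote_code are now threaded through to get_model_name for both
# the direct path and the PEFT base-model path, so the FP8 repo below should be
# resolved to the Unsloth dynamic 4-bit repo before any weights are downloaded.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "Qwen/Qwen3-8B-FP8",
    load_in_4bit = True,
)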


@@ -198,18 +198,42 @@ def _get_new_mapper():
return {}, {}, {}
def get_model_name(model_name, load_in_4bit = True, load_in_fp8 = False):
assert load_in_fp8 in (True, False, "block")
new_model_name = __get_model_name(
def _resolve_with_mappers(
model_name,
load_in_4bit,
load_in_fp8,
int_to_float,
float_to_int,
map_to_unsloth_16bit,
):
return __get_model_name(
model_name = model_name,
load_in_4bit = load_in_4bit,
INT_TO_FLOAT_MAPPER = INT_TO_FLOAT_MAPPER,
FLOAT_TO_INT_MAPPER = FLOAT_TO_INT_MAPPER,
MAP_TO_UNSLOTH_16bit = MAP_TO_UNSLOTH_16bit,
INT_TO_FLOAT_MAPPER = int_to_float,
FLOAT_TO_INT_MAPPER = float_to_int,
MAP_TO_UNSLOTH_16bit = map_to_unsloth_16bit,
load_in_fp8 = load_in_fp8,
FLOAT_TO_FP8_BLOCK_MAPPER = FLOAT_TO_FP8_BLOCK_MAPPER,
FLOAT_TO_FP8_ROW_MAPPER = FLOAT_TO_FP8_ROW_MAPPER,
)
def get_model_name(
model_name,
load_in_4bit = True,
load_in_fp8 = False,
token = None,
trust_remote_code = False,
):
assert load_in_fp8 in (True, False, "block")
new_model_name = _resolve_with_mappers(
model_name = model_name,
load_in_4bit = load_in_4bit,
load_in_fp8 = load_in_fp8,
int_to_float = INT_TO_FLOAT_MAPPER,
float_to_int = FLOAT_TO_INT_MAPPER,
map_to_unsloth_16bit = MAP_TO_UNSLOTH_16bit,
)
# In the rare case, we convert bad model names to other names
# For eg too large dynamic quants or MoEs
if (
@@ -228,15 +252,13 @@ def get_model_name(model_name, load_in_4bit = True, load_in_fp8 = False):
NEW_INT_TO_FLOAT_MAPPER, NEW_FLOAT_TO_INT_MAPPER, NEW_MAP_TO_UNSLOTH_16bit = (
_get_new_mapper()
)
upgraded_model_name = __get_model_name(
upgraded_model_name = _resolve_with_mappers(
model_name = model_name,
load_in_4bit = load_in_4bit,
INT_TO_FLOAT_MAPPER = NEW_INT_TO_FLOAT_MAPPER,
FLOAT_TO_INT_MAPPER = NEW_FLOAT_TO_INT_MAPPER,
MAP_TO_UNSLOTH_16bit = NEW_MAP_TO_UNSLOTH_16bit,
load_in_fp8 = load_in_fp8,
FLOAT_TO_FP8_BLOCK_MAPPER = FLOAT_TO_FP8_BLOCK_MAPPER,
FLOAT_TO_FP8_ROW_MAPPER = FLOAT_TO_FP8_ROW_MAPPER,
int_to_float = NEW_INT_TO_FLOAT_MAPPER,
float_to_int = NEW_FLOAT_TO_INT_MAPPER,
map_to_unsloth_16bit = NEW_MAP_TO_UNSLOTH_16bit,
)
if upgraded_model_name is not None:
raise NotImplementedError(
@@ -245,10 +267,11 @@ def get_model_name(model_name, load_in_4bit = True, load_in_fp8 = False):
'pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
'pip install --upgrade --no-cache-dir "git+https://github.com/unslothai/unsloth-zoo.git"\n'
)
if load_in_fp8 != False:
# Handle on the fly TorchAO FP8 quantization
return new_model_name
return new_model_name if new_model_name is not None else model_name
if new_model_name is None:
new_model_name = model_name
return new_model_name
def _offline_quantize_to_fp8(model_name: str, fp8_mode: str) -> str:


@@ -1337,6 +1337,9 @@ __INT_TO_FLOAT_MAPPER = \
"mistralai/Ministral-3-14B-Reasoning-2512",
"unsloth/Ministral-3-14B-Reasoning-2512-bnb-4bit",
),
"unsloth/Kimi-K2-Instruct-BF16" : (
"unsloth/Kimi-K2-Instruct",
),
}
INT_TO_FLOAT_MAPPER = {}
@@ -1345,6 +1348,19 @@ MAP_TO_UNSLOTH_16bit = {}
FLOAT_TO_FP8_BLOCK_MAPPER = {}
FLOAT_TO_FP8_ROW_MAPPER = {}
def _add_with_lower(mapper, key, value):
if key is None:
return
mapper[key] = value
mapper[key.lower()] = value
def _add_lower_only(mapper, key, value):
if key is None:
return
mapper[key.lower()] = value
for key, values in __INT_TO_FLOAT_MAPPER.items():
block, row = None, None
if type(values) is dict:
@@ -1355,21 +1371,24 @@ for key, values in __INT_TO_FLOAT_MAPPER.items():
float8_values = values["8"]
assert len(float8_values) == 3
official, block, row = float8_values
FLOAT_TO_FP8_BLOCK_MAPPER[key.lower()] = block
FLOAT_TO_FP8_ROW_MAPPER[key.lower()] = row
FLOAT_TO_FP8_BLOCK_MAPPER[official.lower() + "-dynamic"] = block
FLOAT_TO_FP8_ROW_MAPPER[official.lower()] = row
FLOAT_TO_FP8_ROW_MAPPER[official.lower() + "-dynamic"] = row
FLOAT_TO_FP8_BLOCK_MAPPER[float16_values[0]] = block
FLOAT_TO_FP8_BLOCK_MAPPER[float16_values[0].lower()] = block
FLOAT_TO_FP8_ROW_MAPPER[float16_values[0]] = block
FLOAT_TO_FP8_ROW_MAPPER[float16_values[0].lower()] = block
for k in float8_values:
FLOAT_TO_FP8_BLOCK_MAPPER[k.lower()] = block
FLOAT_TO_FP8_ROW_MAPPER[k.lower()] = row
for k in float16_values:
FLOAT_TO_FP8_BLOCK_MAPPER[k.lower()] = block
FLOAT_TO_FP8_ROW_MAPPER[k.lower()] = row
_add_lower_only(FLOAT_TO_FP8_BLOCK_MAPPER, key, block)
_add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, key, row)
_add_lower_only(FLOAT_TO_FP8_BLOCK_MAPPER, official + "-dynamic", block)
_add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, official, row)
_add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, official + "-dynamic", row)
for k in float8_values + float16_values:
_add_lower_only(FLOAT_TO_FP8_BLOCK_MAPPER, k, block)
_add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, k, row)
if float8_values[1] is not None and float8_values[1].startswith("unsloth"):
for value in float8_values:
if value is not None:
_add_with_lower(MAP_TO_UNSLOTH_16bit, value, float8_values[1])
for value in float8_values:
if value is not None:
FLOAT_TO_INT_MAPPER[value] = key
FLOAT_TO_INT_MAPPER[value.lower()] = key.lower()
values = float16_values
INT_TO_FLOAT_MAPPER[key] = values[0]
@@ -1379,27 +1398,16 @@ for key, values in __INT_TO_FLOAT_MAPPER.items():
# Map to Unsloth version for 16bit versions
if len(values) == 2:
if values[0].startswith("unsloth"):
MAP_TO_UNSLOTH_16bit[values[1]] = values[0]
MAP_TO_UNSLOTH_16bit[values[1].lower()] = values[0]
if block is not None:
MAP_TO_UNSLOTH_16bit[block] = values[0]
MAP_TO_UNSLOTH_16bit[block.lower()] = values[0]
if row is not None:
MAP_TO_UNSLOTH_16bit[row] = values[0]
MAP_TO_UNSLOTH_16bit[row.lower()] = values[0]
_add_with_lower(MAP_TO_UNSLOTH_16bit, values[1], values[0])
_add_with_lower(MAP_TO_UNSLOTH_16bit, block, values[0])
_add_with_lower(MAP_TO_UNSLOTH_16bit, row, values[0])
elif len(values) == 3:
# Dynamic Unsloth quantization
if values[0].startswith("unsloth"):
MAP_TO_UNSLOTH_16bit[values[1]] = values[0]
MAP_TO_UNSLOTH_16bit[values[1].lower()] = values[0]
MAP_TO_UNSLOTH_16bit[values[2]] = values[0]
MAP_TO_UNSLOTH_16bit[values[2].lower()] = values[0]
if block is not None:
MAP_TO_UNSLOTH_16bit[block] = values[0]
MAP_TO_UNSLOTH_16bit[block.lower()] = values[0]
if row is not None:
MAP_TO_UNSLOTH_16bit[row] = values[0]
MAP_TO_UNSLOTH_16bit[row.lower()] = values[0]
_add_with_lower(MAP_TO_UNSLOTH_16bit, values[1], values[0])
_add_with_lower(MAP_TO_UNSLOTH_16bit, values[2], values[0])
_add_with_lower(MAP_TO_UNSLOTH_16bit, block, values[0])
_add_with_lower(MAP_TO_UNSLOTH_16bit, row, values[0])
pass
# Get lowercased
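A quick sketch of what the two new helpers do; the key/value strings below are illustrative, not entries from the real mapper.

demo = {}
_add_with_lower(demo, "Example/Model-BF16", "unsloth/Example-Model")
# Both the original-case and lowercased keys now point at the value:
# {"Example/Model-BF16": "unsloth/Example-Model", "example/model-bf16": "unsloth/Example-Model"}
_add_lower_only(demo, "Example/Model-FP8", "unsloth/Example-Model")
# Only the lowercased key "example/model-fp8" is added.
_add_with_lower(demo, None, "ignored")   # None keys are skipped by both helpers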


@@ -30,7 +30,6 @@ from ..kernels import (
post_patch_loss_function,
)
from ._utils import __version__, importlib_version, _prepare_model_for_qat
from ._utils import _redirect_fp8_to_bf16
from ._utils import *
from .loader_utils import _get_fp8_mode_and_check_settings
from ..save import patch_saving_functions
@@ -612,18 +611,9 @@ class FastBaseModel:
model_class = None
flex_attn_impl = prefer_flex_attn_if_supported(model_class, auto_config)
# Handle FP8 models: redirect to BF16 sibling when the model ships with
# FP8 weights (e.g. Ministral-3-3B-Instruct-2512). FP8 weights cannot be
# directly loaded by BNB, and the FP8 quantization config can cause issues
# even for 16-bit loading.
# Redirect is skipped when load_in_fp8 is truthy (True or 'block').
model_name, auto_config = _redirect_fp8_to_bf16(
model_name,
auto_config,
load_in_fp8,
token,
trust_remote_code,
)
# Handle FP8 models: get_model_name has already redirected the name to its BF16 sibling
# when the model ships FP8 weights, so only the config's model_name needs updating here.
auto_config.model_name = model_name
# Re-resolve model_class after potential config change
try:
model_class = auto_model._model_mapping[auto_config.__class__]