mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Fixup mapper issues and resolve properly (#4124)
* Fixup mapper issues and resolve properly * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
e238fd14aa
commit
f840119fa4
7 changed files with 224 additions and 128 deletions
127
tests/test_get_model_name.py
Normal file
127
tests/test_get_model_name.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
import unittest
|
||||
from unittest.mock import patch
|
||||
from unsloth.models.loader_utils import get_model_name
|
||||
from unsloth.models import loader_utils
|
||||
from unsloth.models.mapper import FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit
|
||||
|
||||
|
||||
def _no_remote_mapper():
|
||||
return {}, {}, {}
|
||||
|
||||
|
||||
class TestGetModelName(unittest.TestCase):
|
||||
def _assert_mapping(self, model_name, load_in_4bit, expected, should_change):
|
||||
mapped = get_model_name(model_name, load_in_4bit = load_in_4bit)
|
||||
self.assertEqual(mapped.lower(), expected.lower())
|
||||
if should_change:
|
||||
self.assertNotEqual(mapped.lower(), model_name.lower())
|
||||
else:
|
||||
self.assertEqual(mapped.lower(), model_name.lower())
|
||||
|
||||
@patch.object(loader_utils, "_get_new_mapper", _no_remote_mapper)
|
||||
def test_resolution_matrix(self):
|
||||
cases = [
|
||||
# Core mappings
|
||||
("meta-llama/Llama-2-7b-hf", True, "unsloth/llama-2-7b-bnb-4bit", True),
|
||||
("meta-llama/Llama-2-7b-hf", False, "unsloth/llama-2-7b", True),
|
||||
(
|
||||
"mistralai/Ministral-8B-Instruct-2410",
|
||||
True,
|
||||
"mistralai/Ministral-8B-Instruct-2410",
|
||||
False,
|
||||
),
|
||||
(
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
False,
|
||||
"unsloth/Llama-3.2-1B-Instruct",
|
||||
True,
|
||||
),
|
||||
(
|
||||
"meta-llama/Llama-2-7b-chat-hf",
|
||||
True,
|
||||
"unsloth/llama-2-7b-chat-bnb-4bit",
|
||||
True,
|
||||
),
|
||||
(
|
||||
"meta-llama/Llama-3.3-70B-Instruct",
|
||||
True,
|
||||
"unsloth/llama-3.3-70b-instruct-unsloth-bnb-4bit",
|
||||
True,
|
||||
),
|
||||
("Qwen/Qwen3-8B", True, "unsloth/Qwen3-8B-unsloth-bnb-4bit", True),
|
||||
("Qwen/Qwen3-8B", False, "unsloth/Qwen3-8B", True),
|
||||
("Qwen/Qwen3-8B-FP8", False, "unsloth/Qwen3-8B-FP8", True),
|
||||
("Qwen/Qwen3-8B-FP8", True, "unsloth/Qwen3-8B-unsloth-bnb-4bit", True),
|
||||
(
|
||||
"mistralai/Ministral-3-3B-Instruct-2512",
|
||||
True,
|
||||
"unsloth/Ministral-3-3B-Instruct-2512-unsloth-bnb-4bit",
|
||||
True,
|
||||
),
|
||||
(
|
||||
"mistralai/Ministral-3-3B-Instruct-2512",
|
||||
False,
|
||||
"unsloth/Ministral-3-3B-Instruct-2512",
|
||||
True,
|
||||
),
|
||||
("unsloth/Kimi-K2-Instruct", True, "unsloth/Kimi-K2-Instruct-BF16", True),
|
||||
("unsloth/Kimi-K2-Instruct", False, "unsloth/Kimi-K2-Instruct", False),
|
||||
# Fallback-to-original behavior
|
||||
"nonexistent-user/nonexistent-model-123",
|
||||
"google/gemma-3-random-prototype-123",
|
||||
"imdatta0/nanoqwen-fp8",
|
||||
"imdatta0/nanoqwen-bf16",
|
||||
# Backward compatibility for legacy 4bit names
|
||||
("unsloth/llama-2-7b-bnb-4bit", True, "unsloth/llama-2-7b-bnb-4bit", False),
|
||||
("unsloth/llama-2-7b-bnb-4bit", False, "unsloth/llama-2-7b", True),
|
||||
("google/gemma-2-9b", True, "unsloth/gemma-2-9b-bnb-4bit", True),
|
||||
# GPT-OSS behavior
|
||||
("openai/gpt-oss-20b", False, "unsloth/gpt-oss-20b", True),
|
||||
("openai/gpt-oss-20b", True, "unsloth/gpt-oss-20b-unsloth-bnb-4bit", True),
|
||||
("unsloth/gpt-oss-20b", True, "unsloth/gpt-oss-20b-unsloth-bnb-4bit", True),
|
||||
("unsloth/gpt-oss-20b-bf16", True, "unsloth/gpt-oss-20b-bf16", False),
|
||||
(
|
||||
"unsloth/gpt-oss-20b-unsloth-bnb-4bit",
|
||||
False,
|
||||
"unsloth/gpt-oss-20b",
|
||||
True,
|
||||
),
|
||||
(
|
||||
"unsloth/gpt-oss-20b-bnb-4bit",
|
||||
True,
|
||||
"unsloth/gpt-oss-20b-bnb-4bit",
|
||||
False,
|
||||
),
|
||||
]
|
||||
for case in cases:
|
||||
if isinstance(case, str):
|
||||
model_name = case
|
||||
with self.subTest(model_name = model_name, load_in_4bit = True):
|
||||
self._assert_mapping(model_name, True, model_name, False)
|
||||
else:
|
||||
model_name, load_in_4bit, expected, should_change = case
|
||||
with self.subTest(model_name = model_name, load_in_4bit = load_in_4bit):
|
||||
self._assert_mapping(
|
||||
model_name, load_in_4bit, expected, should_change
|
||||
)
|
||||
|
||||
def test_static_mapper_contract(self):
|
||||
contracts = [
|
||||
("qwen/qwen3-8b", "unsloth/qwen3-8b-unsloth-bnb-4bit"),
|
||||
("qwen/qwen3-8b-fp8", "unsloth/qwen3-8b-unsloth-bnb-4bit"),
|
||||
(
|
||||
"mistralai/ministral-3-3b-instruct-2512",
|
||||
"unsloth/ministral-3-3b-instruct-2512-unsloth-bnb-4bit",
|
||||
),
|
||||
("unsloth/kimi-k2-instruct", "unsloth/kimi-k2-instruct-bf16"),
|
||||
]
|
||||
for src, expected in contracts:
|
||||
with self.subTest(src = src):
|
||||
self.assertEqual(FLOAT_TO_INT_MAPPER[src], expected)
|
||||
self.assertEqual(
|
||||
MAP_TO_UNSLOTH_16bit["qwen/qwen3-8b-fp8"], "unsloth/Qwen3-8B-FP8"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -74,7 +74,6 @@ __all__ = [
|
|||
"dequantize_module_weight",
|
||||
"patch_hf_quantizer",
|
||||
"verify_fp8_support_if_applicable",
|
||||
"_redirect_fp8_to_bf16",
|
||||
"_get_inference_mode_context_manager",
|
||||
"hf_login",
|
||||
"is_moe_model",
|
||||
|
|
@ -2584,59 +2583,6 @@ def patch_hf_quantizer():
|
|||
patch_hf_quantizer()
|
||||
|
||||
|
||||
def _redirect_fp8_to_bf16(
|
||||
model_name, auto_config, load_in_fp8, token, trust_remote_code
|
||||
):
|
||||
"""
|
||||
Detect FP8 quantization in model config and redirect to BF16 sibling.
|
||||
|
||||
Models shipping FP8 as default (e.g. mistralai/Ministral-3-*B-Instruct)
|
||||
cannot be loaded with BNB 4-bit/8-bit or 16-bit mode. This detects
|
||||
quant_method in ("fp8", "fbgemm_fp8") and redirects to {model_name}-BF16.
|
||||
|
||||
Redirect is SKIPPED when load_in_fp8 is truthy (True or 'block'),
|
||||
meaning the user explicitly wants FP8 loading.
|
||||
|
||||
Returns (model_name, auto_config) -- possibly updated.
|
||||
"""
|
||||
if not hasattr(auto_config, "quantization_config"):
|
||||
return model_name, auto_config
|
||||
|
||||
_qc = auto_config.quantization_config
|
||||
_qm = (
|
||||
_qc.get("quant_method", "")
|
||||
if isinstance(_qc, dict)
|
||||
else getattr(_qc, "quant_method", "")
|
||||
)
|
||||
if _qm not in ("fp8", "fbgemm_fp8") or load_in_fp8:
|
||||
return model_name, auto_config
|
||||
|
||||
_bf16_name = model_name.rstrip("/") + "-BF16"
|
||||
_original_name = model_name
|
||||
try:
|
||||
from huggingface_hub import model_info as _hf_model_info
|
||||
from transformers import AutoConfig
|
||||
|
||||
_hf_model_info(_bf16_name, token = token)
|
||||
_bf16_config = AutoConfig.from_pretrained(
|
||||
_bf16_name,
|
||||
token = token,
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
print(
|
||||
f"Unsloth: {_original_name} uses FP8 weights. "
|
||||
f"Redirecting to {_bf16_name}."
|
||||
)
|
||||
return _bf16_name, _bf16_config
|
||||
except Exception:
|
||||
raise RuntimeError(
|
||||
f"Unsloth: {_original_name} uses FP8 weights but no BF16 version "
|
||||
f"was found at {_bf16_name}.\n"
|
||||
f"Loading FP8 weights with BitsAndBytes or in 16-bit will fail.\n"
|
||||
f"Set load_in_fp8=True to use FP8 mode, or upload a BF16 version."
|
||||
)
|
||||
|
||||
|
||||
def verify_fp8_support_if_applicable(model_config):
|
||||
quant_method = get_quant_type(model_config)
|
||||
if quant_method in ["fbgemm_fp8", "fp8"] and DEVICE_TYPE != "cuda":
|
||||
|
|
|
|||
|
|
@ -25,7 +25,8 @@ from ._utils import move_to_device
|
|||
from ._utils import (
|
||||
_get_inference_mode_context_manager,
|
||||
_prepare_model_for_qat,
|
||||
_redirect_fp8_to_bf16,
|
||||
is_bfloat16_supported,
|
||||
get_quant_type,
|
||||
)
|
||||
from .loader_utils import _get_fp8_mode_and_check_settings
|
||||
from ..utils.packing import (
|
||||
|
|
@ -2331,15 +2332,6 @@ class FastLlamaModel:
|
|||
token = token,
|
||||
attn_implementation = "sdpa",
|
||||
)
|
||||
# Handle FP8 models: redirect to BF16 sibling when the model ships with
|
||||
# FP8 weights. Redirect is skipped when load_in_fp8 is truthy (True or 'block').
|
||||
model_name, model_config = _redirect_fp8_to_bf16(
|
||||
model_name,
|
||||
model_config,
|
||||
load_in_fp8,
|
||||
token,
|
||||
trust_remote_code,
|
||||
)
|
||||
model_config.model_name = model_name
|
||||
model_max_seq_length = model_config.max_position_embeddings
|
||||
|
||||
|
|
|
|||
|
|
@ -372,7 +372,11 @@ class FastLanguageModel(FastLlamaModel):
|
|||
fp8_mode = None
|
||||
if not use_exact_model_name:
|
||||
new_model_name = get_model_name(
|
||||
model_name, load_in_4bit = load_in_4bit, load_in_fp8 = load_in_fp8
|
||||
model_name,
|
||||
load_in_4bit = load_in_4bit,
|
||||
load_in_fp8 = load_in_fp8,
|
||||
token = token,
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
if new_model_name is None and load_in_fp8 != False:
|
||||
fp8_mode = _get_fp8_mode_and_check_settings(
|
||||
|
|
@ -525,7 +529,13 @@ class FastLanguageModel(FastLlamaModel):
|
|||
# Check base model again for PEFT
|
||||
model_name = peft_config.base_model_name_or_path
|
||||
if not use_exact_model_name:
|
||||
model_name = get_model_name(model_name, load_in_4bit)
|
||||
model_name = get_model_name(
|
||||
model_name,
|
||||
load_in_4bit = load_in_4bit,
|
||||
load_in_fp8 = load_in_fp8,
|
||||
token = token,
|
||||
trust_remote_code = trust_remote_code,
|
||||
)
|
||||
# Check if pre-quantized models are allowed
|
||||
# For eg AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64
|
||||
if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
|
||||
|
|
|
|||
|
|
@ -198,18 +198,42 @@ def _get_new_mapper():
|
|||
return {}, {}, {}
|
||||
|
||||
|
||||
def get_model_name(model_name, load_in_4bit = True, load_in_fp8 = False):
|
||||
assert load_in_fp8 in (True, False, "block")
|
||||
new_model_name = __get_model_name(
|
||||
def _resolve_with_mappers(
|
||||
model_name,
|
||||
load_in_4bit,
|
||||
load_in_fp8,
|
||||
int_to_float,
|
||||
float_to_int,
|
||||
map_to_unsloth_16bit,
|
||||
):
|
||||
return __get_model_name(
|
||||
model_name = model_name,
|
||||
load_in_4bit = load_in_4bit,
|
||||
INT_TO_FLOAT_MAPPER = INT_TO_FLOAT_MAPPER,
|
||||
FLOAT_TO_INT_MAPPER = FLOAT_TO_INT_MAPPER,
|
||||
MAP_TO_UNSLOTH_16bit = MAP_TO_UNSLOTH_16bit,
|
||||
INT_TO_FLOAT_MAPPER = int_to_float,
|
||||
FLOAT_TO_INT_MAPPER = float_to_int,
|
||||
MAP_TO_UNSLOTH_16bit = map_to_unsloth_16bit,
|
||||
load_in_fp8 = load_in_fp8,
|
||||
FLOAT_TO_FP8_BLOCK_MAPPER = FLOAT_TO_FP8_BLOCK_MAPPER,
|
||||
FLOAT_TO_FP8_ROW_MAPPER = FLOAT_TO_FP8_ROW_MAPPER,
|
||||
)
|
||||
|
||||
|
||||
def get_model_name(
|
||||
model_name,
|
||||
load_in_4bit = True,
|
||||
load_in_fp8 = False,
|
||||
token = None,
|
||||
trust_remote_code = False,
|
||||
):
|
||||
assert load_in_fp8 in (True, False, "block")
|
||||
new_model_name = _resolve_with_mappers(
|
||||
model_name = model_name,
|
||||
load_in_4bit = load_in_4bit,
|
||||
load_in_fp8 = load_in_fp8,
|
||||
int_to_float = INT_TO_FLOAT_MAPPER,
|
||||
float_to_int = FLOAT_TO_INT_MAPPER,
|
||||
map_to_unsloth_16bit = MAP_TO_UNSLOTH_16bit,
|
||||
)
|
||||
# In the rare case, we convert bad model names to other names
|
||||
# For eg too large dynamic quants or MoEs
|
||||
if (
|
||||
|
|
@ -228,15 +252,13 @@ def get_model_name(model_name, load_in_4bit = True, load_in_fp8 = False):
|
|||
NEW_INT_TO_FLOAT_MAPPER, NEW_FLOAT_TO_INT_MAPPER, NEW_MAP_TO_UNSLOTH_16bit = (
|
||||
_get_new_mapper()
|
||||
)
|
||||
upgraded_model_name = __get_model_name(
|
||||
upgraded_model_name = _resolve_with_mappers(
|
||||
model_name = model_name,
|
||||
load_in_4bit = load_in_4bit,
|
||||
INT_TO_FLOAT_MAPPER = NEW_INT_TO_FLOAT_MAPPER,
|
||||
FLOAT_TO_INT_MAPPER = NEW_FLOAT_TO_INT_MAPPER,
|
||||
MAP_TO_UNSLOTH_16bit = NEW_MAP_TO_UNSLOTH_16bit,
|
||||
load_in_fp8 = load_in_fp8,
|
||||
FLOAT_TO_FP8_BLOCK_MAPPER = FLOAT_TO_FP8_BLOCK_MAPPER,
|
||||
FLOAT_TO_FP8_ROW_MAPPER = FLOAT_TO_FP8_ROW_MAPPER,
|
||||
int_to_float = NEW_INT_TO_FLOAT_MAPPER,
|
||||
float_to_int = NEW_FLOAT_TO_INT_MAPPER,
|
||||
map_to_unsloth_16bit = NEW_MAP_TO_UNSLOTH_16bit,
|
||||
)
|
||||
if upgraded_model_name is not None:
|
||||
raise NotImplementedError(
|
||||
|
|
@ -245,10 +267,11 @@ def get_model_name(model_name, load_in_4bit = True, load_in_fp8 = False):
|
|||
'pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
|
||||
'pip install --upgrade --no-cache-dir "git+https://github.com/unslothai/unsloth-zoo.git"\n'
|
||||
)
|
||||
if load_in_fp8 != False:
|
||||
# Handle on the fly TorchAO FP8 quantization
|
||||
return new_model_name
|
||||
return new_model_name if new_model_name is not None else model_name
|
||||
|
||||
if new_model_name is None:
|
||||
new_model_name = model_name
|
||||
|
||||
return new_model_name
|
||||
|
||||
|
||||
def _offline_quantize_to_fp8(model_name: str, fp8_mode: str) -> str:
|
||||
|
|
|
|||
|
|
@ -1337,6 +1337,9 @@ __INT_TO_FLOAT_MAPPER = \
|
|||
"mistralai/Ministral-3-14B-Reasoning-2512",
|
||||
"unsloth/Ministral-3-14B-Reasoning-2512-bnb-4bit",
|
||||
),
|
||||
"unsloth/Kimi-K2-Instruct-BF16" : (
|
||||
"unsloth/Kimi-K2-Instruct",
|
||||
),
|
||||
}
|
||||
|
||||
INT_TO_FLOAT_MAPPER = {}
|
||||
|
|
@ -1345,6 +1348,19 @@ MAP_TO_UNSLOTH_16bit = {}
|
|||
FLOAT_TO_FP8_BLOCK_MAPPER = {}
|
||||
FLOAT_TO_FP8_ROW_MAPPER = {}
|
||||
|
||||
|
||||
def _add_with_lower(mapper, key, value):
|
||||
if key is None:
|
||||
return
|
||||
mapper[key] = value
|
||||
mapper[key.lower()] = value
|
||||
|
||||
|
||||
def _add_lower_only(mapper, key, value):
|
||||
if key is None:
|
||||
return
|
||||
mapper[key.lower()] = value
|
||||
|
||||
for key, values in __INT_TO_FLOAT_MAPPER.items():
|
||||
block, row = None, None
|
||||
if type(values) is dict:
|
||||
|
|
@ -1355,21 +1371,24 @@ for key, values in __INT_TO_FLOAT_MAPPER.items():
|
|||
float8_values = values["8"]
|
||||
assert len(float8_values) == 3
|
||||
official, block, row = float8_values
|
||||
FLOAT_TO_FP8_BLOCK_MAPPER[key.lower()] = block
|
||||
FLOAT_TO_FP8_ROW_MAPPER[key.lower()] = row
|
||||
FLOAT_TO_FP8_BLOCK_MAPPER[official.lower() + "-dynamic"] = block
|
||||
FLOAT_TO_FP8_ROW_MAPPER[official.lower()] = row
|
||||
FLOAT_TO_FP8_ROW_MAPPER[official.lower() + "-dynamic"] = row
|
||||
FLOAT_TO_FP8_BLOCK_MAPPER[float16_values[0]] = block
|
||||
FLOAT_TO_FP8_BLOCK_MAPPER[float16_values[0].lower()] = block
|
||||
FLOAT_TO_FP8_ROW_MAPPER[float16_values[0]] = block
|
||||
FLOAT_TO_FP8_ROW_MAPPER[float16_values[0].lower()] = block
|
||||
for k in float8_values:
|
||||
FLOAT_TO_FP8_BLOCK_MAPPER[k.lower()] = block
|
||||
FLOAT_TO_FP8_ROW_MAPPER[k.lower()] = row
|
||||
for k in float16_values:
|
||||
FLOAT_TO_FP8_BLOCK_MAPPER[k.lower()] = block
|
||||
FLOAT_TO_FP8_ROW_MAPPER[k.lower()] = row
|
||||
_add_lower_only(FLOAT_TO_FP8_BLOCK_MAPPER, key, block)
|
||||
_add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, key, row)
|
||||
_add_lower_only(FLOAT_TO_FP8_BLOCK_MAPPER, official + "-dynamic", block)
|
||||
_add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, official, row)
|
||||
_add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, official + "-dynamic", row)
|
||||
for k in float8_values + float16_values:
|
||||
_add_lower_only(FLOAT_TO_FP8_BLOCK_MAPPER, k, block)
|
||||
_add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, k, row)
|
||||
|
||||
if float8_values[1] is not None and float8_values[1].startswith("unsloth"):
|
||||
for value in float8_values:
|
||||
if value is not None:
|
||||
_add_with_lower(MAP_TO_UNSLOTH_16bit, value, float8_values[1])
|
||||
|
||||
for value in float8_values:
|
||||
if value is not None:
|
||||
FLOAT_TO_INT_MAPPER[value] = key
|
||||
FLOAT_TO_INT_MAPPER[value.lower()] = key.lower()
|
||||
values = float16_values
|
||||
INT_TO_FLOAT_MAPPER[key] = values[0]
|
||||
|
||||
|
|
@ -1379,27 +1398,16 @@ for key, values in __INT_TO_FLOAT_MAPPER.items():
|
|||
# Map to Unsloth version for 16bit versions
|
||||
if len(values) == 2:
|
||||
if values[0].startswith("unsloth"):
|
||||
MAP_TO_UNSLOTH_16bit[values[1]] = values[0]
|
||||
MAP_TO_UNSLOTH_16bit[values[1].lower()] = values[0]
|
||||
if block is not None:
|
||||
MAP_TO_UNSLOTH_16bit[block] = values[0]
|
||||
MAP_TO_UNSLOTH_16bit[block.lower()] = values[0]
|
||||
if row is not None:
|
||||
MAP_TO_UNSLOTH_16bit[row] = values[0]
|
||||
MAP_TO_UNSLOTH_16bit[row.lower()] = values[0]
|
||||
_add_with_lower(MAP_TO_UNSLOTH_16bit, values[1], values[0])
|
||||
_add_with_lower(MAP_TO_UNSLOTH_16bit, block, values[0])
|
||||
_add_with_lower(MAP_TO_UNSLOTH_16bit, row, values[0])
|
||||
elif len(values) == 3:
|
||||
# Dynamic Unsloth quantization
|
||||
if values[0].startswith("unsloth"):
|
||||
MAP_TO_UNSLOTH_16bit[values[1]] = values[0]
|
||||
MAP_TO_UNSLOTH_16bit[values[1].lower()] = values[0]
|
||||
MAP_TO_UNSLOTH_16bit[values[2]] = values[0]
|
||||
MAP_TO_UNSLOTH_16bit[values[2].lower()] = values[0]
|
||||
if block is not None:
|
||||
MAP_TO_UNSLOTH_16bit[block] = values[0]
|
||||
MAP_TO_UNSLOTH_16bit[block.lower()] = values[0]
|
||||
if row is not None:
|
||||
MAP_TO_UNSLOTH_16bit[row] = values[0]
|
||||
MAP_TO_UNSLOTH_16bit[row.lower()] = values[0]
|
||||
_add_with_lower(MAP_TO_UNSLOTH_16bit, values[1], values[0])
|
||||
_add_with_lower(MAP_TO_UNSLOTH_16bit, values[2], values[0])
|
||||
_add_with_lower(MAP_TO_UNSLOTH_16bit, block, values[0])
|
||||
_add_with_lower(MAP_TO_UNSLOTH_16bit, row, values[0])
|
||||
pass
|
||||
|
||||
# Get lowercased
|
||||
|
|
|
|||
|
|
@ -30,7 +30,6 @@ from ..kernels import (
|
|||
post_patch_loss_function,
|
||||
)
|
||||
from ._utils import __version__, importlib_version, _prepare_model_for_qat
|
||||
from ._utils import _redirect_fp8_to_bf16
|
||||
from ._utils import *
|
||||
from .loader_utils import _get_fp8_mode_and_check_settings
|
||||
from ..save import patch_saving_functions
|
||||
|
|
@ -612,18 +611,9 @@ class FastBaseModel:
|
|||
model_class = None
|
||||
flex_attn_impl = prefer_flex_attn_if_supported(model_class, auto_config)
|
||||
|
||||
# Handle FP8 models: redirect to BF16 sibling when the model ships with
|
||||
# FP8 weights (e.g. Ministral-3-3B-Instruct-2512). FP8 weights cannot be
|
||||
# directly loaded by BNB, and the FP8 quantization config can cause issues
|
||||
# even for 16-bit loading.
|
||||
# Redirect is skipped when load_in_fp8 is truthy (True or 'block').
|
||||
model_name, auto_config = _redirect_fp8_to_bf16(
|
||||
model_name,
|
||||
auto_config,
|
||||
load_in_fp8,
|
||||
token,
|
||||
trust_remote_code,
|
||||
)
|
||||
# Handle FP8 models: get_model_name has already redirected this to BF16 sibling if the model ships with
|
||||
# FP8 weights. We just need to update it here for sanity.
|
||||
auto_config.model_name = model_name
|
||||
# Re-resolve model_class after potential config change
|
||||
try:
|
||||
model_class = auto_model._model_mapping[auto_config.__class__]
|
||||
|
|
|
|||
Loading…
Reference in a new issue