mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
695 lines
26 KiB
Python
695 lines
26 KiB
Python
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import os
|
|
import importlib.util
|
|
from pathlib import Path
|
|
from importlib.metadata import version as importlib_version
|
|
from packaging.version import Version as TrueVersion
|
|
import re
|
|
import logging
|
|
import textwrap
|
|
import warnings
|
|
|
|
# We cannot do from unsloth_zoo.log import logger since FBGEMM might cause seg faults.
# Verbose logging is opt-in: only "1", "True" or "true" switch it on.
UNSLOTH_ENABLE_LOGGING = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") in {
    "1",
    "True",
    "true",
}
logger = logging.getLogger(__name__)
# The root configuration and this module's logger always share one level:
# INFO when logging is enabled, WARNING otherwise.
logging.basicConfig(
    level = logging.INFO if UNSLOTH_ENABLE_LOGGING else logging.WARNING,
    format = "[%(name)s|%(levelname)s]%(message)s",
)
logger.setLevel(logging.INFO if UNSLOTH_ENABLE_LOGGING else logging.WARNING)
|
|
|
|
|
|
def Version(version):
    """
    Parse `version` into a `packaging` Version, tolerating dev / rc suffixes.

    Only the leading run of digits and dots of `str(version)` is kept, with
    trailing dots removed. When the kept text differs from `version`, an extra
    ".1" component is appended before parsing.

    Args:
        version: A version string (anything else is passed through `str`).

    Returns:
        The parsed `TrueVersion` object.

    Raises:
        RuntimeError: If no leading numeric version exists or parsing fails.
            The message names the caller's file and line, and the underlying
            error is attached as the cause.
    """
    try:
        new_version = str(version)
        new_version = re.match(r"[0-9\.]{1,}", new_version)
        if new_version is None:
            # Raise a real, descriptive error; the handler below wraps it.
            raise ValueError(f"No leading numeric version found in `{version}`")
        new_version = new_version.group(0).rstrip(".")
        if new_version != version:
            new_version += ".1"  # Add .1 for dev / alpha / beta / rc
        return TrueVersion(new_version)
    except Exception as error:
        from inspect import getframeinfo, stack

        # Report where Version() was called from, not this helper itself.
        caller = getframeinfo(stack()[1][0])
        raise RuntimeError(
            f"Unsloth: Could not get version for `{version}`\n"
            f"File name = [{caller.filename}] Line number = [{caller.lineno}]"
        ) from error
|
|
|
|
|
|
# Ignore logging messages
class HideLoggingMessage(logging.Filter):
    """Logging filter that drops every record whose message contains `text`."""

    __slots__ = ("text",)

    def __init__(self, text):
        # Substring searched for in each formatted log message.
        self.text = text

    def filter(self, x):
        """Return False (drop the record) when `text` occurs in its message."""
        message = x.getMessage()
        return self.text not in message
|
|
|
|
|
|
class HidePrintMessage:
    """
    Stream wrapper that drops any write containing a registered substring.

    Everything else (including attributes not defined here) is delegated to
    the wrapped stream.
    """

    def __init__(self, original_stream):
        self._original_stream = original_stream
        self._hidden_texts = []

    def add_filter(self, text):
        """Register a substring; any write containing it is suppressed."""
        self._hidden_texts.append(text)

    def write(self, message):
        """
        Forward `message` to the wrapped stream unless it must be hidden.

        Returns:
            The wrapped stream's `write` result when forwarded; otherwise
            `len(message)`, so callers following the text-stream protocol see
            the suppressed message as fully consumed.
        """
        if any(text in message for text in self._hidden_texts):
            return len(message)
        return self._original_stream.write(message)

    def flush(self):
        """Flush the wrapped stream."""
        self._original_stream.flush()

    def __getattr__(self, name):
        # Only reached when normal lookup fails. Never delegate the wrapper's
        # own fields: if they are missing (object created without __init__),
        # delegating would recurse forever.
        if name in ("_original_stream", "_hidden_texts"):
            raise AttributeError(name)
        return getattr(self._original_stream, name)
|
|
|
|
|
|
# Install the noise filters only when logging is disabled. Reuse the parsed
# UNSLOTH_ENABLE_LOGGING flag so "1", "True" and "true" are all honoured here,
# exactly as they are for the logger level.
if not UNSLOTH_ENABLE_LOGGING:
    import sys

    # Apply to stderr for FBGEMM and CUTLASS errors
    sys.stderr = HidePrintMessage(sys.stderr)
    # https://github.com/pytorch/FBGEMM/blob/d99cd96490ec4aabac2ee95b1e76ea4dcfcfa628/fbgemm_gpu/experimental/gemm/triton_gemm/utils.py#L43-L52
    sys.stderr.add_filter("TMA benchmarks will be running")
    # CUTLASS/FBGEMM MMA instruction error on SM90 vs SM100 (Blackwell) GPUs
    # https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp
    sys.stderr.add_filter("Arch conditional MMA instruction used without targeting")
    # CUTLASS arch conditional errors for various architectures
    sys.stderr.add_filter("CUTE_INVALID_CONTROL_PATH")
    # CUTLASS TMA-related errors when not targeting correct architecture
    sys.stderr.add_filter("Trying to use tma without CUTE_ARCH_TMA")
    # Skipping import of cpp extensions due to incompatible torch version 2.9.0+cu128 for torchao version 0.15.0
    logging.getLogger("torchao").setLevel(logging.ERROR)
    # Also filter torchao print to stderr about cpp extensions
    sys.stderr.add_filter("Skipping import of cpp extensions")
    # SyntaxWarning: invalid escape sequence '\.'
    warnings.filterwarnings(
        "ignore", message = "invalid escape sequence", category = SyntaxWarning
    )
    # PYTORCH_CUDA_ALLOC_CONF is deprecated warning from torch
    warnings.filterwarnings("ignore", message = "PYTORCH_CUDA_ALLOC_CONF is deprecated")
    # TF32 precision deprecation warning from torch
    warnings.filterwarnings(
        "ignore", message = "Please use the new API settings to control TF32"
    )
    # Deprecation warnings from torchao
    warnings.filterwarnings("ignore", message = "`int4_weight_only` is deprecated")
    warnings.filterwarnings("ignore", message = "`int8_weight_only` is deprecated")
|
|
|
|
|
|
# Fix up AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
# MUST do this at the start primarily due to tensorflow causing issues
def fix_message_factory_issue():
    """
    Make sure `google.protobuf.message_factory.MessageFactory.GetPrototype` exists.

    Three situations are handled:
      * `MessageFactory` is missing: install a stub class.
      * It exists without `GetPrototype` and the module offers
        `GetMessageClass`: add a `GetPrototype` that forwards to it.
      * It exists without `GetPrototype` and there is no `GetMessageClass`:
        replace it with the stub class.

    Best effort: any failure (including protobuf not being installed) is
    ignored and the function returns None.
    """
    try:
        import google.protobuf.message_factory

        class MessageFactory:
            # Stub whose methods accept anything and return None.
            def CreatePrototype(self, *args, **kwargs):
                return

            def GetMessages(self, *args, **kwargs):
                return

            def GetPrototype(self, *args, **kwargs):
                return

        message_factory = google.protobuf.message_factory
        if not hasattr(message_factory, "MessageFactory"):
            logger.info("Unsloth: Patching protobuf.MessageFactory as it doesn't exist")
            message_factory.MessageFactory = MessageFactory
        elif not hasattr(message_factory.MessageFactory, "GetPrototype"):
            if hasattr(message_factory, "GetMessageClass"):
                GetMessageClass = message_factory.GetMessageClass

                def GetPrototype(self, descriptor):
                    return GetMessageClass(descriptor)

                message_factory.MessageFactory.GetPrototype = GetPrototype
                logger.info("Unsloth: Patching protobuf.MessageFactory.GetPrototype")
            else:
                message_factory.MessageFactory = MessageFactory
                logger.info(
                    "Unsloth: Patching protobuf.MessageFactory as it has no GetPrototype"
                )
    except Exception:
        # Deliberately best effort - a failed patch must never block importing.
        pass
|
|
|
|
|
|
# Fix Xformers performance issues since 0.0.25
def fix_xformers_performance_issue():
    """
    Rewrite `num_splits_key=-1,` to `num_splits_key=None,` in the installed
    xformers `ops/fmha/cutlass.py`, for xformers versions below 0.0.29.

    Does nothing when xformers is not installed, is new enough, or the file /
    pattern is absent. File errors are logged, never raised.
    """
    xformers_spec = importlib.util.find_spec("xformers")
    if xformers_spec is None:
        return
    installed_version = importlib_version("xformers")
    if not (Version(installed_version) < Version("0.0.29")):
        return

    # Package directory: parent of __init__.py, or the first search location
    # when the spec has no origin.
    origin = xformers_spec.origin
    if origin is not None:
        package_dir = os.path.split(origin)[0]
    else:
        package_dir = xformers_spec.submodule_search_locations[0]
    cutlass = Path(package_dir) / "ops" / "fmha" / "cutlass.py"

    try:
        if not cutlass.exists():
            return
        with open(cutlass, "r+", encoding = "utf-8") as handle:
            source = handle.read()
            # See https://github.com/facebookresearch/xformers/issues/1176#issuecomment-2545829591
            if "num_splits_key=-1," not in source:
                return
            source = source.replace("num_splits_key=-1,", "num_splits_key=None,")
            # Overwrite in place, then cut off any leftover tail.
            handle.seek(0)
            handle.write(source)
            handle.truncate()
            logger.info("Unsloth: Patching Xformers to fix some performance issues.")
    except Exception as e:
        logger.info(f"Unsloth: Failed patching Xformers with error = {str(e)}")
|
|
|
|
|
|
# ValueError: 'aimv2' is already used by a Transformers config, pick another name.
def fix_vllm_aimv2_issue():
    """
    Patch the installed vLLM `transformers_utils/configs/ovis.py` (vLLM < 0.10.1)
    so it no longer registers "aimv2" with `AutoConfig`.

    The `AutoConfig.register("aimv2", AIMv2Config)` call is removed and the
    backbone-config construction is rewritten to build `AIMv2Config` directly
    when `model_type` is "aimv2". Does nothing when vLLM is not installed, is
    new enough, or the file / pattern is absent. File errors are logged, never
    raised.
    """
    spec = importlib.util.find_spec("vllm")
    if spec is None:
        return
    vllm_version = importlib_version("vllm")
    if Version(vllm_version) < Version("0.10.1"):
        vllm_location = spec.origin
        if vllm_location is None:
            vllm_location = spec.submodule_search_locations[0]
        else:
            vllm_location = os.path.split(vllm_location)[0]
        ovis_config = Path(vllm_location) / "transformers_utils" / "configs" / "ovis.py"

        # The leading whitespace inside these snippets is significant: the old
        # snippet must match ovis.py byte-for-byte and the new one must be
        # valid Python at that nesting depth.
        # NOTE(review): indentation is assumed to be 16 spaces for the
        # statement level, with the continuation line aligned under the first
        # argument of `AutoConfig.for_model(` - confirm against the vLLM file.
        indent = " " * 16
        continuation = indent + " " * len("backbone_config = AutoConfig.for_model(")
        old_backbone = (
            "backbone_config.pop('model_type')\n"
            + indent + "backbone_config = AutoConfig.for_model(model_type,\n"
            + continuation + "**backbone_config)"
        )
        new_backbone = (
            'if model_type != "aimv2":\n'
            + indent + "    backbone_config.pop('model_type')\n"
            + indent + "    backbone_config = AutoConfig.for_model(model_type, **backbone_config)\n"
            + indent + "else:\n"
            + indent + "    backbone_config = AIMv2Config(**backbone_config)"
        )

        try:
            if ovis_config.exists():
                with open(ovis_config, "r+", encoding = "utf-8") as f:
                    text = f.read()
                    # See https://github.com/vllm-project/vllm-ascend/issues/2046
                    if 'AutoConfig.register("aimv2", AIMv2Config)' in text:
                        text = text.replace(
                            'AutoConfig.register("aimv2", AIMv2Config)',
                            "",
                        )
                        text = text.replace(old_backbone, new_backbone)
                        # Overwrite in place, then cut off any leftover tail.
                        f.seek(0)
                        f.write(text)
                        f.truncate()
                        logger.info(
                            "Unsloth: Patching vLLM to fix `'aimv2' is already used by a Transformers config, pick another name.`"
                        )
        except Exception as e:
            logger.info(f"Unsloth: Failed patching vLLM with error = {str(e)}")
|
|
|
|
|
|
def fix_vllm_guided_decoding_params():
    """
    Alias `vllm.sampling_params.GuidedDecodingParams` to
    `StructuredOutputsParams` when the former cannot be imported.

    Does nothing when vLLM is not installed.
    """
    vllm_missing = importlib.util.find_spec("vllm") is None
    if vllm_missing:
        return
    # GuidedDecodingParams is renamed to StructuredOutputsParams in vLLM
    # https://github.com/vllm-project/vllm/pull/22772/files
    # trl still wants to use GuidedDecodingParams. This is a temporary patch till trl updates
    import vllm

    try:
        # Import purely as a probe for the old name.
        from vllm.sampling_params import GuidedDecodingParams
    except ImportError:
        sampling_params = vllm.sampling_params
        sampling_params.GuidedDecodingParams = sampling_params.StructuredOutputsParams
|
|
|
|
|
|
def ignore_logger_messages():
    """
    Hide huggingface_hub log lines that mention `HF_TOKEN`.

    Attaches a `HideLoggingMessage` filter to the logger of
    `huggingface_hub._login`. Best effort: if that module or its logger cannot
    be imported, nothing happens.
    """
    # Ignore Environment variable `HF_TOKEN` is set
    try:
        from huggingface_hub._login import logger as huggingface_hub_logger

        huggingface_hub_logger.addFilter(HideLoggingMessage("`HF_TOKEN`"))
    except Exception:
        # `_login` is a private module; tolerate it being absent or changed.
        pass
|
|
|
|
|
|
def patch_ipykernel_hf_xet():
    """
    Disable Hugging Face progress bars for the known-bad combination of
    `hf_xet == 1.1.10` with `ipykernel == 7.0.0` or `7.0.1`.

    Does nothing unless hf_xet, ipykernel and huggingface_hub are all
    installed and the versions match. Prints a notice when it acts.
    """
    # HF-XET == 1.1.10 and ipykernel == 7.0.0 / 7.0.1 causes issues
    # See https://github.com/huggingface/xet-core/issues/526
    # Symptom: "ERROR Python exception updating progress" with
    # LookupError(<ContextVar name='shell_parent' ...>) from src/progress_update.rs
    for required_package in ("hf_xet", "ipykernel", "huggingface_hub"):
        if importlib.util.find_spec(required_package) is None:
            return

    ipykernel_version = Version(importlib_version("ipykernel"))
    if Version(importlib_version("hf_xet")) != Version("1.1.10"):
        return
    # 7.0.1 seems to also break with LookupError: <ContextVar name='shell_parent' ...>
    if ipykernel_version != Version("7.0.0") and ipykernel_version != Version("7.0.1"):
        return

    print(
        "#### Unsloth: `hf_xet==1.1.10` and `ipykernel==7.0.0` or `ipykernel==7.0.1` breaks progress bars. Using ASCII progress bars.\n"
        "#### Unsloth: To re-enable progress bars, please upgrade to `ipykernel>=7.1.0` or wait for a fix to\n"
        "https://github.com/huggingface/xet-core/issues/526"
    )
    from huggingface_hub.utils import disable_progress_bars

    disable_progress_bars()
|
|
|
|
|
|
def patch_trackio():
    """Set the Trackio dashboard environment variables (logos and plot order)."""
    # Set some environment variables to customize the Trackio dashboard for experiment tracking
    # See https://github.com/unslothai/notebooks/pull/110
    trackio_settings = {
        "TRACKIO_LOGO_LIGHT_URL": "https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20black%20text.png",
        "TRACKIO_LOGO_DARK_URL": "https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20white%20text.png",
        "TRACKIO_PLOT_ORDER": "train/reward",
    }
    for variable, value in trackio_settings.items():
        os.environ[variable] = value
|
|
|
|
|
|
def patch_datasets():
    """
    Refuse to run with `datasets` versions known to cause recursion errors.

    Does nothing when `datasets` is not installed.

    Raises:
        NotImplementedError: If the installed `datasets` version lies between
            4.4.0 and 4.5.0, both inclusive.
    """
    # Datasets 4.4.0 and 4.4.1 have `_thread.RLock_recursion_count` issues.
    # The whole 4.4.0 - 4.5.0 range (inclusive) is rejected below.
    if importlib.util.find_spec("datasets") is None:
        return

    datasets_version = Version(importlib_version("datasets"))
    if Version("4.4.0") <= datasets_version <= Version("4.5.0"):
        raise NotImplementedError(
            f"#### Unsloth: Using `datasets = {str(datasets_version)}` will cause recursion errors.\n"
            "Please downgrade datasets to `datasets==4.3.0`"
        )
|
|
|
|
|
|
def check_fbgemm_gpu_version():
    """
    Disable FBGEMM via `UNSLOTH_HAS_FBGEMM=0` when `fbgemm_gpu_genai` is older
    than 1.4.0.

    Does nothing when the `fbgemm_gpu` module is not importable or the
    `fbgemm_gpu_genai` distribution version cannot be read.
    """
    if importlib.util.find_spec("fbgemm_gpu") is None:
        return
    try:
        fbgemm_gpu_version = importlib_version("fbgemm_gpu_genai")
    except Exception:
        # No readable metadata for the genai distribution - nothing to check.
        return
    # We noticed some SegFault or bad alloc errors on lower versions of fbgemm_gpu.
    # Instead of raising an error, disable FBGEMM and fall back to Triton kernels.
    if Version(fbgemm_gpu_version) < Version("1.4.0"):
        os.environ["UNSLOTH_HAS_FBGEMM"] = "0"
        logger.info(
            f"Unsloth: fbgemm_gpu_genai=={fbgemm_gpu_version} is old and may cause issues. "
            "Disabling FBGEMM - using Triton kernels instead."
        )
        return

    logger.info(f"Unsloth: fbgemm_gpu_genai=={fbgemm_gpu_version} detected.")
|
|
|
|
|
|
def patch_enable_input_require_grads():
    """
    Patch transformers PreTrainedModel.enable_input_require_grads to handle vision models
    that raise NotImplementedError from get_input_embeddings().

    The patch is applied only when the current implementation's source contains
    `for module in self.modules()`. If the source cannot be inspected, nothing
    is changed.
    """
    import inspect
    from transformers import PreTrainedModel

    # Check if the original function iterates over self.modules() instead of just returning the enable_input_require_grads
    # Ref: https://github.com/huggingface/transformers/pull/41993/files#diff-6b72b98c4c2dcfc6cc606843917733f5d858374fbc22a735ff483bbc0c1e63eaL1979-R1996
    try:
        original_source = inspect.getsource(PreTrainedModel.enable_input_require_grads)
    except Exception:
        # Source unavailable (or the method is missing) - leave things alone.
        return

    # Only patch if the new pattern exists (iterating over self.modules())
    if "for module in self.modules()" not in original_source:
        return

    def _patched_enable_input_require_grads(self):
        def make_inputs_require_grads(module, input, output):
            output.requires_grad_(True)

        hooks = []
        # ids of embedding modules that already have a hook, so an embedding
        # returned by several submodules is hooked only once.
        seen_modules = set()

        for module in self.modules():
            if not (
                isinstance(module, PreTrainedModel)
                and hasattr(module, "get_input_embeddings")
            ):
                continue

            try:
                input_embeddings = module.get_input_embeddings()
            except NotImplementedError:
                # Vision models may not implement get_input_embeddings - skip them
                # For GLM V4.6 for example, this skips only `self.visual`
                continue

            if input_embeddings is None:
                continue

            embedding_id = id(input_embeddings)
            if embedding_id in seen_modules:
                continue

            seen_modules.add(embedding_id)
            hooks.append(
                input_embeddings.register_forward_hook(make_inputs_require_grads)
            )

        self._require_grads_hooks = hooks
        if hooks:
            # Also expose the first hook under the singular attribute name.
            self._require_grads_hook = hooks[0]

    PreTrainedModel.enable_input_require_grads = _patched_enable_input_require_grads

    logger.info(
        "Unsloth: Patched enable_input_require_grads for vision model compatibility"
    )
|
|
|
|
|
|
def torchvision_compatibility_check():
    """
    Verify that the installed torchvision is new enough for the installed torch.

    Returns silently when torchvision is not installed or the torch version is
    older than every entry in the compatibility table.

    Raises:
        ImportError: If torch is not installed, or torchvision is older than
            the minimum paired with the installed torch version.
    """
    if importlib.util.find_spec("torch") is None:
        raise ImportError("Unsloth: torch not found. Please install torch first.")
    if importlib.util.find_spec("torchvision") is None:
        return
    torch_version = importlib_version("torch")
    torchvision_version = importlib_version("torchvision")

    # Torch version -> minimum required torchvision version
    # See https://pytorch.org/get-started/previous-versions/
    # Ordered newest first; the first row whose torch version is <= the
    # installed one decides the requirement.
    TORCH_TORCHVISION_COMPAT = [
        ("2.9.0", "0.24.0"),
        ("2.8.0", "0.23.0"),
        ("2.7.0", "0.22.0"),
        ("2.6.0", "0.21.0"),
        ("2.5.0", "0.20.0"),
        ("2.4.0", "0.19.0"),
    ]

    required_torchvision = next(
        (
            minimum_torchvision
            for minimum_torch, minimum_torchvision in TORCH_TORCHVISION_COMPAT
            if Version(torch_version) >= Version(minimum_torch)
        ),
        None,
    )
    if required_torchvision is None:
        # Torch version not in compatibility table, skip check
        return

    if Version(torchvision_version) < Version(required_torchvision):
        raise ImportError(
            f"Unsloth: torch=={torch_version} requires torchvision>={required_torchvision}, "
            f"but found torchvision=={torchvision_version}. "
            f"Please refer to https://pytorch.org/get-started/previous-versions/ for more information."
        )

    logger.info(
        f"Unsloth: torch=={torch_version} and torchvision=={torchvision_version} are compatible."
    )
|
|
|
|
|
|
# Fix TRL OpenEnv 0.26 NameError: name 'SamplingParams' is not defined
def fix_openenv_no_vllm():
    """
    Patch the installed TRL `experimental/openenv/utils.py` so that
    `SamplingParams` and `GuidedDecodingParams` are defined (as `typing.Any`)
    when vLLM is unavailable.

    An `else:` branch is appended to the file's `if is_vllm_available():`
    import block. Does nothing when TRL or the file is missing, when the
    expected block is not found, or when the patch is already present. File
    errors are logged, never raised.
    """
    spec = importlib.util.find_spec("trl")
    if spec is None:
        return
    trl_location = spec.origin
    if trl_location is None:
        trl_location = spec.submodule_search_locations[0]
    else:
        trl_location = os.path.split(trl_location)[0]
    openenv = Path(trl_location) / "experimental" / "openenv" / "utils.py"
    if not openenv.exists():
        return

    # The snippets are Python source: the import lines must be indented under
    # `if` / `else`, and `bad` must match TRL's file byte-for-byte.
    # NOTE(review): 4-space indentation is assumed - confirm against the
    # installed TRL file.
    bad = (
        "if is_vllm_available():\n"
        "    from vllm import SamplingParams\n"
        "    from vllm.sampling_params import GuidedDecodingParams\n"
    )
    replace_with = bad + (
        "else:\n"
        "    from typing import Any\n"
        "    SamplingParams = Any\n"
        "    GuidedDecodingParams = Any\n"
        "\n"
    )

    try:
        with open(openenv, "r+", encoding = "utf-8") as f:
            text = f.read()
            # Require the two blank lines after the block as the anchor, and
            # skip files that already carry the patch.
            if bad + "\n" + "\n" in text and replace_with not in text:
                text = text.replace(bad + "\n" + "\n", replace_with)
                # Overwrite in place, then cut off any leftover tail.
                f.seek(0)
                f.write(text)
                f.truncate()
                logger.info(
                    "Unsloth: Patching TRL OpenEnv to fix SamplingParams not defined"
                )
    except Exception as e:
        logger.info(f"Unsloth: Failed patching TRL OpenEnv with error = {str(e)}")
|
|
|
|
|
|
# Fix Executorch needing get_mapped_key
def fix_executorch():
    """
    Patch the installed Executorch `examples/models/__init__.py` so that
    `torchtune.models.convert_weights.get_mapped_key` is importable.

    A code snippet is inserted after the file's `from enum import Enum` line
    (which must be followed by a blank line). The snippet defines
    `get_mapped_key` and registers in-memory `torchtune`, `torchtune.models`
    and `torchtune.models.convert_weights` modules in `sys.modules`.

    Does nothing when Executorch or the file is missing, when the anchor is
    not found, or when the snippet is already present. File errors are
    logged, never raised.
    """
    spec = importlib.util.find_spec("executorch")
    if spec is None:
        return
    executorch_location = spec.origin
    if executorch_location is None:
        executorch_location = spec.submodule_search_locations[0]
    else:
        executorch_location = os.path.split(executorch_location)[0]
    executorch = Path(executorch_location) / "examples" / "models" / "__init__.py"
    if not executorch.exists():
        return

    try:
        # Source text written into the target file; indentation inside the
        # literal is significant (textwrap.dedent strips the common prefix).
        # NOTE(review): if an earlier run inserted a snippet with different
        # whitespace, the `what not in text` guard will not recognise it -
        # confirm the layout matches previously released versions.
        what = r"""
        import sys
        import types
        import re
        from typing import Any, Optional
        def get_mapped_key(key: str, mapping_dict: dict[str, str]) -> str:
            try:
                # Checks if there is a layer # in the key
                if any(k.isdigit() for k in key.split(".")):
                    # Replace layer number with "{}" to create key for lookup
                    abstract_key = re.sub(r"(\.\d+)", ".{}", key)
                    layer_num = re.search(r"\d+", key).group(0)
                    new_key = mapping_dict[abstract_key]
                    new_key = new_key.format(layer_num)
                else:
                    new_key = mapping_dict[key]
            except KeyError as e:
                raise Exception(
                    f'Error converting the state dict. Found unexpected key: "{key}". '
                    "Please make sure you're loading a checkpoint with the right format. "
                ) from e

            return new_key

        torchtune = types.ModuleType("torchtune")
        torchtune.__path__ = []
        models = types.ModuleType("torchtune.models")
        models.__path__ = []
        convert_weights = types.ModuleType("torchtune.models.convert_weights")
        convert_weights.get_mapped_key = get_mapped_key
        torchtune.models = models
        models.convert_weights = convert_weights
        sys.modules["torchtune"] = torchtune
        sys.modules["torchtune.models"] = models
        sys.modules["torchtune.models.convert_weights"] = convert_weights
        """
        what = textwrap.dedent(what)

        with open(executorch, "r+", encoding = "utf-8") as f:
            text = f.read()
            bad = "from enum import Enum\n"
            # Test for exactly the text being replaced (import line plus the
            # blank line after it), so the file is only rewritten - and the
            # patch only reported - when the replacement really happens.
            anchor = bad + "\n"
            if anchor in text and what not in text:
                text = text.replace(anchor, anchor + what)
                # Overwrite in place, then cut off any leftover tail.
                f.seek(0)
                f.write(text)
                f.truncate()
                logger.info("Unsloth: Patching Executorch to fix get_mapped_key")
    except Exception as e:
        logger.info(f"Unsloth: Failed Executorch with error = {str(e)}")
|
|
|
|
|
|
def fix_diffusers_warnings():
    """Set `DIFFUSERS_VERBOSITY` to "error" in the process environment."""
    # Silence Flax classes are deprecated and will be removed in Diffusers v1.0.0.
    os.environ.update({"DIFFUSERS_VERBOSITY": "error"})
|
|
|
|
|
|
def fix_huggingface_hub():
    """
    Add `huggingface_hub.is_offline_mode` when the installed version lacks it.

    The added callable returns `huggingface_hub.constants.HF_HUB_OFFLINE`,
    read at call time.
    """
    # huggingface_hub.is_offline_mode got removed, so add it back
    import huggingface_hub

    if hasattr(huggingface_hub, "is_offline_mode"):
        return
    huggingface_hub.is_offline_mode = lambda: huggingface_hub.constants.HF_HUB_OFFLINE
|
|
|
|
|
|
def fix_vllm_pdl_blackwell():
    """
    Fix vLLM PDL (Programmatic Dependent Launch) bug on Blackwell GPUs (SM100).

    The issue: vLLM's LoRA Triton kernels use tl.extra.cuda.gdc_wait() for PDL
    optimization on SM90+ GPUs. This fails on SM100 (B200/B100) during CUDA graph
    capture because Triton's pipeliner can't handle gdc_wait in complex kernels.

    When an SM100 device is present and the installed vLLM is not newer than
    0.13.2, this sets TRITON_DISABLE_PDL=1 and replaces `supports_pdl` in
    vLLM's LoRA Triton modules with a function that always returns False.

    See: https://github.com/vllm-project/vllm/issues/30872
    """
    if importlib.util.find_spec("vllm") is None:
        return

    # Look for a CUDA GPU with compute capability major == 10 (SM100).
    try:
        import torch

        if not torch.cuda.is_available():
            return

        # First match wins - the fix is global (env var + monkey-patch).
        sm100_gpu_name = None
        for device_index in range(torch.cuda.device_count()):
            capability_major, _capability_minor = torch.cuda.get_device_capability(
                device_index
            )
            if capability_major == 10:
                sm100_gpu_name = torch.cuda.get_device_name(device_index)
                break
        else:
            # Loop ended without finding an SM100 device.
            return
    except Exception:
        return

    def _spec_exists(name):
        # True when the module spec can be located; lookup errors count as absent.
        try:
            return importlib.util.find_spec(name) is not None
        except (ModuleNotFoundError, ValueError):
            return False

    # Check if vLLM has the PDL-related modules. All three lookups are always
    # performed (list, not generator).
    pdl_module_names = (
        "vllm.lora.ops.triton_ops.utils",
        "vllm.lora.ops.triton_ops.lora_expand_op",
        "vllm.lora.ops.triton_ops.lora_shrink_op",
    )
    if not any([_spec_exists(module_name) for module_name in pdl_module_names]):
        # Old vLLM version without PDL support - nothing to patch
        return

    # Check if vLLM version includes the fix
    VLLM_PDL_FIX_VERSION = "0.13.2"
    try:
        vllm_version = Version(importlib_version("vllm"))
        if vllm_version > Version(VLLM_PDL_FIX_VERSION):
            logger.info(
                f"Unsloth: SM100 ({sm100_gpu_name}) detected but vLLM {vllm_version} "
                f"should include PDL fix - skipping workaround"
            )
            return
    except Exception as e:
        logger.debug(
            f"Unsloth: vLLM version check failed ({e}), applying PDL workaround."
        )

    # Apply the PDL fix
    os.environ["TRITON_DISABLE_PDL"] = "1"

    def fake_supports_pdl(*args, **kwargs):
        return False

    patched = []

    # First, patch the source module (utils.py) where supports_pdl is defined.
    # Its cache (if it has one) is cleared before the function is replaced, so
    # no stale cached result from the original survives.
    try:
        utils_module = importlib.import_module("vllm.lora.ops.triton_ops.utils")
        if hasattr(utils_module, "supports_pdl"):
            original_fn = utils_module.supports_pdl
            if hasattr(original_fn, "cache_clear"):
                original_fn.cache_clear()
            utils_module.supports_pdl = fake_supports_pdl
            patched.append("utils")
    except (ImportError, ModuleNotFoundError, AttributeError):
        pass

    # Also patch the consumer modules that import supports_pdl from utils.
    # They hold their own reference to the function, so each one has to be
    # overwritten separately.
    for consumer_name in ("lora_expand_op", "lora_shrink_op", "fused_moe_lora_op"):
        consumer_path = f"vllm.lora.ops.triton_ops.{consumer_name}"
        try:
            consumer_module = importlib.import_module(consumer_path)
            if hasattr(consumer_module, "supports_pdl"):
                consumer_module.supports_pdl = fake_supports_pdl
                patched.append(consumer_name)
        except (ImportError, ModuleNotFoundError, AttributeError):
            pass

    if not patched:
        # Just set the env var - vLLM might be an older version without supports_pdl
        logger.info(f"Unsloth: Set TRITON_DISABLE_PDL=1 for SM100 ({sm100_gpu_name})")
    else:
        logger.info(
            f"Unsloth: Applied PDL fix for SM100 ({sm100_gpu_name}) - "
            f"patched: {', '.join(patched)}"
        )
|