Merge pull request #3863 from unslothai/fix/fbgemm-cutlass-errors-sm100

Fix FBGEMM/CUTLASS errors on SM100 (Blackwell) GPUs
This commit is contained in:
Daniel Han 2026-01-08 03:19:53 -08:00 committed by GitHub
commit e6536a5884
3 changed files with 52 additions and 9 deletions

View file

@@ -94,16 +94,34 @@ class HidePrintMessage:
if os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") != "1":
import sys
# Apply to stderr for FBGEMM
# Apply to stderr for FBGEMM and CUTLASS errors
sys.stderr = HidePrintMessage(sys.stderr)
# https://github.com/pytorch/FBGEMM/blob/d99cd96490ec4aabac2ee95b1e76ea4dcfcfa628/fbgemm_gpu/experimental/gemm/triton_gemm/utils.py#L43-L52
sys.stderr.add_filter("TMA benchmarks will be running")
# CUTLASS/FBGEMM MMA instruction error on SM90 vs SM100 (Blackwell) GPUs
# https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp
sys.stderr.add_filter("Arch conditional MMA instruction used without targeting")
# CUTLASS arch conditional errors for various architectures
sys.stderr.add_filter("CUTE_INVALID_CONTROL_PATH")
# CUTLASS TMA-related errors when not targeting correct architecture
sys.stderr.add_filter("Trying to use tma without CUTE_ARCH_TMA")
# Skipping import of cpp extensions due to incompatible torch version 2.9.0+cu128 for torchao version 0.15.0
logging.getLogger("torchao").setLevel(logging.ERROR)
# Also filter torchao print to stderr about cpp extensions
sys.stderr.add_filter("Skipping import of cpp extensions")
# SyntaxWarning: invalid escape sequence '\.'
warnings.filterwarnings(
"ignore", message = "invalid escape sequence", category = SyntaxWarning
)
# PYTORCH_CUDA_ALLOC_CONF is deprecated warning from torch
warnings.filterwarnings("ignore", message = "PYTORCH_CUDA_ALLOC_CONF is deprecated")
# TF32 precision deprecation warning from torch
warnings.filterwarnings(
"ignore", message = "Please use the new API settings to control TF32"
)
# Deprecation warnings from torchao
warnings.filterwarnings("ignore", message = "`int4_weight_only` is deprecated")
warnings.filterwarnings("ignore", message = "`int8_weight_only` is deprecated")
# Fix up AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
@@ -323,10 +341,14 @@ def check_fbgemm_gpu_version():
except:
return
# We noticed some SegFault or bad alloc errors on lower versions of fbgemm_gpu.
# Instead of raising an error, disable FBGEMM and fall back to Triton kernels.
if Version(fbgemm_gpu_version) < Version("1.4.0"):
raise ImportError(
f"Unsloth: fbgemm_gpu_genai=={fbgemm_gpu_version} detected. It might cause unexpected issues like segmentation faults. Please uninstall the current one by doing `pip uninstall fbgemm-gpu` && `pip install fbgemm-gpu` to install fbgemm-gpu 1.4.0 or newer!"
os.environ["UNSLOTH_HAS_FBGEMM"] = "0"
logger.info(
f"Unsloth: fbgemm_gpu_genai=={fbgemm_gpu_version} is old and may cause issues. "
f"Disabling FBGEMM - using Triton kernels instead."
)
return
logger.info(f"Unsloth: fbgemm_gpu_genai=={fbgemm_gpu_version} detected.")

View file

@@ -523,6 +523,7 @@ def fp8_fbgemm_block_linear(X, weight, weight_scale, bias = None):
def test_has_fbgemm():
# We must manually check if the faster FBGEMM works on the specific GPU
# For example RTX 5090 and RTX 4090 does not work
# Also SM100 (Blackwell B200/B100) GPUs fail with CUTLASS SM90 kernels
# [TODO] Investigate with TorchAO why FBGEMM fails on consumer GPUs
M, N, K = 128, 128, 128
xq = torch.ones(M, K, dtype = torch.float8_e4m3fn, device = "cuda")
@@ -537,10 +538,25 @@ def test_has_fbgemm():
has_fbgemm = True
del out
except Exception as e:
e = str(e)
if "cutlass cannot initialize" in e.lower():
error_str = str(e).lower()
# Catch any CUTLASS/CUDA errors and disable FBGEMM
# This includes MMA instruction errors, architecture mismatches, kernel launch failures, etc.
cutlass_cuda_errors = (
"cutlass",
"cuda error",
"cuda runtime error",
"no kernel image",
"arch conditional",
"mma instruction",
"compute capability",
"cute_invalid_control_path",
"tma",
)
is_cutlass_cuda_error = any(err in error_str for err in cutlass_cuda_errors)
if is_cutlass_cuda_error:
print(
f"Unsloth: FBGEMM on the current GPU cannot load - will switch to Triton kernels"
"Unsloth: FBGEMM on the current GPU cannot load - will switch to Triton kernels"
)
else:
print(

View file

@@ -408,7 +408,7 @@ def _get_fp8_mode_and_check_settings(
if Version(torchao.__version__) < Version("0.15.0"):
raise ValueError(error_message)
# If fbgemm_gpu_genai is installed, check if it's >= 1.4.1
# If fbgemm_gpu_genai is installed and old, disable FBGEMM and use Triton instead
if (
importlib.util.find_spec("fbgemm_gpu") is not None
and importlib.util.find_spec("fbgemm_gpu.experimental") is not None
@@ -416,7 +416,12 @@ def _get_fp8_mode_and_check_settings(
import fbgemm_gpu.experimental.gen_ai
if Version(fbgemm_gpu.__version__) < Version("1.4.1"):
raise ValueError(
"Unsloth: On the fly `load_in_fp8` is only compatible with fbgemm_gpu_genai 1.4.1+. Try `unsloth/Qwen3-8B` instead."
# Old FBGEMM version - disable and use Triton kernels instead
os.environ["UNSLOTH_HAS_FBGEMM"] = "0"
from unsloth_zoo.log import logger
logger.info(
f"Unsloth: fbgemm_gpu_genai=={fbgemm_gpu.__version__} is old for FP8 loading. "
f"Using Triton kernels instead."
)
return fp8_mode