Patch trunc_normal_ for low-precision stability (#4027)

* Fix low-precision trunc_normal initialization instability

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Document TorchTitan trunc_normal low-precision failure mode

* Fix trunc_normal generator positional compatibility

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix trunc_normal generator TypeError fallback

---------

Co-authored-by: Daniel Hanchen <danielhanchen@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Daniel Han 2026-02-19 04:40:14 -08:00 committed by GitHub
parent 8165266a37
commit 3bddfed117
3 changed files with 199 additions and 0 deletions

View file

@@ -0,0 +1,114 @@
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Tests for trunc_normal low-precision patch compatibility."""
import importlib.util
import inspect
from pathlib import Path
import pytest
import torch
_MISSING = object()
def _load_import_fixes_module():
repo_root = Path(__file__).resolve().parents[2]
import_fixes_path = repo_root / "unsloth" / "import_fixes.py"
spec = importlib.util.spec_from_file_location(
"unsloth_import_fixes_local", import_fixes_path
)
assert spec is not None and spec.loader is not None
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
def _getattr_or_missing(obj, name):
return getattr(obj, name) if hasattr(obj, name) else _MISSING
def _restore_attr(obj, name, value):
if value is _MISSING:
if hasattr(obj, name):
delattr(obj, name)
return
setattr(obj, name, value)
def test_trunc_normal_patch_accepts_positional_generator():
    """The patched trunc_normal_ must take `generator` both positionally and by keyword."""
    fixes = _load_import_fixes_module()
    apply_patch = fixes.patch_trunc_normal_precision_issue
    init_mod = torch.nn.init

    saved_fn = init_mod.trunc_normal_
    saved_flag = _getattr_or_missing(init_mod, "_unsloth_trunc_normal_patched")
    saved_orig = _getattr_or_missing(init_mod, "_unsloth_trunc_normal_original")
    try:
        # Roll back any previously applied patch so we exercise a clean install.
        if saved_orig is not _MISSING:
            init_mod.trunc_normal_ = saved_orig
        for marker in ("_unsloth_trunc_normal_patched", "_unsloth_trunc_normal_original"):
            if hasattr(init_mod, marker):
                delattr(init_mod, marker)

        apply_patch()

        sig = inspect.signature(init_mod.trunc_normal_)
        assert "generator" in sig.parameters
        assert sig.parameters["generator"].kind is not inspect.Parameter.KEYWORD_ONLY

        tensor = torch.empty(1024, dtype = torch.float32)
        gen = torch.Generator()
        gen.manual_seed(3407)
        # Both calling conventions must be accepted by the patched function.
        init_mod.trunc_normal_(tensor, 0.0, 1.0, -2.0, 2.0, gen)
        init_mod.trunc_normal_(tensor, mean = 0.0, std = 1.0, a = -2.0, b = 2.0, generator = gen)
    finally:
        init_mod.trunc_normal_ = saved_fn
        _restore_attr(init_mod, "_unsloth_trunc_normal_patched", saved_flag)
        _restore_attr(init_mod, "_unsloth_trunc_normal_original", saved_orig)
def test_trunc_normal_patch_rejects_invalid_generator():
    """A non-Generator `generator` argument must still raise TypeError after patching."""
    fixes = _load_import_fixes_module()
    apply_patch = fixes.patch_trunc_normal_precision_issue
    init_mod = torch.nn.init

    saved_fn = init_mod.trunc_normal_
    saved_flag = _getattr_or_missing(init_mod, "_unsloth_trunc_normal_patched")
    saved_orig = _getattr_or_missing(init_mod, "_unsloth_trunc_normal_original")
    try:
        # Roll back any previously applied patch so we exercise a clean install.
        if saved_orig is not _MISSING:
            init_mod.trunc_normal_ = saved_orig
        for marker in ("_unsloth_trunc_normal_patched", "_unsloth_trunc_normal_original"):
            if hasattr(init_mod, marker):
                delattr(init_mod, marker)

        apply_patch()

        sig = inspect.signature(init_mod.trunc_normal_)
        if "generator" not in sig.parameters:
            pytest.skip("torch.nn.init.trunc_normal_ lacks a generator parameter")

        tensor = torch.empty(16, dtype = torch.float32)
        with pytest.raises(TypeError):
            init_mod.trunc_normal_(tensor, generator = 123)
    finally:
        init_mod.trunc_normal_ = saved_fn
        _restore_attr(init_mod, "_unsloth_trunc_normal_patched", saved_flag)
        _restore_attr(init_mod, "_unsloth_trunc_normal_original", saved_orig)

View file

@@ -140,6 +140,7 @@ from .import_fixes import (
fix_vllm_pdl_blackwell,
fix_triton_compiled_kernel_missing_attrs,
fix_rocm_triton_key_error,
patch_trunc_normal_precision_issue,
ignore_logger_messages,
patch_ipykernel_hf_xet,
patch_trackio,
@@ -161,6 +162,7 @@ fix_vllm_guided_decoding_params()
fix_vllm_pdl_blackwell()
fix_triton_compiled_kernel_missing_attrs()
fix_rocm_triton_key_error()
patch_trunc_normal_precision_issue()
ignore_logger_messages()
patch_ipykernel_hf_xet()
patch_trackio()
@@ -180,6 +182,7 @@ del fix_vllm_guided_decoding_params
del fix_vllm_pdl_blackwell
del fix_triton_compiled_kernel_missing_attrs
del fix_rocm_triton_key_error
del patch_trunc_normal_precision_issue
del ignore_logger_messages
del patch_ipykernel_hf_xet
del patch_trackio

View file

@@ -927,6 +927,88 @@ def fix_rocm_triton_key_error():
)
def patch_trunc_normal_precision_issue():
    """
    Patch torch.nn.init.trunc_normal_ for low precision tensors to run init in fp32.

    torch.nn.init.trunc_normal_ can saturate at truncation bounds in fp16/bf16 on
    some versions/backends. This was observed in TorchTitan investigations where
    low-precision truncation produced boundary-heavy initialization behavior:
    https://github.com/pytorch/torchtitan/pull/2342

    To avoid that failure mode, initialize into a temporary fp32 tensor, then copy
    back to the original dtype.
    """
    try:
        import torch
    except (ImportError, ModuleNotFoundError):
        return
    # Idempotence guard: apply the patch at most once per process.
    if getattr(torch.nn.init, "_unsloth_trunc_normal_patched", False):
        return
    original_trunc_normal = torch.nn.init.trunc_normal_
    # Another copy of this module may already have wrapped the function.
    if getattr(original_trunc_normal, "__unsloth_trunc_normal_patched__", False):
        torch.nn.init._unsloth_trunc_normal_patched = True
        return
    low_precision_dtypes = {torch.float16, torch.bfloat16}

    def _call_original(target, mean, std, a, b, generator):
        # Forward to the unpatched initializer, tolerating old torch versions.
        if generator is None:
            return original_trunc_normal(target, mean = mean, std = std, a = a, b = b)
        try:
            return original_trunc_normal(
                target, mean = mean, std = std, a = a, b = b, generator = generator
            )
        except TypeError as exc:
            # Older torch versions may not accept a generator keyword argument.
            msg = str(exc).lower()
            if "unexpected keyword argument" in msg and "generator" in msg:
                return original_trunc_normal(target, mean = mean, std = std, a = a, b = b)
            raise

    # DTensor moved from the private torch.distributed._tensor module to the
    # public torch.distributed.tensor namespace in newer torch releases; try
    # the public location first and fall back for older versions.
    try:
        from torch.distributed.tensor import DTensor
    except Exception:
        try:
            from torch.distributed._tensor import DTensor
        except Exception:
            DTensor = None

    @torch.no_grad()
    def _patched_trunc_normal_(
        tensor,
        mean: float = 0.0,
        std: float = 1.0,
        a: float = -2.0,
        b: float = 2.0,
        generator = None,
    ):
        # DTensor path: initialize the local shard directly via _local_tensor.
        if DTensor is not None and isinstance(tensor, DTensor):
            local_tensor = getattr(tensor, "_local_tensor", None)
            if local_tensor is None:
                # No accessible local shard; defer to the original behavior.
                return _call_original(tensor, mean, std, a, b, generator)
            if local_tensor.dtype in low_precision_dtypes:
                # Sample in fp32, then round once into the low-precision shard.
                local_fp32 = local_tensor.float()
                _call_original(local_fp32, mean, std, a, b, generator)
                local_tensor.copy_(local_fp32.to(dtype = local_tensor.dtype))
                return tensor
            return _call_original(tensor, mean, std, a, b, generator)
        if tensor.dtype in low_precision_dtypes:
            # Sample in fp32, then round once into the low-precision storage.
            tensor_fp32 = tensor.float()
            _call_original(tensor_fp32, mean, std, a, b, generator)
            tensor.copy_(tensor_fp32.to(dtype = tensor.dtype))
            return tensor
        return _call_original(tensor, mean, std, a, b, generator)

    # Markers consumed by the idempotence guards above and by the test suite.
    _patched_trunc_normal_.__unsloth_trunc_normal_patched__ = True
    _patched_trunc_normal_._unsloth_original = original_trunc_normal
    torch.nn.init._unsloth_trunc_normal_original = original_trunc_normal
    torch.nn.init.trunc_normal_ = _patched_trunc_normal_
    torch.nn.init._unsloth_trunc_normal_patched = True
    logger.info("Unsloth: Patched torch.nn.init.trunc_normal_ for fp16/bf16 stability.")
def check_vllm_torch_sm100_compatibility():
"""
Check for incompatible vLLM + torch < 2.9.0 + SM100 (Blackwell) combination.