mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Patch trunc_normal_ for low-precision stability (#4027)
* Fix low-precision trunc_normal initialization instability
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Document TorchTitan trunc_normal low-precision failure mode
* Fix trunc_normal generator positional compatibility
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Fix trunc_normal generator TypeError fallback

---------

Co-authored-by: Daniel Hanchen <danielhanchen@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
8165266a37
commit
3bddfed117
3 changed files with 199 additions and 0 deletions
114
tests/utils/test_trunc_normal_patch.py
Normal file
114
tests/utils/test_trunc_normal_patch.py
Normal file
|
|
@@ -0,0 +1,114 @@
|
|||
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
"""Tests for trunc_normal low-precision patch compatibility."""
|
||||
|
||||
import importlib.util
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
# Sentinel distinguishing "attribute was absent" from "attribute was None" when
# saving and restoring monkey-patched attributes in these tests.
_MISSING = object()
|
||||
|
||||
|
||||
def _load_import_fixes_module():
    """Load ``unsloth/import_fixes.py`` straight from the repo checkout.

    Loading by file path (instead of ``import unsloth``) avoids triggering the
    package's import-time side effects; returns the freshly executed module.
    """
    source_path = Path(__file__).resolve().parents[2] / "unsloth" / "import_fixes.py"
    spec = importlib.util.spec_from_file_location(
        "unsloth_import_fixes_local", source_path
    )
    assert spec is not None and spec.loader is not None
    loaded = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded)
    return loaded
|
||||
|
||||
|
||||
def _getattr_or_missing(obj, name):
    """Return ``getattr(obj, name)``, or the ``_MISSING`` sentinel when absent.

    Uses ``getattr`` with a default instead of the original ``hasattr`` +
    ``getattr`` pair: one lookup instead of two, and no window in which the
    attribute could change between the check and the read.
    """
    return getattr(obj, name, _MISSING)
|
||||
|
||||
|
||||
def _restore_attr(obj, name, value):
    """Restore ``obj.name`` to a previously captured state.

    ``value`` is either the saved attribute value, or the ``_MISSING``
    sentinel — in which case any attribute that appeared in the meantime
    is deleted so the object ends up exactly as it was captured.
    """
    if value is not _MISSING:
        setattr(obj, name, value)
    elif hasattr(obj, name):
        delattr(obj, name)
|
||||
|
||||
|
||||
def test_trunc_normal_patch_accepts_positional_generator():
    """The patched trunc_normal_ must accept the generator both positionally and by keyword."""
    fixes = _load_import_fixes_module()
    apply_patch = fixes.patch_trunc_normal_precision_issue

    init = torch.nn.init
    saved_fn = init.trunc_normal_
    saved_flag = _getattr_or_missing(init, "_unsloth_trunc_normal_patched")
    saved_orig = _getattr_or_missing(init, "_unsloth_trunc_normal_original")
    try:
        # Roll back any previously applied patch so we exercise a clean apply.
        if saved_orig is not _MISSING:
            init.trunc_normal_ = saved_orig
        for marker in ("_unsloth_trunc_normal_patched", "_unsloth_trunc_normal_original"):
            if hasattr(init, marker):
                delattr(init, marker)

        apply_patch()

        params = inspect.signature(init.trunc_normal_).parameters
        assert "generator" in params
        assert params["generator"].kind is not inspect.Parameter.KEYWORD_ONLY

        target = torch.empty(1024, dtype=torch.float32)
        rng = torch.Generator()
        rng.manual_seed(3407)

        # Both calling conventions must work without raising.
        init.trunc_normal_(target, 0.0, 1.0, -2.0, 2.0, rng)
        init.trunc_normal_(target, mean=0.0, std=1.0, a=-2.0, b=2.0, generator=rng)
    finally:
        init.trunc_normal_ = saved_fn
        _restore_attr(init, "_unsloth_trunc_normal_patched", saved_flag)
        _restore_attr(init, "_unsloth_trunc_normal_original", saved_orig)
|
||||
|
||||
|
||||
def test_trunc_normal_patch_rejects_invalid_generator():
    """A non-Generator ``generator`` argument must still raise TypeError after patching."""
    fixes = _load_import_fixes_module()
    apply_patch = fixes.patch_trunc_normal_precision_issue

    init = torch.nn.init
    saved_fn = init.trunc_normal_
    saved_flag = _getattr_or_missing(init, "_unsloth_trunc_normal_patched")
    saved_orig = _getattr_or_missing(init, "_unsloth_trunc_normal_original")
    try:
        # Roll back any previously applied patch so we exercise a clean apply.
        if saved_orig is not _MISSING:
            init.trunc_normal_ = saved_orig
        for marker in ("_unsloth_trunc_normal_patched", "_unsloth_trunc_normal_original"):
            if hasattr(init, marker):
                delattr(init, marker)

        apply_patch()

        if "generator" not in inspect.signature(init.trunc_normal_).parameters:
            pytest.skip("torch.nn.init.trunc_normal_ lacks a generator parameter")

        target = torch.empty(16, dtype=torch.float32)
        with pytest.raises(TypeError):
            init.trunc_normal_(target, generator=123)
    finally:
        init.trunc_normal_ = saved_fn
        _restore_attr(init, "_unsloth_trunc_normal_patched", saved_flag)
        _restore_attr(init, "_unsloth_trunc_normal_original", saved_orig)
|
||||
|
|
@@ -140,6 +140,7 @@ from .import_fixes import (
|
|||
fix_vllm_pdl_blackwell,
|
||||
fix_triton_compiled_kernel_missing_attrs,
|
||||
fix_rocm_triton_key_error,
|
||||
patch_trunc_normal_precision_issue,
|
||||
ignore_logger_messages,
|
||||
patch_ipykernel_hf_xet,
|
||||
patch_trackio,
|
||||
|
|
@@ -161,6 +162,7 @@ fix_vllm_guided_decoding_params()
|
|||
fix_vllm_pdl_blackwell()
|
||||
fix_triton_compiled_kernel_missing_attrs()
|
||||
fix_rocm_triton_key_error()
|
||||
patch_trunc_normal_precision_issue()
|
||||
ignore_logger_messages()
|
||||
patch_ipykernel_hf_xet()
|
||||
patch_trackio()
|
||||
|
|
@@ -180,6 +182,7 @@ del fix_vllm_guided_decoding_params
|
|||
del fix_vllm_pdl_blackwell
|
||||
del fix_triton_compiled_kernel_missing_attrs
|
||||
del fix_rocm_triton_key_error
|
||||
del patch_trunc_normal_precision_issue
|
||||
del ignore_logger_messages
|
||||
del patch_ipykernel_hf_xet
|
||||
del patch_trackio
|
||||
|
|
|
|||
|
|
@@ -927,6 +927,88 @@ def fix_rocm_triton_key_error():
|
|||
)
|
||||
|
||||
|
||||
def patch_trunc_normal_precision_issue():
    """
    Patch torch.nn.init.trunc_normal_ for low precision tensors to run init in fp32.

    torch.nn.init.trunc_normal_ can saturate at truncation bounds in fp16/bf16 on
    some versions/backends. This was observed in TorchTitan investigations where
    low-precision truncation produced boundary-heavy initialization behavior:
    https://github.com/pytorch/torchtitan/pull/2342

    To avoid that failure mode, initialize into a temporary fp32 tensor, then copy
    back to the original dtype.
    """
    try:
        import torch
    except (ImportError, ModuleNotFoundError):
        # torch is not installed -- nothing to patch.
        return

    # Idempotence guard: a previous call in this process already patched.
    if getattr(torch.nn.init, "_unsloth_trunc_normal_patched", False):
        return

    original_trunc_normal = torch.nn.init.trunc_normal_
    # The current trunc_normal_ is already our wrapper (e.g. this module was
    # reloaded); record the module-level flag instead of wrapping twice.
    if getattr(original_trunc_normal, "__unsloth_trunc_normal_patched__", False):
        torch.nn.init._unsloth_trunc_normal_patched = True
        return

    # Only these dtypes take the fp32 round-trip; fp32/fp64 go straight through.
    low_precision_dtypes = {torch.float16, torch.bfloat16}

    def _call_original(target, mean, std, a, b, generator):
        # Invoke the saved original, tolerating torch versions whose
        # trunc_normal_ signature has no `generator` keyword.
        if generator is None:
            return original_trunc_normal(target, mean = mean, std = std, a = a, b = b)
        try:
            return original_trunc_normal(
                target, mean = mean, std = std, a = a, b = b, generator = generator
            )
        except TypeError as exc:
            # Older torch versions may not accept a generator keyword argument.
            # Only swallow that exact signature mismatch; any other TypeError
            # (e.g. an invalid generator object) must propagate to the caller.
            msg = str(exc).lower()
            if "unexpected keyword argument" in msg and "generator" in msg:
                return original_trunc_normal(target, mean = mean, std = std, a = a, b = b)
            raise

    try:
        from torch.distributed._tensor import DTensor
    except Exception:
        # Distributed tensors unavailable on this build; skip the DTensor path.
        DTensor = None

    @torch.no_grad()
    def _patched_trunc_normal_(
        tensor,
        mean: float = 0.0,
        std: float = 1.0,
        a: float = -2.0,
        b: float = 2.0,
        generator = None,
    ):
        # DTensor path: operate on the local shard so the fp32 round-trip does
        # not disturb the distributed placement.
        if DTensor is not None and isinstance(tensor, DTensor):
            local_tensor = getattr(tensor, "_local_tensor", None)
            if local_tensor is None:
                # No accessible local shard; defer entirely to the original.
                return _call_original(tensor, mean, std, a, b, generator)
            if local_tensor.dtype in low_precision_dtypes:
                # Initialize a temporary fp32 copy, then cast back in place.
                local_fp32 = local_tensor.float()
                _call_original(local_fp32, mean, std, a, b, generator)
                local_tensor.copy_(local_fp32.to(dtype = local_tensor.dtype))
                return tensor
            return _call_original(tensor, mean, std, a, b, generator)

        if tensor.dtype in low_precision_dtypes:
            # Plain-tensor low-precision path: same fp32 round-trip.
            tensor_fp32 = tensor.float()
            _call_original(tensor_fp32, mean, std, a, b, generator)
            tensor.copy_(tensor_fp32.to(dtype = tensor.dtype))
            return tensor

        return _call_original(tensor, mean, std, a, b, generator)

    # Mark both the wrapper and the init module so re-application is a no-op,
    # and stash the original so tests (and unpatching) can retrieve it.
    _patched_trunc_normal_.__unsloth_trunc_normal_patched__ = True
    _patched_trunc_normal_._unsloth_original = original_trunc_normal
    torch.nn.init._unsloth_trunc_normal_original = original_trunc_normal
    torch.nn.init.trunc_normal_ = _patched_trunc_normal_
    torch.nn.init._unsloth_trunc_normal_patched = True
    logger.info("Unsloth: Patched torch.nn.init.trunc_normal_ for fp16/bf16 stability.")
|
||||
|
||||
|
||||
def check_vllm_torch_sm100_compatibility():
|
||||
"""
|
||||
Check for incompatible vLLM + torch < 2.9.0 + SM100 (Blackwell) combination.
|
||||
|
|
|
|||
Loading…
Reference in a new issue