mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Torch 2.8 (#3186)
* Fix mamba * Update loader.py * Update vision.py * Update loader.py * Filter vLLM standby logs (#3131) * filter vLLM standby logs * safeguard standby logger patch * Update unsloth/models/_utils.py * Update unsloth/models/_utils.py * Update unsloth/models/_utils.py --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com> * Update loader.py * Add scaler * Update llama.py * Update _utils.py * Versioning * GPT OSS fix * GPT OSS fix * Update loader.py * Update vision.py * Update vision.py * Update loader.py * Update vision.py * Update vision.py * Update llama.py * Update llama.py * Update llama.py * Versioning * Update mapper.py * Update vision.py * Update vision.py * Update vision.py * Upcast norms * Update loader.py * Update vision.py * Upcast layernorms * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update save.py * Update rl.py * Update pyproject.toml * Update rl.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl.py * Update _utils.py * Update __init__.py * Torch 2.8 * Update rl_replacements.py --------- Co-authored-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com>
This commit is contained in:
parent
10f68527d8
commit
089a0056e2
5 changed files with 195 additions and 4 deletions
112
pyproject.toml
112
pyproject.toml
|
|
@ -207,6 +207,16 @@ cu126onlytorch260 = [
|
|||
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
|
||||
]
|
||||
cu118onlytorch270 = [
|
||||
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
|
||||
]
|
||||
cu126onlytorch270 = [
|
||||
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
||||
|
|
@ -227,6 +237,30 @@ cu128onlytorch270 = [
|
|||
"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
|
||||
]
|
||||
cu118onlytorch271 = [
|
||||
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.31.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.31.post1-cp39-abi3-win_amd64.whl ; platform_system == 'Windows'",
|
||||
]
|
||||
cu126onlytorch271 = [
|
||||
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.31.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.31.post1-cp39-abi3-win_amd64.whl ; platform_system == 'Windows'",
|
||||
]
|
||||
cu128onlytorch271 = [
|
||||
"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.31.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.31.post1-cp39-abi3-win_amd64.whl ; platform_system == 'Windows'",
|
||||
]
|
||||
cu118onlytorch280 = [
|
||||
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.32.post2-cp39-abi3-win_amd64.whl ; platform_system == 'Windows'",
|
||||
]
|
||||
cu126onlytorch280 = [
|
||||
"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.32.post2-cp39-abi3-win_amd64.whl ; platform_system == 'Windows'",
|
||||
]
|
||||
cu128onlytorch280 = [
|
||||
"xformers @ https://download.pytorch.org/whl/cu129/xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; platform_system == 'Linux'",
|
||||
"xformers @ https://download.pytorch.org/whl/cu129/xformers-0.0.32.post2-cp39-abi3-win_amd64.whl ; platform_system == 'Windows'",
|
||||
]
|
||||
cu118 = [
|
||||
"unsloth[huggingface]",
|
||||
"bitsandbytes>=0.45.5",
|
||||
|
|
@ -337,6 +371,11 @@ cu126-torch260 = [
|
|||
"bitsandbytes>=0.45.5",
|
||||
"unsloth[cu126onlytorch260]",
|
||||
]
|
||||
cu118-torch270 = [
|
||||
"unsloth[huggingface]",
|
||||
"bitsandbytes>=0.45.5",
|
||||
"unsloth[cu118onlytorch270]",
|
||||
]
|
||||
cu126-torch270 = [
|
||||
"unsloth[huggingface]",
|
||||
"bitsandbytes>=0.45.5",
|
||||
|
|
@ -347,6 +386,36 @@ cu128-torch270 = [
|
|||
"bitsandbytes>=0.45.5",
|
||||
"unsloth[cu128onlytorch270]",
|
||||
]
|
||||
cu118-torch271 = [
|
||||
"unsloth[huggingface]",
|
||||
"bitsandbytes>=0.45.5",
|
||||
"unsloth[cu118onlytorch271]",
|
||||
]
|
||||
cu126-torch271 = [
|
||||
"unsloth[huggingface]",
|
||||
"bitsandbytes>=0.45.5",
|
||||
"unsloth[cu126onlytorch271]",
|
||||
]
|
||||
cu128-torch271 = [
|
||||
"unsloth[huggingface]",
|
||||
"bitsandbytes>=0.45.5",
|
||||
"unsloth[cu128onlytorch271]",
|
||||
]
|
||||
cu118-torch280 = [
|
||||
"unsloth[huggingface]",
|
||||
"bitsandbytes>=0.45.5",
|
||||
"unsloth[cu118onlytorch280]",
|
||||
]
|
||||
cu126-torch280 = [
|
||||
"unsloth[huggingface]",
|
||||
"bitsandbytes>=0.45.5",
|
||||
"unsloth[cu126onlytorch280]",
|
||||
]
|
||||
cu128-torch280 = [
|
||||
"unsloth[huggingface]",
|
||||
"bitsandbytes>=0.45.5",
|
||||
"unsloth[cu128onlytorch280]",
|
||||
]
|
||||
kaggle = [
|
||||
"unsloth[huggingface]",
|
||||
]
|
||||
|
|
@ -540,6 +609,12 @@ cu126-ampere-torch260 = [
|
|||
"unsloth[cu126onlytorch260]",
|
||||
"unsloth[flashattention]",
|
||||
]
|
||||
cu118-ampere-torch270 = [
|
||||
"unsloth[huggingface]",
|
||||
"bitsandbytes>=0.45.5",
|
||||
"unsloth[cu118onlytorch270]",
|
||||
"unsloth[flashattention]",
|
||||
]
|
||||
cu126-ampere-torch270 = [
|
||||
"unsloth[huggingface]",
|
||||
"bitsandbytes>=0.45.5",
|
||||
|
|
@ -552,7 +627,42 @@ cu128-ampere-torch270 = [
|
|||
"unsloth[cu128onlytorch270]",
|
||||
"unsloth[flashattention]",
|
||||
]
|
||||
|
||||
cu118-ampere-torch271 = [
|
||||
"unsloth[huggingface]",
|
||||
"bitsandbytes>=0.45.5",
|
||||
"unsloth[cu118onlytorch271]",
|
||||
"unsloth[flashattention]",
|
||||
]
|
||||
cu126-ampere-torch271 = [
|
||||
"unsloth[huggingface]",
|
||||
"bitsandbytes>=0.45.5",
|
||||
"unsloth[cu126onlytorch271]",
|
||||
"unsloth[flashattention]",
|
||||
]
|
||||
cu128-ampere-torch271 = [
|
||||
"unsloth[huggingface]",
|
||||
"bitsandbytes>=0.45.5",
|
||||
"unsloth[cu128onlytorch271]",
|
||||
"unsloth[flashattention]",
|
||||
]
|
||||
cu118-ampere-torch280 = [
|
||||
"unsloth[huggingface]",
|
||||
"bitsandbytes>=0.45.5",
|
||||
"unsloth[cu118onlytorch280]",
|
||||
"unsloth[flashattention]",
|
||||
]
|
||||
cu126-ampere-torch280 = [
|
||||
"unsloth[huggingface]",
|
||||
"bitsandbytes>=0.45.5",
|
||||
"unsloth[cu126onlytorch280]",
|
||||
"unsloth[flashattention]",
|
||||
]
|
||||
cu128-ampere-torch280 = [
|
||||
"unsloth[huggingface]",
|
||||
"bitsandbytes>=0.45.5",
|
||||
"unsloth[cu128onlytorch280]",
|
||||
"unsloth[flashattention]",
|
||||
]
|
||||
flashattentiontorch260abiFALSEcu12x = [
|
||||
"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp39-cp39-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.9'",
|
||||
"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10'",
|
||||
|
|
|
|||
|
|
@ -12,6 +12,31 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
try:
    # Fix up AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
    # MUST do this at the start primarily due to tensorflow causing issues
    import google.protobuf.message_factory

    class MessageFactory:
        """Inert stand-in installed when protobuf offers no usable MessageFactory.

        Every method accepts any arguments and returns None, so callers that
        merely probe for these attributes no longer raise AttributeError.
        """
        def CreatePrototype(self, *args, **kwargs): return
        def GetMessages(self, *args, **kwargs): return
        def GetPrototype(self, *args, **kwargs): return

    if not hasattr(google.protobuf.message_factory, "MessageFactory"):
        # The module has no MessageFactory at all: install the stand-in.
        google.protobuf.message_factory.MessageFactory = MessageFactory
    elif not hasattr(google.protobuf.message_factory.MessageFactory, "GetPrototype"):
        # MessageFactory exists but lacks GetPrototype.
        if hasattr(google.protobuf.message_factory, "GetMessageClass"):
            # Keep the real class and forward GetPrototype to the module-level
            # GetMessageClass function.
            GetMessageClass = google.protobuf.message_factory.GetMessageClass
            def GetPrototype(self, descriptor):
                return GetMessageClass(descriptor)
            google.protobuf.message_factory.MessageFactory.GetPrototype = GetPrototype
        else:
            # Nothing to forward to: replace the class with the stand-in.
            google.protobuf.message_factory.MessageFactory = MessageFactory
    pass
except Exception:
    # Best effort only: protobuf may be absent or laid out differently.
    # `Exception` (not a bare except) so KeyboardInterrupt / SystemExit raised
    # during import still propagate.
    pass
|
||||
|
||||
import warnings, importlib, sys
|
||||
from packaging.version import Version
|
||||
import os, re, subprocess, inspect
|
||||
|
|
|
|||
|
|
@ -30,7 +30,11 @@ elif v < V('2.5.0'): x = 'cu{}{}-torch240'
|
|||
elif v < V('2.5.1'): x = 'cu{}{}-torch250'
|
||||
elif v <= V('2.5.1'): x = 'cu{}{}-torch251'
|
||||
elif v < V('2.7.0'): x = 'cu{}{}-torch260'
|
||||
elif v < V('2.8.0'): x = 'cu{}{}-torch270'
|
||||
elif v < V('2.7.9'): x = 'cu{}{}-torch270'
|
||||
elif v < V('2.8.0'): x = 'cu{}{}-torch271'
|
||||
elif v < V('2.8.9'): x = 'cu{}{}-torch280'
|
||||
else: raise RuntimeError(f"Torch = {v} too new!")
|
||||
if v > V('2.6.9') and cuda not in ("11.8", "12.6", "12.8"):
|
||||
raise RuntimeError(f"CUDA = {cuda} not supported!")
|
||||
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
|
||||
print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"')
|
||||
|
|
@ -273,6 +273,38 @@ try:
|
|||
except:
|
||||
pass
|
||||
|
||||
# Each block below attaches a HideLoggingMessage filter (defined elsewhere in
# this file; presumably it drops log records containing the given text - TODO
# confirm) to one transformers logger. Every block is best effort: if the
# module or its logger is missing, the failure is ignored.

# Silence the message starting "Using a slow image processor as `use_fast`"
# emitted via transformers.processing_utils
try:
    from transformers.processing_utils import logger as processing_utils_logger
    processing_utils_logger.addFilter(HideLoggingMessage("`use_fast`"))
    del processing_utils_logger
except Exception:
    # `Exception` rather than a bare except so Ctrl-C / SystemExit propagate.
    pass

# Same `use_fast` message, emitted via transformers.models.auto.image_processing_auto
try:
    from transformers.models.auto.image_processing_auto import logger as processing_utils_logger
    processing_utils_logger.addFilter(HideLoggingMessage("`use_fast`"))
    del processing_utils_logger
except Exception:
    pass

# Silence "`use_cache=True` is incompatible with gradient checkpointing"
# emitted via transformers.trainer
try:
    from transformers.trainer import logger as trainer_logger
    trainer_logger.addFilter(HideLoggingMessage("`use_cache=True`"))
    del trainer_logger
except Exception:
    pass

# Same `use_cache=True` message, emitted via transformers.utils.generic
try:
    from transformers.utils.generic import logger as trainer_logger
    trainer_logger.addFilter(HideLoggingMessage("`use_cache=True`"))
    del trainer_logger
except Exception:
    pass
|
||||
|
||||
# Errors out on
|
||||
# Some weights of Gemma3nForConditionalGeneration were not initialized from the model checkpoint
|
||||
from transformers.modeling_utils import logger as transformers_logger
|
||||
|
|
|
|||
|
|
@ -133,15 +133,18 @@ class Unsloth{RLConfig_name}({RLConfig_name}):
|
|||
default = -1,
|
||||
metadata = {{'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}},
|
||||
)
|
||||
{max_seq_length_pre}
|
||||
def __init__({RLConfig_arguments},
|
||||
vllm_sampling_params = None,
|
||||
unsloth_num_chunks = -1,
|
||||
{max_seq_length_call}
|
||||
**kwargs,
|
||||
):
|
||||
{RLConfig_extra_args}
|
||||
super().__init__({RLConfig_call_args}{RLConfig_kwargs})
|
||||
self.vllm_sampling_params = vllm_sampling_params
|
||||
self.unsloth_num_chunks = unsloth_num_chunks
|
||||
{max_seq_length_post}
|
||||
pass
|
||||
|
||||
{RLTrainer_extras}
|
||||
|
|
@ -353,9 +356,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
|
|||
" max_length = args.max_length\n"\
|
||||
" else:\n"\
|
||||
" model_max_length = getattr(model, 'max_seq_length', None)\n"\
|
||||
" # print(model_max_length, 'mml1')\n"\
|
||||
" if model_max_length is None: model_max_length = getattr(model, 'max_length', None)\n"\
|
||||
" # print(model_max_length, 'mml2')\n"\
|
||||
" if model_max_length is not None:\n"\
|
||||
" args.max_length = model_max_length\n"\
|
||||
" max_length = args.max_length\n"\
|
||||
|
|
@ -535,6 +536,21 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
|
|||
extra_args += learning_rate_check
|
||||
pass
|
||||
|
||||
# Check if max_seq_length is NOT defined (max_length is now default)
|
||||
if "max_seq_length" not in call_args and "max_length" in call_args:
|
||||
max_seq_length_pre = \
|
||||
"""max_seq_length : Optional[int] = field(
|
||||
default = None,
|
||||
metadata = {'help': 'Maximum sequence length to truncate to.'},
|
||||
)"""
|
||||
max_seq_length_call = "max_seq_length = max_seq_length,"
|
||||
max_seq_length_post = "self.max_seq_length = max_seq_length"
|
||||
else:
|
||||
max_seq_length_pre = ""
|
||||
max_seq_length_call = ""
|
||||
max_seq_length_post = ""
|
||||
pass
|
||||
|
||||
# Add output_dir saving
|
||||
if "output_dir" in call_args:
|
||||
# Default checks
|
||||
|
|
@ -666,6 +682,10 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
|
|||
RLTrainer_post = RLTrainer_post,
|
||||
RL_pre = RL_pre,
|
||||
|
||||
max_seq_length_pre = max_seq_length_pre,
|
||||
max_seq_length_call = max_seq_length_call,
|
||||
max_seq_length_post = max_seq_length_post,
|
||||
|
||||
selective_log_softmax_code = selective_log_softmax_code,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue