unsloth/tests/utils/hf_utils.py
Daniel Han c466303956 Fix Transformers 4.45 (#2151)
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
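"""Shared Hugging Face / PEFT / TRL helpers for Unsloth's QLoRA train-and-merge
reference tests: tokenizer, model, and trainer setup, plus utilities for merging
LoRA adapters back into dense Linear layers."""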
import os
from contextlib import contextmanager, nullcontext
from typing import Callable, Optional
import bitsandbytes as bnb
import torch
from bitsandbytes.functional import dequantize_4bit
from peft import get_peft_model, prepare_model_for_kbit_training
from peft.tuners.lora import LoraConfig, LoraLayer
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
)
from transformers.trainer_callback import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from trl import SFTTrainer
class PeftWeightCallback(TrainerCallback):
def on_log(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
logs,
**kwargs,
):
print(f"DEBUG::CALLBACK::on_log::{state.log_history}")
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
model = kwargs.get("model")
assert model is not None
print(f"DEBUG::CALLBACK::on_train_begin::{kwargs.keys()}")
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
print(f"DEBUG::CALLBACK::on_step_end::{state.global_step}")
@torch.inference_mode()
def generate_responses(
model,
tokenizer,
prompt,
max_new_tokens: int = 100,
temperature: float = 0.8,
do_sample: bool = True,
num_generations: int = 1,
skip_special_tokens: bool = True,
    dtype: Optional[torch.dtype] = None,
):
inputs = [tokenizer(prompt, return_tensors="pt") for _ in range(num_generations)]
keys = inputs[0].keys()
batched_inputs = {
key: torch.cat([input[key] for input in inputs], dim=0).to(model.device)
for key in keys
}
if dtype is not None:
inference_context = torch.autocast(device_type="cuda", dtype=dtype)
else:
inference_context = nullcontext()
with inference_context:
outputs = model.generate(
**batched_inputs,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
)
responses = tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
return responses
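# Thin convenience wrapper around generate_responses with the same defaults.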
def sample_responses(
model,
tokenizer,
prompt,
temperature: float = 0.8,
num_generations: int = 1,
max_new_tokens: int = 100,
skip_special_tokens: bool = True,
    dtype: Optional[torch.dtype] = None,
):
responses = generate_responses(
model,
tokenizer,
prompt,
temperature=temperature,
num_generations=num_generations,
max_new_tokens=max_new_tokens,
skip_special_tokens=skip_special_tokens,
dtype=dtype,
)
return responses
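# Load a tokenizer and apply optional fix-up functions (e.g. fix_llama3_tokenizer below).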
def setup_tokenizer(model_name, fixup_funcs: list[Callable] = []):
tokenizer = AutoTokenizer.from_pretrained(model_name)
for fixup_func in fixup_funcs:
tokenizer = fixup_func(tokenizer)
return tokenizer
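# Load a causal LM, optionally NF4-quantized with bitsandbytes (in which case it
# is also passed through prepare_model_for_kbit_training), and attach a PEFT
# adapter when a peft_config is supplied.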
def setup_model(
model_name,
quantize: bool = True,
dtype=torch.bfloat16,
peft_config=None,
autocast_adapter: bool = True,
):
if quantize:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=dtype,
)
else:
bnb_config = None
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="cuda:0",
attn_implementation="sdpa",
quantization_config=bnb_config,
torch_dtype=dtype,
)
model = prepare_model_for_kbit_training(model) if quantize else model
if peft_config is not None:
model = get_peft_model(
model, peft_config, autocast_adapter_dtype=autocast_adapter
)
return model
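# Build a LoraConfig targeting all linear layers; lora_alpha defaults to 2 * lora_rank.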
def get_peft_config(
lora_rank,
lora_alpha=None,
lora_dropout=0.0,
bias="none",
target_modules="all-linear",
):
lora_alpha = lora_alpha or 2 * lora_rank
peft_config = LoraConfig(
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
r=lora_rank,
bias=bias,
target_modules=target_modules,
task_type="CAUSAL_LM",
)
return peft_config
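# Wrap model, tokenizer, and dataset in a TRL SFTTrainer; passing peft_config
# lets the trainer attach the adapter itself.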
def setup_trainer(
model,
tokenizer,
dataset,
train_args,
peft_config=None,
formatting_func=None,
collator=None,
):
return SFTTrainer(
model=model,
peft_config=peft_config,
train_dataset=dataset,
processing_class=tokenizer,
formatting_func=formatting_func,
data_collator=collator,
args=train_args,
)
def setup_lora(
    model,
    tokenizer,
    dataset,
    peft_config,
    train_args,
    formatting_func=None,
    collator=None,
):
    # Like setup_trainer, but peft_config is required so the returned trainer
    # always trains a LoRA adapter.
    return SFTTrainer(
        model=model,
        peft_config=peft_config,
        train_dataset=dataset,
        processing_class=tokenizer,
        formatting_func=formatting_func,
        data_collator=collator,
        args=train_args,
    )
def convert_weights_back_to_dtype(model, dtype):
"""
    SFTTrainer calls get_peft_model and prepare_model_for_kbit_training, which convert all weights to float32.
    This function converts the non-LoRA weights (norms and embeddings) back to the original dtype.
"""
for name, param in model.named_parameters():
if any(s in name for s in ["norm", "embed"]):
param.data = param.data.to(dtype)
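# Llama 3 style tokenizers ship a dedicated pad token in the added vocabulary;
# this assumes exactly one added token containing "pad", selects it as the
# pad_token, and sets the padding side.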
def fix_llama3_tokenizer(tokenizer, padding_side="right"):
tokenizer.padding_side = padding_side
added_vocab = tokenizer.get_added_vocab()
pad_token = [w for w in added_vocab if "pad" in w]
assert len(pad_token) == 1
    tokenizer.pad_token = pad_token[0]
return tokenizer
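# Recursively replace every submodule of target_module_type with the result of
# conversion_func(child_module).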
def replace_module(
module: torch.nn.Module,
target_module_type: torch.nn.Module,
conversion_func: Callable,
):
for child_name, child_module in module.named_children():
if isinstance(child_module, target_module_type):
new_module = conversion_func(child_module)
setattr(module, child_name, new_module)
else:
replace_module(child_module, target_module_type, conversion_func)
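# Merge a single 4-bit LoRA layer into a dense nn.Linear: dequantize the NF4
# base weight, add lora_B @ lora_A scaled by the adapter's scaling factor in
# float32, then cast back to the original dtype.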
def _convert_lora_to_linear(module: LoraLayer, adapter_name: str = "default"):
base_layer = module.get_base_layer()
weight = base_layer.weight
assert isinstance(weight, bnb.nn.Params4bit)
quant_state = weight.quant_state
original_dtype = quant_state.dtype
w_dq = dequantize_4bit(weight.data, quant_state).float()
lora_delta = (
module.lora_B[adapter_name].weight
@ module.lora_A[adapter_name].weight
* module.scaling[adapter_name]
)
w_dq += lora_delta.float()
w_dq = w_dq.to(original_dtype)
new_module = torch.nn.Linear(
w_dq.shape[1], w_dq.shape[0], bias=module.base_layer.bias is not None
)
    new_module.weight = torch.nn.Parameter(w_dq, requires_grad=False)
if module.lora_bias[adapter_name]:
bias_data = module.base_layer.bias.data + module.lora_B[adapter_name].bias
        new_module.bias = torch.nn.Parameter(bias_data, requires_grad=False)
return new_module
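# Merge every LoRA layer in the model into a plain Linear layer in place and
# verify that no LoraLayer instances remain.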
def convert_lora_to_linear(model: torch.nn.Module):
replace_module(model, LoraLayer, _convert_lora_to_linear)
assert not any(isinstance(module, LoraLayer) for module in model.modules())
return model
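# A minimal usage sketch of how these helpers compose (illustrative only; the
# model name, LoRA rank, `dataset`, and `train_args` below are placeholder
# assumptions, not values taken from the actual tests):
#
#   model_name = "meta-llama/Meta-Llama-3-8B"
#   tokenizer = setup_tokenizer(model_name, fixup_funcs=[fix_llama3_tokenizer])
#   peft_config = get_peft_config(lora_rank=16)
#   model = setup_model(model_name, quantize=True)
#   trainer = setup_trainer(
#       model, tokenizer, dataset, train_args, peft_config=peft_config
#   )
#   trainer.add_callback(PeftWeightCallback())
#   trainer.train()
#   merged_model = convert_lora_to_linear(trainer.model)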