mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Support model.save_pretrained_torchao (#3111)
Summary:
Allow users to merge the LoRA weights and then do a post-training quantization with torchao
Usage:
```
from torchao.quantization import Int8DynamicActivationInt8WeightConfig
torchao_config = Int8DynamicActivationInt8WeightConfig()
model.save_pretrained_torchao(
save_path,
tokenizer=tokenizer,
torchao_config=torchao_config,
)
```
Test Plan:
python tests/saving/test_unsloth_save.py
Reviewers:
Subscribers:
Tasks:
Tags:
This commit is contained in:
parent
ac78311261
commit
f3ab8c21af
2 changed files with 96 additions and 0 deletions
|
|
@ -3,6 +3,7 @@ import os
|
|||
import shutil
|
||||
import tempfile
|
||||
import pytest
|
||||
import importlib
|
||||
|
||||
from unsloth import FastLanguageModel, FastModel
|
||||
|
||||
|
|
@ -167,3 +168,48 @@ def test_save_merged_4bit(model, tokenizer, temp_save_dir: str):
|
|||
load_in_4bit=True,
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(
    importlib.util.find_spec("torchao") is None,
    # pytest requires a reason when skipif is given a boolean condition;
    # without it the marker fails at collection time.
    reason="torchao is not installed",
)
def test_save_torchao(model, tokenizer, temp_save_dir: str):
    """Merge LoRA weights, quantize with torchao, save, and reload the checkpoint."""
    save_path = os.path.join(temp_save_dir, "unsloth_torchao", model.config._name_or_path.replace("/", "_"))

    from torchao.quantization import Int8DynamicActivationInt8WeightConfig
    torchao_config = Int8DynamicActivationInt8WeightConfig()
    model.save_pretrained_torchao(
        save_path,
        tokenizer=tokenizer,
        torchao_config=torchao_config,
    )

    # Check model files
    assert os.path.isdir(save_path), f"Directory {save_path} does not exist."
    assert os.path.isfile(os.path.join(save_path, "config.json")), "config.json not found."

    weight_files = [f for f in os.listdir(save_path) if f.endswith(".bin") or f.endswith(".safetensors")]
    assert len(weight_files) > 0, "No weight files found in the save directory."

    # Check tokenizer files
    for file in tokenizer_files:
        assert os.path.isfile(os.path.join(save_path, file)), f"{file} not found in the save directory."

    # Store the size of the model files under a dedicated "torchao" key so the
    # "merged_4bit" entry recorded by test_save_merged_4bit is not clobbered.
    total_size = sum(os.path.getsize(os.path.join(save_path, f)) for f in weight_files)
    save_file_sizes.setdefault("torchao", {})[model.config._name_or_path] = total_size

    print(f"Total size of torchao files: {total_size} bytes")

    # Int8 weights should come out smaller than the 16bit merged checkpoint.
    assert total_size < save_file_sizes["merged_16bit"][model.config._name_or_path], "Torchao quantized files are larger than merged 16bit files."

    # Check the saved config records the torchao quantization
    config_path = os.path.join(save_path, "config.json")
    with open(config_path, "r") as f:
        config = json.load(f)

    assert "quantization_config" in config, "Quantization config not found in the model config."

    # Test loading the model from the saved path
    # NOTE(review): load_in_4bit=True on an already torchao-quantized checkpoint
    # looks like copy-paste from the merged_4bit test — confirm this is intended.
    loaded_model, loaded_tokenizer = FastModel.from_pretrained(
        save_path,
        max_seq_length=128,
        dtype=None,
        load_in_4bit=True,
    )
|
||||
|
|
|
|||
|
|
@ -2122,6 +2122,55 @@ def unsloth_push_to_hub_gguf(
|
|||
pass
|
||||
pass
|
||||
|
||||
def unsloth_save_pretrained_torchao(
    self,
    save_directory : Union[str, os.PathLike],
    tokenizer = None,
    torchao_config = None,
    push_to_hub : bool = False,
):
    """Quantizes the model with torchao and saves a torchao quantized checkpoint

    Args
        `save_directory`: local folder path or huggingface hub ID when `push_to_hub` is set to True, e.g. `my_model`
        `tokenizer`: tokenizer to save alongside the model (optional)
        `torchao_config` (TorchAOBaseConfig): configuration for torchao quantization, full list: https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize
        `push_to_hub` (bool): whether to push the checkpoint to huggingface hub or not
    """
    # First merge the LoRA weights into a full 16bit checkpoint on disk;
    # quantization is applied afterwards when that checkpoint is reloaded.
    arguments = dict(locals())
    # str() so an os.PathLike save_directory does not raise on concatenation
    arguments["save_directory"] = str(save_directory) + "-local"
    arguments["model"] = self
    arguments["tokenizer"] = tokenizer
    arguments["push_to_hub"] = False # We save ourselves
    arguments["save_method"] = "merged_16bit" # Must be 16bit
    del arguments["self"]
    del arguments["torchao_config"]
    new_save_directory, old_username = unsloth_save_model(**arguments)

    from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
    if torchao_config is None:
        # Default to int8 dynamic activation + int8 weight quantization
        from torchao.quantization import Int8DynamicActivationInt8WeightConfig
        torchao_config = Int8DynamicActivationInt8WeightConfig()
    quantization_config = TorchAoConfig(quant_type=torchao_config)

    # Reload the merged 16bit checkpoint; transformers applies the torchao
    # quantization during loading via quantization_config.
    tokenizer = AutoTokenizer.from_pretrained(new_save_directory)
    model = AutoModelForCausalLM.from_pretrained(
        new_save_directory,
        torch_dtype="auto",
        device_map="auto",
        quantization_config=quantization_config,
    )

    if push_to_hub:
        save_to = save_directory
        # torchao does not support safe_serialization right now
        model.push_to_hub(save_to, safe_serialization=False)
        tokenizer.push_to_hub(save_to)
    else:
        # Bug fix: previously nothing was written when push_to_hub was False,
        # leaving `save_directory` empty. Save the quantized checkpoint locally.
        # torchao does not support safe_serialization right now
        model.save_pretrained(save_directory, safe_serialization=False)
        tokenizer.save_pretrained(save_directory)
    pass
pass
|
||||
|
||||
# Corrected function to save LoRA to a custom directory
|
||||
def save_lora_to_custom_dir(model, tokenizer, save_directory):
|
||||
# Create the custom directory if it doesn't exist
|
||||
|
|
@ -2597,6 +2646,7 @@ def patch_saving_functions(model, vision = False):
|
|||
model.save_pretrained_merged = types.MethodType(unsloth_generic_save_pretrained_merged, model)
|
||||
model.push_to_hub_gguf = types.MethodType(unsloth_push_to_hub_gguf, model)
|
||||
model.save_pretrained_gguf = types.MethodType(unsloth_save_pretrained_gguf, model)
|
||||
model.save_pretrained_torchao = types.MethodType(unsloth_save_pretrained_torchao, model)
|
||||
model.push_to_hub_ggml = types.MethodType(unsloth_convert_lora_to_ggml_and_push_to_hub, model)
|
||||
model.save_pretrained_ggml = types.MethodType(unsloth_convert_lora_to_ggml_and_save_locally, model)
|
||||
pass
|
||||
|
|
|
|||
Loading…
Reference in a new issue