Support model.save_pretrained_torchao (#3111)

Summary:
Allow users to merge the LoRA weights and then do a post-training quantization with torchao

Usage:

```
from torchao.quantization import Int8DynamicActivationInt8WeightConfig
torchao_config = Int8DynamicActivationInt8WeightConfig()
model.save_pretrained_torchao(
    save_path,
    tokenizer=tokenizer,
    torchao_config=torchao_config,
)
```

Test Plan:
python tests/saving/test_unsloth_save.py

Reviewers:

Subscribers:

Tasks:

Tags:
This commit is contained in:
Jerry Zhang 2025-08-26 04:53:39 -07:00 committed by GitHub
parent ac78311261
commit f3ab8c21af
2 changed files with 96 additions and 0 deletions

View file

@ -3,6 +3,7 @@ import os
import shutil
import tempfile
import pytest
import importlib
from unsloth import FastLanguageModel, FastModel
@ -167,3 +168,48 @@ def test_save_merged_4bit(model, tokenizer, temp_save_dir: str):
load_in_4bit=True,
)
@pytest.mark.skipif(
    importlib.util.find_spec("torchao") is None,
    reason="torchao is not installed",
)
def test_save_torchao(model, tokenizer, temp_save_dir: str):
    """Merge LoRA weights, quantize with torchao, save, then reload the checkpoint."""
    save_path = os.path.join(
        temp_save_dir, "unsloth_torchao", model.config._name_or_path.replace("/", "_")
    )
    from torchao.quantization import Int8DynamicActivationInt8WeightConfig
    torchao_config = Int8DynamicActivationInt8WeightConfig()
    model.save_pretrained_torchao(
        save_path,
        tokenizer=tokenizer,
        torchao_config=torchao_config,
    )

    # Check model files
    assert os.path.isdir(save_path), f"Directory {save_path} does not exist."
    assert os.path.isfile(os.path.join(save_path, "config.json")), "config.json not found."
    weight_files = [f for f in os.listdir(save_path) if f.endswith(".bin") or f.endswith(".safetensors")]
    assert len(weight_files) > 0, "No weight files found in the save directory."

    # Check tokenizer files
    for file in tokenizer_files:
        assert os.path.isfile(os.path.join(save_path, file)), f"{file} not found in the save directory."

    # Store the size of the model files under a dedicated "torchao" key so the
    # "merged_4bit" entry written by test_save_merged_4bit is not clobbered.
    total_size = sum(os.path.getsize(os.path.join(save_path, f)) for f in weight_files)
    save_file_sizes.setdefault("torchao", {})[model.config._name_or_path] = total_size
    print(f"Total size of torchao files: {total_size} bytes")
    # int8-quantized weights should come out smaller than the merged 16bit checkpoint.
    assert total_size < save_file_sizes["merged_16bit"][model.config._name_or_path], \
        "Torchao quantized files are larger than merged 16bit files."

    # Check config records the quantization
    config_path = os.path.join(save_path, "config.json")
    with open(config_path, "r") as f:
        config = json.load(f)
    assert "quantization_config" in config, "Quantization config not found in the model config."

    # Test loading the model from the saved path
    loaded_model, loaded_tokenizer = FastModel.from_pretrained(
        save_path,
        max_seq_length=128,
        dtype=None,
        load_in_4bit=True,
    )

View file

@ -2122,6 +2122,55 @@ def unsloth_push_to_hub_gguf(
pass
pass
def unsloth_save_pretrained_torchao(
    self,
    save_directory : Union[str, os.PathLike],
    tokenizer      = None,
    torchao_config = None,
    push_to_hub    : bool = False,
):
    """Quantizes the model with torchao and saves a torchao quantized checkpoint.

    Args
        `save_directory`: local folder path, or huggingface hub ID when `push_to_hub` is set to True, e.g. `my_model`
        `tokenizer`: tokenizer to save alongside the model (optional)
        `torchao_config` (TorchAOBaseConfig): configuration for torchao quantization; defaults to
            Int8DynamicActivationInt8WeightConfig. Full list:
            https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize
        `push_to_hub` (bool): whether to push the checkpoint to huggingface hub or not
    """
    # First merge the lora weights into a plain 16bit checkpoint on disk.
    arguments = dict(locals())
    # os.PathLike does not support `+`; normalize to str before adding the suffix.
    arguments["save_directory"] = str(save_directory) + "-local"
    arguments["model"]       = self
    arguments["tokenizer"]   = tokenizer
    arguments["push_to_hub"] = False # We save ourselves
    arguments["save_method"] = "merged_16bit" # Must be 16bit
    del arguments["self"]
    del arguments["torchao_config"]
    new_save_directory, old_username = unsloth_save_model(**arguments)

    from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
    if torchao_config is None:
        from torchao.quantization import Int8DynamicActivationInt8WeightConfig
        torchao_config = Int8DynamicActivationInt8WeightConfig()
    quantization_config = TorchAoConfig(quant_type=torchao_config)

    # Reload the merged checkpoint; transformers applies the torchao quantization
    # during loading via `quantization_config`.
    tokenizer = AutoTokenizer.from_pretrained(new_save_directory)
    model = AutoModelForCausalLM.from_pretrained(
        new_save_directory,
        torch_dtype="auto",
        device_map="auto",
        quantization_config=quantization_config,
    )

    if push_to_hub:
        save_to = save_directory
        # torchao does not support safe_serialization right now
        model.push_to_hub(save_to, safe_serialization=False)
        tokenizer.push_to_hub(save_to)
    else:
        # Bug fix: previously nothing was written to `save_directory` when not
        # pushing to the hub — the quantized model only ever existed in memory.
        model.save_pretrained(save_directory, safe_serialization=False)
        tokenizer.save_pretrained(save_directory)
    pass
pass
# Corrected function to save LoRA to a custom directory
def save_lora_to_custom_dir(model, tokenizer, save_directory):
# Create the custom directory if it doesn't exist
@ -2597,6 +2646,7 @@ def patch_saving_functions(model, vision = False):
model.save_pretrained_merged = types.MethodType(unsloth_generic_save_pretrained_merged, model)
model.push_to_hub_gguf = types.MethodType(unsloth_push_to_hub_gguf, model)
model.save_pretrained_gguf = types.MethodType(unsloth_save_pretrained_gguf, model)
model.save_pretrained_torchao = types.MethodType(unsloth_save_pretrained_torchao, model)
model.push_to_hub_ggml = types.MethodType(unsloth_convert_lora_to_ggml_and_push_to_hub, model)
model.save_pretrained_ggml = types.MethodType(unsloth_convert_lora_to_ggml_and_save_locally, model)
pass