mirror of https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00

fixed save_pretrained_torchao and associated tests (#3264)

parent f42f0d2116
commit 0135d126df

3 changed files with 29 additions and 18 deletions
@@ -17,7 +17,7 @@ model_to_test = [
-    "unsloth/Phi-4-mini-instruct-bnb-4bit",
+    "unsloth/Qwen2.5-0.5B",
     # Vision Models
     "unsloth/gemma-3-1b-it",
     "unsloth/gemma-3-4b-it",
     "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
     "unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit"
 ]
@@ -182,27 +182,31 @@ def test_save_torchao(model, tokenizer, temp_save_dir: str):
         push_to_hub=False,
     )

     # Check model files
     assert os.path.isdir(save_path), f"Directory {save_path} does not exist."
     assert os.path.isfile(os.path.join(save_path, "config.json")), "config.json not found."
     weight_files_16bit = [f for f in os.listdir(save_path) if f.endswith(".bin") or f.endswith(".safetensors")]
     total_16bit_size = sum(os.path.getsize(os.path.join(save_path, f)) for f in weight_files_16bit)
     save_file_sizes["merged_16bit"][model.config._name_or_path] = total_16bit_size

-    weight_files = [f for f in os.listdir(save_path) if f.endswith(".bin") or f.endswith(".safetensors")]
+    torchao_save_path = save_path + "-torchao"
+
+    # Check model files
+    assert os.path.isdir(torchao_save_path), f"Directory {torchao_save_path} does not exist."
+    assert os.path.isfile(os.path.join(torchao_save_path, "config.json")), "config.json not found."
+
+    weight_files = [f for f in os.listdir(torchao_save_path) if f.endswith(".bin") or f.endswith(".safetensors")]
     assert len(weight_files) > 0, "No weight files found in the save directory."

     # Check tokenizer files
     for file in tokenizer_files:
         assert os.path.isfile(os.path.join(save_path, file)), f"{file} not found in the save directory."
+        assert os.path.isfile(os.path.join(torchao_save_path, file)), f"{file} not found in the save directory."

     # Store the size of the model files
-    total_size = sum(os.path.getsize(os.path.join(save_path, f)) for f in weight_files)
+    total_size = sum(os.path.getsize(os.path.join(torchao_save_path, f)) for f in weight_files)
     save_file_sizes["torchao"][model.config._name_or_path] = total_size

-    # merged_16bit tests are not running yet, so we can't test this for now
-    # TODO: enable this after merged_16bit is fixed
-    # assert total_size < save_file_sizes["merged_16bit"][model.config._name_or_path], "torchao files are larger than merged 16bit files."
+    assert total_size < save_file_sizes["merged_16bit"][model.config._name_or_path], "torchao files are larger than merged 16bit files."

     # Check config to see if it is quantized with torchao
-    config_path = os.path.join(save_path, "config.json")
+    config_path = os.path.join(torchao_save_path, "config.json")
     with open(config_path, "r") as f:
         config = json.load(f)
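For context, here is a minimal sketch of the check the updated test performs once save_pretrained_torchao has written its output. The helper name verify_torchao_save is hypothetical and the pytest fixtures are elided, but the path convention (save_path + "-torchao") and the individual assertions come straight from the hunk above.

import os

def verify_torchao_save(save_path: str) -> int:
    # After model.save_pretrained_torchao(save_path, ...), the quantized
    # checkpoint is written next to the merged one, under save_path + "-torchao".
    torchao_save_path = save_path + "-torchao"
    assert os.path.isdir(torchao_save_path), f"Directory {torchao_save_path} does not exist."
    assert os.path.isfile(os.path.join(torchao_save_path, "config.json")), "config.json not found."
    weight_files = [
        f for f in os.listdir(torchao_save_path)
        if f.endswith(".bin") or f.endswith(".safetensors")
    ]
    assert len(weight_files) > 0, "No weight files found in the save directory."
    # Total on-disk size of the quantized weights, used for the 16bit comparison.
    return sum(os.path.getsize(os.path.join(torchao_save_path, f)) for f in weight_files)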
@@ -812,6 +812,11 @@ __INT_TO_FLOAT_MAPPER = \
         "microsoft/Phi-4-mini-reasoning",
         "unsloth/phi-4-mini-reasoning-bnb-4bit",
     ),
+    "unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit" : (
+        "unsloth/Phi-4-mini-instruct",
+        "microsoft/Phi-4-mini-instruct",
+        "unsloth/Phi-4-mini-instruct-bnb-4bit",
+    ),
     "unsloth/orpheus-3b-0.1-pretrained-unsloth-bnb-4bit" : (
         "unsloth/orpheus-3b-0.1-pretrained",
         "canopylabs/orpheus-3b-0.1-pretrained",
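Judging by the surrounding entries, __INT_TO_FLOAT_MAPPER maps a pre-quantized repo name to a tuple of related checkpoints, with an unquantized source listed first. A hedged sketch of how such a table can be consulted; resolve_float_model is a hypothetical helper for illustration, not unsloth's actual API:

INT_TO_FLOAT_MAPPER = {
    "unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit": (
        "unsloth/Phi-4-mini-instruct",
        "microsoft/Phi-4-mini-instruct",
        "unsloth/Phi-4-mini-instruct-bnb-4bit",
    ),
}

def resolve_float_model(repo_id: str) -> str:
    # Unknown repos resolve to themselves; known quantized repos resolve to
    # the first mapped name, assumed here to be the unquantized source.
    return INT_TO_FLOAT_MAPPER.get(repo_id, (repo_id,))[0]

print(resolve_float_model("unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit"))
# -> unsloth/Phi-4-mini-instruct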
@@ -2516,7 +2516,6 @@ def unsloth_save_pretrained_torchao(
     """
-    # first merge the lora weights
     arguments = dict(locals())
     arguments["save_directory"] = save_directory + "-local"
     arguments["model"] = self
     arguments["tokenizer"] = tokenizer
     arguments["push_to_hub"] = False # We save ourselves
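The dict(locals()) idiom above snapshots every argument of the enclosing function by name, so a few entries can be overridden before forwarding the whole set. A self-contained sketch of the pattern, with merged_save standing in for unsloth's real merged-weights saver (both names here are illustrative, not the library's API):

def merged_save(save_directory, tokenizer=None, push_to_hub=False):
    # Stand-in for the real merged-weights saver.
    print(f"saving merged weights to {save_directory}")

class Saver:
    def save_quantized(self, save_directory, tokenizer=None, push_to_hub=False):
        # Snapshot all arguments by name before defining any other locals.
        arguments = dict(locals())
        # Override selected entries, mirroring the diff: the merged float
        # checkpoint goes to a "-local" sibling directory and is never pushed.
        arguments["save_directory"] = save_directory + "-local"
        arguments["push_to_hub"] = False  # we handle the quantized copy ourselves
        del arguments["self"]             # merged_save is a plain function
        merged_save(**arguments)

Saver().save_quantized("out/model")  # -> saving merged weights to out/model-local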
@@ -2527,7 +2526,7 @@ def unsloth_save_pretrained_torchao(
     for _ in range(3):
         gc.collect()

-    from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
+    from transformers import AutoModel, AutoTokenizer, TorchAoConfig
    from torchao import quantize_
    if torchao_config is None:
        from torchao.quantization import Int8DynamicActivationInt8WeightConfig
@@ -2539,21 +2538,23 @@ def unsloth_save_pretrained_torchao(
         kwargs = {"torch_dtype" : "auto"}
     else:
         kwargs = {"dtype" : "auto"}
-    model = AutoModelForCausalLM.from_pretrained(
+    model = AutoModel.from_pretrained(
         arguments["save_directory"],
         device_map = "auto",
         quantization_config = quantization_config,
         **kwargs,
     )

+    torchao_save_directory = save_directory + "-torchao"
+
     if push_to_hub:
         if token is None and push_to_hub: token = get_token()
         # torchao does not support safe_serialization right now
-        model.push_to_hub(save_directory, safe_serialization = False, token = token)
-        tokenizer.push_to_hub(save_directory, token = token)
+        model.push_to_hub(torchao_save_directory, safe_serialization = False, token = token)
+        tokenizer.push_to_hub(torchao_save_directory, token = token)
     else:
-        model.save_pretrained(save_directory, safe_serialization=False)
-        tokenizer.save_pretrained(save_directory)
+        model.save_pretrained(torchao_save_directory, safe_serialization=False)
+        tokenizer.save_pretrained(torchao_save_directory)
     pass
     for _ in range(3):
         gc.collect()
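Putting the pieces of this function together: reload the merged checkpoint with a torchao quantization config attached, then write the quantized model to the "-torchao" directory. The sketch below assumes a transformers/torchao pairing recent enough for TorchAoConfig to accept an AOBaseConfig instance, and hard-codes paths for illustration; switching to AutoModel (from AutoModelForCausalLM) plausibly lets the same path serve the vision models in the test list.

from transformers import AutoModel, AutoTokenizer, TorchAoConfig
from torchao.quantization import Int8DynamicActivationInt8WeightConfig

merged_dir = "out/model-local"              # merged lora weights (see "-local" above)
torchao_save_directory = "out/model-torchao"

# Default config from the diff; any other torchao AOBaseConfig would also work here.
quantization_config = TorchAoConfig(Int8DynamicActivationInt8WeightConfig())

model = AutoModel.from_pretrained(
    merged_dir,
    device_map = "auto",
    quantization_config = quantization_config,
    torch_dtype = "auto",   # older transformers; newer versions spell this `dtype`
)
tokenizer = AutoTokenizer.from_pretrained(merged_dir)

# torchao does not support safe_serialization right now
model.save_pretrained(torchao_save_directory, safe_serialization = False)
tokenizer.save_pretrained(torchao_save_directory)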
@@ -2671,6 +2672,7 @@ def patch_saving_functions(model, vision = False):
         model.save_pretrained_merged = types.MethodType(unsloth_generic_save_pretrained_merged, model)
         model.push_to_hub_gguf = types.MethodType(save_to_gguf_generic, model)
         model.save_pretrained_gguf = types.MethodType(save_to_gguf_generic, model)
+        model.save_pretrained_torchao = types.MethodType(unsloth_save_pretrained_torchao, model)
     pass
     return model
 pass
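types.MethodType binds a plain function to one model instance, so self is supplied automatically when the patched method is called later. A minimal, runnable illustration; fake_save merely stands in for unsloth_save_pretrained_torchao:

import types

class Model:
    pass

def fake_save(self, save_directory):
    # `self` is the patched model instance, bound automatically by MethodType.
    print(f"saving {self.__class__.__name__} to {save_directory}-torchao")

model = Model()
model.save_pretrained_torchao = types.MethodType(fake_save, model)
model.save_pretrained_torchao("out/model")  # -> saving Model to out/model-torchao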