mirror of
https://github.com/unslothai/unsloth
synced 2026-04-21 13:37:39 +00:00
Tests for MXFP4 and quantized-model merge fix (unsloth-zoo PR 254) (#3223)
This commit is contained in:
parent
25b21f4899
commit
711ec4a3ac
4 changed files with 367 additions and 0 deletions
18
tests/saving/gpt-oss-merge/run_test.sh
Executable file
18
tests/saving/gpt-oss-merge/run_test.sh
Executable file
|
|
@ -0,0 +1,18 @@
|
|||
#!/bin/bash
# End-to-end driver for the gpt-oss merge test:
# step 1 trains and merges the model, step 2 runs inference on the merged output.
set -e  # abort the whole test as soon as any step fails

# Print a message framed by separator rules (same output as the original inline echoes).
banner() {
    echo "================================================================"
    echo "$1"
    echo "================================================================"
}

banner "🚀 STEP 1: Running the training and merging script..."
python train_and_merge.py

echo ""
banner "✅ STEP 2: Training complete. Running the inference script..."
python test_merged_model.py

echo ""
banner "🎉 All steps completed successfully!"
|
||||
55
tests/saving/gpt-oss-merge/test_merged_model.py
Normal file
55
tests/saving/gpt-oss-merge/test_merged_model.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
# inference_on_merged.py
|
||||
from unsloth import FastLanguageModel
|
||||
from transformers import TextStreamer
|
||||
import torch
|
||||
import gc
|
||||
import os
|
||||
import shutil
|
||||
|
||||
def safe_remove_directory(path):
    """Recursively delete *path* if it is an existing directory.

    Best-effort cleanup helper for test artifacts: failures are printed,
    never raised, so cleanup can never abort the test run.

    Args:
        path: Filesystem path of the directory to remove.

    Returns:
        True if the directory existed and was removed, False otherwise.
    """
    try:
        # os.path.isdir() already returns False for nonexistent paths, so the
        # original separate os.path.exists() check was redundant.
        if os.path.isdir(path):
            shutil.rmtree(path)
            return True
        print(f"Path {path} is not a valid directory")
        return False
    except Exception as e:
        # Deliberately broad: this is best-effort cleanup and must not raise.
        print(f"Failed to remove directory {path}: {e}")
        return False
|
||||
|
||||
# --- Load Merged Model ---
print("🔥 Loading the 16-bit merged model from disk...")
# Reload the artifact written by train_and_merge.py. load_in_4bit=True means
# the merged 16-bit weights get re-quantized on load — this is the round-trip
# the merge fix is meant to exercise.
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
    model_name="./gpt-oss-finetuned-merged",
    max_seq_length=1024,
    load_in_4bit=True,
    load_in_8bit=False,
)
print("✅ Merged model loaded successfully.")

# --- Run Inference ---
# Smoke-test generation: we only care that generate() runs on the reloaded
# model, not about the quality of the answer.
print("\n🚀 Running inference...")
messages = [
    {"role": "user", "content": "Solve x^5 + 3x^4 - 10 = 3."},
]
inputs = merged_tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True,
    return_tensors = "pt",
    return_dict = True,
    reasoning_effort = "low", # **NEW!** Set reasoning effort to low, medium or high
).to(merged_model.device)

# Stream tokens to stdout; the return value is unused on purpose.
_ = merged_model.generate(**inputs, max_new_tokens = 512, streamer = TextStreamer(merged_tokenizer))
print("\n✅ Inference complete.")

# --- Final Cleanup ---
print("\n🧹 Cleaning up merged model directory and cache...")
# Drop references before emptying the CUDA cache so GPU memory can actually
# be released.
del merged_model, merged_tokenizer
torch.cuda.empty_cache()
gc.collect()

safe_remove_directory("./gpt-oss-finetuned-merged")
safe_remove_directory("./unsloth_compiled_cache") # Clean up cache created by this process
print("✅ Final cleanup complete. Exiting inference script.")
|
||||
71
tests/saving/gpt-oss-merge/train_and_merge.py
Normal file
71
tests/saving/gpt-oss-merge/train_and_merge.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
# train_and_merge.py
|
||||
from unsloth import FastLanguageModel
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
from datasets import load_dataset
|
||||
import torch
|
||||
import gc
|
||||
import os
|
||||
import shutil
|
||||
|
||||
def safe_remove_directory(path):
    """Recursively delete *path* if it is an existing directory.

    Best-effort cleanup helper for test artifacts: failures are printed,
    never raised, so cleanup can never abort the test run.

    Args:
        path: Filesystem path of the directory to remove.

    Returns:
        True if the directory existed and was removed, False otherwise.
    """
    try:
        # os.path.isdir() already returns False for nonexistent paths, so the
        # original separate os.path.exists() check was redundant.
        if os.path.isdir(path):
            shutil.rmtree(path)
            return True
        print(f"Path {path} is not a valid directory")
        return False
    except Exception as e:
        # Deliberately broad: this is best-effort cleanup and must not raise.
        print(f"Failed to remove directory {path}: {e}")
        return False
|
||||
|
||||
# Placeholder rebound to the real tokenizer once the model is loaded below;
# formatting_prompts_func reads it at call time, not at definition time.
tokenizer = None
def formatting_prompts_func(examples):
    """Convert a batch of chat conversations into plain-text training rows.

    Applies the module-level ``tokenizer``'s chat template to every
    conversation in the "messages" column.

    Args:
        examples: A batched dataset slice containing a "messages" column.

    Returns:
        A dict with a "text" column holding one templated string per row.
    """
    rendered = []
    for conversation in examples["messages"]:
        rendered.append(
            tokenizer.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=False
            )
        )
    return {"text": rendered}
|
||||
|
||||
# --- Load 4-bit Model and Train ---
print("Loading 4-bit Mxfp4 gpt-oss model for training...")
max_seq_length = 1024
# Load the quantized gpt-oss checkpoint; this rebinds the module-level
# `tokenizer` that formatting_prompts_func reads.
model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/gpt-oss-20b", max_seq_length=max_seq_length, load_in_4bit=True
)

# Tiny 50-row slice: the point is exercising the merge path, not training quality.
dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train[:50]").map(
    formatting_prompts_func, batched=True
)

# Attach LoRA adapters to all attention and MLP projections.
model = FastLanguageModel.get_peft_model(
    model, r=8, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16, use_gradient_checkpointing="unsloth", random_state=3407,
)

# 10 optimizer steps only — just enough for the adapters to become non-trivial.
trainer = SFTTrainer(
    model=model, tokenizer=tokenizer, train_dataset=dataset,
    args=SFTConfig(
        per_device_train_batch_size=1, gradient_accumulation_steps=4, max_steps=10,
        learning_rate=2e-4, output_dir="outputs", report_to="none",
    ),
)

print("Starting fine-tuning...")
trainer.train()
print("Fine-tuning complete.")

# --- Merge and Save ---
# The call under test: merging LoRA into the quantized base and saving 16-bit
# weights (the mxfp4/quantized merge path fixed by unsloth-zoo PR 254).
print("\n💾 Merging and saving the 16-bit model to './gpt-oss-finetuned-merged'...")
model.save_pretrained_merged(save_directory="./gpt-oss-finetuned-merged", tokenizer=tokenizer)
print("✅ Model merged and saved.")

# --- Cleanup ---
print("\n🧹 Cleaning up training artifacts...")
# Drop references before emptying the CUDA cache so GPU memory is released
# for the follow-up inference script.
del model, trainer, tokenizer, dataset
torch.cuda.empty_cache()
gc.collect()

safe_remove_directory("./outputs")
safe_remove_directory("./unsloth_compiled_cache") # Clean up the cache created by this process
print("✅ Cleanup complete. Exiting training script.")
|
||||
223
tests/saving/language_models/test_merge_4bit_validation.py
Normal file
223
tests/saving/language_models/test_merge_4bit_validation.py
Normal file
|
|
@ -0,0 +1,223 @@
|
|||
from unsloth import FastLanguageModel
|
||||
from unsloth.chat_templates import get_chat_template
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
from transformers import DataCollatorForSeq2Seq, TrainingArguments
|
||||
from datasets import load_dataset
|
||||
import torch
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Make the repository root importable so the shared test helpers resolve
# (this file lives three directories below the root: tests/saving/language_models/).
REPO_ROOT = Path(__file__).parents[3]
sys.path.insert(0, str(REPO_ROOT))

from tests.utils.cleanup_utils import safe_remove_directory
|
||||
|
||||
def formatting_prompts_func(examples):
    """Turn each conversation in a "messages" batch into a chat-templated text row.

    Reads the module-level ``tokenizer`` at call time; returns a dict with a
    "text" column of rendered strings (one per conversation).
    """
    render = tokenizer.apply_chat_template
    texts = [
        render(convo, tokenize=False, add_generation_prompt=False)
        for convo in examples["messages"]
    ]
    return {"text": texts}
|
||||
|
||||
# PHASE 1: load the 4-bit base model and prepare the training data.
print(f"\n{'='*80}")
print("🔍 PHASE 1: Loading Base Model and Initial Training")
print(f"{'='*80}")

# Pick dtype/attention backend from hardware capability: bf16 + FlashAttention-2
# on Ampere+ GPUs, fp16 + SDPA otherwise.
if torch.cuda.is_bf16_supported():
    compute_dtype = torch.bfloat16
    attn_implementation = 'flash_attention_2'
else:
    compute_dtype = torch.float16
    attn_implementation = 'sdpa'

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.1-8B-Instruct",
    max_seq_length=2048,
    dtype=compute_dtype,
    load_in_4bit=True,
    load_in_8bit=False,
    full_finetuning=False,
    attn_implementation=attn_implementation
)

tokenizer = get_chat_template(
    tokenizer,
    chat_template="llama-3.1",
)

# Load small dataset for quick training
dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train[:100]")
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)

print("✅ Base model loaded successfully!")

# PHASE 2: short LoRA fine-tune on the 4-bit base (10 steps).
print(f"\n{'='*80}")
print("🔍 PHASE 2: First Fine-tuning")
print(f"{'='*80}")

model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)

from unsloth import is_bfloat16_supported

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset_train,
    dataset_text_field="text",
    max_seq_length=2048,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_ratio=0.1,
        max_steps=10, # Very short training for test
        learning_rate=2e-4,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=5,
        optim="adamw_8bit",
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
        report_to="none",
    ),
)

trainer_stats = trainer.train()
print("✅ First fine-tuning completed!")

# PHASE 3: explicitly request a 4-bit merge — this save method must succeed
# on a 4-bit base.
print(f"\n{'='*80}")
print("🔍 PHASE 3: Save with Forced 4bit Merge")
print(f"{'='*80}")

model.save_pretrained_merged(
    save_directory='./test_4bit_model',
    tokenizer=tokenizer,
    save_method="forced_merged_4bit"
)

print("✅ Model saved with forced 4bit merge!")

# PHASE 4: reload the 4-bit-merged artifact and fine-tune it a second time to
# prove the saved model is a usable base.
print(f"\n{'='*80}")
print("🔍 PHASE 4: Loading 4bit Model and Second Fine-tuning")
print(f"{'='*80}")

# Clean up first model
del model
del tokenizer
torch.cuda.empty_cache()

# Load the 4bit merged model
model_4bit, tokenizer_4bit = FastLanguageModel.from_pretrained(
    model_name="./test_4bit_model",
    max_seq_length=2048,
    load_in_4bit=True,
    load_in_8bit=False,
)

tokenizer_4bit = get_chat_template(
    tokenizer_4bit,
    chat_template="llama-3.1",
)

print("✅ 4bit model loaded successfully!")

# Add LoRA adapters to the 4bit model
model_4bit = FastLanguageModel.get_peft_model(
    model_4bit,
    r=16,
    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)

# Second fine-tuning
trainer_4bit = SFTTrainer(
    model=model_4bit,
    tokenizer=tokenizer_4bit,
    train_dataset=dataset_train,
    dataset_text_field="text",
    max_seq_length=2048,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer_4bit),
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_ratio=0.1,
        max_steps=10, # Very short training for test
        learning_rate=2e-4,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=5,
        optim="adamw_8bit",
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs_4bit",
        report_to="none",
    ),
)

trainer_4bit.train()
print("✅ Second fine-tuning on 4bit model completed!")

# PHASE 5: a default (16-bit) merge from a 4-bit base must be rejected with a
# specific TypeError — this is the validation added by the merge fix.
print(f"\n{'='*80}")
print("🔍 PHASE 5: Testing TypeError on Regular Merge (Should Fail)")
print(f"{'='*80}")

try:
    model_4bit.save_pretrained_merged(
        save_directory='./test_should_fail',
        tokenizer=tokenizer_4bit
        # No save_method specified, should default to regular merge
    )
    assert False, "Expected TypeError but merge succeeded!"
except TypeError as e:
    # The message must match the error raised by the save-path validation.
    expected_error = "Base model should be a 16bits or mxfp4 base model for a 16bit model merge. Use `save_method=forced_merged_4bit` instead"
    assert expected_error in str(e), f"Unexpected error message: {str(e)}"
    print("✅ Correct TypeError raised for 4bit base model regular merge attempt!")
    print(f"Error message: {str(e)}")

# PHASE 6: the explicit forced_merged_4bit escape hatch must still work.
print(f"\n{'='*80}")
print("🔍 PHASE 6: Successful Save with Forced 4bit Method")
print(f"{'='*80}")

try:
    model_4bit.save_pretrained_merged(
        save_directory='./test_4bit_second',
        tokenizer=tokenizer_4bit,
        save_method="forced_merged_4bit"
    )
    print("✅ Successfully saved 4bit model with forced 4bit method!")
except Exception as e:
    assert False, f"Phase 6 failed unexpectedly: {e}"

print(f"\n{'='*80}")
print("🔍 CLEANUP")
print(f"{'='*80}")

# Cleanup
safe_remove_directory("./outputs")
safe_remove_directory("./outputs_4bit")
safe_remove_directory("./unsloth_compiled_cache")
safe_remove_directory("./test_4bit_model")
safe_remove_directory("./test_4bit_second")
safe_remove_directory("./test_should_fail")

print("✅ All tests passed successfully!")
|
||||
Loading…
Reference in a new issue