Tests for mxfp4 and quantized models merge fix (unsloth zoo PR 254) (#3223)

Roland Tannous 2025-08-29 11:30:48 +03:00 committed by GitHub
parent 25b21f4899
commit 711ec4a3ac
4 changed files with 367 additions and 0 deletions


@@ -0,0 +1,18 @@
#!/bin/bash
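# End-to-end mxfp4 merge test: fine-tune and merge the model, then run inference on the merged checkpoint.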
set -e
echo "================================================================"
echo "🚀 STEP 1: Running the training and merging script..."
echo "================================================================"
python train_and_merge.py
echo ""
echo "================================================================"
echo "✅ STEP 2: Training complete. Running the inference script..."
echo "================================================================"
python test_merged_model.py
echo ""
echo "================================================================"
echo "🎉 All steps completed successfully!"
echo "================================================================"


@@ -0,0 +1,55 @@
# inference_on_merged.py
from unsloth import FastLanguageModel
from transformers import TextStreamer
import torch
import gc
import os
import shutil
def safe_remove_directory(path):
    try:
        if os.path.exists(path) and os.path.isdir(path):
            shutil.rmtree(path)
            return True
        else:
            print(f"Path {path} is not a valid directory")
            return False
    except Exception as e:
        print(f"Failed to remove directory {path}: {e}")
        return False
pass
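# Reload the merged checkpoint written by train_and_merge.py.
# load_in_4bit=True re-quantizes the saved 16-bit weights on the fly for inference.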
print("🔥 Loading the 16-bit merged model from disk...")
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
    model_name="./gpt-oss-finetuned-merged",
    max_seq_length=1024,
    load_in_4bit=True,
    load_in_8bit=False,
)
print("✅ Merged model loaded successfully.")
# --- Run Inference ---
print("\n🚀 Running inference...")
messages = [
    {"role": "user", "content": "Solve x^5 + 3x^4 - 10 = 3."},
]
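# Build the prompt with the model's chat template; reasoning_effort is a gpt-oss template option.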
inputs = merged_tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True,
    return_tensors = "pt",
    return_dict = True,
    reasoning_effort = "low",  # Set reasoning effort to "low", "medium", or "high"
).to(merged_model.device)
_ = merged_model.generate(**inputs, max_new_tokens = 512, streamer = TextStreamer(merged_tokenizer))
print("\n✅ Inference complete.")
# --- Final Cleanup ---
print("\n🧹 Cleaning up merged model directory and cache...")
del merged_model, merged_tokenizer
torch.cuda.empty_cache()
gc.collect()
safe_remove_directory("./gpt-oss-finetuned-merged")
safe_remove_directory("./unsloth_compiled_cache") # Clean up cache created by this process
print("✅ Final cleanup complete. Exiting inference script.")


@@ -0,0 +1,71 @@
# train_and_merge.py
from unsloth import FastLanguageModel
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
import torch
import gc
import os
import shutil
def safe_remove_directory(path):
    try:
        if os.path.exists(path) and os.path.isdir(path):
            shutil.rmtree(path)
            return True
        else:
            print(f"Path {path} is not a valid directory")
            return False
    except Exception as e:
        print(f"Failed to remove directory {path}: {e}")
        return False
pass
# This tokenizer will be used by the mapping function
tokenizer = None
def formatting_prompts_func(examples):
    convos = examples["messages"]
    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
    return {"text": texts}
# --- Load 4-bit Model and Train ---
print("Loading 4-bit Mxfp4 gpt-oss model for training...")
max_seq_length = 1024
model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/gpt-oss-20b", max_seq_length=max_seq_length, load_in_4bit=True
)
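# A 50-example slice keeps the run fast; this is a smoke test, not a real fine-tune.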
dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train[:50]").map(
    formatting_prompts_func, batched=True
)
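# Attach LoRA adapters to the quantized base; only the adapter weights are updated during training.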
model = FastLanguageModel.get_peft_model(
    model, r=8, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16, use_gradient_checkpointing="unsloth", random_state=3407,
)
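# 10 optimizer steps (batch 1 x grad-accum 4) are enough to exercise the full training path.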
trainer = SFTTrainer(
    model=model, tokenizer=tokenizer, train_dataset=dataset,
    args=SFTConfig(
        per_device_train_batch_size=1, gradient_accumulation_steps=4, max_steps=10,
        learning_rate=2e-4, output_dir="outputs", report_to="none",
    ),
)
print("Starting fine-tuning...")
trainer.train()
print("Fine-tuning complete.")
# --- Merge and Save ---
print("\n💾 Merging and saving the 16-bit model to './gpt-oss-finetuned-merged'...")
model.save_pretrained_merged(save_directory="./gpt-oss-finetuned-merged", tokenizer=tokenizer)
print("✅ Model merged and saved.")
# --- Cleanup ---
print("\n🧹 Cleaning up training artifacts...")
del model, trainer, tokenizer, dataset
torch.cuda.empty_cache()
gc.collect()
safe_remove_directory("./outputs")
safe_remove_directory("./unsloth_compiled_cache") # Clean up the cache created by this process
print("✅ Cleanup complete. Exiting training script.")


@@ -0,0 +1,223 @@
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
from trl import SFTTrainer, SFTConfig
from transformers import DataCollatorForSeq2Seq, TrainingArguments
from datasets import load_dataset
import torch
import sys
from pathlib import Path
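# Make the repository root importable so the shared cleanup helper in tests/utils can be reused.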
REPO_ROOT = Path(__file__).parents[3]
sys.path.insert(0, str(REPO_ROOT))
from tests.utils.cleanup_utils import safe_remove_directory
def formatting_prompts_func(examples):
    convos = examples["messages"]
    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
    return {"text": texts}
print(f"\n{'='*80}")
print("🔍 PHASE 1: Loading Base Model and Initial Training")
print(f"{'='*80}")
if torch.cuda.is_bf16_supported():
    compute_dtype = torch.bfloat16
    attn_implementation = 'flash_attention_2'
else:
    compute_dtype = torch.float16
    attn_implementation = 'sdpa'
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.1-8B-Instruct",
    max_seq_length=2048,
    dtype=compute_dtype,
    load_in_4bit=True,
    load_in_8bit=False,
    full_finetuning=False,
    attn_implementation=attn_implementation
)
tokenizer = get_chat_template(
    tokenizer,
    chat_template="llama-3.1",
)
# Load small dataset for quick training
dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train[:100]")
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
print("✅ Base model loaded successfully!")
print(f"\n{'='*80}")
print("🔍 PHASE 2: First Fine-tuning")
print(f"{'='*80}")
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)
from unsloth import is_bfloat16_supported
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset_train,
    dataset_text_field="text",
    max_seq_length=2048,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_ratio=0.1,
        max_steps=10,  # Very short training for test
        learning_rate=2e-4,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=5,
        optim="adamw_8bit",
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
        report_to="none",
    ),
)
trainer_stats = trainer.train()
print("✅ First fine-tuning completed!")
print(f"\n{'='*80}")
print("🔍 PHASE 3: Save with Forced 4bit Merge")
print(f"{'='*80}")
model.save_pretrained_merged(
    save_directory='./test_4bit_model',
    tokenizer=tokenizer,
    save_method="forced_merged_4bit"
)
print("✅ Model saved with forced 4bit merge!")
print(f"\n{'='*80}")
print("🔍 PHASE 4: Loading 4bit Model and Second Fine-tuning")
print(f"{'='*80}")
# Clean up first model
del model
del tokenizer
torch.cuda.empty_cache()
# Load the 4bit merged model
model_4bit, tokenizer_4bit = FastLanguageModel.from_pretrained(
    model_name="./test_4bit_model",
    max_seq_length=2048,
    load_in_4bit=True,
    load_in_8bit=False,
)
tokenizer_4bit = get_chat_template(
    tokenizer_4bit,
    chat_template="llama-3.1",
)
print("✅ 4bit model loaded successfully!")
# Add LoRA adapters to the 4bit model
model_4bit = FastLanguageModel.get_peft_model(
    model_4bit,
    r=16,
    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)
# Second fine-tuning
trainer_4bit = SFTTrainer(
    model=model_4bit,
    tokenizer=tokenizer_4bit,
    train_dataset=dataset_train,
    dataset_text_field="text",
    max_seq_length=2048,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer_4bit),
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_ratio=0.1,
        max_steps=10,  # Very short training for test
        learning_rate=2e-4,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=5,
        optim="adamw_8bit",
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs_4bit",
        report_to="none",
    ),
)
trainer_4bit.train()
print("✅ Second fine-tuning on 4bit model completed!")
print(f"\n{'='*80}")
print("🔍 PHASE 5: Testing TypeError on Regular Merge (Should Fail)")
print(f"{'='*80}")
try:
    model_4bit.save_pretrained_merged(
        save_directory='./test_should_fail',
        tokenizer=tokenizer_4bit
        # No save_method specified, should default to regular merge
    )
    assert False, "Expected TypeError but merge succeeded!"
except TypeError as e:
    expected_error = "Base model should be a 16bits or mxfp4 base model for a 16bit model merge. Use `save_method=forced_merged_4bit` instead"
    assert expected_error in str(e), f"Unexpected error message: {str(e)}"
    print("✅ Correct TypeError raised for 4bit base model regular merge attempt!")
    print(f"Error message: {str(e)}")
print(f"\n{'='*80}")
print("🔍 PHASE 6: Successful Save with Forced 4bit Method")
print(f"{'='*80}")
try:
    model_4bit.save_pretrained_merged(
        save_directory='./test_4bit_second',
        tokenizer=tokenizer_4bit,
        save_method="forced_merged_4bit"
    )
    print("✅ Successfully saved 4bit model with forced 4bit method!")
except Exception as e:
    assert False, f"Phase 6 failed unexpectedly: {e}"
print(f"\n{'='*80}")
print("🔍 CLEANUP")
print(f"{'='*80}")
# Cleanup
safe_remove_directory("./outputs")
safe_remove_directory("./outputs_4bit")
safe_remove_directory("./unsloth_compiled_cache")
safe_remove_directory("./test_4bit_model")
safe_remove_directory("./test_4bit_second")
safe_remove_directory("./test_should_fail")
print("✅ All tests passed successfully!")