Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"

This reverts commit cad158a56c.
This commit is contained in:
Daniel Han 2025-12-01 07:24:58 -08:00
parent cad158a56c
commit 66649d18bd
42 changed files with 2394 additions and 2394 deletions

View file

@ -54,37 +54,37 @@ if __name__ == "__main__":
seed = 42
batch_size = 5
num_generations = 5
tokenizer = setup_tokenizer(model_name, fixup_funcs=[fix_llama3_tokenizer])
tokenizer = setup_tokenizer(model_name, fixup_funcs = [fix_llama3_tokenizer])
temperature = 0.8
max_new_tokens = 20
peft_config = get_peft_config(lora_rank=lora_rank, target_modules="all-linear")
model = setup_model(model_name, quantize=True, dtype=dtype, peft_config=peft_config)
peft_config = get_peft_config(lora_rank = lora_rank, target_modules = "all-linear")
model = setup_model(model_name, quantize = True, dtype = dtype, peft_config = peft_config)
prompt = tokenizer.apply_chat_template(
[USER_MESSAGE], tokenize=False, add_generation_prompt=True
[USER_MESSAGE], tokenize = False, add_generation_prompt = True
)
with header_footer_context("Test Prompt and Answer"):
print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}")
dataset: Dataset = create_dataset(
tokenizer, num_examples=num_examples, messages=DEFAULT_MESSAGES
tokenizer, num_examples = num_examples, messages = DEFAULT_MESSAGES
)
with header_footer_context("Dataset"):
print(f"Dataset: {next(iter(dataset))}")
training_args = SFTConfig(
output_dir=output_dir,
max_steps=max_steps,
per_device_train_batch_size=batch_size,
log_level="info",
report_to="none",
num_train_epochs=1,
logging_steps=1,
seed=seed,
bf16=dtype == torch.bfloat16,
fp16=dtype == torch.float16,
save_strategy="no",
output_dir = output_dir,
max_steps = max_steps,
per_device_train_batch_size = batch_size,
log_level = "info",
report_to = "none",
num_train_epochs = 1,
logging_steps = 1,
seed = seed,
bf16 = dtype == torch.bfloat16,
fp16 = dtype == torch.float16,
save_strategy = "no",
)
with header_footer_context("Train Args"):
@ -92,7 +92,7 @@ if __name__ == "__main__":
print(peft_config)
trainer = setup_trainer(
model, tokenizer, dataset, training_args, peft_config=peft_config
model, tokenizer, dataset, training_args, peft_config = peft_config
)
with header_footer_context("Model"):
@ -108,11 +108,11 @@ if __name__ == "__main__":
responses = sample_responses(
model,
tokenizer,
prompt=prompt,
prompt = prompt,
**generation_args,
)
with header_footer_context("Responses before training"):
check_responses(responses, answer=ANSWER, prompt=prompt)
check_responses(responses, answer = ANSWER, prompt = prompt)
with header_footer_context("Peft Weights before training"):
for name, stats in itertools.islice(describe_peft_weights(model), 2):
@ -129,11 +129,11 @@ if __name__ == "__main__":
responses = sample_responses(
model,
tokenizer,
prompt=prompt,
prompt = prompt,
**generation_args,
)
with header_footer_context("Responses after training"):
check_responses(responses, answer=ANSWER, prompt=prompt)
check_responses(responses, answer = ANSWER, prompt = prompt)
model_copy = deepcopy(model)
@ -142,18 +142,18 @@ if __name__ == "__main__":
responses = sample_responses(
merged_model,
tokenizer,
prompt=prompt,
prompt = prompt,
**generation_args,
)
with header_footer_context("Responses after custom merging to 16bit"):
check_responses(responses, answer=ANSWER, prompt=prompt)
check_responses(responses, answer = ANSWER, prompt = prompt)
merged_model_peft = model_copy.merge_and_unload()
responses = sample_responses(
merged_model_peft,
tokenizer,
prompt=prompt,
prompt = prompt,
**generation_args,
)
with header_footer_context("Responses after peft merge_and_unload"):
check_responses(responses, answer=ANSWER, prompt=prompt)
check_responses(responses, answer = ANSWER, prompt = prompt)

View file

@ -50,13 +50,13 @@ def get_unsloth_model_and_tokenizer(
dtype: torch.dtype = torch.bfloat16,
):
return FastLanguageModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
load_in_4bit=load_in_4bit,
fast_inference=fast_inference,
max_lora_rank=max_lora_rank,
gpu_memory_utilization=gpu_memory_utilization,
dtype=dtype,
model_name = model_name,
max_seq_length = max_seq_length,
load_in_4bit = load_in_4bit,
fast_inference = fast_inference,
max_lora_rank = max_lora_rank,
gpu_memory_utilization = gpu_memory_utilization,
dtype = dtype,
)
@ -69,11 +69,11 @@ def get_unsloth_peft_model(
):
return FastLanguageModel.get_peft_model(
model,
r=lora_rank,
target_modules=target_modules,
lora_alpha=lora_rank,
use_gradient_checkpointing=use_gradient_checkpointing,
random_state=random_state,
r = lora_rank,
target_modules = target_modules,
lora_alpha = lora_rank,
use_gradient_checkpointing = use_gradient_checkpointing,
random_state = random_state,
)
@ -101,48 +101,48 @@ if __name__ == "__main__":
model, tokenizer = get_unsloth_model_and_tokenizer(
model_name,
max_seq_length=512,
load_in_4bit=True,
fast_inference=False,
max_lora_rank=lora_rank,
dtype=dtype,
max_seq_length = 512,
load_in_4bit = True,
fast_inference = False,
max_lora_rank = lora_rank,
dtype = dtype,
)
temperature = 0.8
max_new_tokens = 20
model = get_unsloth_peft_model(
model,
lora_rank=lora_rank,
target_modules=target_modules,
use_gradient_checkpointing=gradient_checkpointing,
random_state=seed,
lora_rank = lora_rank,
target_modules = target_modules,
use_gradient_checkpointing = gradient_checkpointing,
random_state = seed,
)
prompt = tokenizer.apply_chat_template(
[USER_MESSAGE], tokenize=False, add_generation_prompt=True
[USER_MESSAGE], tokenize = False, add_generation_prompt = True
)
with header_footer_context("Test Prompt and Answer"):
print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}")
dataset: Dataset = create_dataset(
tokenizer, num_examples=num_examples, messages=DEFAULT_MESSAGES
tokenizer, num_examples = num_examples, messages = DEFAULT_MESSAGES
)
with header_footer_context("Dataset"):
print(f"Dataset: {next(iter(dataset))}")
training_args = SFTConfig(
output_dir=output_dir,
max_steps=max_steps,
per_device_train_batch_size=batch_size,
log_level="info",
report_to="none",
num_train_epochs=1,
logging_steps=1,
seed=seed,
bf16=dtype == torch.bfloat16,
fp16=dtype == torch.float16,
save_strategy="no",
output_dir = output_dir,
max_steps = max_steps,
per_device_train_batch_size = batch_size,
log_level = "info",
report_to = "none",
num_train_epochs = 1,
logging_steps = 1,
seed = seed,
bf16 = dtype == torch.bfloat16,
fp16 = dtype == torch.float16,
save_strategy = "no",
)
with header_footer_context("Train Args"):
@ -163,11 +163,11 @@ if __name__ == "__main__":
responses = sample_responses(
model,
tokenizer,
prompt=prompt,
prompt = prompt,
**generation_args,
)
with header_footer_context("Responses before training"):
check_responses(responses, answer=ANSWER, prompt=prompt)
check_responses(responses, answer = ANSWER, prompt = prompt)
with header_footer_context("Peft Weights before training"):
for name, stats in itertools.islice(describe_peft_weights(model), 2):
print(f"{name}:\n{stats}")
@ -183,29 +183,29 @@ if __name__ == "__main__":
responses = sample_responses(
model,
tokenizer,
prompt=prompt,
prompt = prompt,
**generation_args,
)
with header_footer_context("Responses after training"):
check_responses(responses, answer=ANSWER, prompt=prompt)
check_responses(responses, answer = ANSWER, prompt = prompt)
model.save_pretrained_merged(
unsloth_merged_path,
tokenizer,
save_method="merged_16bit",
save_method = "merged_16bit",
)
merged_model_unsloth, tokenizer = get_unsloth_model_and_tokenizer(
unsloth_merged_path,
max_seq_length=512,
load_in_4bit=False,
fast_inference=False,
dtype=dtype,
max_seq_length = 512,
load_in_4bit = False,
fast_inference = False,
dtype = dtype,
)
responses = sample_responses(
merged_model_unsloth,
tokenizer,
prompt=prompt,
prompt = prompt,
**generation_args,
)
with header_footer_context("Responses after unsloth merge to 16bit"):
check_responses(responses, answer=ANSWER, prompt=prompt)
check_responses(responses, answer = ANSWER, prompt = prompt)

View file

@ -78,17 +78,17 @@ def formatting_prompts_func(examples):
}
def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=False):
def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
"""Load model and compute perplexity in subprocess"""
from unsloth import FastLanguageModel
from tests.utils.perplexity_eval import ppl_model
# Load model
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
model_name="./unsloth_out/merged_qwen_text_model",
max_seq_length=2048,
load_in_4bit=load_in_4bit,
load_in_8bit=load_in_8bit,
model_name = "./unsloth_out/merged_qwen_text_model",
max_seq_length = 2048,
load_in_4bit = load_in_4bit,
load_in_8bit = load_in_8bit,
)
# Set up tokenizer
# merged_tokenizer = get_chat_template(
@ -98,7 +98,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
# Load dataset fresh in subprocess
dataset_ppl = load_dataset(
"allenai/openassistant-guanaco-reformatted", split="eval"
"allenai/openassistant-guanaco-reformatted", split = "eval"
)
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
@ -146,7 +146,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
"text": texts,
}
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
# Compute perplexity using the passed dataset
ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
@ -172,7 +172,7 @@ def load_and_compute_8bit_ppl(result_queue, load_in_4bit=False, load_in_8bit=Fal
# Main execution code should be wrapped in this guard
if __name__ == "__main__":
mp.set_start_method("spawn", force=True)
mp.set_start_method("spawn", force = True)
if torch.cuda.is_bf16_supported():
compute_dtype = torch.bfloat16
@ -182,31 +182,31 @@ if __name__ == "__main__":
attn_implementation = "sdpa"
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/Qwen2.5-7B-Instruct",
max_seq_length=2048,
dtype=compute_dtype,
load_in_4bit=True,
load_in_8bit=False,
full_finetuning=False,
attn_implementation=attn_implementation,
model_name = "unsloth/Qwen2.5-7B-Instruct",
max_seq_length = 2048,
dtype = compute_dtype,
load_in_4bit = True,
load_in_8bit = False,
full_finetuning = False,
attn_implementation = attn_implementation,
)
dataset_train = load_dataset(
"allenai/openassistant-guanaco-reformatted", split="train"
"allenai/openassistant-guanaco-reformatted", split = "train"
)
dataset_ppl = load_dataset(
"allenai/openassistant-guanaco-reformatted", split="eval"
"allenai/openassistant-guanaco-reformatted", split = "eval"
)
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=[
r = 16,
target_modules = [
"k_proj",
"q_proj",
"v_proj",
@ -215,40 +215,40 @@ if __name__ == "__main__":
"down_proj",
"up_proj",
],
lora_alpha=16,
lora_dropout=0,
bias="none",
use_gradient_checkpointing="unsloth",
random_state=3407,
use_rslora=False,
loftq_config=None,
lora_alpha = 16,
lora_dropout = 0,
bias = "none",
use_gradient_checkpointing = "unsloth",
random_state = 3407,
use_rslora = False,
loftq_config = None,
)
from unsloth import is_bfloat16_supported
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset_train,
dataset_text_field="text",
max_seq_length=2048,
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
dataset_num_proc=2,
packing=False,
args=TrainingArguments(
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
warmup_ratio=0.1,
max_steps=200,
learning_rate=2e-4,
fp16=not is_bfloat16_supported(),
bf16=is_bfloat16_supported(),
logging_steps=50,
optim="adamw_8bit",
lr_scheduler_type="linear",
seed=3407,
output_dir="outputs",
report_to="none",
model = model,
tokenizer = tokenizer,
train_dataset = dataset_train,
dataset_text_field = "text",
max_seq_length = 2048,
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
dataset_num_proc = 2,
packing = False,
args = TrainingArguments(
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
warmup_ratio = 0.1,
max_steps = 200,
learning_rate = 2e-4,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 50,
optim = "adamw_8bit",
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "outputs",
report_to = "none",
),
)
@ -260,7 +260,7 @@ if __name__ == "__main__":
# saving and merging the model to local disk
print("merge and save to local disk")
model.save_pretrained_merged(
save_directory="./unsloth_out/merged_qwen_text_model", tokenizer=tokenizer
save_directory = "./unsloth_out/merged_qwen_text_model", tokenizer = tokenizer
)
# print("cleaning")
@ -272,10 +272,10 @@ if __name__ == "__main__":
# load model from local disk and test
print("Loading merged model in 4 bit for perplexity test")
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
model_name="./unsloth_out/merged_qwen_text_model",
max_seq_length=2048,
load_in_4bit=True,
load_in_8bit=False,
model_name = "./unsloth_out/merged_qwen_text_model",
max_seq_length = 2048,
load_in_4bit = True,
load_in_8bit = False,
)
add_to_comparison(
@ -284,7 +284,7 @@ if __name__ == "__main__":
print("Computing 8-bit model perplexity in subprocess...")
result_queue = mp.Queue()
p = mp.Process(target=load_and_compute_8bit_ppl, args=(result_queue, False, True))
p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
p.start()
p.join()
@ -293,10 +293,10 @@ if __name__ == "__main__":
print("Loading merged model in 16 bit for perplexity test")
merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
model_name="./unsloth_out/merged_qwen_text_model",
max_seq_length=2048,
load_in_4bit=False,
load_in_8bit=False,
model_name = "./unsloth_out/merged_qwen_text_model",
max_seq_length = 2048,
load_in_4bit = False,
load_in_8bit = False,
)
add_to_comparison(

View file

@ -37,7 +37,7 @@ def formatting_prompts_func(examples):
convos = examples["messages"]
texts = [
tokenizer.apply_chat_template(
convo, tokenize=False, add_generation_prompt=False
convo, tokenize = False, add_generation_prompt = False
)
for convo in convos
]
@ -52,34 +52,34 @@ else:
attn_implementation = "sdpa"
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/Llama-3.2-1B-Instruct",
max_seq_length=2048,
dtype=compute_dtype,
load_in_4bit=True,
load_in_8bit=False,
full_finetuning=False,
attn_implementation=attn_implementation,
model_name = "unsloth/Llama-3.2-1B-Instruct",
max_seq_length = 2048,
dtype = compute_dtype,
load_in_4bit = True,
load_in_8bit = False,
full_finetuning = False,
attn_implementation = attn_implementation,
)
tokenizer = get_chat_template(
tokenizer,
chat_template="llama-3.1",
chat_template = "llama-3.1",
)
from unsloth.chat_templates import standardize_sharegpt
dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train")
dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval")
dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split = "train")
dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split = "eval")
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=[
r = 16,
target_modules = [
"k_proj",
"q_proj",
"v_proj",
@ -88,40 +88,40 @@ model = FastLanguageModel.get_peft_model(
"down_proj",
"up_proj",
],
lora_alpha=16,
lora_dropout=0,
bias="none",
use_gradient_checkpointing="unsloth",
random_state=3407,
use_rslora=False,
loftq_config=None,
lora_alpha = 16,
lora_dropout = 0,
bias = "none",
use_gradient_checkpointing = "unsloth",
random_state = 3407,
use_rslora = False,
loftq_config = None,
)
from unsloth import is_bfloat16_supported
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset_train,
dataset_text_field="text",
max_seq_length=2048,
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
dataset_num_proc=2,
packing=False,
args=TrainingArguments(
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
warmup_ratio=0.1,
max_steps=30,
learning_rate=2e-4,
fp16=not is_bfloat16_supported(),
bf16=is_bfloat16_supported(),
logging_steps=50,
optim="adamw_8bit",
lr_scheduler_type="linear",
seed=3407,
output_dir="outputs",
report_to="none",
model = model,
tokenizer = tokenizer,
train_dataset = dataset_train,
dataset_text_field = "text",
max_seq_length = 2048,
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
dataset_num_proc = 2,
packing = False,
args = TrainingArguments(
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
warmup_ratio = 0.1,
max_steps = 30,
learning_rate = 2e-4,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 50,
optim = "adamw_8bit",
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "outputs",
report_to = "none",
),
)
@ -129,8 +129,8 @@ from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
trainer,
instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
)
# run training
@ -160,7 +160,7 @@ try:
print("\n" + "=" * 80)
print("=== UPLOADING MODEL TO HUB ===".center(80))
print("=" * 80 + "\n")
model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token)
model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
success["upload"] = True
print("✅ Model uploaded successfully!")
except Exception as e:

View file

@ -37,7 +37,7 @@ def formatting_prompts_func(examples):
convos = examples["messages"]
texts = [
tokenizer.apply_chat_template(
convo, tokenize=False, add_generation_prompt=False
convo, tokenize = False, add_generation_prompt = False
)
for convo in convos
]
@ -52,34 +52,34 @@ else:
attn_implementation = "sdpa"
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/Llama-3.1-8B-Instruct",
max_seq_length=2048,
dtype=compute_dtype,
load_in_4bit=True,
load_in_8bit=False,
full_finetuning=False,
attn_implementation=attn_implementation,
model_name = "unsloth/Llama-3.1-8B-Instruct",
max_seq_length = 2048,
dtype = compute_dtype,
load_in_4bit = True,
load_in_8bit = False,
full_finetuning = False,
attn_implementation = attn_implementation,
)
tokenizer = get_chat_template(
tokenizer,
chat_template="llama-3.1",
chat_template = "llama-3.1",
)
from unsloth.chat_templates import standardize_sharegpt
dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split="train")
dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split="eval")
dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split = "train")
dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split = "eval")
dataset_train = dataset_train.map(formatting_prompts_func, batched=True)
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched=True)
dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=[
r = 16,
target_modules = [
"k_proj",
"q_proj",
"v_proj",
@ -88,40 +88,40 @@ model = FastLanguageModel.get_peft_model(
"down_proj",
"up_proj",
],
lora_alpha=16,
lora_dropout=0,
bias="none",
use_gradient_checkpointing="unsloth",
random_state=3407,
use_rslora=False,
loftq_config=None,
lora_alpha = 16,
lora_dropout = 0,
bias = "none",
use_gradient_checkpointing = "unsloth",
random_state = 3407,
use_rslora = False,
loftq_config = None,
)
from unsloth import is_bfloat16_supported
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset_train,
dataset_text_field="text",
max_seq_length=2048,
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
dataset_num_proc=2,
packing=False,
args=TrainingArguments(
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
warmup_ratio=0.1,
max_steps=30,
learning_rate=2e-4,
fp16=not is_bfloat16_supported(),
bf16=is_bfloat16_supported(),
logging_steps=50,
optim="adamw_8bit",
lr_scheduler_type="linear",
seed=3407,
output_dir="outputs",
report_to="none",
model = model,
tokenizer = tokenizer,
train_dataset = dataset_train,
dataset_text_field = "text",
max_seq_length = 2048,
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
dataset_num_proc = 2,
packing = False,
args = TrainingArguments(
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
warmup_ratio = 0.1,
max_steps = 30,
learning_rate = 2e-4,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 50,
optim = "adamw_8bit",
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "outputs",
report_to = "none",
),
)
@ -129,8 +129,8 @@ from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
trainer,
instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
)
# run training
@ -161,7 +161,7 @@ try:
print("\n" + "=" * 80)
print("=== UPLOADING MODEL TO HUB ===".center(80))
print("=" * 80 + "\n")
model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token)
model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
success["upload"] = True
print("✅ Model uploaded successfully!")
except Exception as e:
@ -173,8 +173,8 @@ try:
print("\n" + "=" * 80)
print("=== VERIFYING REPO CONTENTS ===".center(80))
print("=" * 80 + "\n")
fs = HfFileSystem(token=hf_token)
file_list = fs.ls(repo_name, detail=True)
fs = HfFileSystem(token = hf_token)
file_list = fs.ls(repo_name, detail = True)
safetensors_found = any(
file["name"].endswith("model.safetensors.index.json") for file in file_list
)

View file

@ -24,7 +24,7 @@ max_seq_length = 2048 # Can increase for longer reasoning traces
lora_rank = 64 # Larger rank = smarter, but slower
def evaluate_merged_model(result_queue, load_in_4bit=False, load_in_8bit=False):
def evaluate_merged_model(result_queue, load_in_4bit = False, load_in_8bit = False):
from unsloth import FastLanguageModel
from tests.utils.aime_eval import evaluate_model_aime
@ -32,12 +32,12 @@ def evaluate_merged_model(result_queue, load_in_4bit=False, load_in_8bit=False):
lora_rank = 64 # Larger rank = smarter, but slower
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="./final_merged_model",
max_seq_length=max_seq_length,
load_in_4bit=True, # False for LoRA 16bit
fast_inference=True, # Enable vLLM fast inference
max_lora_rank=lora_rank,
gpu_memory_utilization=0.8, # Reduce if out of memory
model_name = "./final_merged_model",
max_seq_length = max_seq_length,
load_in_4bit = True, # False for LoRA 16bit
fast_inference = True, # Enable vLLM fast inference
max_lora_rank = lora_rank,
gpu_memory_utilization = 0.8, # Reduce if out of memory
)
print(f"\n{'='*60}")
@ -53,14 +53,14 @@ def evaluate_merged_model(result_queue, load_in_4bit=False, load_in_8bit=False):
print(f"{'='*60}")
evaluate_model_aime(
model=model,
tokenizer=tokenizer,
model_type=model_type,
temperature=0.3,
n_sampling=8,
max_tokens=32768,
top_p=0.95,
seed=0,
model = model,
tokenizer = tokenizer,
model_type = model_type,
temperature = 0.3,
n_sampling = 8,
max_tokens = 32768,
top_p = 0.95,
seed = 0,
)
result_queue.put(results)
@ -74,12 +74,12 @@ def evaluate_merged_model(result_queue, load_in_4bit=False, load_in_8bit=False):
# Main execution code should be wrapped in this guard
def training_run(result_queue):
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="meta-llama/Llama-3.2-3B-Instruct",
max_seq_length=max_seq_length,
load_in_4bit=False, # False for LoRA 16bit
fast_inference=True, # Enable vLLM fast inference
max_lora_rank=lora_rank,
gpu_memory_utilization=0.8, # Reduce if out of memory
model_name = "meta-llama/Llama-3.2-3B-Instruct",
max_seq_length = max_seq_length,
load_in_4bit = False, # False for LoRA 16bit
fast_inference = True, # Enable vLLM fast inference
max_lora_rank = lora_rank,
gpu_memory_utilization = 0.8, # Reduce if out of memory
)
"""### Helper Functions
@ -166,10 +166,10 @@ def training_run(result_queue):
lengths = dataset.map(
lambda x: {
"tokens": tokenizer.apply_chat_template(
x["prompt"], add_generation_prompt=True, tokenize=True
x["prompt"], add_generation_prompt = True, tokenize = True
)
},
batched=True,
batched = True,
).map(lambda x: {"length": len(x["tokens"])})["length"]
max_length = max(lengths)
@ -181,7 +181,7 @@ def training_run(result_queue):
)
return max_length, avg_length
def extract_unsloth_answer(text, start_tag="<SOLUTION>", end_tag="</SOLUTION>"):
def extract_unsloth_answer(text, start_tag = "<SOLUTION>", end_tag = "</SOLUTION>"):
"""Extract answer from Unsloth SOLUTION tags"""
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
matches = re.findall(pattern, text, re.DOTALL)
@ -213,10 +213,10 @@ def training_run(result_queue):
"""Count tokens in text"""
if not text:
return 0
encoding = tokenizer_instance(text, return_tensors="pt")
encoding = tokenizer_instance(text, return_tensors = "pt")
return len(encoding["input_ids"][0])
def check_format_compliance(text, format_type="unsloth"):
def check_format_compliance(text, format_type = "unsloth"):
"""Check if response follows expected format"""
if format_type == "unsloth":
reasoning_start = "<start_reasoning>"
@ -419,11 +419,11 @@ def training_run(result_queue):
# Save comparison
comparison_data = {
"summary": all_results,
"best_model": max(all_results, key=lambda x: x["exact_match_pct"]),
"best_model": max(all_results, key = lambda x: x["exact_match_pct"]),
}
with open("model_comparison_comprehensive.json", "w") as f:
json.dump(comparison_data, f, indent=4)
json.dump(comparison_data, f, indent = 4)
print(
f"\nBest performing model: {comparison_data['best_model']['model_type']} "
@ -449,10 +449,10 @@ def training_run(result_queue):
from datasets import load_dataset
# Load GSM8K
gsm8k_dataset = load_dataset("openai/gsm8k", "main", split="train")
gsm8k_dataset = load_dataset("openai/gsm8k", "main", split = "train")
# Load LIMO (adjust this based on your access method)
limo_train = load_dataset("GAIR/LIMO", split="train")
limo_train = load_dataset("GAIR/LIMO", split = "train")
# Prepare datasets
gsm8k_train = prepare_gsm8k_dataset(gsm8k_dataset)
@ -466,28 +466,28 @@ def training_run(result_queue):
# Single temperature evaluation on combined dataset
results = evaluate_model_aime(
model=model,
tokenizer=tokenizer,
model_type="base",
temperature=0.3,
n_sampling=8,
max_tokens=32768,
top_p=0.95,
seed=0,
model = model,
tokenizer = tokenizer,
model_type = "base",
temperature = 0.3,
n_sampling = 8,
max_tokens = 32768,
top_p = 0.95,
seed = 0,
)
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(
tokenizer,
chat_template="llama-3.1",
chat_template = "llama-3.1",
)
def formatting_prompts_func(examples):
convos = examples["prompt"]
texts = [
tokenizer.apply_chat_template(
convo, tokenize=False, add_generation_prompt=False
convo, tokenize = False, add_generation_prompt = False
)
for convo in convos
]
@ -497,7 +497,7 @@ def training_run(result_queue):
limo_train = limo_train.map(
formatting_prompts_func,
batched=True,
batched = True,
)
from trl import SFTTrainer
@ -510,8 +510,8 @@ def training_run(result_queue):
model = FastLanguageModel.get_peft_model(
model,
r=lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules=[
r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules = [
"q_proj",
"k_proj",
"v_proj",
@ -520,37 +520,37 @@ def training_run(result_queue):
"up_proj",
"down_proj",
], # Remove QKVO if out of memory
lora_alpha=lora_rank,
use_gradient_checkpointing="unsloth", # Enable long context finetuning
random_state=3407,
lora_alpha = lora_rank,
use_gradient_checkpointing = "unsloth", # Enable long context finetuning
random_state = 3407,
)
if limo_train is not None:
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=limo_train,
dataset_text_field="text",
max_seq_length=max_seq_length,
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
dataset_num_proc=2,
packing=False, # Can make training 5x faster for short sequences.
args=TrainingArguments(
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
warmup_steps=5,
num_train_epochs=1, # Set this for 1 full training run.
model = model,
tokenizer = tokenizer,
train_dataset = limo_train,
dataset_text_field = "text",
max_seq_length = max_seq_length,
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
dataset_num_proc = 2,
packing = False, # Can make training 5x faster for short sequences.
args = TrainingArguments(
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
warmup_steps = 5,
num_train_epochs = 1, # Set this for 1 full training run.
# max_steps = 60,
learning_rate=2e-4,
fp16=not is_bfloat16_supported(),
bf16=is_bfloat16_supported(),
logging_steps=1,
optim="adamw_8bit",
weight_decay=0.01,
lr_scheduler_type="linear",
seed=3407,
output_dir="outputs",
report_to="none", # Use this for WandB etc
learning_rate = 2e-4,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 1,
optim = "adamw_8bit",
weight_decay = 0.01,
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "outputs",
report_to = "none", # Use this for WandB etc
),
)
@ -558,8 +558,8 @@ def training_run(result_queue):
trainer = train_on_responses_only(
trainer,
instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
)
# Train
@ -588,7 +588,7 @@ def training_run(result_queue):
PRINT_EVERY_STEPS = 5
match_numbers = re.compile(
solution_start + r".*?([\d\.\,]{1,})", flags=re.MULTILINE | re.DOTALL
solution_start + r".*?([\d\.\,]{1,})", flags = re.MULTILINE | re.DOTALL
)
def check_numbers(prompts, completions, answer, **kwargs):
@ -642,37 +642,37 @@ def training_run(result_queue):
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
learning_rate=5e-6,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type="cosine",
optim="adamw_torch_fused",
logging_steps=1,
per_device_train_batch_size=1,
gradient_accumulation_steps=4, # Increase to 4 for smoother training
num_generations=8, # Decrease if out of memory
max_prompt_length=max_prompt_length,
max_completion_length=max_seq_length - max_prompt_length,
learning_rate = 5e-6,
weight_decay = 0.1,
warmup_ratio = 0.1,
lr_scheduler_type = "cosine",
optim = "adamw_torch_fused",
logging_steps = 1,
per_device_train_batch_size = 1,
gradient_accumulation_steps = 4, # Increase to 4 for smoother training
num_generations = 8, # Decrease if out of memory
max_prompt_length = max_prompt_length,
max_completion_length = max_seq_length - max_prompt_length,
# num_train_epochs = 1, # Set to 1 for a full training run
# max_steps = 250,
max_steps=1000,
save_steps=250,
max_grad_norm=0.1,
report_to="none", # Can use Weights & Biases
output_dir="outputs",
max_steps = 1000,
save_steps = 250,
max_grad_norm = 0.1,
report_to = "none", # Can use Weights & Biases
output_dir = "outputs",
)
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[
model = model,
processing_class = tokenizer,
reward_funcs = [
match_format_exactly,
match_format_approximately,
check_answer_correctness,
check_numbers,
],
args=training_args,
train_dataset=gsm8k_train,
args = training_args,
train_dataset = gsm8k_train,
)
# Train
@ -696,14 +696,14 @@ def training_run(result_queue):
print(f"{'='*60}")
grpo_results = evaluate_model_aime(
model=model,
tokenizer=tokenizer,
model_type="grpo",
temperature=0.3,
n_sampling=8,
max_tokens=32768,
top_p=0.95,
seed=0,
model = model,
tokenizer = tokenizer,
model_type = "grpo",
temperature = 0.3,
n_sampling = 8,
max_tokens = 32768,
top_p = 0.95,
seed = 0,
)
all_results.append(grpo_results)
@ -716,7 +716,7 @@ def training_run(result_queue):
# Save as merged model
try:
model.save_pretrained_merged(
"final_merged_model", tokenizer, save_method="merged_16bit"
"final_merged_model", tokenizer, save_method = "merged_16bit"
)
print("✅ Merged model saved to: final_merged_model/")
except Exception as e:
@ -774,12 +774,12 @@ def training_run(result_queue):
if __name__ == "__main__":
mp.set_start_method("spawn", force=True)
mp.set_start_method("spawn", force = True)
result_queue = mp.Queue()
all_results = []
# run main finetuning and grpo loop
p = mp.Process(target=training_run, args=(result_queue,))
p = mp.Process(target = training_run, args = (result_queue,))
p.start()
p.join()
@ -787,7 +787,7 @@ if __name__ == "__main__":
all_results = results
# evaluate merged model loaded 16bits
p = mp.Process(target=evaluate_merged_model, args=(result_queue, False, False))
p = mp.Process(target = evaluate_merged_model, args = (result_queue, False, False))
p.start()
p.join()
@ -796,7 +796,7 @@ if __name__ == "__main__":
safe_remove_directory("./unsloth_compiled_cache")
# Merged model load 8 bits model AIME eval
p = mp.Process(target=evaluate_merged_model, args=(result_queue, False, True))
p = mp.Process(target = evaluate_merged_model, args = (result_queue, False, True))
p.start()
p.join()
@ -806,7 +806,7 @@ if __name__ == "__main__":
safe_remove_directory("./unsloth_compiled_cache")
# Merged model load 4 bits model AIME eval
p = mp.Process(target=evaluate_merged_model, args=(result_queue, True, False))
p = mp.Process(target = evaluate_merged_model, args = (result_queue, True, False))
p.start()
p.join()

View file

@ -43,31 +43,31 @@ tokenizer_files = [
]
@pytest.fixture(scope="session", params=model_to_test)
@pytest.fixture(scope = "session", params = model_to_test)
def loaded_model_tokenizer(request):
model_name = request.param
print("Loading model and tokenizer...")
model, tokenizer = FastModel.from_pretrained(
model_name, # use small model
max_seq_length=128,
dtype=None,
load_in_4bit=True,
max_seq_length = 128,
dtype = None,
load_in_4bit = True,
)
# Apply LoRA
model = FastModel.get_peft_model(
model,
r=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
lora_alpha=16,
use_gradient_checkpointing="unsloth",
r = 16,
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
lora_alpha = 16,
use_gradient_checkpointing = "unsloth",
)
return model, tokenizer
@pytest.fixture(scope="session", params=torchao_models)
@pytest.fixture(scope = "session", params = torchao_models)
def fp16_model_tokenizer(request):
"""Load model in FP16 for TorchAO quantization"""
model_name = request.param
@ -75,29 +75,29 @@ def fp16_model_tokenizer(request):
model, tokenizer = FastModel.from_pretrained(
model_name,
max_seq_length=128,
dtype=None,
load_in_4bit=False, # No BnB quantization
max_seq_length = 128,
dtype = None,
load_in_4bit = False, # No BnB quantization
)
# Apply LoRA
model = FastModel.get_peft_model(
model,
r=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
lora_alpha=16,
use_gradient_checkpointing="unsloth",
r = 16,
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
lora_alpha = 16,
use_gradient_checkpointing = "unsloth",
)
return model, tokenizer
@pytest.fixture(scope="session")
@pytest.fixture(scope = "session")
def model(loaded_model_tokenizer):
return loaded_model_tokenizer[0]
@pytest.fixture(scope="session")
@pytest.fixture(scope = "session")
def tokenizer(loaded_model_tokenizer):
return loaded_model_tokenizer[1]
@ -133,7 +133,7 @@ def test_save_merged_16bit(model, tokenizer, temp_save_dir: str):
)
model.save_pretrained_merged(
save_path, tokenizer=tokenizer, save_method="merged_16bit"
save_path, tokenizer = tokenizer, save_method = "merged_16bit"
)
# Check model files
@ -172,9 +172,9 @@ def test_save_merged_16bit(model, tokenizer, temp_save_dir: str):
# Test loading the model from the saved path
loaded_model, loaded_tokenizer = FastLanguageModel.from_pretrained(
save_path,
max_seq_length=128,
dtype=None,
load_in_4bit=True,
max_seq_length = 128,
dtype = None,
load_in_4bit = True,
)
@ -186,7 +186,7 @@ def test_save_merged_4bit(model, tokenizer, temp_save_dir: str):
)
model.save_pretrained_merged(
save_path, tokenizer=tokenizer, save_method="merged_4bit_forced"
save_path, tokenizer = tokenizer, save_method = "merged_4bit_forced"
)
# Check model files
@ -230,15 +230,15 @@ def test_save_merged_4bit(model, tokenizer, temp_save_dir: str):
# Test loading the model from the saved path
loaded_model, loaded_tokenizer = FastModel.from_pretrained(
save_path,
max_seq_length=128,
dtype=None,
load_in_4bit=True,
max_seq_length = 128,
dtype = None,
load_in_4bit = True,
)
@pytest.mark.skipif(
importlib.util.find_spec("torchao") is None,
reason="require torchao to be installed",
reason = "require torchao to be installed",
)
def test_save_torchao(fp16_model_tokenizer, temp_save_dir: str):
model, tokenizer = fp16_model_tokenizer
@ -251,9 +251,9 @@ def test_save_torchao(fp16_model_tokenizer, temp_save_dir: str):
torchao_config = Int8DynamicActivationInt8WeightConfig()
model.save_pretrained_torchao(
save_path,
tokenizer=tokenizer,
torchao_config=torchao_config,
push_to_hub=False,
tokenizer = tokenizer,
torchao_config = torchao_config,
push_to_hub = False,
)
weight_files_16bit = [
@ -316,15 +316,15 @@ def test_save_torchao(fp16_model_tokenizer, temp_save_dir: str):
with torch.serialization.safe_globals([getattr]):
loaded_model, loaded_tokenizer = FastModel.from_pretrained(
torchao_save_path,
max_seq_length=128,
dtype=None,
load_in_4bit=False,
max_seq_length = 128,
dtype = None,
load_in_4bit = False,
)
@pytest.mark.skipif(
importlib.util.find_spec("torchao") is None,
reason="require torchao to be installed",
reason = "require torchao to be installed",
)
def test_save_and_inference_torchao(fp16_model_tokenizer, temp_save_dir: str):
model, tokenizer = fp16_model_tokenizer
@ -343,9 +343,9 @@ def test_save_and_inference_torchao(fp16_model_tokenizer, temp_save_dir: str):
# Save with TorchAO
model.save_pretrained_torchao(
save_path,
tokenizer=tokenizer,
torchao_config=torchao_config,
push_to_hub=False,
tokenizer = tokenizer,
torchao_config = torchao_config,
push_to_hub = False,
)
torchao_save_path = save_path + "-torchao"
@ -361,9 +361,9 @@ def test_save_and_inference_torchao(fp16_model_tokenizer, temp_save_dir: str):
with torch.serialization.safe_globals([getattr]):
loaded_model, loaded_tokenizer = FastModel.from_pretrained(
torchao_save_path,
max_seq_length=128,
dtype=None,
load_in_4bit=False,
max_seq_length = 128,
dtype = None,
load_in_4bit = False,
)
FastModel.for_inference(loaded_model) # Enable native 2x faster inference
@ -376,24 +376,24 @@ def test_save_and_inference_torchao(fp16_model_tokenizer, temp_save_dir: str):
]
inputs = loaded_tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True, # Must add for generation
return_tensors="pt",
tokenize = True,
add_generation_prompt = True, # Must add for generation
return_tensors = "pt",
).to("cuda")
outputs = loaded_model.generate( # ← Use loaded_model, not model
input_ids=inputs,
max_new_tokens=64,
use_cache=False, # Avoid cache issues
temperature=1.5,
min_p=0.1,
do_sample=True,
pad_token_id=loaded_tokenizer.pad_token_id or loaded_tokenizer.eos_token_id,
input_ids = inputs,
max_new_tokens = 64,
use_cache = False, # Avoid cache issues
temperature = 1.5,
min_p = 0.1,
do_sample = True,
pad_token_id = loaded_tokenizer.pad_token_id or loaded_tokenizer.eos_token_id,
)
# Decode with the LOADED tokenizer
generated_text = loaded_tokenizer.decode(outputs[0], skip_special_tokens=True)
input_text = loaded_tokenizer.decode(inputs[0], skip_special_tokens=True)
generated_text = loaded_tokenizer.decode(outputs[0], skip_special_tokens = True)
input_text = loaded_tokenizer.decode(inputs[0], skip_special_tokens = True)
response_part = generated_text[len(input_text) :].strip()
print(f"Input: {input_text}")

View file

@ -26,11 +26,11 @@ print(f"{'='*80}")
model, tokenizer = FastModel.from_pretrained(
model_name="unsloth/csm-1b",
max_seq_length=2048, # Choose any for long context!
dtype=None, # Leave as None for auto-detection
auto_model=CsmForConditionalGeneration,
load_in_4bit=False, # Select True for 4bit - reduces memory usage
model_name = "unsloth/csm-1b",
max_seq_length = 2048, # Choose any for long context!
dtype = None, # Leave as None for auto-detection
auto_model = CsmForConditionalGeneration,
load_in_4bit = False, # Select True for 4bit - reduces memory usage
)
@ -39,8 +39,8 @@ base_model_class = model.__class__.__name__
model = FastModel.get_peft_model(
model,
r=32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules=[
r = 32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules = [
"q_proj",
"k_proj",
"v_proj",
@ -49,14 +49,14 @@ model = FastModel.get_peft_model(
"up_proj",
"down_proj",
],
lora_alpha=32,
lora_dropout=0, # Supports any, but = 0 is optimized
bias="none", # Supports any, but = "none" is optimized
lora_alpha = 32,
lora_dropout = 0, # Supports any, but = 0 is optimized
bias = "none", # Supports any, but = "none" is optimized
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
random_state=3407,
use_rslora=False, # We support rank stabilized LoRA
loftq_config=None, # And LoftQ
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
random_state = 3407,
use_rslora = False, # We support rank stabilized LoRA
loftq_config = None, # And LoftQ
)
print("✅ Model and LoRA adapters loaded successfully!")
@ -110,11 +110,11 @@ print(f"{'='*80}")
model, processor = FastModel.from_pretrained(
model_name="./csm",
max_seq_length=2048, # Choose any for long context!
dtype=None, # Leave as None for auto-detection
auto_model=CsmForConditionalGeneration,
load_in_4bit=False, # Select True for 4bit - reduces memory usage
model_name = "./csm",
max_seq_length = 2048, # Choose any for long context!
dtype = None, # Leave as None for auto-detection
auto_model = CsmForConditionalGeneration,
load_in_4bit = False, # Select True for 4bit - reduces memory usage
)
from transformers import AutoProcessor
@ -138,19 +138,19 @@ try:
"We just finished fine tuning a text to speech model... and it's pretty good!"
)
speaker_id = 0
inputs = processor(f"[{speaker_id}]{text}", add_special_tokens=True).to("cuda")
inputs = processor(f"[{speaker_id}]{text}", add_special_tokens = True).to("cuda")
audio_values = model.generate(
**inputs,
max_new_tokens=125, # 125 tokens is 10 seconds of audio, for longer speech increase this
max_new_tokens = 125, # 125 tokens is 10 seconds of audio, for longer speech increase this
# play with these parameters to get the best results
depth_decoder_temperature=0.6,
depth_decoder_top_k=0,
depth_decoder_top_p=0.9,
temperature=0.8,
top_k=50,
top_p=1.0,
depth_decoder_temperature = 0.6,
depth_decoder_top_k = 0,
depth_decoder_top_p = 0.9,
temperature = 0.8,
top_k = 50,
top_p = 1.0,
#########################################################
output_audio=True,
output_audio = True,
)
audio = audio_values[0].to(torch.float32).cpu().numpy()
sf.write("example_without_context.wav", audio, 24000)

View file

@ -42,10 +42,10 @@ print(f"{'='*80}")
max_seq_length = 2048
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/Llasa-1B",
max_seq_length=max_seq_length,
dtype=None, # Select None for auto detection
load_in_4bit=False, # Choose True for 4bit which reduces memory
model_name = "unsloth/Llasa-1B",
max_seq_length = max_seq_length,
dtype = None, # Select None for auto detection
load_in_4bit = False, # Choose True for 4bit which reduces memory
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
@ -54,16 +54,16 @@ base_model_class = model.__class__.__name__
model = FastLanguageModel.get_peft_model(
model,
r=128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules=["q_proj", "v_proj"],
lora_alpha=128,
lora_dropout=0, # Supports any, but = 0 is optimized
bias="none", # Supports any, but = "none" is optimized
r = 128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules = ["q_proj", "v_proj"],
lora_alpha = 128,
lora_dropout = 0, # Supports any, but = 0 is optimized
bias = "none", # Supports any, but = "none" is optimized
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
random_state=3407,
use_rslora=False, # We support rank stabilized LoRA
loftq_config=None, # And LoftQ
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
random_state = 3407,
use_rslora = False, # We support rank stabilized LoRA
loftq_config = None, # And LoftQ
)
print("✅ Model and LoRA adapters loaded successfully!")
@ -117,10 +117,10 @@ print(f"{'='*80}")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="./lasa",
max_seq_length=max_seq_length,
dtype=None, # Select None for auto detection
load_in_4bit=False, # Choose True for 4bit which reduces memory
model_name = "./lasa",
max_seq_length = max_seq_length,
dtype = None, # Select None for auto detection
load_in_4bit = False, # Choose True for 4bit which reduces memory
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
@ -166,7 +166,7 @@ def extract_speech_ids(speech_tokens_str):
# TTS start!
with torch.inference_mode():
with torch.amp.autocast("cuda", dtype=model.dtype):
with torch.amp.autocast("cuda", dtype = model.dtype):
formatted_text = (
f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
)
@ -178,7 +178,7 @@ with torch.inference_mode():
]
input_ids = tokenizer.apply_chat_template(
chat, tokenize=True, return_tensors="pt", continue_final_message=True
chat, tokenize = True, return_tensors = "pt", continue_final_message = True
)
input_ids = input_ids.to("cuda")
@ -187,16 +187,16 @@ with torch.inference_mode():
# Generate the speech autoregressively
outputs = model.generate(
input_ids,
max_length=2048, # We trained our model with a max length of 2048
eos_token_id=speech_end_id,
do_sample=True,
top_p=1.2, # Adjusts the diversity of generated content
temperature=1.2, # Controls randomness in output
max_length = 2048, # We trained our model with a max length of 2048
eos_token_id = speech_end_id,
do_sample = True,
top_p = 1.2, # Adjusts the diversity of generated content
temperature = 1.2, # Controls randomness in output
)
# Extract the speech tokens
generated_ids = outputs[0][input_ids.shape[1] : -1]
speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens = True)
# Convert token <|s_23456|> to int 23456
speech_tokens = extract_speech_ids(speech_tokens)

View file

@ -28,12 +28,12 @@ print(f"{'='*80}")
model, tokenizer = FastModel.from_pretrained(
model_name="unsloth/whisper-large-v3",
dtype=None, # Leave as None for auto detection
load_in_4bit=False, # Set to True to do 4bit quantization which reduces memory
auto_model=WhisperForConditionalGeneration,
whisper_language="English",
whisper_task="transcribe",
model_name = "unsloth/whisper-large-v3",
dtype = None, # Leave as None for auto detection
load_in_4bit = False, # Set to True to do 4bit quantization which reduces memory
auto_model = WhisperForConditionalGeneration,
whisper_language = "English",
whisper_task = "transcribe",
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
@ -46,17 +46,17 @@ model.generation_config.forced_decoder_ids = None
model = FastModel.get_peft_model(
model,
r=64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules=["q_proj", "v_proj"],
lora_alpha=64,
lora_dropout=0, # Supports any, but = 0 is optimized
bias="none", # Supports any, but = "none" is optimized
r = 64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules = ["q_proj", "v_proj"],
lora_alpha = 64,
lora_dropout = 0, # Supports any, but = 0 is optimized
bias = "none", # Supports any, but = "none" is optimized
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
random_state=3407,
use_rslora=False, # We support rank stabilized LoRA
loftq_config=None, # And LoftQ
task_type=None, # ** MUST set this for Whisper **
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
random_state = 3407,
use_rslora = False, # We support rank stabilized LoRA
loftq_config = None, # And LoftQ
task_type = None, # ** MUST set this for Whisper **
)
print("✅ Model and LoRA adapters loaded successfully!")
@ -110,12 +110,12 @@ print(f"{'='*80}")
model, tokenizer = FastModel.from_pretrained(
model_name="./whisper",
dtype=None, # Leave as None for auto detection
load_in_4bit=False, # Set to True to do 4bit quantization which reduces memory
auto_model=WhisperForConditionalGeneration,
whisper_language="English",
whisper_task="transcribe",
model_name = "./whisper",
dtype = None, # Leave as None for auto detection
load_in_4bit = False, # Set to True to do 4bit quantization which reduces memory
auto_model = WhisperForConditionalGeneration,
whisper_language = "English",
whisper_task = "transcribe",
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
@ -135,7 +135,7 @@ try:
headers = {
"User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
}
response = requests.get(audio_url, headers=headers)
response = requests.get(audio_url, headers = headers)
response.raise_for_status()
with open(audio_file, "wb") as f:
f.write(response.content)
@ -156,12 +156,12 @@ model.eval()
# Create pipeline without specifying the device
whisper = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=tokenizer.tokenizer,
feature_extractor=tokenizer.feature_extractor,
processor=tokenizer,
return_language=True,
torch_dtype=torch.float16, # Remove the device parameter
model = model,
tokenizer = tokenizer.tokenizer,
feature_extractor = tokenizer.feature_extractor,
processor = tokenizer,
return_language = True,
torch_dtype = torch.float16, # Remove the device parameter
)
# Example usage
audio_file = "Speech_12dB_s16.flac"

View file

@ -21,7 +21,7 @@ from tests.utils.cleanup_utils import safe_remove_directory
## Dataset Preparation"""
print("\n📊 Loading and preparing dataset...")
dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split="train")
dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train")
# To select the first 2000 examples
train_dataset = dataset.select(range(2000))
@ -81,11 +81,11 @@ print("🤖 Loading base vision model...")
try:
model, tokenizer = FastVisionModel.from_pretrained(
# model_name = "unsloth/Qwen2-VL-7B-Instruct",
model_name="unsloth/Qwen2-VL-7B-Instruct",
max_seq_length=2048, # Choose any for long context!
load_in_4bit=True, # 4 bit quantization to reduce memory
load_in_8bit=False, # [NEW!] A bit more accurate, uses 2x memory
full_finetuning=False, # [NEW!] We have full finetuning now!
model_name = "unsloth/Qwen2-VL-7B-Instruct",
max_seq_length = 2048, # Choose any for long context!
load_in_4bit = True, # 4 bit quantization to reduce memory
load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
full_finetuning = False, # [NEW!] We have full finetuning now!
)
except Exception as e:
print(f"❌ Failed to load base model: {e}")
@ -96,18 +96,18 @@ print("\n🔧 Setting up LoRA configuration...")
try:
model = FastVisionModel.get_peft_model(
model,
finetune_vision_layers=True, # Turn off for just text!
finetune_language_layers=True, # Should leave on!
finetune_attention_modules=True, # Attention good for GRPO
finetune_mlp_modules=True, # SHould leave on always!
r=16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
lora_alpha=32,
lora_dropout=0, # Supports any, but = 0 is optimized
bias="none", # Supports any, but = "none" is optimized
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
random_state=3407,
use_rslora=False, # We support rank stabilized LoRA
loftq_config=None, # And LoftQ
finetune_vision_layers = True, # Turn off for just text!
finetune_language_layers = True, # Should leave on!
finetune_attention_modules = True, # Attention good for GRPO
finetune_mlp_modules = True, # SHould leave on always!
r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
lora_alpha = 32,
lora_dropout = 0, # Supports any, but = 0 is optimized
bias = "none", # Supports any, but = "none" is optimized
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
random_state = 3407,
use_rslora = False, # We support rank stabilized LoRA
loftq_config = None, # And LoftQ
)
print("✅ LoRA configuration applied successfully!")
print(f" 🎯 LoRA rank (r): 16")
@ -128,40 +128,40 @@ FastVisionModel.for_training(model) # Enable for training!
try:
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
data_collator=UnslothVisionDataCollator(model, tokenizer),
train_dataset=train_dataset,
args=SFTConfig(
model = model,
tokenizer = tokenizer,
data_collator = UnslothVisionDataCollator(model, tokenizer),
train_dataset = train_dataset,
args = SFTConfig(
# per_device_train_batch_size = 4,
# gradient_accumulation_steps = 8,
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
gradient_checkpointing = True,
gradient_checkpointing_kwargs = {
"use_reentrant": False
}, # use reentrant checkpointing
max_grad_norm=0.3, # max gradient norm based on QLoRA paper
warmup_ratio=0.03,
max_grad_norm = 0.3, # max gradient norm based on QLoRA paper
warmup_ratio = 0.03,
# num_train_epochs = 2, # Set this instead of max_steps for full training runs
max_steps=10,
learning_rate=2e-4,
fp16=not is_bf16_supported(),
bf16=is_bf16_supported(),
logging_steps=5,
save_strategy="epoch",
optim="adamw_torch_fused",
weight_decay=0.01,
lr_scheduler_type="linear",
seed=3407,
output_dir="checkpoints",
report_to="none", # For Weights and Biases
max_steps = 10,
learning_rate = 2e-4,
fp16 = not is_bf16_supported(),
bf16 = is_bf16_supported(),
logging_steps = 5,
save_strategy = "epoch",
optim = "adamw_torch_fused",
weight_decay = 0.01,
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "checkpoints",
report_to = "none", # For Weights and Biases
# You MUST put the below items for vision finetuning:
remove_unused_columns=False,
dataset_text_field="",
dataset_kwargs={"skip_prepare_dataset": True},
dataset_num_proc=4,
max_seq_length=2048,
remove_unused_columns = False,
dataset_text_field = "",
dataset_kwargs = {"skip_prepare_dataset": True},
dataset_num_proc = 4,
max_seq_length = 2048,
),
)
print("✅ Trainer setup completed!")
@ -221,7 +221,7 @@ try:
print("=== UPLOADING MODEL TO HUB ===".center(80))
print("=" * 80 + "\n")
print(f"🚀 Uploading to repository: {repo_name}")
model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token)
model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
success["upload"] = True
print("✅ Model uploaded successfully!")
except Exception as e:
@ -233,8 +233,8 @@ try:
print("\n" + "=" * 80)
print("=== VERIFYING REPO CONTENTS ===".center(80))
print("=" * 80 + "\n")
fs = HfFileSystem(token=hf_token)
file_list = fs.ls(repo_name, detail=True)
fs = HfFileSystem(token = hf_token)
file_list = fs.ls(repo_name, detail = True)
safetensors_found = any(
file["name"].endswith("model.safetensors.index.json") for file in file_list
)

View file

@ -22,7 +22,7 @@ from tests.utils.cleanup_utils import safe_remove_directory
## Dataset Preparation"""
print("\n📊 Loading and preparing dataset...")
dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split="train")
dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train")
# To select the first 2000 examples
train_dataset = dataset.select(range(2000))
@ -82,11 +82,11 @@ print("🤖 Loading base vision model...")
try:
model, tokenizer = FastVisionModel.from_pretrained(
# model_name = "unsloth/Qwen2-VL-7B-Instruct",
model_name="unsloth/Qwen2-VL-2B-Instruct",
max_seq_length=2048, # Choose any for long context!
load_in_4bit=True, # 4 bit quantization to reduce memory
load_in_8bit=False, # [NEW!] A bit more accurate, uses 2x memory
full_finetuning=False, # [NEW!] We have full finetuning now!
model_name = "unsloth/Qwen2-VL-2B-Instruct",
max_seq_length = 2048, # Choose any for long context!
load_in_4bit = True, # 4 bit quantization to reduce memory
load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
full_finetuning = False, # [NEW!] We have full finetuning now!
)
except Exception as e:
print(f"❌ Failed to load base model: {e}")
@ -97,18 +97,18 @@ print("\n🔧 Setting up LoRA configuration...")
try:
model = FastVisionModel.get_peft_model(
model,
finetune_vision_layers=True, # Turn off for just text!
finetune_language_layers=True, # Should leave on!
finetune_attention_modules=True, # Attention good for GRPO
finetune_mlp_modules=True, # SHould leave on always!
r=16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
lora_alpha=32,
lora_dropout=0, # Supports any, but = 0 is optimized
bias="none", # Supports any, but = "none" is optimized
use_gradient_checkpointing="unsloth", # True or "unsloth" for very long context
random_state=3407,
use_rslora=False, # We support rank stabilized LoRA
loftq_config=None, # And LoftQ
finetune_vision_layers = True, # Turn off for just text!
finetune_language_layers = True, # Should leave on!
finetune_attention_modules = True, # Attention good for GRPO
finetune_mlp_modules = True, # SHould leave on always!
r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
lora_alpha = 32,
lora_dropout = 0, # Supports any, but = 0 is optimized
bias = "none", # Supports any, but = "none" is optimized
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
random_state = 3407,
use_rslora = False, # We support rank stabilized LoRA
loftq_config = None, # And LoftQ
)
print("✅ LoRA configuration applied successfully!")
print(f" 🎯 LoRA rank (r): 16")
@ -129,40 +129,40 @@ FastVisionModel.for_training(model) # Enable for training!
try:
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
data_collator=UnslothVisionDataCollator(model, tokenizer),
train_dataset=train_dataset,
args=SFTConfig(
model = model,
tokenizer = tokenizer,
data_collator = UnslothVisionDataCollator(model, tokenizer),
train_dataset = train_dataset,
args = SFTConfig(
# per_device_train_batch_size = 4,
# gradient_accumulation_steps = 8,
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
gradient_checkpointing = True,
gradient_checkpointing_kwargs = {
"use_reentrant": False
}, # use reentrant checkpointing
max_grad_norm=0.3, # max gradient norm based on QLoRA paper
warmup_ratio=0.03,
max_grad_norm = 0.3, # max gradient norm based on QLoRA paper
warmup_ratio = 0.03,
# num_train_epochs = 2, # Set this instead of max_steps for full training runs
max_steps=10,
learning_rate=2e-4,
fp16=not is_bf16_supported(),
bf16=is_bf16_supported(),
logging_steps=5,
save_strategy="epoch",
optim="adamw_torch_fused",
weight_decay=0.01,
lr_scheduler_type="linear",
seed=3407,
output_dir="checkpoints",
report_to="none", # For Weights and Biases
max_steps = 10,
learning_rate = 2e-4,
fp16 = not is_bf16_supported(),
bf16 = is_bf16_supported(),
logging_steps = 5,
save_strategy = "epoch",
optim = "adamw_torch_fused",
weight_decay = 0.01,
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "checkpoints",
report_to = "none", # For Weights and Biases
# You MUST put the below items for vision finetuning:
remove_unused_columns=False,
dataset_text_field="",
dataset_kwargs={"skip_prepare_dataset": True},
dataset_num_proc=4,
max_seq_length=2048,
remove_unused_columns = False,
dataset_text_field = "",
dataset_kwargs = {"skip_prepare_dataset": True},
dataset_num_proc = 4,
max_seq_length = 2048,
),
)
print("✅ Trainer setup completed!")
@ -221,7 +221,7 @@ try:
print("=== UPLOADING MODEL TO HUB ===".center(80))
print("=" * 80 + "\n")
print(f"🚀 Uploading to repository: {repo_name}")
model.push_to_hub_merged(repo_name, tokenizer=tokenizer, token=hf_token)
model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
success["upload"] = True
print("✅ Model uploaded successfully!")
except Exception as e:

View file

@ -24,7 +24,7 @@ def download_and_combine_aime_datasets(data_dir: str = "./data/aime") -> str:
"test2025-II": "https://raw.githubusercontent.com/GAIR-NLP/AIME-Preview/main/eval/data/aime/test2025-II.jsonl",
}
os.makedirs(data_dir, exist_ok=True)
os.makedirs(data_dir, exist_ok = True)
combined_filepath = os.path.join(data_dir, "aime.jsonl")
# Check if combined file already exists
@ -67,9 +67,9 @@ def download_and_combine_aime_datasets(data_dir: str = "./data/aime") -> str:
# Write combined dataset
if all_problems:
with open(combined_filepath, "w", encoding="utf-8") as f:
with open(combined_filepath, "w", encoding = "utf-8") as f:
for problem in all_problems:
f.write(json.dumps(problem, ensure_ascii=False) + "\n")
f.write(json.dumps(problem, ensure_ascii = False) + "\n")
print(f"✅ Combined {len(all_problems)} problems from {len(datasets)} datasets")
print(f" Saved to: {combined_filepath}")
@ -92,7 +92,7 @@ def load_aime_dataset(data_dir: str = "./data/aime") -> List[Dict[str, Any]]:
filepath = download_and_combine_aime_datasets(data_dir)
examples = []
with open(filepath, "r", encoding="utf-8") as f:
with open(filepath, "r", encoding = "utf-8") as f:
for line_num, line in enumerate(f):
line = line.strip()
if line:
@ -188,20 +188,20 @@ def get_num_tokens(text, tokenizer_instance):
"""Count tokens in text"""
if not text:
return 0
encoding = tokenizer_instance(text, return_tensors="pt")
encoding = tokenizer_instance(text, return_tensors = "pt")
return len(encoding["input_ids"][0])
def evaluate_model_aime(
model,
tokenizer,
model_type="base",
lora_request=None,
temperature=0.3,
n_sampling=8,
max_tokens=32768,
top_p=0.95,
seed=0,
model_type = "base",
lora_request = None,
temperature = 0.3,
n_sampling = 8,
max_tokens = 32768,
top_p = 0.95,
seed = 0,
):
"""Evaluate model on combined AIME dataset with official configuration"""
@ -237,11 +237,11 @@ def evaluate_model_aime(
# Setup sampling parameters (AIME configuration)
sampling_params = SamplingParams(
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
n=n_sampling, # Multiple samples per question
seed=seed,
temperature = temperature,
top_p = top_p,
max_tokens = max_tokens,
n = n_sampling, # Multiple samples per question
seed = seed,
)
print(f"\n🔧 Configuration:")
@ -272,13 +272,13 @@ def evaluate_model_aime(
# Main evaluation loop
with tqdm(
total=len(eval_dataset), desc="Processing AIME problems", unit="problem"
total = len(eval_dataset), desc = "Processing AIME problems", unit = "problem"
) as pbar:
for task_id, item in enumerate(eval_dataset):
try:
# Prepare prompt
prompt_text = tokenizer.apply_chat_template(
item["prompt"], add_generation_prompt=True, tokenize=False
item["prompt"], add_generation_prompt = True, tokenize = False
)
input_tokens.append(get_num_tokens(prompt_text, tokenizer))
@ -286,9 +286,9 @@ def evaluate_model_aime(
# Generate multiple responses
outputs = model.fast_generate(
[prompt_text],
sampling_params=sampling_params,
lora_request=lora_request,
use_tqdm=False,
sampling_params = sampling_params,
lora_request = lora_request,
use_tqdm = False,
)[0].outputs
# Process all generated responses
@ -413,8 +413,8 @@ def evaluate_model_aime(
# Save results
filename = f"aime_eval_combined_{model_type}_t{temperature}_n{n_sampling}.json"
with open(filename, "w", encoding="utf-8") as f:
json.dump({"results": results, "records": records}, f, indent=4)
with open(filename, "w", encoding = "utf-8") as f:
json.dump({"results": results, "records": records}, f, indent = 4)
# Print comprehensive summary
print(f"\n{'='*70}")
@ -517,27 +517,27 @@ def compare_aime_results(all_results):
if all_results and "source_accuracies" in all_results[0]:
datasets = list(all_results[0]["source_accuracies"].keys())
print(f"{'Model':<15}", end="")
print(f"{'Model':<15}", end = "")
for dataset in datasets:
print(f"{dataset:<15}", end="")
print(f"{dataset:<15}", end = "")
print()
print("-" * (15 + 15 * len(datasets)))
for result in all_results:
print(f"{result['model_type']:<15}", end="")
print(f"{result['model_type']:<15}", end = "")
for dataset in datasets:
accuracy = result["source_accuracies"].get(dataset, 0)
print(f"{accuracy:<15.1f}", end="")
print(f"{accuracy:<15.1f}", end = "")
print()
# Save comparison
comparison_data = {
"summary": all_results,
"best_model": max(all_results, key=lambda x: x["accuracy"]),
"best_model": max(all_results, key = lambda x: x["accuracy"]),
}
with open("aime_model_comparison.json", "w") as f:
json.dump(comparison_data, f, indent=4)
json.dump(comparison_data, f, indent = 4)
print(
f"\nBest performing model: {comparison_data['best_model']['model_type']} "

View file

@ -79,27 +79,27 @@ def generate_responses(
skip_special_tokens: bool = True,
dtype: torch.dtype = None,
):
inputs = [tokenizer(prompt, return_tensors="pt") for _ in range(num_generations)]
inputs = [tokenizer(prompt, return_tensors = "pt") for _ in range(num_generations)]
keys = inputs[0].keys()
batched_inputs = {
key: torch.cat([input[key] for input in inputs], dim=0).to(model.device)
key: torch.cat([input[key] for input in inputs], dim = 0).to(model.device)
for key in keys
}
if dtype is not None:
inference_context = torch.autocast(device_type="cuda", dtype=dtype)
inference_context = torch.autocast(device_type = "cuda", dtype = dtype)
else:
inference_context = nullcontext()
with inference_context:
outputs = model.generate(
**batched_inputs,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
max_new_tokens = max_new_tokens,
do_sample = do_sample,
temperature = temperature,
)
responses = tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
responses = tokenizer.batch_decode(outputs, skip_special_tokens = skip_special_tokens)
return responses
@ -117,11 +117,11 @@ def sample_responses(
model,
tokenizer,
prompt,
temperature=temperature,
num_generations=num_generations,
max_new_tokens=max_new_tokens,
skip_special_tokens=skip_special_tokens,
dtype=dtype,
temperature = temperature,
num_generations = num_generations,
max_new_tokens = max_new_tokens,
skip_special_tokens = skip_special_tokens,
dtype = dtype,
)
return responses
@ -136,32 +136,32 @@ def setup_tokenizer(model_name, fixup_funcs: list[Callable] = []):
def setup_model(
model_name,
quantize: bool = True,
dtype=torch.bfloat16,
peft_config=None,
dtype = torch.bfloat16,
peft_config = None,
autocast_adapter: bool = True,
):
if quantize:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=dtype,
load_in_4bit = True,
bnb_4bit_use_double_quant = True,
bnb_4bit_quant_type = "nf4",
bnb_4bit_compute_dtype = dtype,
)
else:
bnb_config = None
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="cuda:0",
attn_implementation="sdpa",
quantization_config=bnb_config,
torch_dtype=dtype,
device_map = "cuda:0",
attn_implementation = "sdpa",
quantization_config = bnb_config,
torch_dtype = dtype,
)
model = prepare_model_for_kbit_training(model) if quantize else model
if peft_config is not None:
model = get_peft_model(
model, peft_config, autocast_adapter_dtype=autocast_adapter
model, peft_config, autocast_adapter_dtype = autocast_adapter
)
return model
@ -169,19 +169,19 @@ def setup_model(
def get_peft_config(
lora_rank,
lora_alpha=None,
lora_dropout=0.0,
bias="none",
target_modules="all-linear",
lora_alpha = None,
lora_dropout = 0.0,
bias = "none",
target_modules = "all-linear",
):
lora_alpha = lora_alpha or 2 * lora_rank
peft_config = LoraConfig(
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
r=lora_rank,
bias=bias,
target_modules=target_modules,
task_type="CAUSAL_LM",
lora_alpha = lora_alpha,
lora_dropout = lora_dropout,
r = lora_rank,
bias = bias,
target_modules = target_modules,
task_type = "CAUSAL_LM",
)
return peft_config
@ -191,18 +191,18 @@ def setup_trainer(
tokenizer,
dataset,
train_args,
peft_config=None,
formatting_func=None,
collator=None,
peft_config = None,
formatting_func = None,
collator = None,
):
return SFTTrainer(
model=model,
peft_config=peft_config,
train_dataset=dataset,
processing_class=tokenizer,
formatting_func=formatting_func,
data_collator=collator,
args=train_args,
model = model,
peft_config = peft_config,
train_dataset = dataset,
processing_class = tokenizer,
formatting_func = formatting_func,
data_collator = collator,
args = train_args,
)
@ -212,17 +212,17 @@ def setup_lora(
dataset,
peft_config,
train_args,
formatting_func=None,
collator=None,
formatting_func = None,
collator = None,
):
return LoraConfig(
model=model,
peft_config=peft_config,
train_dataset=dataset,
processing_class=tokenizer,
formatting_func=formatting_func,
data_collator=collator,
args=train_args,
model = model,
peft_config = peft_config,
train_dataset = dataset,
processing_class = tokenizer,
formatting_func = formatting_func,
data_collator = collator,
args = train_args,
)
@ -236,7 +236,7 @@ def convert_weights_back_to_dtype(model, dtype):
param.data = param.data.to(dtype)
def fix_llama3_tokenizer(tokenizer, padding_side="right"):
def fix_llama3_tokenizer(tokenizer, padding_side = "right"):
tokenizer.padding_side = padding_side
added_vocab = tokenizer.get_added_vocab()
pad_token = [w for w in added_vocab if "pad" in w]
@ -276,12 +276,12 @@ def _convert_lora_to_linear(module: LoraLayer, adapter_name: str = "default"):
w_dq = w_dq.to(original_dtype)
new_module = torch.nn.Linear(
w_dq.shape[1], w_dq.shape[0], bias=module.base_layer.bias is not None
w_dq.shape[1], w_dq.shape[0], bias = module.base_layer.bias is not None
)
new_module.weight.data = torch.nn.Parameter(w_dq, requires_grad=False)
new_module.weight.data = torch.nn.Parameter(w_dq, requires_grad = False)
if module.lora_bias[adapter_name]:
bias_data = module.base_layer.bias.data + module.lora_B[adapter_name].bias
new_module.bias.data = torch.nn.Parameter(bias_data, requires_grad=False)
new_module.bias.data = torch.nn.Parameter(bias_data, requires_grad = False)
return new_module

View file

@ -39,7 +39,7 @@ if already_imported:
f"to ensure all optimizations are applied. Your code may run slower or encounter "
f"memory issues without these optimizations.\n\n"
f"Please restructure your imports with 'import unsloth' at the top of your file.",
stacklevel=2,
stacklevel = 2,
)
del already_imported, critical_modules
@ -141,7 +141,7 @@ if DEVICE_TYPE == "cuda":
old_is_bf16_supported = torch.cuda.is_bf16_supported
if "including_emulation" in str(inspect.signature(old_is_bf16_supported)):
def is_bf16_supported(including_emulation=False):
def is_bf16_supported(including_emulation = False):
return old_is_bf16_supported(including_emulation)
torch.cuda.is_bf16_supported = is_bf16_supported

View file

@ -86,7 +86,7 @@ class LoRA_MLP(torch.autograd.Function):
downS,
_forward_function,
_backward_function,
inplace=True,
inplace = True,
):
dtype = X.dtype
@ -169,39 +169,39 @@ class LoRA_MLP(torch.autograd.Function):
# d_downB = (downA.t() @ h.t()) @ dY
# d_downA *= downS
# d_downB *= downS
d_downA.addmm_(h.t(), dY @ downB.t(), alpha=downS, beta=0)
d_downB.addmm_(downA.t() @ h.t(), dY, alpha=downS, beta=0)
d_downA.addmm_(h.t(), dY @ downB.t(), alpha = downS, beta = 0)
d_downB.addmm_(downA.t() @ h.t(), dY, alpha = downS, beta = 0)
# Up projection LoRA weights
# d_upA = X.t() @ (df @ upB.t())
# d_upB = (upA.t() @ X.t()) @ df
# d_upA *= upS
# d_upB *= upS
d_upA.addmm_(X.t(), df @ upB.t(), alpha=upS, beta=0)
d_upB.addmm_(upA.t() @ X.t(), df, alpha=upS, beta=0)
d_upA.addmm_(X.t(), df @ upB.t(), alpha = upS, beta = 0)
d_upB.addmm_(upA.t() @ X.t(), df, alpha = upS, beta = 0)
# Gate projection LoRA weights
# d_gateA = X.t() @ (de @ gateB.t())
# d_gateB = (gateA.t() @ X.t()) @ de
# d_gateA *= gateS
# d_gateB *= gateS
d_gateA.addmm_(X.t(), de @ gateB.t(), alpha=gateS, beta=0)
d_gateB.addmm_(gateA.t() @ X.t(), de, alpha=gateS, beta=0)
d_gateA.addmm_(X.t(), de @ gateB.t(), alpha = gateS, beta = 0)
d_gateB.addmm_(gateA.t() @ X.t(), de, alpha = gateS, beta = 0)
# dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
# dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
upW = fast_dequantize(upW.t(), upW_quant)
dX = torch.matmul(df, upW.t(), out=X if ctx.inplace else None)
dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None)
del upW
# dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
dX.addmm_(df @ upB.t(), upA.t(), alpha=upS)
dX.addmm_(df @ upB.t(), upA.t(), alpha = upS)
gateW = fast_dequantize(gateW.t(), gateW_quant)
# dX += de @ gateW.t()
dX.addmm_(de, gateW.t())
del gateW
# dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
dX.addmm_(de @ gateB.t(), gateA.t(), alpha=gateS)
dX.addmm_(de @ gateB.t(), gateA.t(), alpha = gateS)
# gateW, gateW_quant, gateA, gateB, gateS,
# upW, upW_quant, upA, upB, upS,
@ -232,7 +232,7 @@ class LoRA_MLP(torch.autograd.Function):
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
def apply_lora_mlp_swiglu(self, X, inplace=True):
def apply_lora_mlp_swiglu(self, X, inplace = True):
X = _maybe_fake_quantize_activations(X, self.gate_proj)
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
@ -264,7 +264,7 @@ def apply_lora_mlp_swiglu(self, X, inplace=True):
from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
def apply_lora_mlp_geglu_exact(self, X, inplace=True):
def apply_lora_mlp_geglu_exact(self, X, inplace = True):
X = _maybe_fake_quantize_activations(X, self.gate_proj)
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
@ -375,7 +375,7 @@ class LoRA_QKV(torch.autograd.Function):
VA,
VB,
VS,
inplace=True,
inplace = True,
):
dtype = X.dtype
@ -452,32 +452,32 @@ class LoRA_QKV(torch.autograd.Function):
# d_QB = (QA.t() @ X.t()) @ dQ
# d_QA *= QS
# d_QB *= QS
d_QA.addmm_(X.t(), dQ @ QB.t(), alpha=QS, beta=0)
d_QB.addmm_(QA.t() @ X.t(), dQ, alpha=QS, beta=0)
d_QA.addmm_(X.t(), dQ @ QB.t(), alpha = QS, beta = 0)
d_QB.addmm_(QA.t() @ X.t(), dQ, alpha = QS, beta = 0)
# K Projection
# d_KA = X.t() @ (dK @ KB.t())
# d_KB = (KA.t() @ X.t()) @ dK
# d_KA *= KS
# d_KB *= KS
d_KA.addmm_(X.t(), dK @ KB.t(), alpha=KS, beta=0)
d_KB.addmm_(KA.t() @ X.t(), dK, alpha=KS, beta=0)
d_KA.addmm_(X.t(), dK @ KB.t(), alpha = KS, beta = 0)
d_KB.addmm_(KA.t() @ X.t(), dK, alpha = KS, beta = 0)
# V Projection
# d_VA = X.t() @ (dV @ VB.t())
# d_VB = (VA.t() @ X.t()) @ dV
# d_VA *= VS
# d_VB *= VS
d_VA.addmm_(X.t(), dV @ VB.t(), alpha=VS, beta=0)
d_VB.addmm_(VA.t() @ X.t(), dV, alpha=VS, beta=0)
d_VA.addmm_(X.t(), dV @ VB.t(), alpha = VS, beta = 0)
d_VB.addmm_(VA.t() @ X.t(), dV, alpha = VS, beta = 0)
# Combine derivatives to find dX
# dQ
QW = fast_dequantize(QW.t(), QW_quant)
dX = torch.matmul(dQ, QW.t(), out=X if ctx.inplace else None)
dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None)
del QW
# dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
dX.addmm_(dQ @ QB.t(), QA.t(), alpha=QS)
dX.addmm_(dQ @ QB.t(), QA.t(), alpha = QS)
# dK
KW = fast_dequantize(KW.t(), KW_quant)
@ -485,7 +485,7 @@ class LoRA_QKV(torch.autograd.Function):
dX.addmm_(dK, KW.t())
del KW
# dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
dX.addmm_(dK @ KB.t(), KA.t(), alpha=KS)
dX.addmm_(dK @ KB.t(), KA.t(), alpha = KS)
# dV
VW = fast_dequantize(VW.t(), VW_quant)
@ -493,7 +493,7 @@ class LoRA_QKV(torch.autograd.Function):
dX.addmm_(dV, VW.t())
del VW
# dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
dX.addmm_(dV @ VB.t(), VA.t(), alpha=VS)
dX.addmm_(dV @ VB.t(), VA.t(), alpha = VS)
# QW, QW_quant, QA, QB, QS,
# KW, KW_quant, KA, KB, KS,
@ -519,7 +519,7 @@ class LoRA_QKV(torch.autograd.Function):
)
def apply_lora_qkv(self, X, inplace=True):
def apply_lora_qkv(self, X, inplace = True):
X = _maybe_fake_quantize_activations(X, self.q_proj)
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
@ -611,15 +611,15 @@ class LoRA_W(torch.autograd.Function):
# d_B = (A.t() @ X.t()) @ dY
# d_A *= S
# d_B *= S
d_A.addmm_(X.t(), dY @ B.t(), alpha=S, beta=0)
d_B.addmm_(A.t() @ X.t(), dY, alpha=S, beta=0)
d_A.addmm_(X.t(), dY @ B.t(), alpha = S, beta = 0)
d_B.addmm_(A.t() @ X.t(), dY, alpha = S, beta = 0)
# Get derivative for dX
W = fast_dequantize(W.t(), W_quant)
dX = dY @ W.t()
del W
# dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
dX.addmm_(dY @ B.t(), A.t(), alpha=S)
dX.addmm_(dY @ B.t(), A.t(), alpha = S)
# W, W_quant, A, B, S
return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None
@ -649,7 +649,7 @@ def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
result = self.base_layer(x, *args, **kwargs)
elif adapter_names is not None:
result = self._mixed_batch_forward(
x, *args, adapter_names=adapter_names, **kwargs
x, *args, adapter_names = adapter_names, **kwargs
)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
@ -705,11 +705,11 @@ def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
result = result + self.lora_magnitude_vector[active_adapter](
x,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=self.get_base_layer(),
base_result=base_result,
lora_A = lora_A,
lora_B = lora_B,
scaling = scaling,
base_layer = self.get_base_layer(),
base_result = base_result,
)
if requires_conversion:
result = result.to(expected_dtype)

View file

@ -56,7 +56,7 @@ def run_benchmark_forward(
hidden_size = config.hidden_size
X = torch.randn(
bs, seqlen, hidden_size, dtype=dtype, device=device, requires_grad=True
bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True
)
# Forward
@ -89,7 +89,7 @@ def run_benchmark_backward(
config: AutoConfig,
seqlen: int,
dtype: torch.dtype,
bs=1,
bs = 1,
):
torch.manual_seed(
SEED
@ -98,7 +98,7 @@ def run_benchmark_backward(
hidden_size = config.hidden_size
X = torch.randn(
bs, seqlen, hidden_size, dtype=dtype, device=device, requires_grad=True
bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True
)
X_test = X.detach().clone().requires_grad_(True)
@ -114,14 +114,14 @@ def run_benchmark_backward(
# Bench
grad_output = torch.randn_like(output)
bench_backward_ref = lambda: output.backward(grad_output, retain_graph=True) # noqa: E731
bench_backward_fused = lambda: test_output.backward(grad_output, retain_graph=True) # noqa: E731
bench_backward_ref = lambda: output.backward(grad_output, retain_graph = True) # noqa: E731
bench_backward_fused = lambda: test_output.backward(grad_output, retain_graph = True) # noqa: E731
ref_backward_time = do_bench(
bench_backward_ref, grad_to_none=[X, *ref_model.parameters()]
bench_backward_ref, grad_to_none = [X, *ref_model.parameters()]
)
fused_backward_time = do_bench(
bench_backward_fused, grad_to_none=[X_test, *tt_model.parameters()]
bench_backward_fused, grad_to_none = [X_test, *tt_model.parameters()]
)
print(
f"Backward: ref {ref_backward_time:.4f}, fused {fused_backward_time:.4f}, speedup {ref_backward_time / fused_backward_time:.1f}x"
@ -138,10 +138,10 @@ def setup_model(
kernel_config_fwd,
kernel_config_bwd_dW,
kernel_config_bwd_dX,
dX_only=False,
dW_only=False,
overlap_router_shared=False,
device="cuda",
dX_only = False,
dW_only = False,
overlap_router_shared = False,
device = "cuda",
):
if isinstance(config, Qwen3MoeConfig):
ref_model = Qwen3MoeSparseMoeBlock(config).to(device, dtype)
@ -149,29 +149,29 @@ def setup_model(
# Triton kernel grouped gemm version of MoE Block -- this is what we're testing
tt_model = Qwen3MoeFusedGroupedGEMMBlock.from_hf(
ref_model,
permute_x=permute_x,
permute_y=permute_y,
autotune=autotune,
kernel_config_fwd=kernel_config_fwd,
kernel_config_bwd_dW=kernel_config_bwd_dW,
kernel_config_bwd_dX=kernel_config_bwd_dX,
dX_only=dX_only,
dW_only=dW_only,
permute_x = permute_x,
permute_y = permute_y,
autotune = autotune,
kernel_config_fwd = kernel_config_fwd,
kernel_config_bwd_dW = kernel_config_bwd_dW,
kernel_config_bwd_dX = kernel_config_bwd_dX,
dX_only = dX_only,
dW_only = dW_only,
).to(device, dtype)
elif isinstance(config, Llama4TextConfig):
ref_model = Llama4TextMoe(config).to(device, dtype)
tt_model = Llama4TritonTextMoe(
config,
overlap_router_shared=overlap_router_shared,
permute_x=permute_x,
permute_y=permute_y,
autotune=autotune,
kernel_config_fwd=kernel_config_fwd,
kernel_config_bwd_dW=kernel_config_bwd_dW,
kernel_config_bwd_dX=kernel_config_bwd_dX,
dX_only=dX_only,
dW_only=dW_only,
overlap_router_shared = overlap_router_shared,
permute_x = permute_x,
permute_y = permute_y,
autotune = autotune,
kernel_config_fwd = kernel_config_fwd,
kernel_config_bwd_dW = kernel_config_bwd_dW,
kernel_config_bwd_dX = kernel_config_bwd_dX,
dX_only = dX_only,
dW_only = dW_only,
).to(device, dtype)
else:
@ -205,31 +205,31 @@ def run_benchmark(
ref_model, tt_model = setup_model(
model_config,
dtype=dtype,
permute_x=permute_x,
permute_y=permute_y,
autotune=autotune,
kernel_config_fwd=kernel_config_fwd,
kernel_config_bwd_dW=kernel_config_bwd_dW,
kernel_config_bwd_dX=kernel_config_bwd_dX,
dX_only=dX_only,
dW_only=dW_only,
overlap_router_shared=overlap_router_shared,
dtype = dtype,
permute_x = permute_x,
permute_y = permute_y,
autotune = autotune,
kernel_config_fwd = kernel_config_fwd,
kernel_config_bwd_dW = kernel_config_bwd_dW,
kernel_config_bwd_dX = kernel_config_bwd_dX,
dX_only = dX_only,
dW_only = dW_only,
overlap_router_shared = overlap_router_shared,
)
if mode == "forward":
ref_time, fused_time = run_benchmark_forward(
ref_model,
tt_model,
config=model_config,
seqlen=seqlen,
dtype=dtype,
autotune=autotune,
kernel_config_fwd=kernel_config_fwd,
config = model_config,
seqlen = seqlen,
dtype = dtype,
autotune = autotune,
kernel_config_fwd = kernel_config_fwd,
)
else:
ref_time, fused_time = run_benchmark_backward(
ref_model, tt_model, config=model_config, seqlen=seqlen, dtype=dtype
ref_model, tt_model, config = model_config, seqlen = seqlen, dtype = dtype
)
if autotune:
@ -251,60 +251,60 @@ def run_benchmark(
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--results_dir", type=str, default="benchmark_results")
parser.add_argument("--model", type=str, choices=["llama4", "qwen3"], required=True)
parser.add_argument("--seqlen", type=int, default=1024)
parser.add_argument("--results_dir", type = str, default = "benchmark_results")
parser.add_argument("--model", type = str, choices = ["llama4", "qwen3"], required = True)
parser.add_argument("--seqlen", type = int, default = 1024)
parser.add_argument(
"--dtype", type=str, choices=["bfloat16", "float16"], default="bfloat16"
"--dtype", type = str, choices = ["bfloat16", "float16"], default = "bfloat16"
)
parser.add_argument("--permute_x", action="store_true")
parser.add_argument("--permute_y", action="store_true")
parser.add_argument("--autotune", action="store_true")
parser.add_argument("--overlap_router_shared", action="store_true")
parser.add_argument("--permute_x", action = "store_true")
parser.add_argument("--permute_y", action = "store_true")
parser.add_argument("--autotune", action = "store_true")
parser.add_argument("--overlap_router_shared", action = "store_true")
parser.add_argument(
"--BLOCK_SIZE_M",
nargs=2,
type=int,
default=[DEFAULT_M_BLOCK_SIZES[0], DEFAULT_M_BLOCK_SIZES[-1]],
nargs = 2,
type = int,
default = [DEFAULT_M_BLOCK_SIZES[0], DEFAULT_M_BLOCK_SIZES[-1]],
)
parser.add_argument(
"--BLOCK_SIZE_N",
nargs=2,
type=int,
default=[DEFAULT_N_BLOCK_SIZES[0], DEFAULT_N_BLOCK_SIZES[-1]],
nargs = 2,
type = int,
default = [DEFAULT_N_BLOCK_SIZES[0], DEFAULT_N_BLOCK_SIZES[-1]],
)
parser.add_argument(
"--BLOCK_SIZE_K",
nargs=2,
type=int,
default=[DEFAULT_K_BLOCK_SIZES[0], DEFAULT_K_BLOCK_SIZES[-1]],
nargs = 2,
type = int,
default = [DEFAULT_K_BLOCK_SIZES[0], DEFAULT_K_BLOCK_SIZES[-1]],
)
parser.add_argument(
"--num_warps",
nargs=2,
type=int,
default=[DEFAULT_NUM_WARPS[0], DEFAULT_NUM_WARPS[-1]],
nargs = 2,
type = int,
default = [DEFAULT_NUM_WARPS[0], DEFAULT_NUM_WARPS[-1]],
)
parser.add_argument(
"--num_stages",
nargs=2,
type=int,
default=[DEFAULT_NUM_STAGES[0], DEFAULT_NUM_STAGES[-1]],
nargs = 2,
type = int,
default = [DEFAULT_NUM_STAGES[0], DEFAULT_NUM_STAGES[-1]],
)
parser.add_argument(
"--use_tma_load_w", action="store_true"
"--use_tma_load_w", action = "store_true"
) # No need to specify, will automatically parametrize these for each kernel config
parser.add_argument(
"--use_tma_load_x", action="store_true"
"--use_tma_load_x", action = "store_true"
) # No need to specify, will automatically parametrize these for each kernel config
parser.add_argument(
"--use_tma_load_dy", action="store_true"
"--use_tma_load_dy", action = "store_true"
) # No need to specify, will automatically parametrize these for each kernel config
parser.add_argument(
"--mode",
type=str,
choices=["forward", "backward", "dW", "dX"],
default="forward",
type = str,
choices = ["forward", "backward", "dW", "dX"],
default = "forward",
)
args = parser.parse_args()
args.dtype = getattr(torch, args.dtype)
@ -324,13 +324,13 @@ if __name__ == "__main__":
ref_time, fused_time = run_benchmark(
args.mode,
model_config,
seqlen=args.seqlen,
dtype=args.dtype,
permute_x=args.permute_x,
permute_y=args.permute_y,
autotune=args.autotune,
overlap_router_shared=args.overlap_router_shared,
results_dir=args.results_dir,
seqlen = args.seqlen,
dtype = args.dtype,
permute_x = args.permute_x,
permute_y = args.permute_y,
autotune = args.autotune,
overlap_router_shared = args.overlap_router_shared,
results_dir = args.results_dir,
)
end_time = time.time()
print(f"Total time: {end_time - start_time:.4f} seconds")
@ -343,13 +343,13 @@ if __name__ == "__main__":
kernel_configs = create_kernel_configs(args, args.permute_x, args.permute_y)
print(f"Running {len(kernel_configs)} kernel configs")
default_kernel_config_fwd = KernelConfigForward(
permute_x=args.permute_x, permute_y=args.permute_y
permute_x = args.permute_x, permute_y = args.permute_y
)
default_kernel_config_bwd_dW = KernelConfigBackward_dW(
permute_x=args.permute_x, permute_y=args.permute_y
permute_x = args.permute_x, permute_y = args.permute_y
)
default_kernel_config_bwd_dX = KernelConfigBackward_dX(
permute_x=args.permute_x, permute_y=args.permute_y
permute_x = args.permute_x, permute_y = args.permute_y
)
results = []
for kernel_config in kernel_configs:
@ -374,21 +374,21 @@ if __name__ == "__main__":
ref_time, fused_time = run_benchmark(
args.mode,
model_config,
seqlen=args.seqlen,
dtype=args.dtype,
permute_x=kernel_config.permute_x,
permute_y=kernel_config.permute_y,
autotune=False,
kernel_config_fwd=kernel_config_fwd,
kernel_config_bwd_dW=kernel_config_bwd_dW,
kernel_config_bwd_dX=kernel_config_bwd_dX,
seqlen = args.seqlen,
dtype = args.dtype,
permute_x = kernel_config.permute_x,
permute_y = kernel_config.permute_y,
autotune = False,
kernel_config_fwd = kernel_config_fwd,
kernel_config_bwd_dW = kernel_config_bwd_dW,
kernel_config_bwd_dX = kernel_config_bwd_dX,
)
results.append(
KernelResult(
torch_time=ref_time,
triton_time=fused_time,
speedup=ref_time / fused_time,
kernel_config=kernel_config,
torch_time = ref_time,
triton_time = fused_time,
speedup = ref_time / fused_time,
kernel_config = kernel_config,
)
)
df = post_process_results(

View file

@ -37,15 +37,15 @@ def convert_args_to_list(args):
def get_forward_configs(
BLOCK_M=DEFAULT_M_BLOCK_SIZES,
BLOCK_N=DEFAULT_N_BLOCK_SIZES,
BLOCK_K=DEFAULT_K_BLOCK_SIZES,
TMA_LOAD_X=True,
TMA_LOAD_W=True,
TMA_STORE=False, # NOTE: TMA_STORE is disabled for now
num_warps=DEFAULT_NUM_WARPS,
num_stages=DEFAULT_NUM_STAGES,
num_ctas=DEFAULT_NUM_CTAS,
BLOCK_M = DEFAULT_M_BLOCK_SIZES,
BLOCK_N = DEFAULT_N_BLOCK_SIZES,
BLOCK_K = DEFAULT_K_BLOCK_SIZES,
TMA_LOAD_X = True,
TMA_LOAD_W = True,
TMA_STORE = False, # NOTE: TMA_STORE is disabled for now
num_warps = DEFAULT_NUM_WARPS,
num_stages = DEFAULT_NUM_STAGES,
num_ctas = DEFAULT_NUM_CTAS,
):
(
BLOCK_M,
@ -95,16 +95,16 @@ def get_forward_configs(
kernel_configs.append(
triton.Config(
dict(
BLOCK_SIZE_M=block_m,
BLOCK_SIZE_N=block_n,
BLOCK_SIZE_K=block_k,
USE_TMA_LOAD_X=tma_load_x,
USE_TMA_LOAD_W=tma_load_w,
USE_TMA_STORE=tma_store,
BLOCK_SIZE_M = block_m,
BLOCK_SIZE_N = block_n,
BLOCK_SIZE_K = block_k,
USE_TMA_LOAD_X = tma_load_x,
USE_TMA_LOAD_W = tma_load_w,
USE_TMA_STORE = tma_store,
),
num_warps=w,
num_stages=s,
num_ctas=num_ctas,
num_warps = w,
num_stages = s,
num_ctas = num_ctas,
)
)
@ -112,15 +112,15 @@ def get_forward_configs(
def get_dX_kernel_configs(
BLOCK_M=DEFAULT_M_BLOCK_SIZES,
BLOCK_N=DEFAULT_N_BLOCK_SIZES,
BLOCK_K=DEFAULT_K_BLOCK_SIZES,
TMA_LOAD_dY=True,
TMA_LOAD_W=True,
TMA_STORE=False, # NOTE: TMA_STORE is disabled for now
num_warps=DEFAULT_NUM_WARPS,
num_stages=DEFAULT_NUM_STAGES,
num_ctas=DEFAULT_NUM_CTAS,
BLOCK_M = DEFAULT_M_BLOCK_SIZES,
BLOCK_N = DEFAULT_N_BLOCK_SIZES,
BLOCK_K = DEFAULT_K_BLOCK_SIZES,
TMA_LOAD_dY = True,
TMA_LOAD_W = True,
TMA_STORE = False, # NOTE: TMA_STORE is disabled for now
num_warps = DEFAULT_NUM_WARPS,
num_stages = DEFAULT_NUM_STAGES,
num_ctas = DEFAULT_NUM_CTAS,
):
(
BLOCK_M,
@ -170,16 +170,16 @@ def get_dX_kernel_configs(
kernel_configs.append(
triton.Config(
dict(
BLOCK_SIZE_M=block_m,
BLOCK_SIZE_N=block_n,
BLOCK_SIZE_K=block_k,
USE_TMA_LOAD_dY=tma_load_dy,
USE_TMA_LOAD_W=tma_load_w,
USE_TMA_STORE=tma_store,
BLOCK_SIZE_M = block_m,
BLOCK_SIZE_N = block_n,
BLOCK_SIZE_K = block_k,
USE_TMA_LOAD_dY = tma_load_dy,
USE_TMA_LOAD_W = tma_load_w,
USE_TMA_STORE = tma_store,
),
num_warps=w,
num_stages=s,
num_ctas=num_ctas,
num_warps = w,
num_stages = s,
num_ctas = num_ctas,
)
)
@ -187,15 +187,15 @@ def get_dX_kernel_configs(
def get_dW_kernel_configs(
BLOCK_M=DEFAULT_M_BLOCK_SIZES,
BLOCK_N=DEFAULT_N_BLOCK_SIZES,
BLOCK_K=DEFAULT_K_BLOCK_SIZES,
num_warps=DEFAULT_NUM_WARPS,
num_stages=DEFAULT_NUM_STAGES,
num_ctas=DEFAULT_NUM_CTAS,
TMA_LOAD_dY=True,
TMA_LOAD_X=True,
TMA_STORE=False,
BLOCK_M = DEFAULT_M_BLOCK_SIZES,
BLOCK_N = DEFAULT_N_BLOCK_SIZES,
BLOCK_K = DEFAULT_K_BLOCK_SIZES,
num_warps = DEFAULT_NUM_WARPS,
num_stages = DEFAULT_NUM_STAGES,
num_ctas = DEFAULT_NUM_CTAS,
TMA_LOAD_dY = True,
TMA_LOAD_X = True,
TMA_STORE = False,
):
(
BLOCK_M,
@ -245,16 +245,16 @@ def get_dW_kernel_configs(
kernel_configs.append(
triton.Config(
dict(
BLOCK_SIZE_M=block_m,
BLOCK_SIZE_N=block_n,
BLOCK_SIZE_K=block_k,
USE_TMA_LOAD_dY=tma_load_dy,
USE_TMA_LOAD_X=tma_load_x,
USE_TMA_STORE=tma_store,
BLOCK_SIZE_M = block_m,
BLOCK_SIZE_N = block_n,
BLOCK_SIZE_K = block_k,
USE_TMA_LOAD_dY = tma_load_dy,
USE_TMA_LOAD_X = tma_load_x,
USE_TMA_STORE = tma_store,
),
num_warps=w,
num_stages=s,
num_ctas=num_ctas,
num_warps = w,
num_stages = s,
num_ctas = num_ctas,
)
)

View file

@ -84,18 +84,18 @@ def _grouped_gemm_dX_kernel(
if USE_TMA_LOAD_dY:
dY_desc = tl._experimental_make_tensor_descriptor(
dY_ptr,
shape=[TOTAL_TOKENS, N],
strides=[N, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
shape = [TOTAL_TOKENS, N],
strides = [N, 1],
block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],
)
if USE_TMA_LOAD_W:
expert_stride = N * K
w_desc = tl._experimental_make_tensor_descriptor(
w_ptr,
shape=[NUM_EXPERTS, N, K],
strides=[expert_stride, K, 1],
block_shape=[1, BLOCK_SIZE_N, BLOCK_SIZE_K],
shape = [NUM_EXPERTS, N, K],
strides = [expert_stride, K, 1],
block_shape = [1, BLOCK_SIZE_N, BLOCK_SIZE_K],
)
m_end = 0
@ -104,7 +104,7 @@ def _grouped_gemm_dX_kernel(
n_block_range = tl.arange(0, BLOCK_SIZE_N)
k_block_range = tl.arange(0, BLOCK_SIZE_K)
for expert_idx in range(NUM_EXPERTS, flatten=FLATTEN):
for expert_idx in range(NUM_EXPERTS, flatten = FLATTEN):
m_start = m_end
m_size = tl.load(m_sizes_ptr + expert_idx).to(tl.int32)
m_end = m_start + m_size
@ -125,9 +125,9 @@ def _grouped_gemm_dX_kernel(
)
dX_desc = tl._experimental_make_tensor_descriptor(
dX_ptr,
shape=[m_end, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
shape = [m_end, K],
strides = [K, 1],
block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],
)
# Lower bound and upper bound are defined relative to the total tiles processed so far
@ -152,7 +152,7 @@ def _grouped_gemm_dX_kernel(
)
expert_token_idx = tl.load(
gather_indices_ptr + indices_to_gather,
mask=indices_to_gather < TOTAL_TOKENS,
mask = indices_to_gather < TOTAL_TOKENS,
)
expert_token_offsets = expert_token_idx[:, None]
@ -210,13 +210,13 @@ def _grouped_gemm_dX_kernel(
# col_mask = offs_bk[None, :] < K
store_mask = row_mask # & col_mask
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype = tl.float32)
# GEMM main loop
for n_offset in range(0, N, BLOCK_SIZE_N):
# dY block [M, N]
if not USE_TMA_LOAD_dY:
dY = tl.load(dY_ptrs, mask=row_mask)
dY = tl.load(dY_ptrs, mask = row_mask)
else:
dY = dY_desc.load(
[m_start + tile_m_idx * BLOCK_SIZE_M, n_offset]
@ -253,7 +253,7 @@ def _grouped_gemm_dX_kernel(
tl.store(
dX_ptr + store_idx + offs_bk[None, :],
dX,
mask=store_mask,
mask = store_mask,
)
# Move to the next tile within this expert group
@ -264,9 +264,9 @@ def _grouped_gemm_dX_kernel(
_autotuned_grouped_gemm_dX_kernel = triton.autotune(
configs=get_dX_kernel_configs(),
prune_configs_by={"early_config_prune": prune_dX_configs},
key=["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
configs = get_dX_kernel_configs(),
prune_configs_by = {"early_config_prune": prune_dX_configs},
key = ["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
)(_grouped_gemm_dX_kernel)
"""
@ -324,17 +324,17 @@ def _grouped_gemm_dW_kernel(
if USE_TMA_LOAD_dY and not TMA_LOAD_BOTH:
dY_desc = tl._experimental_make_tensor_descriptor(
dY_ptr,
shape=[TOTAL_TOKENS, N],
strides=[N, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
shape = [TOTAL_TOKENS, N],
strides = [N, 1],
block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],
)
if USE_TMA_LOAD_X and not TMA_LOAD_BOTH:
x_desc = tl._experimental_make_tensor_descriptor(
x_ptr,
shape=[TOTAL_TOKENS, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
shape = [TOTAL_TOKENS, K],
strides = [K, 1],
block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],
)
# Output tiles per expert, since each expert weight matrix is [N, K]
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
@ -351,9 +351,9 @@ def _grouped_gemm_dW_kernel(
tl.static_assert(K % BLOCK_SIZE_K == 0, "K must be divisible by BLOCK_SIZE_K")
dW_desc = tl._experimental_make_tensor_descriptor(
dW_ptr,
shape=[NUM_EXPERTS, N, K],
strides=[N * K, K, 1],
block_shape=[1, BLOCK_SIZE_N, BLOCK_SIZE_K],
shape = [NUM_EXPERTS, N, K],
strides = [N * K, K, 1],
block_shape = [1, BLOCK_SIZE_N, BLOCK_SIZE_K],
)
for tile_idx in range(
@ -377,7 +377,7 @@ def _grouped_gemm_dW_kernel(
m_end = 0
for expert_idx in range(NUM_EXPERTS):
# We need to instantiate a fresh accumulator for each expert
accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=acc_dtype)
accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype = acc_dtype)
m_start = m_end
# Need to figure out why this cast is needed, otherwise compiler complains about mismatching types
@ -392,16 +392,16 @@ def _grouped_gemm_dW_kernel(
if TMA_LOAD_BOTH:
dY_desc = tl._experimental_make_tensor_descriptor(
dY_ptr,
shape=[m_end, N],
strides=[N, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
shape = [m_end, N],
strides = [N, 1],
block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],
)
x_desc = tl._experimental_make_tensor_descriptor(
x_ptr,
shape=[m_end, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
shape = [m_end, K],
strides = [K, 1],
block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],
)
for tile_m_idx in range(0, m_size, BLOCK_SIZE_M):
@ -425,7 +425,7 @@ def _grouped_gemm_dW_kernel(
# indices_to_gather = m_start + gather_offsets
expert_token_idx = tl.load(
gather_indices_ptr + indices_to_gather,
mask=indices_to_gather < TOTAL_TOKENS,
mask = indices_to_gather < TOTAL_TOKENS,
)
expert_token_offsets = expert_token_idx[:, None]
@ -461,7 +461,7 @@ def _grouped_gemm_dW_kernel(
x_ptr
+ x_row_load_idx
+ (k_offset + block_range_k)[None, :],
mask=mk_mask,
mask = mk_mask,
)
if USE_TMA_LOAD_dY:
@ -471,7 +471,7 @@ def _grouped_gemm_dW_kernel(
dY_ptr
+ dY_row_load_idx
+ (n_offset + block_range_n)[None, :],
mask=mn_mask,
mask = mn_mask,
)
accumulator += tl.dot(
@ -491,12 +491,12 @@ def _grouped_gemm_dW_kernel(
+ store_row_offs[:, None] * K
+ (k_offset + block_range_k)[None, :],
y,
mask=nk_mask,
mask = nk_mask,
)
_autotuned_grouped_gemm_dW_kernel = triton.autotune(
configs=get_dW_kernel_configs(),
prune_configs_by={"early_config_prune": prune_kernel_configs_backward_dW},
key=["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
configs = get_dW_kernel_configs(),
prune_configs_by = {"early_config_prune": prune_kernel_configs_backward_dW},
key = ["NUM_EXPERTS", "NUM_TOKENS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
)(_grouped_gemm_dW_kernel)

View file

@ -51,9 +51,9 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe):
def __init__(
self,
config: Llama4TextConfig,
overlap_router_shared=False,
verbose=False,
debug=False,
overlap_router_shared = False,
verbose = False,
debug = False,
):
super().__init__(config)
self.overlap_router_shared = overlap_router_shared
@ -136,7 +136,7 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe):
hidden_states = hidden_states.view(-1, self.hidden_dim)
router_logits = self.router(hidden_states)
routing_weights, selected_experts = torch.topk(
router_logits, self.top_k, dim=-1
router_logits, self.top_k, dim = -1
)
routing_weights = F.sigmoid(routing_weights.float()).to(hidden_states.dtype)
@ -195,7 +195,7 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe):
)
if self.top_k > 1:
hidden_states = hidden_states.sum(dim=1)
hidden_states = hidden_states.sum(dim = 1)
hidden_states_after_weight_merge = hidden_states.view(-1, hidden_dim)
# 1. Compute tokens per expert and indices for gathering tokes from token order to expert order
@ -212,7 +212,7 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe):
# Start expert computation
first_gemm = torch_grouped_gemm(
X=hidden_states, W=self.experts.gate_up_proj, m_sizes=token_counts_by_expert
X = hidden_states, W = self.experts.gate_up_proj, m_sizes = token_counts_by_expert
)
assert first_gemm.shape == (total_tokens, 2 * self.experts.expert_dim)
@ -221,7 +221,7 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe):
# See comment above
second_gemm = torch_grouped_gemm(
X=intermediate, W=self.experts.down_proj, m_sizes=token_counts_by_expert
X = intermediate, W = self.experts.down_proj, m_sizes = token_counts_by_expert
)
assert second_gemm.shape == (total_tokens, hidden_dim)
@ -234,17 +234,17 @@ class Llama4GroupedGemmTextMoe(Llama4TextMoe):
result = (
Llama4MoeResult(
token_counts_by_expert=token_counts_by_expert,
gather_indices=gather_indices,
topk_weights=routing_weights,
hidden_states_after_weight_merge=hidden_states_after_weight_merge,
first_gemm=first_gemm,
intermediate=intermediate,
second_gemm=second_gemm,
hidden_states_unpermute=hidden_states_unpermute,
shared_expert_out=shared_expert_out,
final_out=final_out,
router_logits=router_logits,
token_counts_by_expert = token_counts_by_expert,
gather_indices = gather_indices,
topk_weights = routing_weights,
hidden_states_after_weight_merge = hidden_states_after_weight_merge,
first_gemm = first_gemm,
intermediate = intermediate,
second_gemm = second_gemm,
hidden_states_unpermute = hidden_states_unpermute,
shared_expert_out = shared_expert_out,
final_out = final_out,
router_logits = router_logits,
)
if self.debug
else (final_out, routing_weights)
@ -257,7 +257,7 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):
def __init__(
self,
config: Llama4TextConfig,
overlap_router_shared=False,
overlap_router_shared = False,
permute_x: bool = False,
permute_y: bool = True,
autotune: bool = True,
@ -266,9 +266,9 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):
kernel_config_bwd_dX: KernelConfigBackward_dX = None,
dW_only: bool = False,
dX_only: bool = False,
verbose=False,
verbose = False,
):
super().__init__(config, overlap_router_shared=overlap_router_shared)
super().__init__(config, overlap_router_shared = overlap_router_shared)
assert not permute_x, "Llama4 triton grouped gemm does not support permute x due to pre-multiplication of router weights"
self.permute_x = permute_x
self.permute_y = permute_y
@ -321,7 +321,7 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):
hidden_states = hidden_states.view(-1, self.hidden_dim)
router_logits = self.router(hidden_states)
routing_weights, selected_experts = torch.topk(
router_logits, self.top_k, dim=-1
router_logits, self.top_k, dim = -1
)
routing_weights = F.sigmoid(routing_weights.float()).to(hidden_states.dtype)
@ -380,7 +380,7 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):
)
if self.top_k > 1:
hidden_states = hidden_states.sum(dim=1)
hidden_states = hidden_states.sum(dim = 1)
hidden_states = hidden_states.view(-1, hidden_dim)
# 1. Compute tokens per expert and indices for gathering tokes from token order to expert order
@ -395,37 +395,37 @@ class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):
# Start expert computation
hidden_states = grouped_gemm(
X=hidden_states,
W=self.experts.gate_up_proj,
m_sizes=token_counts_by_expert,
gather_indices=gather_indices,
topk=self.top_k,
permute_x=self.permute_x,
permute_y=False, # output of first grouped gemm should never be permuted
autotune=self.autotune,
kernel_config_fwd=self.kernel_config_fwd,
kernel_config_bwd_dW=self.kernel_config_bwd_dW,
kernel_config_bwd_dX=self.kernel_config_bwd_dX,
is_first_gemm=True,
dW_only=self.dW_only,
dX_only=self.dX_only,
X = hidden_states,
W = self.experts.gate_up_proj,
m_sizes = token_counts_by_expert,
gather_indices = gather_indices,
topk = self.top_k,
permute_x = self.permute_x,
permute_y = False, # output of first grouped gemm should never be permuted
autotune = self.autotune,
kernel_config_fwd = self.kernel_config_fwd,
kernel_config_bwd_dW = self.kernel_config_bwd_dW,
kernel_config_bwd_dX = self.kernel_config_bwd_dX,
is_first_gemm = True,
dW_only = self.dW_only,
dX_only = self.dX_only,
)
hidden_states = self.act_and_mul(hidden_states)
hidden_states = grouped_gemm(
X=hidden_states,
W=self.experts.down_proj,
m_sizes=token_counts_by_expert,
gather_indices=gather_indices,
topk=self.top_k,
permute_x=False,
permute_y=self.permute_y,
autotune=self.autotune,
kernel_config_fwd=self.kernel_config_fwd,
kernel_config_bwd_dW=self.kernel_config_bwd_dW,
kernel_config_bwd_dX=self.kernel_config_bwd_dX,
is_first_gemm=False,
dW_only=self.dW_only,
dX_only=self.dX_only,
X = hidden_states,
W = self.experts.down_proj,
m_sizes = token_counts_by_expert,
gather_indices = gather_indices,
topk = self.top_k,
permute_x = False,
permute_y = self.permute_y,
autotune = self.autotune,
kernel_config_fwd = self.kernel_config_fwd,
kernel_config_bwd_dW = self.kernel_config_bwd_dW,
kernel_config_bwd_dX = self.kernel_config_bwd_dX,
is_first_gemm = False,
dW_only = self.dW_only,
dX_only = self.dX_only,
)
# Post-processing

View file

@ -80,14 +80,14 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
gate,
gate_up_proj,
down_proj,
permute_x=permute_x,
permute_y=permute_y,
autotune=autotune,
kernel_config_fwd=kernel_config_fwd,
kernel_config_bwd_dW=kernel_config_bwd_dW,
kernel_config_bwd_dX=kernel_config_bwd_dX,
dW_only=dW_only,
dX_only=dX_only,
permute_x = permute_x,
permute_y = permute_y,
autotune = autotune,
kernel_config_fwd = kernel_config_fwd,
kernel_config_bwd_dW = kernel_config_bwd_dW,
kernel_config_bwd_dX = kernel_config_bwd_dX,
dW_only = dW_only,
dX_only = dX_only,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -112,37 +112,37 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
hidden_states = permute(hidden_states, gather_indices, self.top_k)
# Start expert computation
hidden_states = grouped_gemm(
X=hidden_states,
W=self.gate_up_proj,
m_sizes=token_counts_by_expert,
gather_indices=gather_indices,
topk=self.top_k,
permute_x=self.permute_x,
permute_y=False, # output of first grouped gemm should never be permuted
autotune=self.autotune,
kernel_config_fwd=self.kernel_config_fwd,
kernel_config_bwd_dW=self.kernel_config_bwd_dW,
kernel_config_bwd_dX=self.kernel_config_bwd_dX,
is_first_gemm=True,
dW_only=self.dW_only,
dX_only=self.dX_only,
X = hidden_states,
W = self.gate_up_proj,
m_sizes = token_counts_by_expert,
gather_indices = gather_indices,
topk = self.top_k,
permute_x = self.permute_x,
permute_y = False, # output of first grouped gemm should never be permuted
autotune = self.autotune,
kernel_config_fwd = self.kernel_config_fwd,
kernel_config_bwd_dW = self.kernel_config_bwd_dW,
kernel_config_bwd_dX = self.kernel_config_bwd_dX,
is_first_gemm = True,
dW_only = self.dW_only,
dX_only = self.dX_only,
)
hidden_states = self.act_and_mul(hidden_states)
hidden_states = grouped_gemm(
X=hidden_states,
W=self.down_proj,
m_sizes=token_counts_by_expert,
gather_indices=gather_indices,
topk=self.top_k,
permute_x=False,
permute_y=self.permute_y,
autotune=self.autotune,
kernel_config_fwd=self.kernel_config_fwd,
kernel_config_bwd_dW=self.kernel_config_bwd_dW,
kernel_config_bwd_dX=self.kernel_config_bwd_dX,
is_first_gemm=False,
dW_only=self.dW_only,
dX_only=self.dX_only,
X = hidden_states,
W = self.down_proj,
m_sizes = token_counts_by_expert,
gather_indices = gather_indices,
topk = self.top_k,
permute_x = False,
permute_y = self.permute_y,
autotune = self.autotune,
kernel_config_fwd = self.kernel_config_fwd,
kernel_config_bwd_dW = self.kernel_config_bwd_dW,
kernel_config_bwd_dX = self.kernel_config_bwd_dX,
is_first_gemm = False,
dW_only = self.dW_only,
dX_only = self.dX_only,
)
# Post-processing
@ -155,7 +155,7 @@ class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
hidden_states.view(num_tokens, self.top_k, hidden_dim)
* routing_weights[..., None]
)
hidden_states = hidden_states.sum(dim=1)
hidden_states = hidden_states.sum(dim = 1)
hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
return hidden_states, router_logits

View file

@ -56,7 +56,7 @@ def calculate_topk(
gating_output.dtype
)
else:
scores = F.softmax(gating_output.to(torch.float32), dim=1).to(
scores = F.softmax(gating_output.to(torch.float32), dim = 1).to(
gating_output.dtype
)
@ -67,13 +67,13 @@ def calculate_topk(
else:
scores = gating_output
topk_weights, topk_ids = torch.topk(scores, k=top_k, dim=1)
topk_weights, topk_ids = torch.topk(scores, k = top_k, dim = 1)
if post_act:
topk_weights = _activation(topk_weights)
if renormalize:
topk_weights /= torch.sum(topk_weights, dim=-1, keepdim=True).to(
topk_weights /= torch.sum(topk_weights, dim = -1, keepdim = True).to(
gating_output.dtype
)
@ -94,12 +94,12 @@ def get_routing_indices(
# group tokens together by expert indices from 0 to num_experts and pass that to experts forward
token_counts_by_expert = torch.histc(
selected_experts.view(-1),
bins=num_experts,
min=0,
max=num_experts,
bins = num_experts,
min = 0,
max = num_experts,
)
# token_indices_experts_sorted shape (bs*slen*top_k,)
gather_indices = torch.argsort(selected_experts.view(-1), stable=True)
gather_indices = torch.argsort(selected_experts.view(-1), stable = True)
if return_scatter_indices:
scatter_indices = gather_indices.argsort()
return token_counts_by_expert, gather_indices, scatter_indices
@ -107,7 +107,7 @@ def get_routing_indices(
return token_counts_by_expert, gather_indices
def torch_grouped_gemm(X, W, m_sizes, transpose=True):
def torch_grouped_gemm(X, W, m_sizes, transpose = True):
"""
X: [M, K] if forward, else [M, N]
W: [E, N, K]
@ -127,7 +127,7 @@ def torch_grouped_gemm(X, W, m_sizes, transpose=True):
N = W.shape[1]
result = torch.zeros((M, N), dtype=X.dtype, device=X.device)
result = torch.zeros((M, N), dtype = X.dtype, device = X.device)
m_start = 0
for g in range(E):

View file

@ -37,7 +37,7 @@ NUM_AUTOTUNE_CONFIGS = 50
@contextmanager
def annotated_context(prelude, epilogue="Passed!", char="-", num_chars=80):
def annotated_context(prelude, epilogue = "Passed!", char = "-", num_chars = 80):
print(char * num_chars)
print(prelude)
yield
@ -81,7 +81,7 @@ def prep_triton_kernel_traits(autotune):
def sparse_to_dense(t: torch.Tensor):
t = t.sum(dim=0).view(-1)
t = t.sum(dim = 0).view(-1)
return t
@ -91,9 +91,9 @@ def _check_diff(
t2: torch.Tensor,
atol,
rtol,
precision=".6f",
verbose=False,
msg="",
precision = ".6f",
verbose = False,
msg = "",
):
t2 = t2.view_as(t1)
diff = t1.sub(t2).abs().max().item()
@ -101,7 +101,7 @@ def _check_diff(
if msg == "":
msg = "diff"
print(f"{msg}: {diff:{precision}}")
assert torch.allclose(t1, t2, atol=atol, rtol=rtol)
assert torch.allclose(t1, t2, atol = atol, rtol = rtol)
def run_backwards(y: torch.Tensor, grad_output: torch.Tensor, module: torch.nn.Module):
@ -115,19 +115,19 @@ def _check_grads(
m2: torch.nn.Module,
atol,
rtol,
precision=".6f",
verbose=False,
msg="",
precision = ".6f",
verbose = False,
msg = "",
):
for name, param in m1.named_parameters():
_check_diff(
param.grad,
m2.get_parameter(name).grad,
atol=atol,
rtol=rtol,
precision=precision,
verbose=verbose,
msg=f"{msg}:{name}.grad",
atol = atol,
rtol = rtol,
precision = precision,
verbose = verbose,
msg = f"{msg}:{name}.grad",
)
@ -139,19 +139,19 @@ def model_config():
@pytest.mark.parametrize(
"overlap_router_shared",
[False, True],
ids=lambda x: "overlap_router_shared" if x else "no_overlap",
ids = lambda x: "overlap_router_shared" if x else "no_overlap",
)
@pytest.mark.parametrize(
"permute_y", [False, True], ids=lambda x: "permute_y" if x else "no_permute_y"
"permute_y", [False, True], ids = lambda x: "permute_y" if x else "no_permute_y"
)
@pytest.mark.parametrize(
"permute_x", [False], ids=lambda x: "permute_x" if x else "no_permute_x"
"permute_x", [False], ids = lambda x: "permute_x" if x else "no_permute_x"
) # Llama4 does not support permute_x
@pytest.mark.parametrize(
"autotune", [True], ids=lambda x: "autotune" if x else "manual"
"autotune", [True], ids = lambda x: "autotune" if x else "manual"
)
@pytest.mark.parametrize("seqlen", SEQ_LENS, ids=lambda x: f"seqlen={x}")
@pytest.mark.parametrize("dtype", DTYPES, ids=str)
@pytest.mark.parametrize("seqlen", SEQ_LENS, ids = lambda x: f"seqlen={x}")
@pytest.mark.parametrize("dtype", DTYPES, ids = str)
def test_llama4_ref(
dtype: torch.dtype,
seqlen,
@ -161,9 +161,9 @@ def test_llama4_ref(
overlap_router_shared: bool,
model_config: Llama4TextConfig, # test fixture
bs: int = 1,
device="cuda",
precision=".6f",
verbose=False,
device = "cuda",
precision = ".6f",
verbose = False,
):
torch.manual_seed(
SEED
@ -172,24 +172,24 @@ def test_llama4_ref(
hidden_dim = model_config.hidden_size
atol, rtol = TOLERANCES[dtype]
check_diff = partial(
_check_diff, atol=atol, rtol=rtol, precision=precision, verbose=verbose
_check_diff, atol = atol, rtol = rtol, precision = precision, verbose = verbose
)
check_grads = partial(
_check_grads, atol=atol, rtol=rtol, precision=precision, verbose=verbose
_check_grads, atol = atol, rtol = rtol, precision = precision, verbose = verbose
)
# Reference op -- HF
llama4_ref = Llama4TextMoe(model_config).to(dtype=dtype, device=device)
llama4_ref = Llama4TextMoe(model_config).to(dtype = dtype, device = device)
# Torch grouped gemm impl
llama4_gg_ref = Llama4GroupedGemmTextMoe(
model_config, overlap_router_shared=overlap_router_shared
).to(dtype=dtype, device=device)
model_config, overlap_router_shared = overlap_router_shared
).to(dtype = dtype, device = device)
llama4_gg_ref.copy_weights(llama4_ref)
llama4_gg_ref.check_weights(llama4_ref)
x_ref = torch.randn(
bs, seqlen, hidden_dim, dtype=dtype, device=device, requires_grad=True
bs, seqlen, hidden_dim, dtype = dtype, device = device, requires_grad = True
)
x_torch_gg = x_ref.detach().clone().requires_grad_()
x_triton = x_ref.detach().clone().requires_grad_()
@ -198,9 +198,9 @@ def test_llama4_ref(
y_torch_gg, routing_torch_gg = llama4_gg_ref(x_torch_gg)
assert y_ref.shape == y_torch_gg.shape, f"{y_ref.shape} != {y_torch_gg.shape}"
with annotated_context("Testing torch grouped gemm Llama4TextMoe"):
check_diff(y_ref, y_torch_gg, msg="y_torch_gg")
check_diff(y_ref, y_torch_gg, msg = "y_torch_gg")
check_diff(
sparse_to_dense(routing_ref), routing_torch_gg, msg="routing_torch_gg"
sparse_to_dense(routing_ref), routing_torch_gg, msg = "routing_torch_gg"
)
kernel_config_fwd, kernel_config_bwd_dW, kernel_config_bwd_dX = (
@ -209,38 +209,38 @@ def test_llama4_ref(
llama4_triton = Llama4TritonTextMoe(
model_config,
overlap_router_shared=overlap_router_shared,
permute_x=permute_x,
permute_y=permute_y,
autotune=autotune,
kernel_config_fwd=kernel_config_fwd,
kernel_config_bwd_dW=kernel_config_bwd_dW,
kernel_config_bwd_dX=kernel_config_bwd_dX,
).to(device=device, dtype=dtype)
overlap_router_shared = overlap_router_shared,
permute_x = permute_x,
permute_y = permute_y,
autotune = autotune,
kernel_config_fwd = kernel_config_fwd,
kernel_config_bwd_dW = kernel_config_bwd_dW,
kernel_config_bwd_dX = kernel_config_bwd_dX,
).to(device = device, dtype = dtype)
llama4_triton.copy_weights(llama4_ref)
llama4_triton.check_weights(llama4_ref)
y_triton, routing_triton = llama4_triton(x_triton)
with annotated_context("Testing triton grouped gemm Llama4TextMoe forward"):
check_diff(y_ref, y_triton, msg="y_triton")
check_diff(sparse_to_dense(routing_ref), routing_triton, msg="routing_triton")
check_diff(y_ref, y_triton, msg = "y_triton")
check_diff(sparse_to_dense(routing_ref), routing_triton, msg = "routing_triton")
ref_grad = torch.randn_like(y_ref)
run_backwards(y_ref, ref_grad, llama4_ref)
run_backwards(y_torch_gg, ref_grad, llama4_gg_ref)
with annotated_context("Testing torch group gemm Llama4TextMoe backward"):
check_grads(llama4_ref, llama4_gg_ref, msg="torch_gg")
check_grads(llama4_ref, llama4_gg_ref, msg = "torch_gg")
run_backwards(y_triton, ref_grad, llama4_triton)
with annotated_context("Testing triton group gemm Llama4TextMoe backward"):
check_grads(llama4_ref, llama4_triton, msg="triton")
check_grads(llama4_ref, llama4_triton, msg = "triton")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--seqlen", type=int, default=1024)
parser.add_argument("--seqlen", type = int, default = 1024)
parser.add_argument(
"--dtype", type=str, choices=["bfloat16", "float16"], default="bfloat16"
"--dtype", type = str, choices = ["bfloat16", "float16"], default = "bfloat16"
)
args = parser.parse_args()
args.dtype = getattr(torch, args.dtype)
@ -251,12 +251,12 @@ if __name__ == "__main__":
text_config: Llama4TextConfig = get_text_config(model_id)
for overlap in [False, True]:
test_llama4_ref(
seqlen=args.seqlen,
model_config=text_config,
dtype=args.dtype,
autotune=True,
permute_x=False,
permute_y=True,
overlap_router_shared=overlap,
verbose=True,
seqlen = args.seqlen,
model_config = text_config,
dtype = args.dtype,
autotune = True,
permute_x = False,
permute_y = True,
overlap_router_shared = overlap,
verbose = True,
)

View file

@ -45,16 +45,16 @@ def _rms_layernorm_forward(
X += row_idx * X_row_stride
r += row_idx * r_row_stride
X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask=mask, other=0) # .to(tl.float32)
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask = mask, other = 0) # .to(tl.float32)
row_var = tl.sum(X_row * X_row, axis=0) / n_cols
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
inv_var = tl.math.rsqrt(row_var + eps)
tl.store(r, inv_var)
normed = X_row * inv_var
normed = normed.to(W_row.dtype) # Exact copy from HF
output = normed * W_row
tl.store(Y + col_offsets, output, mask=mask)
tl.store(Y + col_offsets, output, mask = mask)
def _rms_layernorm_backward(
@ -92,9 +92,9 @@ def _rms_layernorm_backward(
else:
dX = dY
dY_row = tl.load(dY + col_offsets, mask=mask, other=0).to(tl.float32)
X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask=mask, other=0).to(tl.float32)
dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
# Get saved row variance
inv_var = tl.load(r).to(tl.float32)
@ -105,9 +105,9 @@ def _rms_layernorm_backward(
else:
dY_W = dY_row * W_row
rowsum_dY_normed = tl.sum(dY_W * normed, axis=0)
rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
output = inv_var / n_cols * (n_cols * dY_W - normed * rowsum_dY_normed)
tl.store(dX + col_offsets, output, mask=mask)
tl.store(dX + col_offsets, output, mask = mask)
_rms_layernorm_backward = triton.jit(_rms_layernorm_backward)
@ -143,16 +143,16 @@ def _gemma_rms_layernorm_forward(
X += row_idx * X_row_stride
r += row_idx * r_row_stride
X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask=mask, other=0).to(tl.float32)
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
row_var = tl.sum(X_row * X_row, axis=0) / n_cols
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
inv_var = tl.math.rsqrt(row_var + eps)
tl.store(r, inv_var)
normed = X_row * inv_var
output = normed * (W_row + 1.0)
tl.store(Y + col_offsets, output, mask=mask)
tl.store(Y + col_offsets, output, mask = mask)
class Fast_RMS_Layernorm(torch.autograd.Function):
@ -169,8 +169,8 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
device = X.device
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=device)
r = torch.empty(n_rows, dtype=torch.float32, device=device)
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)
r = torch.empty(n_rows, dtype = torch.float32, device = device)
fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
with torch_gpu_device(device):
@ -185,8 +185,8 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
r.stride(0),
n_cols,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
BLOCK_SIZE = BLOCK_SIZE,
num_warps = num_warps,
)
ctx.eps = eps
ctx.BLOCK_SIZE = BLOCK_SIZE
@ -222,9 +222,9 @@ class Fast_RMS_Layernorm(torch.autograd.Function):
# dW, dW.stride(0),
n_cols,
ctx.eps,
GEMMA=ctx.GEMMA,
BLOCK_SIZE=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
GEMMA = ctx.GEMMA,
BLOCK_SIZE = ctx.BLOCK_SIZE,
num_warps = ctx.num_warps,
)
dX = dX.view(*shape)
return dX, None, None, None
@ -248,7 +248,7 @@ from transformers.models.llama.modeling_llama import LlamaRMSNorm
class Unsloth_LlamaRMSNorm(LlamaRMSNorm):
def forward(self, X):
return fast_rms_layernorm(self, X, gemma=False)
return fast_rms_layernorm(self, X, gemma = False)
try:
@ -256,7 +256,7 @@ try:
class Unsloth_MllamaTextRMSNorm(MllamaTextRMSNorm):
def forward(self, X):
return fast_rms_layernorm(self, X, gemma=False)
return fast_rms_layernorm(self, X, gemma = False)
except:
@ -292,25 +292,25 @@ def unpatch_rms_layernorm():
def test_rms_layernorm(
dim=1024,
eps=1e-5,
dtype=torch.float16,
bsz=21,
random_state=3407,
seqlen=3341,
dim = 1024,
eps = 1e-5,
dtype = torch.float16,
bsz = 21,
random_state = 3407,
seqlen = 3341,
):
from transformers.models.llama.modeling_llama import LlamaRMSNorm
layernorm = LlamaRMSNorm((dim,), eps=eps).to("cuda")
layernorm = LlamaRMSNorm((dim,), eps = eps).to("cuda")
torch.cuda.manual_seed(random_state)
torch.manual_seed(random_state)
torch.nn.init.uniform_(layernorm.weight)
X = torch.randn((bsz, seqlen, dim), dtype=dtype, device="cuda")
X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda")
XX = X.clone()
X.requires_grad_(True)
XX.requires_grad_(True)
Y = layernorm(X)
YY = torch.randn((bsz, seqlen, dim), dtype=dtype, device="cuda", requires_grad=True)
YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda", requires_grad = True)
Y.backward(YY)
correct_grad = X.grad.clone()
# from unsloth.kernels import fast_rms_layernorm
@ -322,14 +322,14 @@ def test_rms_layernorm(
def testing_suite_layernorm():
for dim in [512, 1024, 2048]:
for dtype in [torch.float16, torch.bfloat16]:
with torch.autocast(device_type="cuda", dtype=dtype):
with torch.autocast(device_type = "cuda", dtype = dtype):
for seqlen in [3341, 2048, 349]:
for random_state in [3407, 42]:
test_rms_layernorm(
dim=dim,
eps=1e-5,
dtype=dtype,
bsz=21,
random_state=random_state,
seqlen=seqlen,
dim = dim,
eps = 1e-5,
dtype = dtype,
bsz = 21,
random_state = random_state,
seqlen = seqlen,
)

View file

@ -50,16 +50,16 @@ def _rope_embedding(
+ (row_position % seqlen) * sin_row_stride
+ half_head_dim * 0
+ col_offsets,
mask=mask,
other=0,
mask = mask,
other = 0,
)
cos1 = tl.load(
cos
+ (row_position % seqlen) * cos_row_stride
+ half_head_dim * 0
+ col_offsets,
mask=mask,
other=0,
mask = mask,
other = 0,
)
if BACKWARD_PASS:
@ -78,11 +78,11 @@ def _rope_embedding(
)
# For Gemma - sometimes RoPE must be done in float32 and not bfloat16
Q1 = tl.load(Q + offs_q1, mask=mask, other=0).to(sin1.dtype)
Q2 = tl.load(Q + offs_q2, mask=mask, other=0).to(sin1.dtype)
Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)
tl.store(Q + offs_q1, Q1 * cos1 - Q2 * sin1, mask=mask)
tl.store(Q + offs_q2, Q2 * cos1 + Q1 * sin1, mask=mask)
tl.store(Q + offs_q1, Q1 * cos1 - Q2 * sin1, mask = mask)
tl.store(Q + offs_q2, Q2 * cos1 + Q1 * sin1, mask = mask)
_rope_embedding = triton.jit(_rope_embedding)
@ -134,9 +134,9 @@ class Fast_RoPE_Embedding(torch.autograd.Function):
seq_len,
head_dim,
n_heads,
BACKWARD_PASS=False,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
BACKWARD_PASS = False,
BLOCK_SIZE = BLOCK_SIZE,
num_warps = num_warps,
)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
@ -177,9 +177,9 @@ class Fast_RoPE_Embedding(torch.autograd.Function):
seq_len,
head_dim,
n_heads,
BACKWARD_PASS=True,
BLOCK_SIZE=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
BACKWARD_PASS = True,
BLOCK_SIZE = ctx.BLOCK_SIZE,
num_warps = ctx.num_warps,
)
dY = dY.view(batch, seq_len, n_heads, head_dim)
return (
@ -211,7 +211,7 @@ class Slow_RoPE_Embedding(torch.autograd.Function):
# Q * cos + rotate_half(Q) * sin
half = Q.shape[-1] // 2
RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim=-1)
RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
Q *= cos
Q.addcmul_(RH_Q, sin)
# RH_Q *= sin
@ -224,7 +224,7 @@ class Slow_RoPE_Embedding(torch.autograd.Function):
cos, sin = ctx.saved_tensors
# Q * cos + rotate_half.T(Q) * sin
half = dY.shape[-1] // 2
RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim=-1)
RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
dY *= cos
dY.addcmul_(RH_dY, sin)
# RH_dY *= sin

View file

@ -43,8 +43,8 @@ def _fg_kernel(
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32)
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
# f = e * sigmoid(e)
f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
@ -53,13 +53,13 @@ def _fg_kernel(
h_row = f_row * g_row
# Store h
tl.store(h + offsets, h_row, mask=mask)
tl.store(h + offsets, h_row, mask = mask)
def swiglu_fg_kernel(e, g):
batch, seq_len, hd = e.shape
n_elements = e.numel()
h = torch.empty((batch, seq_len, hd), dtype=e.dtype, device=e.device)
h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = e.device)
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
with torch_gpu_device(e.device):
_fg_kernel[grid](
@ -67,8 +67,8 @@ def swiglu_fg_kernel(e, g):
g,
h,
n_elements,
BLOCK_SIZE=BLOCK_SIZE,
LONG_INDEXING=0 if n_elements <= INT32_SAFETY_BUFFER else 1,
BLOCK_SIZE = BLOCK_SIZE,
LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
)
return h
@ -101,9 +101,9 @@ def _DWf_DW_dfg_kernel(
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
DW_row = tl.load(DW + offsets, mask=mask, other=0) # .to(tl.float32)
e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
g_row = tl.load(g + offsets, mask=mask, other=0) # .to(tl.float32)
DW_row = tl.load(DW + offsets, mask = mask, other = 0) # .to(tl.float32)
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
# e = e.float()
# se = 1.0 / (1.0 + torch.exp(-e))
@ -122,9 +122,9 @@ def _DWf_DW_dfg_kernel(
de_row = de_row.to(DW_row.dtype)
# Store derivatives in buffers
tl.store(DW + offsets, h_row, mask=mask) # h = f * g
tl.store(e + offsets, df_row, mask=mask) # df = DW * f
tl.store(g + offsets, de_row, mask=mask) # de
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
tl.store(g + offsets, de_row, mask = mask) # de
def swiglu_DWf_DW_dfg_kernel(DW, e, g):
@ -137,7 +137,7 @@ def swiglu_DWf_DW_dfg_kernel(DW, e, g):
e,
g,
n_elements,
BLOCK_SIZE=BLOCK_SIZE,
LONG_INDEXING=0 if n_elements <= INT32_SAFETY_BUFFER else 1,
BLOCK_SIZE = BLOCK_SIZE,
LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
)
return DW, e, g

View file

@ -47,12 +47,12 @@ if Version(torch.__version__) < Version("2.4.0"):
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
if DEVICE_TYPE == "xpu":
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="xpu")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="xpu")
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "xpu")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "xpu")
# tl.math.tanh now is libdevice.tanh
@ -349,7 +349,7 @@ def _maybe_fake_quantize_activations(
if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
@torch.inference_mode
def fast_dequantize(W, quant_state=None, out=None, use_global_buffer=False):
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
# TODO: After adding XPU BNB support, check this function
if isinstance(W, Float8Tensor):
return W.dequantize()
@ -390,13 +390,13 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]
if WEIGHT_BUFFER is None:
WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(
size, dtype=dtype, device=device, requires_grad=False
size, dtype = dtype, device = device, requires_grad = False
)
ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(
n_elements_absmax,
dtype=torch.float32,
device=device,
requires_grad=False,
dtype = torch.float32,
device = device,
requires_grad = False,
)
if size > WEIGHT_BUFFER.numel():
@ -409,16 +409,16 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
else:
if out is None:
out = torch_empty(
shape, dtype=dtype, device=device, requires_grad=False
shape, dtype = dtype, device = device, requires_grad = False
)
else:
assert out.shape == shape
assert out.dtype == dtype
out_absmax = torch_empty(
n_elements_absmax,
dtype=torch_float32,
device=device,
requires_grad=False,
dtype = torch_float32,
device = device,
requires_grad = False,
)
# NF4 dequantization of statistics
@ -458,7 +458,7 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
@torch.inference_mode
def fast_dequantize(W, quant_state=None, out=None, use_global_buffer=False):
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
if isinstance(W, Float8Tensor):
return W.dequantize()
if quant_state is None:
@ -500,13 +500,13 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]
if WEIGHT_BUFFER is None:
WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(
size, dtype=dtype, device=device, requires_grad=False
size, dtype = dtype, device = device, requires_grad = False
)
ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(
n_elements_absmax,
dtype=torch_float32,
device=device,
requires_grad=False,
dtype = torch_float32,
device = device,
requires_grad = False,
)
if size > WEIGHT_BUFFER.numel():
@ -519,16 +519,16 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
else:
if out is None:
out = torch_empty(
shape, dtype=dtype, device=device, requires_grad=False
shape, dtype = dtype, device = device, requires_grad = False
)
else:
assert out.shape == shape
assert out.dtype == dtype
out_absmax = torch_empty(
n_elements_absmax,
dtype=torch_float32,
device=device,
requires_grad=False,
dtype = torch_float32,
device = device,
requires_grad = False,
)
pass
@ -570,7 +570,7 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
else:
@torch.inference_mode
def fast_dequantize(W, quant_state=None, out=None, use_global_buffer=False):
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
if isinstance(W, Float8Tensor):
return W.dequantize()
if quant_state is None:
@ -601,12 +601,12 @@ else:
# Create weight matrix
if out is None:
out = torch_empty(shape, dtype=dtype, device=device, requires_grad=False)
out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False)
else:
assert out.shape == shape
assert out.dtype == dtype
out_absmax = torch_empty(
n_elements_absmax, dtype=torch_float32, device=device, requires_grad=False
n_elements_absmax, dtype = torch_float32, device = device, requires_grad = False
)
# Do dequantization
@ -645,9 +645,9 @@ else:
# INTEL GPU Specific Logic
if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
def fast_gemv(X, W, quant_state, out=None):
def fast_gemv(X, W, quant_state, out = None):
if quant_state is None:
return torch_matmul(X, W, out=out)
return torch_matmul(X, W, out = out)
# For fast X @ W where seq_len == 1
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
_, q_len, hd = X.shape
@ -686,8 +686,8 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
1,
bout,
),
dtype=dtype,
device=device,
dtype = dtype,
device = device,
)
# else:
# assert(out.shape == (1, 1, bout,))
@ -710,7 +710,7 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
ldb = ctypes_c_int32(ldb)
ldc = ctypes_c_int32(ldc)
df = torch_empty(absmax.shape, dtype=torch_float32, device=device)
df = torch_empty(absmax.shape, dtype = torch_float32, device = device)
with torch_gpu_device(device):
cdequantize_blockwise_fp32(
get_ptr(code2),
@ -751,9 +751,9 @@ if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
def fast_gemv(X, W, quant_state, out=None):
def fast_gemv(X, W, quant_state, out = None):
if quant_state is None:
return torch_matmul(X, W, out=out)
return torch_matmul(X, W, out = out)
# For fast X @ W where seq_len == 1
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
_, q_len, hd = X.shape
@ -793,8 +793,8 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
1,
bout,
),
dtype=dtype,
device=device,
dtype = dtype,
device = device,
)
# else:
# assert(out.shape == (1, 1, bout,))
@ -813,7 +813,7 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
ldb = ctypes_c_int32(ldb)
ldc = ctypes_c_int32(ldc)
df = torch_empty(absmax.shape, dtype=torch_float32, device=device)
df = torch_empty(absmax.shape, dtype = torch_float32, device = device)
with torch_gpu_device(device):
cdequantize_blockwise_fp32(
get_ptr(code2),
@ -856,9 +856,9 @@ elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
pass
else:
def fast_gemv(X, W, quant_state, out=None):
def fast_gemv(X, W, quant_state, out = None):
if quant_state is None:
return torch_matmul(X, W, out=out)
return torch_matmul(X, W, out = out)
# For fast X @ W where seq_len == 1
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
_, q_len, hd = X.shape
@ -894,8 +894,8 @@ else:
1,
bout,
),
dtype=dtype,
device=device,
dtype = dtype,
device = device,
)
# else:
# assert(out.shape == (1, 1, bout,))
@ -914,7 +914,7 @@ else:
ldb = ctypes_c_int32(ldb)
ldc = ctypes_c_int32(ldc)
df = torch_empty(absmax.shape, dtype=torch_float32, device=device)
df = torch_empty(absmax.shape, dtype = torch_float32, device = device)
cdequantize_blockwise_fp32(
get_ptr(code2),
get_ptr(absmax),
@ -953,21 +953,21 @@ else:
pass
def fast_linear_forward(proj, X, temp_lora=None, out=None):
def fast_linear_forward(proj, X, temp_lora = None, out = None):
W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
bsz, q_len, in_dim = X.shape
if q_len != 1:
return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
if W_quant is None:
out = torch_matmul(X, W.t(), out=out)
out = torch_matmul(X, W.t(), out = out)
elif W.dtype == torch.float8_e4m3fn:
out = fp8_linear(X, W, W_quant, bias)
elif bsz == 1 and q_len == 1:
out = fast_gemv(X, W, W_quant, out=out)
out = fast_gemv(X, W, W_quant, out = out)
else:
W = fast_dequantize(W.t(), W_quant, use_global_buffer=True)
out = torch_matmul(X, W, out=out)
W = fast_dequantize(W.t(), W_quant, use_global_buffer = True)
out = torch_matmul(X, W, out = out)
# Add in LoRA weights
if lora_A is not None:
@ -980,14 +980,14 @@ def fast_linear_forward(proj, X, temp_lora=None, out=None):
if bsz == 1:
out = out.view(out_dim)
temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out=temp_lora)
out.addmv_(lora_B._fast_lora, temp_lora, alpha=lora_S)
temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
else:
out = out.view(bsz, out_dim)
temp_lora = torch_mm(
X.view(bsz, in_dim), lora_A._fast_lora.t(), out=temp_lora
X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora
)
out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha=lora_S)
out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
out = out.view(bsz, 1, out_dim)
if bias is not None:
@ -996,7 +996,7 @@ def fast_linear_forward(proj, X, temp_lora=None, out=None):
return out
def matmul_lora(X, W, W_quant, A, B, s, out=None):
def matmul_lora(X, W, W_quant, A, B, s, out = None):
dtype = X.dtype
if X.dim() == 3:
@ -1015,12 +1015,12 @@ def matmul_lora(X, W, W_quant, A, B, s, out=None):
W = W.dequantize()
else:
W = W.contiguous()
out = torch_matmul(X, W.t(), out=out)
out = torch_matmul(X, W.t(), out = out)
elif W.dtype == torch.float8_e4m3fn:
out = fp8_linear(X, W, W_quant)
else:
W = fast_dequantize(W, W_quant, use_global_buffer=True)
out = torch_matmul(X, W.t(), out=out)
W = fast_dequantize(W, W_quant, use_global_buffer = True)
out = torch_matmul(X, W.t(), out = out)
if W_quant is not None:
del W
@ -1028,7 +1028,7 @@ def matmul_lora(X, W, W_quant, A, B, s, out=None):
# LoRA is enabled
A, B = A.t(), B.t()
XA = torch_matmul(X, A.to(dtype))
out.addmm_(XA, B.to(dtype), alpha=s)
out.addmm_(XA, B.to(dtype), alpha = s)
# out += (X @ A.to(dtype)) @ (s * B.to(dtype))
return out.view(batch, seq_len, -1) if reshape else out

View file

@ -150,23 +150,23 @@ for temporary_patch in TEMPORARY_PATCHES:
# =============================================
# Disable some warnings which can get annoying
warnings.filterwarnings(action="ignore", category=UserWarning, module="torch")
warnings.filterwarnings(action="ignore", category=FutureWarning, module="torch")
warnings.filterwarnings(action="ignore", category=UserWarning, module="huggingface_hub")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "torch")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
warnings.filterwarnings(
action="ignore", category=FutureWarning, module="huggingface_hub"
action = "ignore", category = FutureWarning, module = "huggingface_hub"
)
warnings.filterwarnings(action="ignore", category=UserWarning, module="trl")
warnings.filterwarnings(action="ignore", category=FutureWarning, module="trl")
warnings.filterwarnings(action="ignore", category=FutureWarning, module="xformers")
warnings.filterwarnings(action="ignore", category=RuntimeWarning, module="subprocess")
warnings.filterwarnings(action="ignore", category=UserWarning, module="transformers")
warnings.filterwarnings(action="ignore", category=FutureWarning, module="accelerate")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "trl")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "trl")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "xformers")
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate")
warnings.filterwarnings(
action="ignore", category=RuntimeWarning, module="multiprocessing"
action = "ignore", category = RuntimeWarning, module = "multiprocessing"
)
warnings.filterwarnings(action="ignore", category=RuntimeWarning, module="multiprocess")
warnings.filterwarnings(action="ignore", category=UserWarning, module="triton")
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocess")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "triton")
# Stop "Special tokens have been added in the vocabulary, ..."
import logging
@ -345,10 +345,10 @@ except:
# You passed `quantization_config` or equivalent parameters
try:
warnings.filterwarnings(
action="ignore",
message=r".*quantization_config.*",
category=UserWarning,
append=True,
action = "ignore",
message = r".*quantization_config.*",
category = UserWarning,
append = True,
)
except:
pass
@ -357,10 +357,10 @@ except:
# Will be fixed in torch 2.8.1 https://github.com/pytorch/pytorch/issues/158463
try:
warnings.filterwarnings(
action="ignore",
message=r".*Logical operators 'and' and 'or'.*",
category=UserWarning,
append=True,
action = "ignore",
message = r".*Logical operators 'and' and 'or'.*",
category = UserWarning,
append = True,
)
except:
pass
@ -465,7 +465,7 @@ def extract_quant_model_param_count(model):
return count
def get_model_param_count(model, trainable_only=False):
def get_model_param_count(model, trainable_only = False):
"""
Calculate model's total param count. If trainable_only is True then count only those requiring grads
"""
@ -596,14 +596,14 @@ if DEVICE_TYPE in ("cuda", "hip"):
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
elif DEVICE_TYPE == "xpu":
if Version(torch_version) < Version("2.6.0"):
raise RuntimeError("torch.xpu currently only supports torch.version >= 2.6.0")
else:
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="xpu")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="xpu")
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "xpu")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "xpu")
# =============================================
# =============================================
@ -934,9 +934,9 @@ import torch._inductor.utils
torch._inductor.utils.is_big_gpu = is_big_gpu
patch_torch_compile(
debug=UNSLOTH_COMPILE_DEBUG,
O3=UNSLOTH_COMPILE_MAXIMUM,
ignore_errors=UNSLOTH_COMPILE_IGNORE_ERRORS,
debug = UNSLOTH_COMPILE_DEBUG,
O3 = UNSLOTH_COMPILE_MAXIMUM,
ignore_errors = UNSLOTH_COMPILE_IGNORE_ERRORS,
)
torch_compile_options = {
@ -983,9 +983,9 @@ def patch_regional_compilation():
[
torch.compile(
x,
dynamic=True,
options=torch_compile_options,
fullgraph=False,
dynamic = True,
options = torch_compile_options,
fullgraph = False,
)
for x in args[0]
]
@ -1008,14 +1008,14 @@ def prepare_model_for_kbit_training(
use_reentrant: Optional[bool] = True,
) -> Any:
return prepare_model_for_training(
model=model,
use_gradient_checkpointing=use_gradient_checkpointing,
use_reentrant=use_reentrant,
full_finetuning=False,
train_layernorms=False,
train_embedding=False,
train_lm_head=False,
float32_mixed_precision=True,
model = model,
use_gradient_checkpointing = use_gradient_checkpointing,
use_reentrant = use_reentrant,
full_finetuning = False,
train_layernorms = False,
train_embedding = False,
train_lm_head = False,
float32_mixed_precision = True,
)
@ -1033,7 +1033,7 @@ if Version(peft_version) < Version("0.12.0"):
text = "if weight is not None:\n"
start = source.find(text) + len(text)
end = source.find("self.to(weight.device)", start)
spaces = re.findall(r"^([ ]{1,})break", source, flags=re.MULTILINE)[0]
spaces = re.findall(r"^([ ]{1,})break", source, flags = re.MULTILINE)[0]
source = source.replace(source[start:end], spaces)
spaces = len(re.match(r"[\s]{1,}", source).group(0))
lines = source.split("\n")
@ -1070,7 +1070,7 @@ import socket
@functools.lru_cache(1)
def has_internet(host="8.8.8.8", port=53, timeout=3):
def has_internet(host = "8.8.8.8", port = 53, timeout = 3):
if os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1":
return False
try:
@ -1084,12 +1084,12 @@ def has_internet(host="8.8.8.8", port=53, timeout=3):
import psutil
def _get_statistics(statistics=None, force_download=True):
def _get_statistics(statistics = None, force_download = True):
# We log some basic stats about which environment is being used.
# We simply download a README.md file from HF - all data is made public.
# This is simply so we can check if some envs are broken or not.
# You can disable this by commenting the below out
n_cpus = psutil.cpu_count(logical=False)
n_cpus = psutil.cpu_count(logical = False)
keynames = "\n" + "\n".join(os.environ.keys())
# Check modelscope for down detection
global USE_MODELSCOPE
@ -1150,12 +1150,12 @@ def _get_statistics(statistics=None, force_download=True):
if has_internet():
def stats_check():
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as f:
with tempfile.TemporaryDirectory(ignore_cleanup_errors = True) as f:
snapshot_download(
f"unslothai/{statistics}",
force_download=True,
cache_dir=f,
local_dir=f,
force_download = True,
cache_dir = f,
local_dir = f,
)
time_limited_stats_check = execute_with_time_limit(120)(stats_check)
@ -1178,7 +1178,7 @@ def _get_statistics(statistics=None, force_download=True):
stats_check()
def get_statistics(local_files_only=False):
def get_statistics(local_files_only = False):
# We log some basic stats about which environment is being used.
# This is also to check if HuggingFace is down or not!
# We simply download a README.md file from HF - all data is made public.
@ -1201,7 +1201,7 @@ def get_statistics(local_files_only=False):
disable_progress_bars()
disabled = True
_get_statistics(None)
_get_statistics("repeat", force_download=False)
_get_statistics("repeat", force_download = False)
total_memory = (
torch.xpu.get_device_properties(0).total_memory
if DEVICE_TYPE == "xpu"
@ -1242,7 +1242,7 @@ BitsAndBytesConfig__init__ = re.sub(
r"if[\s]{1,}kwargs\:[\s]{1,}.+?\n",
"",
BitsAndBytesConfig__init__,
flags=re.MULTILINE,
flags = re.MULTILINE,
)
BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.split("\n")
length_spaces = len(re.match(r"[\s]{1,}", BitsAndBytesConfig__init__[0]).group(0))
@ -1322,12 +1322,12 @@ def offload_to_disk(
torch.save(
W,
filename,
pickle_module=pickle,
pickle_protocol=pickle.HIGHEST_PROTOCOL,
pickle_module = pickle,
pickle_protocol = pickle.HIGHEST_PROTOCOL,
)
# We must use weights_only = False due to pickling
offloaded_W = torch.load(
filename, map_location="cpu", mmap=True, weights_only=False
filename, map_location = "cpu", mmap = True, weights_only = False
)
offloaded_W._offloaded_file_location = filename
return offloaded_W
@ -1352,7 +1352,7 @@ def offload_output_embeddings(
model.get_output_embeddings(), model, "output_embeddings", temporary_location
)
new_output_embeddings = torch.nn.Linear(1, 1, bias=None)
new_output_embeddings = torch.nn.Linear(1, 1, bias = None)
del new_output_embeddings.weight
new_output_embeddings.weight = offloaded_W
new_output_embeddings.in_features = offloaded_W.shape[1]
@ -1376,10 +1376,10 @@ def is_vLLM_available():
# Patches models to add RoPE Scaling
def patch_linear_scaling(
model_name="gemma2",
rope_module=None,
scaled_rope_module=None,
attention_module=None,
model_name = "gemma2",
rope_module = None,
scaled_rope_module = None,
attention_module = None,
):
assert rope_module is not None and scaled_rope_module is not None
assert attention_module is not None
@ -1430,13 +1430,13 @@ def patch_linear_scaling(
pass
"""
fix_rope_function = fix_rope_function.format(
rope_function=rope_module.__name__,
scaled_rope_function=scaled_rope_module.__name__,
rope_function = rope_module.__name__,
scaled_rope_function = scaled_rope_module.__name__,
)
rotary_emb = re.findall(
r"self\.rotary\_emb \= .+?\)",
function,
flags=re.DOTALL | re.MULTILINE,
flags = re.DOTALL | re.MULTILINE,
)
if len(rotary_emb) == 0:
return None, exec_code + "\n\n" + function
@ -1449,12 +1449,12 @@ def patch_linear_scaling(
# Patches for Llama-3 LlamaExtendedRotaryEmbedding
def patch_llama_rope_scaling(
model_name="llama",
rope_module=None,
scaled_rope_module=None,
extended_rope_module=None,
attention_module=None,
longrope_module=None,
model_name = "llama",
rope_module = None,
scaled_rope_module = None,
extended_rope_module = None,
attention_module = None,
longrope_module = None,
):
assert (
rope_module is not None
@ -1528,17 +1528,17 @@ def patch_llama_rope_scaling(
"""
fix_rope_function = fix_rope_function.format(
rope_function=rope_module.__name__,
scaled_rope_function=scaled_rope_module.__name__,
extended_rope_function=extended_rope_module.__name__,
longrope_rope_function=(
rope_function = rope_module.__name__,
scaled_rope_function = scaled_rope_module.__name__,
extended_rope_function = extended_rope_module.__name__,
longrope_rope_function = (
longrope_module if longrope_module is not None else rope_module
).__name__,
)
rotary_emb = re.findall(
r"self\.rotary\_emb \= .+?\)",
function,
flags=re.DOTALL | re.MULTILINE,
flags = re.DOTALL | re.MULTILINE,
)
if len(rotary_emb) == 0:
return None, function
@ -1548,15 +1548,15 @@ def patch_llama_rope_scaling(
return init_name, function
def create_boolean_mask(n=4096, sliding_window=2048):
def create_boolean_mask(n = 4096, sliding_window = 2048):
# Creates a boolean mask for attention
mask = torch.ones(n, n, dtype=torch.bool)
mask = torch.ones(n, n, dtype = torch.bool)
if sliding_window == 0:
return torch.triu(mask, diagonal=1, out=mask)
torch.triu(mask, diagonal=0, out=mask)
torch.triu(mask.T, diagonal=-sliding_window, out=mask.T)
return torch.triu(mask, diagonal = 1, out = mask)
torch.triu(mask, diagonal = 0, out = mask)
torch.triu(mask.T, diagonal = -sliding_window, out = mask.T)
mask = mask.T
torch.logical_not(mask, out=mask)
torch.logical_not(mask, out = mask)
return mask
@ -1567,37 +1567,37 @@ def test_mask_creation():
for s in range(1, 23):
correct_mask = (
AttentionMaskConverter(
is_causal=True,
sliding_window=s,
is_causal = True,
sliding_window = s,
)
.to_causal_4d(
1,
n,
n,
dtype=torch.float16,
dtype = torch.float16,
)
.squeeze(0)
.squeeze(0)
)
correct_mask = correct_mask == correct_mask.min()
our_mask = create_boolean_mask(n=n, sliding_window=s)
our_mask = create_boolean_mask(n = n, sliding_window = s)
assert torch.all(correct_mask == our_mask)
correct_mask = (
AttentionMaskConverter(
is_causal=True,
sliding_window=None,
is_causal = True,
sliding_window = None,
)
.to_causal_4d(
1,
n,
n,
dtype=torch.float16,
dtype = torch.float16,
)
.squeeze(0)
.squeeze(0)
)
correct_mask = correct_mask == correct_mask.min()
our_mask = create_boolean_mask(n=n, sliding_window=0)
our_mask = create_boolean_mask(n = n, sliding_window = 0)
assert torch.all(correct_mask == our_mask)
@ -1812,34 +1812,34 @@ def unsloth_compile_transformers(
dtype,
model_name,
model_types,
token=None,
revision=None,
trust_remote_code=False,
sdpa_dynamic_mask=True,
sdpa_bool_masks=True,
sdpa_gqa_replace=True,
sdpa_dynamic_compile=True,
compile_attention=True,
disable_causal_masks=True,
compile_torch_modules=True,
compile_custom_modules=True,
compile_function_calls=True,
fuse_lm_head=True,
gradient_checkpointing=True,
manual_replacements=True,
fast_lora_forwards=True,
fast_residual_stream=True,
accurate_accumulation=True,
epilogue_fusion=True,
max_autotune=False,
shape_padding=True,
cudagraphs=False,
debug=False,
fullgraph=True,
import_from_cache=False,
disable=False,
return_logits=False,
unsloth_force_compile=False,
token = None,
revision = None,
trust_remote_code = False,
sdpa_dynamic_mask = True,
sdpa_bool_masks = True,
sdpa_gqa_replace = True,
sdpa_dynamic_compile = True,
compile_attention = True,
disable_causal_masks = True,
compile_torch_modules = True,
compile_custom_modules = True,
compile_function_calls = True,
fuse_lm_head = True,
gradient_checkpointing = True,
manual_replacements = True,
fast_lora_forwards = True,
fast_residual_stream = True,
accurate_accumulation = True,
epilogue_fusion = True,
max_autotune = False,
shape_padding = True,
cudagraphs = False,
debug = False,
fullgraph = True,
import_from_cache = False,
disable = False,
return_logits = False,
unsloth_force_compile = False,
):
if Version(torch_version) < Version("2.4.0"):
print(
@ -1864,31 +1864,31 @@ def unsloth_compile_transformers(
for model_type in model_types:
_unsloth_compile_transformers(
model_type,
sdpa_dynamic_mask=sdpa_dynamic_mask,
sdpa_bool_masks=sdpa_bool_masks,
sdpa_gqa_replace=sdpa_gqa_replace,
sdpa_dynamic_compile=sdpa_dynamic_compile,
compile_attention=compile_attention,
disable_causal_masks=disable_causal_masks,
compile_torch_modules=compile_torch_modules,
compile_custom_modules=compile_custom_modules,
compile_function_calls=compile_function_calls,
fuse_lm_head=fuse_lm_head,
gradient_checkpointing=gradient_checkpointing,
manual_replacements=manual_replacements,
fast_lora_forwards=fast_lora_forwards,
fast_residual_stream=fast_residual_stream,
accurate_accumulation=accurate_accumulation,
epilogue_fusion=epilogue_fusion,
max_autotune=max_autotune,
shape_padding=shape_padding,
cudagraphs=cudagraphs,
debug=debug,
fullgraph=fullgraph,
import_from_cache=import_from_cache,
disable=disable,
return_logits=return_logits,
supports_sdpa=supports_sdpa,
sdpa_dynamic_mask = sdpa_dynamic_mask,
sdpa_bool_masks = sdpa_bool_masks,
sdpa_gqa_replace = sdpa_gqa_replace,
sdpa_dynamic_compile = sdpa_dynamic_compile,
compile_attention = compile_attention,
disable_causal_masks = disable_causal_masks,
compile_torch_modules = compile_torch_modules,
compile_custom_modules = compile_custom_modules,
compile_function_calls = compile_function_calls,
fuse_lm_head = fuse_lm_head,
gradient_checkpointing = gradient_checkpointing,
manual_replacements = manual_replacements,
fast_lora_forwards = fast_lora_forwards,
fast_residual_stream = fast_residual_stream,
accurate_accumulation = accurate_accumulation,
epilogue_fusion = epilogue_fusion,
max_autotune = max_autotune,
shape_padding = shape_padding,
cudagraphs = cudagraphs,
debug = debug,
fullgraph = fullgraph,
import_from_cache = import_from_cache,
disable = disable,
return_logits = return_logits,
supports_sdpa = supports_sdpa,
)
# Redo patches which override compiler
for temporary_patch in TEMPORARY_PATCHES:
@ -1993,7 +1993,7 @@ def validate_loftq_config(loftq_config, lora_dropout, bias, init_lora_weights, m
"Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"
"We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
)
loftq_config = LoftQConfig(loftq_bits=4, loftq_iter=1)
loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)
if hasattr(model.config, "quantization_config"):
raise ValueError(
@ -2069,9 +2069,9 @@ class TorchAOConfig:
base_config_and_filter_fns: List[
Tuple["AOBaseConfig", Optional[Callable[[torch.nn.Module, str], bool]]]
] = field(
default_factory=lambda: [
default_factory = lambda: [
(
Int4WeightOnlyConfig(group_size=128),
Int4WeightOnlyConfig(group_size = 128),
lambda m, _: isinstance(m, torch.nn.Linear)
and getattr(m, "in_features", 0) >= 128,
),
@ -2143,7 +2143,7 @@ def _convert_torchao_model(model):
module_to_fqn_dict = {}
for base_config, filter_fn in model._torchao_config.base_config_and_filter_fns:
quantize_(model, QATConfig(base_config, step="convert"), filter_fn=filter_fn)
quantize_(model, QATConfig(base_config, step = "convert"), filter_fn = filter_fn)
# Default filter function used for quantize_
if filter_fn is None:
@ -2166,7 +2166,7 @@ def _convert_torchao_model(model):
kwargs["modules_to_not_convert"] = []
quant_config = ModuleFqnToConfig(module_to_fqn_dict)
quantization_config = TorchAoConfig(quant_type=quant_config, **kwargs)
quantization_config = TorchAoConfig(quant_type = quant_config, **kwargs)
model.config.quantization_config = quantization_config
@ -2199,17 +2199,17 @@ def _prepare_model_for_qat(
and m.in_features >= group_size
)
torchao_config = TorchAOConfig(
qat_scheme=qat_scheme,
base_config_and_filter_fns=[(base_config, filter_fn)],
qat_scheme = qat_scheme,
base_config_and_filter_fns = [(base_config, filter_fn)],
)
elif qat_scheme == "fp8-fp8":
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
base_config = Float8DynamicActivationFloat8WeightConfig(
granularity=PerRow()
granularity = PerRow()
)
torchao_config = TorchAOConfig(
qat_scheme=qat_scheme, base_config_and_filter_fns=[(base_config, None)]
qat_scheme = qat_scheme, base_config_and_filter_fns = [(base_config, None)]
)
elif qat_scheme == "int8-int4":
from torchao.quantization import (
@ -2218,35 +2218,35 @@ def _prepare_model_for_qat(
)
torchao_config = TorchAOConfig(
qat_scheme=qat_scheme,
base_config_and_filter_fns=[
qat_scheme = qat_scheme,
base_config_and_filter_fns = [
(
IntxWeightOnlyConfig(
weight_dtype=torch.int8, granularity=PerAxis(0)
weight_dtype = torch.int8, granularity = PerAxis(0)
),
lambda m, fqn: isinstance(m, torch.nn.Embedding),
),
(
Int8DynamicActivationIntxWeightConfig(
weight_dtype=torch.int4, weight_granularity=PerGroup(32)
weight_dtype = torch.int4, weight_granularity = PerGroup(32)
),
None,
),
],
prequantization_transform=_untie_input_output_embeddings,
prequantization_transform = _untie_input_output_embeddings,
)
elif qat_scheme == "int4":
from torchao.quantization import Int4WeightOnlyConfig
group_size = 128
base_config = Int4WeightOnlyConfig(group_size=group_size)
base_config = Int4WeightOnlyConfig(group_size = group_size)
filter_fn = (
lambda m, _: isinstance(m, torch.nn.Linear)
and m.in_features >= group_size
)
torchao_config = TorchAOConfig(
qat_scheme=qat_scheme,
base_config_and_filter_fns=[(base_config, filter_fn)],
qat_scheme = qat_scheme,
base_config_and_filter_fns = [(base_config, filter_fn)],
)
else:
raise ValueError(f"Unexpected QAT scheme {qat_scheme}")
@ -2264,7 +2264,7 @@ def _prepare_model_for_qat(
if torchao_config.prequantization_transform is not None:
torchao_config.prequantization_transform(model)
for base_config, filter_fn in torchao_config.base_config_and_filter_fns:
quantize_(model, QATConfig(base_config, step="prepare"), filter_fn=filter_fn)
quantize_(model, QATConfig(base_config, step = "prepare"), filter_fn = filter_fn)
return model

View file

@ -113,8 +113,8 @@ def Gemma2Attention_fast_forward(
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim=2)
V = torch.cat([past_key_value[1], V], dim=2)
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
past_key_value = (K, V) if use_cache else None
# Only enable if the attention_mask is True
@ -139,10 +139,10 @@ def Gemma2Attention_fast_forward(
Q,
K,
V,
causal=True,
softcap=self.config.attn_logit_softcapping,
softmax_scale=self._flash_attention_softmax_scale,
window_size=window,
causal = True,
softcap = self.config.attn_logit_softcapping,
softmax_scale = self._flash_attention_softmax_scale,
window_size = window,
)
A = A.reshape(bsz, q_len, n_heads * head_dim)
else:
@ -174,7 +174,7 @@ def Gemma2DecoderLayer_fast_forward(
self, "_flag_for_generation"
): # past_key_value is not None:
out_weight = torch.empty(
self.input_layernorm.weight.shape, dtype=torch.float32, device="cuda:0"
self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0"
)
# Self Attention
@ -183,15 +183,15 @@ def Gemma2DecoderLayer_fast_forward(
self.input_layernorm, hidden_states, out_weight
)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
_flag_for_generation=self._flag_for_generation,
hidden_states = hidden_states,
causal_mask = causal_mask,
attention_mask = attention_mask,
position_ids = position_ids,
past_key_value = past_key_value,
output_attentions = output_attentions,
use_cache = use_cache,
padding_mask = padding_mask,
_flag_for_generation = self._flag_for_generation,
)
hidden_states = fast_rms_layernorm_inference_gemma(
self.post_attention_layernorm, hidden_states, out_weight
@ -211,31 +211,31 @@ def Gemma2DecoderLayer_fast_forward(
else:
residual = hidden_states
hidden_states = fast_rms_layernorm(
self.input_layernorm, hidden_states, gemma=True
self.input_layernorm, hidden_states, gemma = True
)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
hidden_states = hidden_states,
causal_mask = causal_mask,
attention_mask = attention_mask,
position_ids = position_ids,
past_key_value = past_key_value,
output_attentions = output_attentions,
use_cache = use_cache,
padding_mask = padding_mask,
)
hidden_states = fast_rms_layernorm(
self.post_attention_layernorm, hidden_states, gemma=True
self.post_attention_layernorm, hidden_states, gemma = True
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(
self.pre_feedforward_layernorm, hidden_states, gemma=True
self.pre_feedforward_layernorm, hidden_states, gemma = True
)
hidden_states = self.mlp(hidden_states)
hidden_states = fast_rms_layernorm(
self.post_feedforward_layernorm, hidden_states, gemma=True
self.post_feedforward_layernorm, hidden_states, gemma = True
)
hidden_states = residual + hidden_states
@ -260,9 +260,9 @@ def Gemma2Attention_fast_forward_inference(
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
do_prefill=False,
attention_mask=None,
use_sliding_window=False,
do_prefill = False,
attention_mask = None,
use_sliding_window = False,
):
Xn = hidden_states
bsz, _, hd = hidden_states.size()
@ -286,24 +286,24 @@ def Gemma2Attention_fast_forward_inference(
if do_prefill:
self.paged_attention = torch.empty(
(KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
dtype=dtype,
device=device,
dtype = dtype,
device = device,
)
self.paged_attention_K = self.paged_attention[:, 0]
self.paged_attention_V = self.paged_attention[:, 1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
self.temp_QA = torch.empty(
(2, bsz, 1, attention_size), dtype=dtype, device=device
(2, bsz, 1, attention_size), dtype = dtype, device = device
)
self.temp_KV = torch.empty(
(2, bsz, 1, n_kv_heads * head_dim), dtype=dtype, device=device
(2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
)
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype=dtype, device=device)
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
# Only for Gemma2
self.temp_O = torch.empty((1, bsz, hidden_size), dtype=dtype, device=device)
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
self.attention = torch.empty(
(bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype=dtype, device=device
(bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
)
# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
@ -331,9 +331,9 @@ def Gemma2Attention_fast_forward_inference(
(bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
)
Qn = fast_linear_forward(self.q_proj, Xn, out=self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out=self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out=self.temp_KV[1])
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
@ -348,7 +348,7 @@ def Gemma2Attention_fast_forward_inference(
RH_Q = self.RH_Q
RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
torch.neg(RH_Q[:, :, :, :h], out=RH_Q[:, :, :, :h])
torch.neg(RH_Q[:, :, :, :h], out = RH_Q[:, :, :, :h])
Qn *= cos
Qn.addcmul_(RH_Q, sin)
@ -357,7 +357,7 @@ def Gemma2Attention_fast_forward_inference(
] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
RH_K[:, :, :, :h] = Kn[:, :, :, h:]
RH_K[:, :, :, h:] = Kn[:, :, :, :h]
torch.neg(RH_K[:, :, :, :h], out=RH_K[:, :, :, :h])
torch.neg(RH_K[:, :, :, :h], out = RH_K[:, :, :, :h])
Kn *= cos
Kn.addcmul_(RH_K, sin)
@ -400,21 +400,21 @@ def Gemma2Attention_fast_forward_inference(
self.scalar
) # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
A = torch_matmul(Qn, Knn.transpose(2, 3), out=self.attention[:, :, :, :cached_len])
A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len])
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
A *= self.reciprocal_t
torch_tanh(A, out=A)
torch_tanh(A, out = A)
A *= self.t # Logit softcapping
A[:] = torch_nn_functional_softmax(A, dim=-1, dtype=torch.float32) # .to(A.dtype)
A = torch_matmul(A, Vnn, out=Qn)
A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32) # .to(A.dtype)
A = torch_matmul(A, Vnn, out = Qn)
# else:
# A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
# pass
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
A = fast_linear_forward(self.o_proj, A, out=self.temp_O)
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
@ -425,13 +425,13 @@ def Gemma2Model_fast_forward_inference(
input_ids,
past_key_values,
position_ids,
attention_mask=None,
attention_mask = None,
):
out_weights = tuple(
torch.empty_like(
self.model.layers[0].input_layernorm.weight,
dtype=torch.float32,
device=torch.device(x),
dtype = torch.float32,
device = torch.device(x),
)
for x in range(DEVICE_COUNT)
)
@ -441,7 +441,7 @@ def Gemma2Model_fast_forward_inference(
# 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
hidden_states *= torch.tensor(
math_sqrt(self.config.hidden_size), dtype=hidden_states.dtype
math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype
)
bsz, q_len, hd = hidden_states.shape
@ -456,7 +456,7 @@ def Gemma2Model_fast_forward_inference(
(bsz, q_len),
hidden_states,
seq_len,
sliding_window=self.config.sliding_window,
sliding_window = self.config.sliding_window,
)
GA = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
@ -484,12 +484,12 @@ def Gemma2Model_fast_forward_inference(
)
hidden_states, present_key_value = Gemma2Attention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states=hidden_states,
past_key_value=past_key_values[idx],
position_ids=position_ids,
attention_mask=SWA if use_sliding_window else GA,
do_prefill=not hasattr(decoder_layer.self_attn, "paged_attention"),
use_sliding_window=use_sliding_window,
hidden_states = hidden_states,
past_key_value = past_key_values[idx],
position_ids = position_ids,
attention_mask = SWA if use_sliding_window else GA,
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
use_sliding_window = use_sliding_window,
)
hidden_states = fast_rms_layernorm_inference_gemma(
decoder_layer.post_attention_layernorm,
@ -518,10 +518,10 @@ def Gemma2Model_fast_forward_inference(
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=[],
attentions=[],
last_hidden_state = hidden_states,
past_key_values = next_decoder_cache,
hidden_states = [],
attentions = [],
)
@ -529,10 +529,10 @@ class FastGemma2Model(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name="gemma2",
rope_module=GemmaFixedRotaryEmbedding,
scaled_rope_module=GemmaFixedLinearScalingRotaryEmbedding,
attention_module=Gemma2Attention,
model_name = "gemma2",
rope_module = GemmaFixedRotaryEmbedding,
scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
attention_module = Gemma2Attention,
)
if init_name is not None:
exec(function, globals())
@ -564,7 +564,7 @@ class FastGemma2Model(FastLlamaModel):
def post_patch(model, tokenizer):
# Gemma does not downcast RoPE
model, tokenizer = patch_model_and_tokenizer(
model, tokenizer, downcast_rope=False
model, tokenizer, downcast_rope = False
)
# Add 1 to weight

View file

@ -109,8 +109,8 @@ def GraniteAttention_fast_forward(
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim=2)
V = torch.cat([past_key_value[1], V], dim=2)
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
past_key_value = (K, V) if use_cache else None
# Attention module
@ -135,7 +135,7 @@ def GraniteAttention_fast_forward(
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
A = xformers_attention(
Q, K, V, attn_bias=causal_mask, scale=self.scaling, p=dropout_p
Q, K, V, attn_bias = causal_mask, scale = self.scaling, p = dropout_p
)
A = A.view(bsz, q_len, n_heads, head_dim)
@ -148,10 +148,10 @@ def GraniteAttention_fast_forward(
Q,
K,
V,
causal=True,
window_size=window,
softmax_scale=self.scaling,
dropout_p=dropout_p,
causal = True,
window_size = window,
softmax_scale = self.scaling,
dropout_p = dropout_p,
)
else:
# Grouped query attention
@ -170,10 +170,10 @@ def GraniteAttention_fast_forward(
Q,
K,
V,
attn_mask=attention_mask,
scale=self.scaling,
is_causal=False,
dropout_p=dropout_p,
attn_mask = attention_mask,
scale = self.scaling,
is_causal = False,
dropout_p = dropout_p,
)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
@ -212,18 +212,18 @@ def GraniteDecoderLayer_fast_forward(
self.input_layernorm, hidden_states
)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
position_embeddings=position_embeddings,
_flag_for_generation=self._flag_for_generation,
hidden_states = hidden_states,
causal_mask = causal_mask,
attention_mask = attention_mask,
position_ids = position_ids,
past_key_value = past_key_value,
output_attentions = output_attentions,
use_cache = use_cache,
padding_mask = padding_mask,
position_embeddings = position_embeddings,
_flag_for_generation = self._flag_for_generation,
)
hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier)
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
# Fully Connected
residual = hidden_states
@ -231,28 +231,28 @@ def GraniteDecoderLayer_fast_forward(
self.post_attention_layernorm, hidden_states
)
hidden_states = fast_swiglu_inference(self.mlp, hidden_states)
hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier)
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
else:
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
position_embeddings=position_embeddings,
hidden_states = hidden_states,
causal_mask = causal_mask,
attention_mask = attention_mask,
position_ids = position_ids,
past_key_value = past_key_value,
output_attentions = output_attentions,
use_cache = use_cache,
padding_mask = padding_mask,
position_embeddings = position_embeddings,
)
hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier)
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier)
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
outputs = (hidden_states,)
if output_attentions:
@ -275,9 +275,9 @@ def GraniteAttention_fast_forward_inference(
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
do_prefill=False,
attention_mask=None,
use_sliding_window=False,
do_prefill = False,
attention_mask = None,
use_sliding_window = False,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
assert (
@ -306,24 +306,24 @@ def GraniteAttention_fast_forward_inference(
if do_prefill:
self.paged_attention = torch.empty(
(KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
dtype=dtype,
device=device,
dtype = dtype,
device = device,
)
self.paged_attention_K = self.paged_attention[:, 0]
self.paged_attention_V = self.paged_attention[:, 1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
self.temp_QA = torch.empty(
(2, bsz, 1, attention_size), dtype=dtype, device=device
(2, bsz, 1, attention_size), dtype = dtype, device = device
)
self.temp_KV = torch.empty(
(2, bsz, 1, n_kv_heads * head_dim), dtype=dtype, device=device
(2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
)
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype=dtype, device=device)
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
# Only for Gemma2
self.temp_O = torch.empty((1, bsz, hidden_size), dtype=dtype, device=device)
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
self.attention = torch.empty(
(bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype=dtype, device=device
(bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
)
self.half_head_dim = head_dim // 2
@ -343,9 +343,9 @@ def GraniteAttention_fast_forward_inference(
(bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
)
Qn = fast_linear_forward(self.q_proj, Xn, out=self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out=self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out=self.temp_KV[1])
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
@ -359,7 +359,7 @@ def GraniteAttention_fast_forward_inference(
RH_Q = self.RH_Q
RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
torch.neg(RH_Q[:, :, :, :h], out=RH_Q[:, :, :, :h])
torch.neg(RH_Q[:, :, :, :h], out = RH_Q[:, :, :, :h])
Qn *= cos
Qn.addcmul_(RH_Q, sin)
@ -368,7 +368,7 @@ def GraniteAttention_fast_forward_inference(
] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
RH_K[:, :, :, :h] = Kn[:, :, :, h:]
RH_K[:, :, :, h:] = Kn[:, :, :, :h]
torch.neg(RH_K[:, :, :, :h], out=RH_K[:, :, :, :h])
torch.neg(RH_K[:, :, :, :h], out = RH_K[:, :, :, :h])
Kn *= cos
Kn.addcmul_(RH_K, sin)
@ -396,18 +396,18 @@ def GraniteAttention_fast_forward_inference(
# pass
Qn *= self.scaling
A = torch_matmul(Qn, Kn.transpose(2, 3), out=self.attention[:, :, :, :cached_len])
A = torch_matmul(Qn, Kn.transpose(2, 3), out = self.attention[:, :, :, :cached_len])
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
A[:] = torch_nn_functional_softmax(A, dim=-1, dtype=torch.float32) # .to(A.dtype)
A = torch_matmul(A, Vn, out=Qn)
A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32) # .to(A.dtype)
A = torch_matmul(A, Vn, out = Qn)
# else:
# A = scaled_dot_product_attention(Qn, Kn, Vn, attn_mask = attention_mask, is_causal = False)
# pass
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
A = fast_linear_forward(self.o_proj, A, out=self.temp_O)
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
@ -418,7 +418,7 @@ def GraniteModel_fast_forward_inference(
input_ids,
past_key_values,
position_ids,
attention_mask=None,
attention_mask = None,
):
input_ids = input_ids[:, : self.max_seq_length]
hidden_states = self.model.embed_tokens(input_ids)
@ -459,37 +459,37 @@ def GraniteModel_fast_forward_inference(
)
hidden_states, present_key_value = GraniteAttention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states=hidden_states,
past_key_value=past_key_values[idx],
position_ids=position_ids,
attention_mask=attention_mask,
do_prefill=not hasattr(decoder_layer.self_attn, "paged_attention"),
position_embeddings=position_embeddings,
hidden_states = hidden_states,
past_key_value = past_key_values[idx],
position_ids = position_ids,
attention_mask = attention_mask,
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
position_embeddings = position_embeddings,
)
hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier)
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(
decoder_layer.post_attention_layernorm, hidden_states
)
hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states)
hidden_states = torch.add(residual, hidden_states, alpha=residual_multiplier)
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
next_decoder_cache.append(present_key_value)
hidden_states = fast_rms_layernorm_inference(self.model.norm, hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=[],
attentions=[],
last_hidden_state = hidden_states,
past_key_values = next_decoder_cache,
hidden_states = [],
attentions = [],
)
class GraniteRotaryEmbedding(LlamaRotaryEmbedding):
def __init__(self, config):
super().__init__(config=config)
super().__init__(config = config)
def patched_init(original_init):
@ -510,10 +510,10 @@ class FastGraniteModel(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name="granite",
rope_module=GraniteRotaryEmbedding,
scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
attention_module=GraniteAttention,
model_name = "granite",
rope_module = GraniteRotaryEmbedding,
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
attention_module = GraniteAttention,
)
if init_name is not None:
exec(function, globals())
@ -548,7 +548,7 @@ class FastGraniteModel(FastLlamaModel):
model.config.update({"unsloth_version": __version__})
# We also do this for the lm_head
lm_head = torch.nn.Linear(1, 1, bias=None)
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.lm_head.weight
lm_head.in_features = lm_head.weight.shape[1]
@ -560,7 +560,7 @@ class FastGraniteModel(FastLlamaModel):
model.model.embed_tokens.weight.data_ptr()
!= model.lm_head.weight.data_ptr()
):
lm_head = torch.nn.Linear(1, 1, bias=None)
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.model.embed_tokens.weight
lm_head.in_features = lm_head.weight.shape[1]

File diff suppressed because it is too large Load diff

View file

@ -120,33 +120,33 @@ DISABLE_SDPA_MODEL_NAMES = [
class FastLanguageModel(FastLlamaModel):
@staticmethod
def from_pretrained(
model_name="unsloth/Llama-3.2-1B-Instruct",
max_seq_length=2048,
dtype=None,
load_in_4bit=True, # 4bit QLoRA
load_in_8bit=False, # 8bit LoRA
load_in_16bit=False, # 16bit LoRA
full_finetuning=False,
token=None,
device_map="sequential",
rope_scaling=None,
fix_tokenizer=True,
trust_remote_code=False,
use_gradient_checkpointing="unsloth",
resize_model_vocab=None,
revision=None,
use_exact_model_name=False,
offload_embedding=False,
float32_mixed_precision=None, # Forces float32 mixed precision
fast_inference=False, # uses vLLM
gpu_memory_utilization=0.5,
float8_kv_cache=False,
random_state=3407,
max_lora_rank=64,
disable_log_stats=True,
qat_scheme=None,
load_in_fp8=False, # fp8 LoRA (True, False, 'block')
unsloth_tiled_mlp=False,
model_name = "unsloth/Llama-3.2-1B-Instruct",
max_seq_length = 2048,
dtype = None,
load_in_4bit = True, # 4bit QLoRA
load_in_8bit = False, # 8bit LoRA
load_in_16bit = False, # 16bit LoRA
full_finetuning = False,
token = None,
device_map = "sequential",
rope_scaling = None,
fix_tokenizer = True,
trust_remote_code = False,
use_gradient_checkpointing = "unsloth",
resize_model_vocab = None,
revision = None,
use_exact_model_name = False,
offload_embedding = False,
float32_mixed_precision = None, # Forces float32 mixed precision
fast_inference = False, # uses vLLM
gpu_memory_utilization = 0.5,
float8_kv_cache = False,
random_state = 3407,
max_lora_rank = 64,
disable_log_stats = True,
qat_scheme = None,
load_in_fp8 = False, # fp8 LoRA (True, False, 'block')
unsloth_tiled_mlp = False,
*args,
**kwargs,
):
@ -157,40 +157,40 @@ class FastLanguageModel(FastLlamaModel):
try:
from huggingface_hub import login
login(token=token)
login(token = token)
except:
pass
if load_in_8bit or full_finetuning or qat_scheme is not None:
return FastModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
dtype=dtype,
load_in_4bit=load_in_4bit,
load_in_8bit=load_in_8bit,
load_in_16bit=load_in_16bit,
full_finetuning=full_finetuning,
token=token,
device_map=device_map,
rope_scaling=rope_scaling, # [TODO] No effect
fix_tokenizer=fix_tokenizer, # [TODO] No effect
trust_remote_code=trust_remote_code,
use_gradient_checkpointing=use_gradient_checkpointing,
resize_model_vocab=resize_model_vocab, # [TODO] No effect
revision=revision,
return_logits=False, # Return logits
fullgraph=True, # No graph breaks
use_exact_model_name=use_exact_model_name,
offload_embedding=offload_embedding,
float32_mixed_precision=float32_mixed_precision,
model_name = model_name,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
load_in_8bit = load_in_8bit,
load_in_16bit = load_in_16bit,
full_finetuning = full_finetuning,
token = token,
device_map = device_map,
rope_scaling = rope_scaling, # [TODO] No effect
fix_tokenizer = fix_tokenizer, # [TODO] No effect
trust_remote_code = trust_remote_code,
use_gradient_checkpointing = use_gradient_checkpointing,
resize_model_vocab = resize_model_vocab, # [TODO] No effect
revision = revision,
return_logits = False, # Return logits
fullgraph = True, # No graph breaks
use_exact_model_name = use_exact_model_name,
offload_embedding = offload_embedding,
float32_mixed_precision = float32_mixed_precision,
# Pass vLLM/inference parameters
fast_inference=fast_inference,
gpu_memory_utilization=gpu_memory_utilization,
float8_kv_cache=float8_kv_cache,
random_state=random_state,
max_lora_rank=max_lora_rank,
disable_log_stats=disable_log_stats,
qat_scheme=qat_scheme,
load_in_fp8=load_in_fp8,
fast_inference = fast_inference,
gpu_memory_utilization = gpu_memory_utilization,
float8_kv_cache = float8_kv_cache,
random_state = random_state,
max_lora_rank = max_lora_rank,
disable_log_stats = disable_log_stats,
qat_scheme = qat_scheme,
load_in_fp8 = load_in_fp8,
*args,
**kwargs,
)
@ -231,7 +231,7 @@ class FastLanguageModel(FastLlamaModel):
fp8_mode = None
if not use_exact_model_name:
new_model_name = get_model_name(
model_name, load_in_4bit=load_in_4bit, load_in_fp8=load_in_fp8
model_name, load_in_4bit = load_in_4bit, load_in_fp8 = load_in_fp8
)
if new_model_name is None and load_in_fp8 != False:
fp8_mode = _get_fp8_mode_and_check_settings(
@ -284,9 +284,9 @@ class FastLanguageModel(FastLlamaModel):
try:
model_config = AutoConfig.from_pretrained(
model_name,
token=token,
revision=revision,
trust_remote_code=trust_remote_code,
token = token,
revision = revision,
trust_remote_code = trust_remote_code,
)
is_model = True
except Exception as error:
@ -300,9 +300,9 @@ class FastLanguageModel(FastLlamaModel):
try:
peft_config = PeftConfig.from_pretrained(
model_name,
token=token,
revision=revision,
trust_remote_code=trust_remote_code,
token = token,
revision = revision,
trust_remote_code = trust_remote_code,
)
is_peft = True
except Exception as error:
@ -345,7 +345,7 @@ class FastLanguageModel(FastLlamaModel):
both_exist = exist_adapter_config and exist_config
else:
# Because HfFileSystem assumes linux paths, we need to set the path with forward slashes, even on Windows.
files = HfFileSystem(token=token).glob(f"{model_name}/*.json")
files = HfFileSystem(token = token).glob(f"{model_name}/*.json")
files = list(os.path.split(x)[-1] for x in files)
if (
sum(x == "adapter_config.json" or x == "config.json" for x in files)
@ -393,8 +393,8 @@ class FastLanguageModel(FastLlamaModel):
model_config = AutoConfig.from_pretrained(
model_name,
token=token,
trust_remote_code=trust_remote_code,
token = token,
trust_remote_code = trust_remote_code,
)
if not was_disabled:
@ -484,42 +484,42 @@ class FastLanguageModel(FastLlamaModel):
# dispatch_model = FastGraniteModel
else:
return FastModel.from_pretrained(
model_name=old_model_name,
max_seq_length=max_seq_length,
dtype=dtype,
load_in_4bit=load_in_4bit,
load_in_8bit=load_in_8bit,
load_in_16bit=load_in_16bit,
full_finetuning=full_finetuning,
token=token,
device_map=device_map,
rope_scaling=rope_scaling, # [TODO] No effect
fix_tokenizer=fix_tokenizer, # [TODO] No effect
trust_remote_code=trust_remote_code,
use_gradient_checkpointing=use_gradient_checkpointing,
resize_model_vocab=resize_model_vocab, # [TODO] No effect
revision=revision,
return_logits=False, # Return logits
fullgraph=True, # No graph breaks
use_exact_model_name=use_exact_model_name,
offload_embedding=offload_embedding,
float32_mixed_precision=float32_mixed_precision,
model_name = old_model_name,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
load_in_8bit = load_in_8bit,
load_in_16bit = load_in_16bit,
full_finetuning = full_finetuning,
token = token,
device_map = device_map,
rope_scaling = rope_scaling, # [TODO] No effect
fix_tokenizer = fix_tokenizer, # [TODO] No effect
trust_remote_code = trust_remote_code,
use_gradient_checkpointing = use_gradient_checkpointing,
resize_model_vocab = resize_model_vocab, # [TODO] No effect
revision = revision,
return_logits = False, # Return logits
fullgraph = True, # No graph breaks
use_exact_model_name = use_exact_model_name,
offload_embedding = offload_embedding,
float32_mixed_precision = float32_mixed_precision,
# Pass vLLM/inference parameters
fast_inference=fast_inference,
gpu_memory_utilization=gpu_memory_utilization,
float8_kv_cache=float8_kv_cache,
random_state=random_state,
max_lora_rank=max_lora_rank,
disable_log_stats=disable_log_stats,
qat_scheme=qat_scheme,
load_in_fp8=load_in_fp8,
unsloth_tiled_mlp=unsloth_tiled_mlp,
fast_inference = fast_inference,
gpu_memory_utilization = gpu_memory_utilization,
float8_kv_cache = float8_kv_cache,
random_state = random_state,
max_lora_rank = max_lora_rank,
disable_log_stats = disable_log_stats,
qat_scheme = qat_scheme,
load_in_fp8 = load_in_fp8,
unsloth_tiled_mlp = unsloth_tiled_mlp,
*args,
**kwargs,
)
if use_gradient_checkpointing == "unsloth":
patch_unsloth_smart_gradient_checkpointing(dtype=dtype)
patch_unsloth_smart_gradient_checkpointing(dtype = dtype)
# Check if this is local model since the tokenizer gets overwritten
if (
@ -535,24 +535,24 @@ class FastLanguageModel(FastLlamaModel):
fast_inference, model_name = fast_inference_setup(model_name, model_config)
model, tokenizer = dispatch_model.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
dtype=_get_dtype(dtype),
load_in_4bit=load_in_4bit,
token=token,
device_map=device_map,
rope_scaling=rope_scaling,
fix_tokenizer=fix_tokenizer,
model_patcher=dispatch_model,
tokenizer_name=tokenizer_name,
trust_remote_code=trust_remote_code,
revision=revision if not is_peft else None,
fast_inference=fast_inference,
gpu_memory_utilization=gpu_memory_utilization,
float8_kv_cache=float8_kv_cache,
random_state=random_state,
max_lora_rank=max_lora_rank,
disable_log_stats=disable_log_stats,
model_name = model_name,
max_seq_length = max_seq_length,
dtype = _get_dtype(dtype),
load_in_4bit = load_in_4bit,
token = token,
device_map = device_map,
rope_scaling = rope_scaling,
fix_tokenizer = fix_tokenizer,
model_patcher = dispatch_model,
tokenizer_name = tokenizer_name,
trust_remote_code = trust_remote_code,
revision = revision if not is_peft else None,
fast_inference = fast_inference,
gpu_memory_utilization = gpu_memory_utilization,
float8_kv_cache = float8_kv_cache,
random_state = random_state,
max_lora_rank = max_lora_rank,
disable_log_stats = disable_log_stats,
*args,
**kwargs,
)
@ -601,10 +601,10 @@ class FastLanguageModel(FastLlamaModel):
model = PeftModel.from_pretrained(
model,
old_model_name,
token=token,
revision=revision,
is_trainable=True,
trust_remote_code=trust_remote_code,
token = token,
revision = revision,
is_trainable = True,
trust_remote_code = trust_remote_code,
)
# Patch it as well!
model = dispatch_model.patch_peft_model(model, use_gradient_checkpointing)
@ -615,7 +615,7 @@ class FastLanguageModel(FastLlamaModel):
"UNSLOTH_TILED_MLP", "arctic" if unsloth_tiled_mlp else "0"
)
if patch_tiled_mlp_choice != "0" or unsloth_tiled_mlp:
patch_tiled_mlp(model, patch_options_str=patch_tiled_mlp_choice)
patch_tiled_mlp(model, patch_options_str = patch_tiled_mlp_choice)
return model, tokenizer
@ -645,40 +645,40 @@ class FastModel(FastBaseModel):
@staticmethod
def from_pretrained(
model_name="unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
max_seq_length=2048,
dtype=None,
load_in_4bit=True, # 4bit QLoRA
load_in_8bit=False, # 8bit LoRA
load_in_16bit=False, # 16bit LoRA
full_finetuning=False,
token=None,
device_map="sequential",
rope_scaling=None, # [TODO] No effect
fix_tokenizer=True, # [TODO] No effect
trust_remote_code=False,
use_gradient_checkpointing="unsloth",
resize_model_vocab=None, # [TODO] No effect
revision=None,
return_logits=False, # Return logits
fullgraph=True, # No graph breaks
use_exact_model_name=False,
auto_model=None,
whisper_language=None,
whisper_task=None,
unsloth_force_compile=False,
offload_embedding=False,
float32_mixed_precision=None, # Forces float32 mixed precision
model_name = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
max_seq_length = 2048,
dtype = None,
load_in_4bit = True, # 4bit QLoRA
load_in_8bit = False, # 8bit LoRA
load_in_16bit = False, # 16bit LoRA
full_finetuning = False,
token = None,
device_map = "sequential",
rope_scaling = None, # [TODO] No effect
fix_tokenizer = True, # [TODO] No effect
trust_remote_code = False,
use_gradient_checkpointing = "unsloth",
resize_model_vocab = None, # [TODO] No effect
revision = None,
return_logits = False, # Return logits
fullgraph = True, # No graph breaks
use_exact_model_name = False,
auto_model = None,
whisper_language = None,
whisper_task = None,
unsloth_force_compile = False,
offload_embedding = False,
float32_mixed_precision = None, # Forces float32 mixed precision
# Add the missing vLLM/inference parameters
fast_inference=False, # uses vLLM
gpu_memory_utilization=0.5,
float8_kv_cache=False,
random_state=3407,
max_lora_rank=64,
disable_log_stats=True,
qat_scheme=None,
load_in_fp8=False, # fp8 LoRA (True, False, 'block')
unsloth_tiled_mlp=False,
fast_inference = False, # uses vLLM
gpu_memory_utilization = 0.5,
float8_kv_cache = False,
random_state = 3407,
max_lora_rank = 64,
disable_log_stats = True,
qat_scheme = None,
load_in_fp8 = False, # fp8 LoRA (True, False, 'block')
unsloth_tiled_mlp = False,
*args,
**kwargs,
):
@ -689,7 +689,7 @@ class FastModel(FastBaseModel):
try:
from huggingface_hub import login
login(token=token)
login(token = token)
except:
pass
if whisper_language is not None:
@ -765,7 +765,7 @@ class FastModel(FastBaseModel):
fp8_mode = None
if not use_exact_model_name:
new_model_name = get_model_name(
model_name, load_in_4bit=load_in_4bit, load_in_fp8=load_in_fp8
model_name, load_in_4bit = load_in_4bit, load_in_fp8 = load_in_fp8
)
if new_model_name is None and load_in_fp8 != False:
fp8_mode = _get_fp8_mode_and_check_settings(
@ -819,9 +819,9 @@ class FastModel(FastBaseModel):
try:
model_config = AutoConfig.from_pretrained(
model_name,
token=token,
revision=revision,
trust_remote_code=trust_remote_code,
token = token,
revision = revision,
trust_remote_code = trust_remote_code,
)
is_model = True
except Exception as error:
@ -835,9 +835,9 @@ class FastModel(FastBaseModel):
try:
peft_config = PeftConfig.from_pretrained(
model_name,
token=token,
revision=revision,
trust_remote_code=trust_remote_code,
token = token,
revision = revision,
trust_remote_code = trust_remote_code,
)
is_peft = True
except Exception as error:
@ -1011,7 +1011,7 @@ class FastModel(FastBaseModel):
exist_config = os.path.exists(os.path.join(model_name, "config.json"))
both_exist = exist_adapter_config and exist_config
else:
files = HfFileSystem(token=token).glob(f"{model_name}/*.json")
files = HfFileSystem(token = token).glob(f"{model_name}/*.json")
files = list(os.path.split(x)[-1] for x in files)
if (
sum(x == "adapter_config.json" or x == "config.json" for x in files)
@ -1059,8 +1059,8 @@ class FastModel(FastBaseModel):
model_config = AutoConfig.from_pretrained(
model_name,
token=token,
trust_remote_code=trust_remote_code,
token = token,
trust_remote_code = trust_remote_code,
)
if not was_disabled:
@ -1092,40 +1092,40 @@ class FastModel(FastBaseModel):
break
# Patch gradient checkpointing
if use_gradient_checkpointing == "unsloth":
patch_unsloth_smart_gradient_checkpointing(dtype=dtype)
patch_unsloth_smart_gradient_checkpointing(dtype = dtype)
with redirector:
patch_loss_functions(torch_compile=False)
patch_loss_functions(torch_compile = False)
model_types, supports_sdpa = unsloth_compile_transformers(
dtype=dtype,
model_name=model_name,
model_types=model_types,
token=token,
sdpa_dynamic_mask=True,
sdpa_bool_masks=True,
sdpa_gqa_replace=True,
sdpa_dynamic_compile=True,
compile_attention=True,
disable_causal_masks=True,
compile_torch_modules=True,
compile_custom_modules=True,
compile_function_calls=True,
fuse_lm_head=True,
gradient_checkpointing=True,
manual_replacements=True,
fast_lora_forwards=True,
fast_residual_stream=False,
accurate_accumulation=True,
epilogue_fusion=True,
max_autotune=False,
shape_padding=True,
cudagraphs=False,
debug=False,
fullgraph=fullgraph,
import_from_cache=False,
disable=False,
return_logits=return_logits,
trust_remote_code=trust_remote_code,
unsloth_force_compile=unsloth_force_compile,
dtype = dtype,
model_name = model_name,
model_types = model_types,
token = token,
sdpa_dynamic_mask = True,
sdpa_bool_masks = True,
sdpa_gqa_replace = True,
sdpa_dynamic_compile = True,
compile_attention = True,
disable_causal_masks = True,
compile_torch_modules = True,
compile_custom_modules = True,
compile_function_calls = True,
fuse_lm_head = True,
gradient_checkpointing = True,
manual_replacements = True,
fast_lora_forwards = True,
fast_residual_stream = False,
accurate_accumulation = True,
epilogue_fusion = True,
max_autotune = False,
shape_padding = True,
cudagraphs = False,
debug = False,
fullgraph = fullgraph,
import_from_cache = False,
disable = False,
return_logits = return_logits,
trust_remote_code = trust_remote_code,
unsloth_force_compile = unsloth_force_compile,
)
# Fix SDPA issues
for model_type in DISABLE_SDPA_MODEL_NAMES:
@ -1152,34 +1152,34 @@ class FastModel(FastBaseModel):
auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM
model, tokenizer = FastBaseModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
dtype=_get_dtype(dtype),
load_in_4bit=load_in_4bit,
load_in_8bit=load_in_8bit,
load_in_16bit=load_in_16bit,
full_finetuning=full_finetuning,
token=token,
device_map=device_map,
trust_remote_code=trust_remote_code,
revision=revision if not is_peft else None,
model_types=model_types,
tokenizer_name=tokenizer_name,
auto_model=auto_model,
use_gradient_checkpointing=use_gradient_checkpointing,
supports_sdpa=supports_sdpa,
whisper_language=whisper_language,
whisper_task=whisper_task,
auto_config=model_config,
offload_embedding=offload_embedding,
float32_mixed_precision=float32_mixed_precision,
model_name = model_name,
max_seq_length = max_seq_length,
dtype = _get_dtype(dtype),
load_in_4bit = load_in_4bit,
load_in_8bit = load_in_8bit,
load_in_16bit = load_in_16bit,
full_finetuning = full_finetuning,
token = token,
device_map = device_map,
trust_remote_code = trust_remote_code,
revision = revision if not is_peft else None,
model_types = model_types,
tokenizer_name = tokenizer_name,
auto_model = auto_model,
use_gradient_checkpointing = use_gradient_checkpointing,
supports_sdpa = supports_sdpa,
whisper_language = whisper_language,
whisper_task = whisper_task,
auto_config = model_config,
offload_embedding = offload_embedding,
float32_mixed_precision = float32_mixed_precision,
# Pass vLLM/inference parameters
fast_inference=fast_inference,
gpu_memory_utilization=gpu_memory_utilization,
float8_kv_cache=float8_kv_cache,
random_state=random_state,
max_lora_rank=max_lora_rank,
disable_log_stats=disable_log_stats,
fast_inference = fast_inference,
gpu_memory_utilization = gpu_memory_utilization,
float8_kv_cache = float8_kv_cache,
random_state = random_state,
max_lora_rank = max_lora_rank,
disable_log_stats = disable_log_stats,
*args,
**kwargs,
)
@ -1228,14 +1228,14 @@ class FastModel(FastBaseModel):
model = PeftModel.from_pretrained(
model,
old_model_name,
token=token,
revision=revision,
is_trainable=True,
trust_remote_code=trust_remote_code,
token = token,
revision = revision,
is_trainable = True,
trust_remote_code = trust_remote_code,
)
# Patch it as well!
model = FastBaseModel.post_patch_model(
model, use_gradient_checkpointing, trust_remote_code=trust_remote_code
model, use_gradient_checkpointing, trust_remote_code = trust_remote_code
)
# Apply QAT if specified
@ -1249,7 +1249,7 @@ class FastModel(FastBaseModel):
"UNSLOTH_TILED_MLP", "arctic" if unsloth_tiled_mlp else "0"
)
if patch_tiled_mlp_choice != "0" or unsloth_tiled_mlp:
patch_tiled_mlp(model, patch_options_str=patch_tiled_mlp_choice)
patch_tiled_mlp(model, patch_options_str = patch_tiled_mlp_choice)
return model, tokenizer

View file

@ -39,10 +39,10 @@ class FastQwen2Model(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name="qwen2",
rope_module=LlamaRotaryEmbedding,
scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
attention_module=Qwen2Attention,
model_name = "qwen2",
rope_module = LlamaRotaryEmbedding,
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
attention_module = Qwen2Attention,
)
if init_name is not None:
exec(function, globals())
@ -72,30 +72,30 @@ class FastQwen2Model(FastLlamaModel):
@staticmethod
def from_pretrained(
model_name="Qwen/Qwen2-7B",
max_seq_length=4096,
dtype=None,
load_in_4bit=True,
token=None,
device_map="sequential",
rope_scaling=None, # Qwen2 does not support RoPE scaling
fix_tokenizer=True,
model_patcher=None,
tokenizer_name=None,
trust_remote_code=False,
model_name = "Qwen/Qwen2-7B",
max_seq_length = 4096,
dtype = None,
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None, # Qwen2 does not support RoPE scaling
fix_tokenizer = True,
model_patcher = None,
tokenizer_name = None,
trust_remote_code = False,
**kwargs,
):
return FastLlamaModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
dtype=dtype,
load_in_4bit=load_in_4bit,
token=token,
device_map=device_map,
rope_scaling=rope_scaling,
fix_tokenizer=fix_tokenizer,
model_patcher=FastQwen2Model,
tokenizer_name=tokenizer_name,
trust_remote_code=trust_remote_code,
model_name = model_name,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
token = token,
device_map = device_map,
rope_scaling = rope_scaling,
fix_tokenizer = fix_tokenizer,
model_patcher = FastQwen2Model,
tokenizer_name = tokenizer_name,
trust_remote_code = trust_remote_code,
**kwargs,
)

View file

@ -115,7 +115,7 @@ def Qwen3Attention_fast_forward(
else:
# Extend RoPE dynamically to fit in VRA
rotary_emb = self.rotary_emb
rotary_emb.extend_rope_embedding(V, seq_len=kv_seq_len)
rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
device_index = Q.device.index
if position_ids is None:
@ -126,8 +126,8 @@ def Qwen3Attention_fast_forward(
Q, K = fast_rope_embedding(Q, K, cos, sin)
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim=2)
V = torch.cat([past_key_value[1], V], dim=2)
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
past_key_value = (K, V) if use_cache else None
# Attention module
@ -163,7 +163,7 @@ def Qwen3Attention_fast_forward(
K = K.view(1, K_M, n_kv_heads, n_groups, head_dim)
V = V.view(1, V_M, n_kv_heads, n_groups, head_dim)
A = xformers_attention(Q, K, V, attn_bias=causal_mask)
A = xformers_attention(Q, K, V, attn_bias = causal_mask)
A = A.view(bsz, q_len, n_heads, head_dim)
elif HAS_FLASH_ATTENTION and attention_mask is None:
@ -172,7 +172,7 @@ def Qwen3Attention_fast_forward(
V = V.transpose(1, 2)
sw = kv_seq_len
window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
A = flash_attn_func(Q, K, V, causal=True, window_size=window)
A = flash_attn_func(Q, K, V, causal = True, window_size = window)
else:
# Grouped query attention
# if n_groups != 1:
@ -195,7 +195,7 @@ def Qwen3Attention_fast_forward(
is_causal = False
A = scaled_dot_product_attention(
Q, K, V, attn_mask=attention_mask, is_causal=is_causal
Q, K, V, attn_mask = attention_mask, is_causal = is_causal
)
# Go back to (batch_size, seq_len, n_heads, head_dim)
A = A.transpose(1, 2).contiguous()
@ -214,8 +214,8 @@ def Qwen3Attention_fast_forward_inference(
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
do_prefill=False,
attention_mask=None,
do_prefill = False,
attention_mask = None,
):
"""
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
@ -267,29 +267,29 @@ def Qwen3Attention_fast_forward_inference(
if do_prefill:
self.paged_attention = torch.empty(
(KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
dtype=dtype,
device=device,
dtype = dtype,
device = device,
)
self.paged_attention_K = self.paged_attention[:, 0]
self.paged_attention_V = self.paged_attention[:, 1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
self.temp_QA = torch.empty(
(2, bsz, 1, attention_size), dtype=dtype, device=device
(2, bsz, 1, attention_size), dtype = dtype, device = device
)
self.temp_KV = torch.empty(
(2, bsz, 1, n_kv_heads * head_dim), dtype=dtype, device=device
(2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
)
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype=dtype, device=device)
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
# Mistral Nemo 12b has weird dimensions
if attention_size != hidden_size:
self.temp_O = torch.empty((1, bsz, hidden_size), dtype=dtype, device=device)
self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
else:
self.temp_O = self.temp_QA[1][:, :, :hidden_size]
self.attention = torch.empty(
(bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype=dtype, device=device
(bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
)
self.scalar = 1.0 / math_sqrt(self.head_dim)
self.half_head_dim = head_dim // 2
@ -309,9 +309,9 @@ def Qwen3Attention_fast_forward_inference(
(bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
)
Qn = fast_linear_forward(self.q_proj, Xn, out=self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out=self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out=self.temp_KV[1])
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
Qn = Qn.view(
bsz, 1, n_heads, head_dim
) # .transpose(1, 2) # we will transpose after normalisation
@ -399,30 +399,30 @@ def Qwen3Attention_fast_forward_inference(
Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
A = torch_matmul(
Qn, Knn.transpose(2, 3), out=self.attention[:, :, :, :cached_len]
Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]
)
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
A[:] = torch_nn_functional_softmax(
A, dim=-1, dtype=torch.float32
A, dim = -1, dtype = torch.float32
) # .to(A.dtype)
A = torch_matmul(A, Vnn, out=Qn)
A = torch_matmul(A, Vnn, out = Qn)
else:
if SDPA_HAS_GQA:
A = scaled_dot_product_attention(
Qn,
Knn,
Vnn,
attn_mask=attention_mask,
is_causal=is_causal,
enable_gqa=True,
attn_mask = attention_mask,
is_causal = is_causal,
enable_gqa = True,
)
else:
A = scaled_dot_product_attention(
Qn, Knn, Vnn, attn_mask=attention_mask, is_causal=is_causal
Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = is_causal
)
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
A = fast_linear_forward(self.o_proj, A, out=self.temp_O)
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
@ -430,10 +430,10 @@ class FastQwen3Model(FastLlamaModel):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name="Qwen3",
rope_module=LlamaRotaryEmbedding,
scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
attention_module=Qwen3Attention,
model_name = "Qwen3",
rope_module = LlamaRotaryEmbedding,
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
attention_module = Qwen3Attention,
)
if init_name is not None:
exec(function, globals())
@ -463,30 +463,30 @@ class FastQwen3Model(FastLlamaModel):
@staticmethod
def from_pretrained( # TODO: Change after release
model_name="Qwen/Qwen3-7B",
max_seq_length=4096,
dtype=None,
load_in_4bit=True,
token=None,
device_map="sequential",
rope_scaling=None,
fix_tokenizer=True,
model_patcher=None,
tokenizer_name=None,
trust_remote_code=False,
model_name = "Qwen/Qwen3-7B",
max_seq_length = 4096,
dtype = None,
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None,
fix_tokenizer = True,
model_patcher = None,
tokenizer_name = None,
trust_remote_code = False,
**kwargs,
):
return FastLlamaModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
dtype=dtype,
load_in_4bit=load_in_4bit,
token=token,
device_map=device_map,
rope_scaling=rope_scaling,
fix_tokenizer=fix_tokenizer,
model_patcher=FastQwen3Model,
tokenizer_name=tokenizer_name,
trust_remote_code=trust_remote_code,
model_name = model_name,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
token = token,
device_map = device_map,
rope_scaling = rope_scaling,
fix_tokenizer = fix_tokenizer,
model_patcher = FastQwen3Model,
tokenizer_name = tokenizer_name,
trust_remote_code = trust_remote_code,
**kwargs,
)

View file

@ -49,29 +49,29 @@ from unsloth_zoo.utils import Version, _get_dtype
torch_nn_functional_softmax = torch.nn.functional.softmax
def Qwen3MoeSparseMoeBlock_fast_forward(self, X, temp_gate=None, temp_up=None):
def Qwen3MoeSparseMoeBlock_fast_forward(self, X, temp_gate = None, temp_up = None):
# adapted from https://github.com/huggingface/transformers/pull/36878/files#diff-0855b77fc27ad9449158a1c74953f909b011c00de7125f7c8e68d0ff209c092aR356-R370
bsz, seq_len, hd = X.shape
X = X.view(-1, hd)
router_logits = fast_linear_forward(
self.gate_proj, X, out=temp_gate
self.gate_proj, X, out = temp_gate
) # pretty much the only change from transformers implementation.
routing_weights = torch_nn_functional_softmax(
router_logits, dim=-1, dtype=torch.float32
router_logits, dim = -1, dtype = torch.float32
)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim = -1)
routing_weights /= routing_weights.sum(dim = -1, keepdim = True)
# we cast back to the input dtype
routing_weights = routing_weights.to(X.dtype)
final_X = torch.zeros((bsz * seq_len, hd), dtype=torch.float32, device=X.device)
final_X = torch.zeros((bsz * seq_len, hd), dtype = torch.float32, device = X.device)
# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be sollicitated
expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes=self.num_experts
selected_experts, num_classes = self.num_experts
).permute(2, 1, 0)
# Loop over all available experts in the model and perform the computation on each expert
@ -119,16 +119,16 @@ def Qwen3MoeDecoderLayer_fast_forward(
self.input_layernorm, hidden_states
)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
position_embeddings=position_embeddings,
_flag_for_generation=self._flag_for_generation,
hidden_states = hidden_states,
causal_mask = causal_mask,
attention_mask = attention_mask,
position_ids = position_ids,
past_key_value = past_key_value,
output_attentions = output_attentions,
use_cache = use_cache,
padding_mask = padding_mask,
position_embeddings = position_embeddings,
_flag_for_generation = self._flag_for_generation,
)
hidden_states += residual
@ -145,15 +145,15 @@ def Qwen3MoeDecoderLayer_fast_forward(
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
position_embeddings=position_embeddings,
hidden_states = hidden_states,
causal_mask = causal_mask,
attention_mask = attention_mask,
position_ids = position_ids,
past_key_value = past_key_value,
output_attentions = output_attentions,
use_cache = use_cache,
padding_mask = padding_mask,
position_embeddings = position_embeddings,
)
hidden_states = residual + hidden_states
@ -177,10 +177,10 @@ class FastQwen3MoeModel(FastQwen3Model):
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name="Qwen3Moe",
rope_module=LlamaRotaryEmbedding,
scaled_rope_module=LlamaLinearScalingRotaryEmbedding,
attention_module=Qwen3MoeAttention,
model_name = "Qwen3Moe",
rope_module = LlamaRotaryEmbedding,
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
attention_module = Qwen3MoeAttention,
)
if init_name is not None:
exec(function, globals())
@ -214,30 +214,30 @@ class FastQwen3MoeModel(FastQwen3Model):
@staticmethod
def from_pretrained( # TODO: Change after release
model_name="Qwen/Qwen3-7B",
max_seq_length=4096,
dtype=None,
load_in_4bit=True,
token=None,
device_map="sequential",
rope_scaling=None,
fix_tokenizer=True,
model_patcher=None,
tokenizer_name=None,
trust_remote_code=False,
model_name = "Qwen/Qwen3-7B",
max_seq_length = 4096,
dtype = None,
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None,
fix_tokenizer = True,
model_patcher = None,
tokenizer_name = None,
trust_remote_code = False,
**kwargs,
):
return FastLlamaModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
dtype=dtype,
load_in_4bit=load_in_4bit,
token=token,
device_map=device_map,
rope_scaling=rope_scaling,
fix_tokenizer=fix_tokenizer,
model_patcher=FastQwen3Model,
tokenizer_name=tokenizer_name,
trust_remote_code=trust_remote_code,
model_name = model_name,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
token = token,
device_map = device_map,
rope_scaling = rope_scaling,
fix_tokenizer = fix_tokenizer,
model_patcher = FastQwen3Model,
tokenizer_name = tokenizer_name,
trust_remote_code = trust_remote_code,
**kwargs,
)

View file

@ -30,70 +30,70 @@ class DeepseekR1ModelInfo(ModelInfo):
# Deepseek V3 Model Meta
DeepseekV3Meta = ModelMeta(
org="deepseek-ai",
base_name="DeepSeek",
instruct_tags=[None],
model_version="3",
model_sizes=[""],
model_info_cls=DeepseekV3ModelInfo,
is_multimodal=False,
quant_types=[QuantType.NONE, QuantType.BF16],
org = "deepseek-ai",
base_name = "DeepSeek",
instruct_tags = [None],
model_version = "3",
model_sizes = [""],
model_info_cls = DeepseekV3ModelInfo,
is_multimodal = False,
quant_types = [QuantType.NONE, QuantType.BF16],
)
DeepseekV3_0324Meta = ModelMeta(
org="deepseek-ai",
base_name="DeepSeek",
instruct_tags=[None],
model_version="3-0324",
model_sizes=[""],
model_info_cls=DeepseekV3ModelInfo,
is_multimodal=False,
quant_types=[QuantType.NONE, QuantType.GGUF],
org = "deepseek-ai",
base_name = "DeepSeek",
instruct_tags = [None],
model_version = "3-0324",
model_sizes = [""],
model_info_cls = DeepseekV3ModelInfo,
is_multimodal = False,
quant_types = [QuantType.NONE, QuantType.GGUF],
)
DeepseekR1Meta = ModelMeta(
org="deepseek-ai",
base_name="DeepSeek-R1",
instruct_tags=[None],
model_version="",
model_sizes=[""],
model_info_cls=DeepseekR1ModelInfo,
is_multimodal=False,
quant_types=[QuantType.NONE, QuantType.BF16, QuantType.GGUF],
org = "deepseek-ai",
base_name = "DeepSeek-R1",
instruct_tags = [None],
model_version = "",
model_sizes = [""],
model_info_cls = DeepseekR1ModelInfo,
is_multimodal = False,
quant_types = [QuantType.NONE, QuantType.BF16, QuantType.GGUF],
)
DeepseekR1ZeroMeta = ModelMeta(
org="deepseek-ai",
base_name="DeepSeek-R1",
instruct_tags=[None],
model_version="Zero",
model_sizes=[""],
model_info_cls=DeepseekR1ModelInfo,
is_multimodal=False,
quant_types=[QuantType.NONE, QuantType.GGUF],
org = "deepseek-ai",
base_name = "DeepSeek-R1",
instruct_tags = [None],
model_version = "Zero",
model_sizes = [""],
model_info_cls = DeepseekR1ModelInfo,
is_multimodal = False,
quant_types = [QuantType.NONE, QuantType.GGUF],
)
DeepseekR1DistillLlamaMeta = ModelMeta(
org="deepseek-ai",
base_name="DeepSeek-R1-Distill",
instruct_tags=[None],
model_version="Llama",
model_sizes=["8", "70"],
model_info_cls=DeepseekR1ModelInfo,
is_multimodal=False,
quant_types={"8": [QuantType.UNSLOTH, QuantType.GGUF], "70": [QuantType.GGUF]},
org = "deepseek-ai",
base_name = "DeepSeek-R1-Distill",
instruct_tags = [None],
model_version = "Llama",
model_sizes = ["8", "70"],
model_info_cls = DeepseekR1ModelInfo,
is_multimodal = False,
quant_types = {"8": [QuantType.UNSLOTH, QuantType.GGUF], "70": [QuantType.GGUF]},
)
# Deepseek R1 Distill Qwen Model Meta
DeepseekR1DistillQwenMeta = ModelMeta(
org="deepseek-ai",
base_name="DeepSeek-R1-Distill",
instruct_tags=[None],
model_version="Qwen",
model_sizes=["1.5", "7", "14", "32"],
model_info_cls=DeepseekR1ModelInfo,
is_multimodal=False,
quant_types={
org = "deepseek-ai",
base_name = "DeepSeek-R1-Distill",
instruct_tags = [None],
model_version = "Qwen",
model_sizes = ["1.5", "7", "14", "32"],
model_info_cls = DeepseekR1ModelInfo,
is_multimodal = False,
quant_types = {
"1.5": [QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF],
"7": [QuantType.UNSLOTH, QuantType.BNB],
"14": [QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF],
@ -106,7 +106,7 @@ def register_deepseek_v3_models(include_original_model: bool = False):
global _IS_DEEPSEEK_V3_REGISTERED
if _IS_DEEPSEEK_V3_REGISTERED:
return
_register_models(DeepseekV3Meta, include_original_model=include_original_model)
_register_models(DeepseekV3Meta, include_original_model = include_original_model)
_IS_DEEPSEEK_V3_REGISTERED = True
@ -114,7 +114,7 @@ def register_deepseek_v3_0324_models(include_original_model: bool = False):
global _IS_DEEPSEEK_V3_0324_REGISTERED
if _IS_DEEPSEEK_V3_0324_REGISTERED:
return
_register_models(DeepseekV3_0324Meta, include_original_model=include_original_model)
_register_models(DeepseekV3_0324Meta, include_original_model = include_original_model)
_IS_DEEPSEEK_V3_0324_REGISTERED = True
@ -122,7 +122,7 @@ def register_deepseek_r1_models(include_original_model: bool = False):
global _IS_DEEPSEEK_R1_REGISTERED
if _IS_DEEPSEEK_R1_REGISTERED:
return
_register_models(DeepseekR1Meta, include_original_model=include_original_model)
_register_models(DeepseekR1Meta, include_original_model = include_original_model)
_IS_DEEPSEEK_R1_REGISTERED = True
@ -130,7 +130,7 @@ def register_deepseek_r1_zero_models(include_original_model: bool = False):
global _IS_DEEPSEEK_R1_ZERO_REGISTERED
if _IS_DEEPSEEK_R1_ZERO_REGISTERED:
return
_register_models(DeepseekR1ZeroMeta, include_original_model=include_original_model)
_register_models(DeepseekR1ZeroMeta, include_original_model = include_original_model)
_IS_DEEPSEEK_R1_ZERO_REGISTERED = True
@ -139,7 +139,7 @@ def register_deepseek_r1_distill_llama_models(include_original_model: bool = Fal
if _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED:
return
_register_models(
DeepseekR1DistillLlamaMeta, include_original_model=include_original_model
DeepseekR1DistillLlamaMeta, include_original_model = include_original_model
)
_IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED = True
@ -149,21 +149,21 @@ def register_deepseek_r1_distill_qwen_models(include_original_model: bool = Fals
if _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED:
return
_register_models(
DeepseekR1DistillQwenMeta, include_original_model=include_original_model
DeepseekR1DistillQwenMeta, include_original_model = include_original_model
)
_IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED = True
def register_deepseek_models(include_original_model: bool = False):
register_deepseek_v3_models(include_original_model=include_original_model)
register_deepseek_v3_0324_models(include_original_model=include_original_model)
register_deepseek_r1_models(include_original_model=include_original_model)
register_deepseek_r1_zero_models(include_original_model=include_original_model)
register_deepseek_v3_models(include_original_model = include_original_model)
register_deepseek_v3_0324_models(include_original_model = include_original_model)
register_deepseek_r1_models(include_original_model = include_original_model)
register_deepseek_r1_zero_models(include_original_model = include_original_model)
register_deepseek_r1_distill_llama_models(
include_original_model=include_original_model
include_original_model = include_original_model
)
register_deepseek_r1_distill_qwen_models(
include_original_model=include_original_model
include_original_model = include_original_model
)
@ -172,7 +172,7 @@ def _list_deepseek_r1_distill_models():
from unsloth.utils.hf_hub import list_models
models: list[HfModelInfo] = list_models(
author="unsloth", search="Distill", limit=1000
author = "unsloth", search = "Distill", limit = 1000
)
distill_models = []
for model in models:
@ -185,14 +185,14 @@ def _list_deepseek_r1_distill_models():
return distill_models
register_deepseek_models(include_original_model=True)
register_deepseek_models(include_original_model = True)
if __name__ == "__main__":
from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info
MODEL_REGISTRY.clear()
register_deepseek_models(include_original_model=True)
register_deepseek_models(include_original_model = True)
for model_id, model_info in MODEL_REGISTRY.items():
model_info = _check_model_info(model_id)

View file

@ -15,26 +15,26 @@ class GemmaModelInfo(ModelInfo):
# Gemma3 Base Model Meta
GemmaMeta3Base = ModelMeta(
org="google",
base_name="gemma",
instruct_tags=["pt"], # pt = base
model_version="3",
model_sizes=["1", "4", "12", "27"],
model_info_cls=GemmaModelInfo,
is_multimodal=True,
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
org = "google",
base_name = "gemma",
instruct_tags = ["pt"], # pt = base
model_version = "3",
model_sizes = ["1", "4", "12", "27"],
model_info_cls = GemmaModelInfo,
is_multimodal = True,
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
)
# Gemma3 Instruct Model Meta
GemmaMeta3Instruct = ModelMeta(
org="google",
base_name="gemma",
instruct_tags=["it"], # it = instruction tuned
model_version="3",
model_sizes=["1", "4", "12", "27"],
model_info_cls=GemmaModelInfo,
is_multimodal=True,
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
org = "google",
base_name = "gemma",
instruct_tags = ["it"], # it = instruction tuned
model_version = "3",
model_sizes = ["1", "4", "12", "27"],
model_info_cls = GemmaModelInfo,
is_multimodal = True,
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
)
@ -42,7 +42,7 @@ def register_gemma_3_base_models(include_original_model: bool = False):
global _IS_GEMMA_3_BASE_REGISTERED
if _IS_GEMMA_3_BASE_REGISTERED:
return
_register_models(GemmaMeta3Base, include_original_model=include_original_model)
_register_models(GemmaMeta3Base, include_original_model = include_original_model)
_IS_GEMMA_3_BASE_REGISTERED = True
@ -50,13 +50,13 @@ def register_gemma_3_instruct_models(include_original_model: bool = False):
global _IS_GEMMA_3_INSTRUCT_REGISTERED
if _IS_GEMMA_3_INSTRUCT_REGISTERED:
return
_register_models(GemmaMeta3Instruct, include_original_model=include_original_model)
_register_models(GemmaMeta3Instruct, include_original_model = include_original_model)
_IS_GEMMA_3_INSTRUCT_REGISTERED = True
def register_gemma_models(include_original_model: bool = False):
register_gemma_3_base_models(include_original_model=include_original_model)
register_gemma_3_instruct_models(include_original_model=include_original_model)
register_gemma_3_base_models(include_original_model = include_original_model)
register_gemma_3_instruct_models(include_original_model = include_original_model)
if __name__ == "__main__":
@ -64,7 +64,7 @@ if __name__ == "__main__":
MODEL_REGISTRY.clear()
register_gemma_models(include_original_model=True)
register_gemma_models(include_original_model = True)
for model_id, model_info in MODEL_REGISTRY.items():
model_info = _check_model_info(model_id)

View file

@ -23,14 +23,14 @@ class MistralSmallModelInfo(ModelInfo):
MistralSmall_2503_Base_Meta = ModelMeta(
org="mistralai",
base_name="Mistral-Small",
instruct_tags=["Base"],
model_version=_MISTRAL_SMALL_03_25_VERSION,
model_sizes=["24"],
model_info_cls=MistralSmallModelInfo,
is_multimodal=False,
quant_types=[QuantType.NONE, QuantType.UNSLOTH, QuantType.BNB],
org = "mistralai",
base_name = "Mistral-Small",
instruct_tags = ["Base"],
model_version = _MISTRAL_SMALL_03_25_VERSION,
model_sizes = ["24"],
model_info_cls = MistralSmallModelInfo,
is_multimodal = False,
quant_types = [QuantType.NONE, QuantType.UNSLOTH, QuantType.BNB],
)
MistralSmall_2503_Instruct_Meta = copy.deepcopy(MistralSmall_2503_Base_Meta)
@ -54,23 +54,23 @@ def register_mistral_small_models(include_original_model: bool = False):
if _IS_MISTRAL_SMALL_REGISTERED:
return
_register_models(
MistralSmall_2503_Base_Meta, include_original_model=include_original_model
MistralSmall_2503_Base_Meta, include_original_model = include_original_model
)
_register_models(
MistralSmall_2503_Instruct_Meta, include_original_model=include_original_model
MistralSmall_2503_Instruct_Meta, include_original_model = include_original_model
)
_register_models(
MistralSmall_2501_Base_Meta, include_original_model=include_original_model
MistralSmall_2501_Base_Meta, include_original_model = include_original_model
)
_register_models(
MistralSmall_2501_Instruct_Meta, include_original_model=include_original_model
MistralSmall_2501_Instruct_Meta, include_original_model = include_original_model
)
_IS_MISTRAL_SMALL_REGISTERED = True
def register_mistral_models(include_original_model: bool = False):
register_mistral_small_models(include_original_model=include_original_model)
register_mistral_small_models(include_original_model = include_original_model)
if __name__ == "__main__":
@ -78,7 +78,7 @@ if __name__ == "__main__":
MODEL_REGISTRY.clear()
register_mistral_models(include_original_model=True)
register_mistral_models(include_original_model = True)
for model_id, model_info in MODEL_REGISTRY.items():
model_info = _check_model_info(model_id)

View file

@ -43,50 +43,50 @@ class QwenQVQPreviewModelInfo(ModelInfo):
# Qwen2.5 Model Meta
Qwen_2_5_Meta = ModelMeta(
org="Qwen",
base_name="Qwen",
instruct_tags=[None, "Instruct"],
model_version="2.5",
model_sizes=["3", "7"],
model_info_cls=QwenModelInfo,
is_multimodal=False,
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
org = "Qwen",
base_name = "Qwen",
instruct_tags = [None, "Instruct"],
model_version = "2.5",
model_sizes = ["3", "7"],
model_info_cls = QwenModelInfo,
is_multimodal = False,
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
)
# Qwen2.5 VL Model Meta
Qwen_2_5_VLMeta = ModelMeta(
org="Qwen",
base_name="Qwen",
instruct_tags=["Instruct"], # No base, only instruction tuned
model_version="2.5",
model_sizes=["3", "7", "32", "72"],
model_info_cls=QwenVLModelInfo,
is_multimodal=True,
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
org = "Qwen",
base_name = "Qwen",
instruct_tags = ["Instruct"], # No base, only instruction tuned
model_version = "2.5",
model_sizes = ["3", "7", "32", "72"],
model_info_cls = QwenVLModelInfo,
is_multimodal = True,
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
)
# Qwen QwQ Model Meta
QwenQwQMeta = ModelMeta(
org="Qwen",
base_name="QwQ",
instruct_tags=[None],
model_version="",
model_sizes=["32"],
model_info_cls=QwenQwQModelInfo,
is_multimodal=False,
quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
org = "Qwen",
base_name = "QwQ",
instruct_tags = [None],
model_version = "",
model_sizes = ["32"],
model_info_cls = QwenQwQModelInfo,
is_multimodal = False,
quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
)
# Qwen QVQ Preview Model Meta
QwenQVQPreviewMeta = ModelMeta(
org="Qwen",
base_name="QVQ",
instruct_tags=[None],
model_version="",
model_sizes=["72"],
model_info_cls=QwenQVQPreviewModelInfo,
is_multimodal=True,
quant_types=[QuantType.NONE, QuantType.BNB],
org = "Qwen",
base_name = "QVQ",
instruct_tags = [None],
model_version = "",
model_sizes = ["72"],
model_info_cls = QwenQVQPreviewModelInfo,
is_multimodal = True,
quant_types = [QuantType.NONE, QuantType.BNB],
)
@ -94,7 +94,7 @@ def register_qwen_2_5_models(include_original_model: bool = False):
global _IS_QWEN_2_5_REGISTERED
if _IS_QWEN_2_5_REGISTERED:
return
_register_models(Qwen_2_5_Meta, include_original_model=include_original_model)
_register_models(Qwen_2_5_Meta, include_original_model = include_original_model)
_IS_QWEN_2_5_REGISTERED = True
@ -102,7 +102,7 @@ def register_qwen_2_5_vl_models(include_original_model: bool = False):
global _IS_QWEN_2_5_VL_REGISTERED
if _IS_QWEN_2_5_VL_REGISTERED:
return
_register_models(Qwen_2_5_VLMeta, include_original_model=include_original_model)
_register_models(Qwen_2_5_VLMeta, include_original_model = include_original_model)
_IS_QWEN_2_5_VL_REGISTERED = True
@ -110,15 +110,15 @@ def register_qwen_qwq_models(include_original_model: bool = False):
global _IS_QWEN_QWQ_REGISTERED
if _IS_QWEN_QWQ_REGISTERED:
return
_register_models(QwenQwQMeta, include_original_model=include_original_model)
_register_models(QwenQVQPreviewMeta, include_original_model=include_original_model)
_register_models(QwenQwQMeta, include_original_model = include_original_model)
_register_models(QwenQVQPreviewMeta, include_original_model = include_original_model)
_IS_QWEN_QWQ_REGISTERED = True
def register_qwen_models(include_original_model: bool = False):
register_qwen_2_5_models(include_original_model=include_original_model)
register_qwen_2_5_vl_models(include_original_model=include_original_model)
register_qwen_qwq_models(include_original_model=include_original_model)
register_qwen_2_5_models(include_original_model = include_original_model)
register_qwen_2_5_vl_models(include_original_model = include_original_model)
register_qwen_qwq_models(include_original_model = include_original_model)
if __name__ == "__main__":
@ -126,7 +126,7 @@ if __name__ == "__main__":
MODEL_REGISTRY.clear()
register_qwen_models(include_original_model=True)
register_qwen_models(include_original_model = True)
for model_id, model_info in MODEL_REGISTRY.items():
model_info = _check_model_info(model_id)

View file

@ -62,7 +62,7 @@ class ModelInfo:
@classmethod
def construct_model_name(
cls, base_name, version, size, quant_type, instruct_tag, key=""
cls, base_name, version, size, quant_type, instruct_tag, key = ""
):
key = cls.append_instruct_tag(key, instruct_tag)
key = cls.append_quant_type(key, quant_type)
@ -81,10 +81,10 @@ class ModelMeta:
base_name: str
model_version: str
model_info_cls: type[ModelInfo]
model_sizes: list[str] = field(default_factory=list)
instruct_tags: list[str] = field(default_factory=list)
model_sizes: list[str] = field(default_factory = list)
instruct_tags: list[str] = field(default_factory = list)
quant_types: list[QuantType] | dict[str, list[QuantType]] = field(
default_factory=list
default_factory = list
)
is_multimodal: bool = False
@ -104,11 +104,11 @@ def register_model(
name: str = None,
):
name = name or model_info_cls.construct_model_name(
base_name=base_name,
version=version,
size=size,
quant_type=quant_type,
instruct_tag=instruct_tag,
base_name = base_name,
version = version,
size = size,
quant_type = quant_type,
instruct_tag = instruct_tag,
)
key = f"{org}/{name}"
@ -118,14 +118,14 @@ def register_model(
)
MODEL_REGISTRY[key] = model_info_cls(
org=org,
base_name=base_name,
version=version,
size=size,
is_multimodal=is_multimodal,
instruct_tag=instruct_tag,
quant_type=quant_type,
name=name,
org = org,
base_name = base_name,
version = version,
size = size,
is_multimodal = is_multimodal,
instruct_tag = instruct_tag,
quant_type = quant_type,
name = name,
)
@ -137,7 +137,7 @@ def _check_model_info(model_id: str, properties: list[str] = ["lastModified"]):
api = HfApi()
try:
model_info: HfModelInfo = api.model_info(model_id, expand=properties)
model_info: HfModelInfo = api.model_info(model_id, expand = properties)
except Exception as e:
if isinstance(e, RepositoryNotFoundError):
warnings.warn(f"{model_id} not found on Hugging Face")
@ -168,24 +168,24 @@ def _register_models(model_meta: ModelMeta, include_original_model: bool = False
# NOTE: models registered with org="unsloth" and QUANT_TYPE.NONE are aliases of QUANT_TYPE.UNSLOTH
_org = "unsloth" # unsloth models -- these are all quantized versions of the original model
register_model(
model_info_cls=model_info_cls,
org=_org,
base_name=base_name,
version=model_version,
size=size,
instruct_tag=instruct_tag,
quant_type=quant_type,
is_multimodal=is_multimodal,
model_info_cls = model_info_cls,
org = _org,
base_name = base_name,
version = model_version,
size = size,
instruct_tag = instruct_tag,
quant_type = quant_type,
is_multimodal = is_multimodal,
)
# include original model from releasing organization
if include_original_model:
register_model(
model_info_cls=model_info_cls,
org=org,
base_name=base_name,
version=model_version,
size=size,
instruct_tag=instruct_tag,
quant_type=QuantType.NONE,
is_multimodal=is_multimodal,
model_info_cls = model_info_cls,
org = org,
base_name = base_name,
version = model_version,
size = size,
instruct_tag = instruct_tag,
quant_type = QuantType.NONE,
is_multimodal = is_multimodal,
)

View file

@ -79,7 +79,7 @@ def _create_unsloth_optimizer(
model,
optimizer_cls,
optimizer_kwargs,
embedding_lr=5e-5,
embedding_lr = 5e-5,
):
lr = optimizer_kwargs["lr"]
weight_decay = optimizer_kwargs.get("weight_decay", 0.0)

View file

@ -36,7 +36,7 @@ def get_model_info(
if _HFAPI is None:
_HFAPI = HfApi()
try:
model_info: ModelInfo = _HFAPI.model_info(model_id, expand=properties)
model_info: ModelInfo = _HFAPI.model_info(model_id, expand = properties)
except Exception as e:
print(f"Error getting model info for {model_id}: {e}")
model_info = None
@ -68,11 +68,11 @@ def list_models(
properties = None
models: list[ModelInfo] = _HFAPI.list_models(
author=author,
search=search,
sort=sort,
limit=limit,
expand=properties,
full=full,
author = author,
search = search,
sort = sort,
limit = limit,
expand = properties,
full = full,
)
return models